From 482860d0d6f5cdf35d9c1fa2d89a9d05cca87187 Mon Sep 17 00:00:00 2001 From: arizhih Date: Fri, 17 May 2024 02:05:12 +0200 Subject: [PATCH] Pass aheads by reference, add dtw_mem_size param, rustfmt --- examples/audio_transcription.rs | 2 +- examples/basic_use.rs | 2 +- src/lib.rs | 5 ++- src/whisper_ctx.rs | 71 ++++++++++++++++++--------------- src/whisper_ctx_wrapper.rs | 2 +- 5 files changed, 45 insertions(+), 37 deletions(-) diff --git a/examples/audio_transcription.rs b/examples/audio_transcription.rs index c31aa0e..7661e04 100644 --- a/examples/audio_transcription.rs +++ b/examples/audio_transcription.rs @@ -11,7 +11,7 @@ fn main() -> Result<(), &'static str> { // Load a context and model. let ctx = WhisperContext::new_with_params( "example/path/to/model/whisper.cpp/models/ggml-base.en.bin", - &mut WhisperContextParameters::default(), + WhisperContextParameters::default(), ) .expect("failed to load model"); // Create a state diff --git a/examples/basic_use.rs b/examples/basic_use.rs index 2bc75a6..8627473 100644 --- a/examples/basic_use.rs +++ b/examples/basic_use.rs @@ -7,7 +7,7 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar // more dependencies than the base library. pub fn usage() -> Result<(), &'static str> { // load a context and model - let ctx = WhisperContext::new_with_params("path/to/model", &mut WhisperContextParameters::default()) + let ctx = WhisperContext::new_with_params("path/to/model", WhisperContextParameters::default()) .expect("failed to load model"); // make a state let mut state = ctx.create_state().expect("failed to create state"); diff --git a/src/lib.rs b/src/lib.rs index c1ef2b8..1822174 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,6 +22,9 @@ pub use standalone::*; #[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))] use std::sync::Once; pub use utilities::*; +pub use whisper_ctx::DtwMode; +pub use whisper_ctx::DtwModelPreset; +pub use whisper_ctx::DtwParameters; pub use whisper_ctx::WhisperContextParameters; use whisper_ctx::WhisperInnerContext; pub use whisper_ctx_wrapper::WhisperContext; @@ -34,8 +37,6 @@ pub use whisper_state::WhisperState; pub use whisper_sys_log::install_whisper_log_trampoline; #[cfg(feature = "whisper-cpp-tracing")] pub use whisper_sys_tracing::install_whisper_tracing_trampoline; -pub use whisper_ctx::DtwParameters; -pub use whisper_ctx::DtwPredefinedModels; pub type WhisperSysContext = whisper_rs_sys::whisper_context; pub type WhisperSysState = whisper_rs_sys::whisper_state; diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 92f9b6a..5ab7e21 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -26,7 +26,7 @@ impl WhisperInnerContext { /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` pub fn new_with_params( path: &str, - parameters: &mut WhisperContextParameters, + parameters: WhisperContextParameters, ) -> Result { let path_cstr = CString::new(path)?; let ctx = unsafe { @@ -518,7 +518,7 @@ impl WhisperContextParameters { } fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params { - let dtw_token_timestamps = !matches!(self.dtw_parameters, DtwParameters::Disabled); + let dtw_token_timestamps = !matches!(self.dtw_parameters.mode, DtwMode::None); let mut dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE; let mut dtw_n_top: c_int = -1; @@ -527,16 +527,14 @@ impl WhisperContextParameters { heads: std::ptr::null(), }; - let dtw_mem_size = 1024 * 1024 * 128; - - match &self.dtw_parameters { - DtwParameters::Disabled => {} - DtwParameters::TopMost { n_top } => { + match &self.dtw_parameters.mode { + DtwMode::None => {} + DtwMode::TopMost { n_top } => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST; dtw_n_top = *n_top; } - DtwParameters::Custom { aheads } => { + DtwMode::Custom { aheads } => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM; @@ -545,48 +543,48 @@ impl WhisperContextParameters { heads: aheads.as_ptr(), }; } - DtwParameters::Predefined { model_preset } => match model_preset { - DtwPredefinedModels::TinyEn => { + DtwMode::ModelPreset { model_preset } => match model_preset { + DtwModelPreset::TinyEn => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN; } - DtwPredefinedModels::Tiny => { + DtwModelPreset::Tiny => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY; } - DtwPredefinedModels::BaseEn => { + DtwModelPreset::BaseEn => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN; } - DtwPredefinedModels::Base => { + DtwModelPreset::Base => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE; } - DtwPredefinedModels::SmallEn => { + DtwModelPreset::SmallEn => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN; } - DtwPredefinedModels::Small => { + DtwModelPreset::Small => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL; } - DtwPredefinedModels::MediumEn => { + DtwModelPreset::MediumEn => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN; } - DtwPredefinedModels::Medium => { + DtwModelPreset::Medium => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM; } - DtwPredefinedModels::LargeV1 => { + DtwModelPreset::LargeV1 => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1; } - DtwPredefinedModels::LargeV2 => { + DtwModelPreset::LargeV2 => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2; } - DtwPredefinedModels::LargeV3 => { + DtwModelPreset::LargeV3 => { dtw_aheads_preset = whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3; } @@ -601,16 +599,31 @@ impl WhisperContextParameters { dtw_aheads_preset, dtw_n_top, dtw_aheads, - dtw_mem_size, + dtw_mem_size: self.dtw_parameters.dtw_mem_size, } } } /// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled #[derive(Debug, Clone)] -pub enum DtwParameters { +pub struct DtwParameters { + pub mode: DtwMode, + pub dtw_mem_size: usize, +} + +impl Default for DtwParameters { + fn default() -> Self { + Self { + mode: DtwMode::None, + dtw_mem_size: 1024 * 1024 * 128, + } + } +} + +#[derive(Debug, Clone)] +pub enum DtwMode { /// DTW token level timestamps disabled - Disabled, + None, /// Use N Top Most layers from loaded model TopMost { /// Number of top text layers used from model, should be 0 < n_top <= model n_text_layer @@ -620,20 +633,14 @@ pub enum DtwParameters { /// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element /// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143 Custom { - aheads: Vec, + aheads: &'static [whisper_rs_sys::whisper_ahead], }, /// Use predefined preset for standard models - Predefined { model_preset: DtwPredefinedModels }, -} - -impl Default for DtwParameters { - fn default() -> Self { - Self::Disabled - } + ModelPreset { model_preset: DtwModelPreset }, } #[derive(Debug, Clone)] -pub enum DtwPredefinedModels { +pub enum DtwModelPreset { TinyEn, Tiny, BaseEn, diff --git a/src/whisper_ctx_wrapper.rs b/src/whisper_ctx_wrapper.rs index 87d062c..ff3caff 100644 --- a/src/whisper_ctx_wrapper.rs +++ b/src/whisper_ctx_wrapper.rs @@ -27,7 +27,7 @@ impl WhisperContext { /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` pub fn new_with_params( path: &str, - parameters: &mut WhisperContextParameters, + parameters: WhisperContextParameters, ) -> Result { let ctx = WhisperInnerContext::new_with_params(path, parameters)?; Ok(Self::wrap(ctx))