From da6b410439c7c18ff1c494c2f5f53b4e0ad606ae Mon Sep 17 00:00:00 2001 From: arizhih Date: Thu, 16 May 2024 18:58:48 +0200 Subject: [PATCH] Add safe wrapper for raw dtw parameters --- examples/audio_transcription.rs | 2 +- examples/basic_use.rs | 2 +- src/lib.rs | 2 + src/whisper_ctx.rs | 178 +++++++++++++++++++++++--------- src/whisper_ctx_wrapper.rs | 2 +- 5 files changed, 135 insertions(+), 51 deletions(-) diff --git a/examples/audio_transcription.rs b/examples/audio_transcription.rs index 7661e04..c31aa0e 100644 --- a/examples/audio_transcription.rs +++ b/examples/audio_transcription.rs @@ -11,7 +11,7 @@ fn main() -> Result<(), &'static str> { // Load a context and model. let ctx = WhisperContext::new_with_params( "example/path/to/model/whisper.cpp/models/ggml-base.en.bin", - WhisperContextParameters::default(), + &mut WhisperContextParameters::default(), ) .expect("failed to load model"); // Create a state diff --git a/examples/basic_use.rs b/examples/basic_use.rs index 8627473..2bc75a6 100644 --- a/examples/basic_use.rs +++ b/examples/basic_use.rs @@ -7,7 +7,7 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar // more dependencies than the base library. 
pub fn usage() -> Result<(), &'static str> { // load a context and model - let ctx = WhisperContext::new_with_params("path/to/model", WhisperContextParameters::default()) + let ctx = WhisperContext::new_with_params("path/to/model", &mut WhisperContextParameters::default()) .expect("failed to load model"); // make a state let mut state = ctx.create_state().expect("failed to create state"); diff --git a/src/lib.rs b/src/lib.rs index 42109bc..c1ef2b8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,6 +34,8 @@ pub use whisper_state::WhisperState; pub use whisper_sys_log::install_whisper_log_trampoline; #[cfg(feature = "whisper-cpp-tracing")] pub use whisper_sys_tracing::install_whisper_tracing_trampoline; +pub use whisper_ctx::DtwParameters; +pub use whisper_ctx::DtwPredefinedModels; pub type WhisperSysContext = whisper_rs_sys::whisper_context; pub type WhisperSysState = whisper_rs_sys::whisper_state; diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 3857bcc..92f9b6a 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -26,7 +26,7 @@ impl WhisperInnerContext { /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` pub fn new_with_params( path: &str, - parameters: WhisperContextParameters, + parameters: &mut WhisperContextParameters, ) -> Result { let path_cstr = CString::new(path)?; let ctx = unsafe { @@ -476,22 +476,13 @@ pub struct WhisperContextParameters { /// (in that case, GPU is always enabled). pub use_gpu: bool, /// Enable flash attention, default false + /// + /// **Warning** Can't be used with DTW. 
DTW will be disabled if flash_attn is true pub flash_attn: bool, /// GPU device id, default 0 pub gpu_device: c_int, - /// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default 0 - pub dtw_token_timestamps: bool, - /// Preset id for DTW, default whisper_alignment_heads_preset_WHISPER_AHEADS_NONE - pub dtw_aheads_preset: whisper_rs_sys::whisper_alignment_heads_preset, - /// Number of top text layers used from model. Only with whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST preset. - pub dtw_n_top: c_int, - /// Custom aheads, only with whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM preset - /// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143 - pub dtw_aheads: whisper_rs_sys::whisper_aheads, - /// Memory size for DTW - /// - /// **Warning**: Might be removed in next version of whisper.cpp - pub dtw_mem_size: usize, + /// DTW token level timestamp parameters + pub dtw_parameters: DtwParameters, } #[allow(clippy::derivable_impls)] // this impl cannot be derived @@ -501,14 +492,7 @@ impl Default for WhisperContextParameters { use_gpu: cfg!(feature = "_gpu"), flash_attn: false, gpu_device: 0, - dtw_token_timestamps: false, - dtw_aheads_preset: whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE, - dtw_n_top: -1, - dtw_aheads: whisper_rs_sys::whisper_aheads { - n_heads: 0, - heads: std::ptr::null(), - }, - dtw_mem_size: 1024 * 1024 * 128, + dtw_parameters: DtwParameters::default(), } } } @@ -528,43 +512,141 @@ impl WhisperContextParameters { self.gpu_device = gpu_device; self } - pub fn dtw_token_timestamps(&mut self, dtw_token_timestamps: bool) -> &mut Self { - self.dtw_token_timestamps = dtw_token_timestamps; - self - } - pub fn dtw_aheads_preset( - &mut self, - dtw_aheads_preset: whisper_rs_sys::whisper_alignment_heads_preset, - ) -> &mut Self { - self.dtw_aheads_preset = dtw_aheads_preset; - self - } - pub fn dtw_n_top(&mut self, dtw_n_top: c_int) -> &mut Self { - self.dtw_n_top = dtw_n_top; - 
self - } - pub fn dtw_aheads(&mut self, dtw_aheads: whisper_rs_sys::whisper_aheads) -> &mut Self { - self.dtw_aheads = dtw_aheads; - self - } - pub fn dtw_mem_size(&mut self, dtw_mem_size: usize) -> &mut Self { - self.dtw_mem_size = dtw_mem_size; + pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters) -> &mut Self { + self.dtw_parameters = dtw_parameters; self } + fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params { + let dtw_token_timestamps = !matches!(self.dtw_parameters, DtwParameters::Disabled); + let mut dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE; + let mut dtw_n_top: c_int = -1; + let mut dtw_aheads = whisper_rs_sys::whisper_aheads { + n_heads: 0, + heads: std::ptr::null(), + }; + + let dtw_mem_size = 1024 * 1024 * 128; + + match &self.dtw_parameters { + DtwParameters::Disabled => {} + DtwParameters::TopMost { n_top } => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST; + dtw_n_top = *n_top; + } + DtwParameters::Custom { aheads } => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM; + + dtw_aheads = whisper_rs_sys::whisper_aheads { + n_heads: aheads.len(), + heads: aheads.as_ptr(), + }; + } + DtwParameters::Predefined { model_preset } => match model_preset { + DtwPredefinedModels::TinyEn => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN; + } + DtwPredefinedModels::Tiny => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY; + } + DtwPredefinedModels::BaseEn => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN; + } + DtwPredefinedModels::Base => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE; + } + DtwPredefinedModels::SmallEn => { + dtw_aheads_preset = + 
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN; + } + DtwPredefinedModels::Small => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL; + } + DtwPredefinedModels::MediumEn => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN; + } + DtwPredefinedModels::Medium => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM; + } + DtwPredefinedModels::LargeV1 => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1; + } + DtwPredefinedModels::LargeV2 => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2; + } + DtwPredefinedModels::LargeV3 => { + dtw_aheads_preset = + whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3; + } + }, + } + whisper_rs_sys::whisper_context_params { use_gpu: self.use_gpu, flash_attn: self.flash_attn, gpu_device: self.gpu_device, - dtw_token_timestamps: self.dtw_token_timestamps, - dtw_aheads_preset: self.dtw_aheads_preset, - dtw_n_top: self.dtw_n_top, - dtw_aheads: self.dtw_aheads, - dtw_mem_size: self.dtw_mem_size, + dtw_token_timestamps, + dtw_aheads_preset, + dtw_n_top, + dtw_aheads, + dtw_mem_size, } } } +/// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled +#[derive(Debug, Clone)] +pub enum DtwParameters { + /// DTW token level timestamps disabled + Disabled, + /// Use N Top Most layers from loaded model + TopMost { + /// Number of top text layers used from model, should be 0 < n_top <= model n_text_layer + n_top: c_int, + }, + /// Use custom aheads, non-empty list of whisper_ahead. 
+ /// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element + /// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143 + Custom { + aheads: Vec<whisper_rs_sys::whisper_ahead>, + }, + /// Use predefined preset for standard models + Predefined { model_preset: DtwPredefinedModels }, +} + +impl Default for DtwParameters { + fn default() -> Self { + Self::Disabled + } +} + +#[derive(Debug, Clone)] +pub enum DtwPredefinedModels { + TinyEn, + Tiny, + BaseEn, + Base, + SmallEn, + Small, + MediumEn, + Medium, + LargeV1, + LargeV2, + LargeV3, +} + #[cfg(test)] #[cfg(feature = "test-with-tiny-model")] mod test_with_tiny_model { diff --git a/src/whisper_ctx_wrapper.rs b/src/whisper_ctx_wrapper.rs index ff3caff..87d062c 100644 --- a/src/whisper_ctx_wrapper.rs +++ b/src/whisper_ctx_wrapper.rs @@ -27,7 +27,7 @@ impl WhisperContext { /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` pub fn new_with_params( path: &str, - parameters: WhisperContextParameters, + parameters: &mut WhisperContextParameters, ) -> Result { let ctx = WhisperInnerContext::new_with_params(path, parameters)?; Ok(Self::wrap(ctx))