Add safe wrapper for raw dtw parameters

This commit is contained in:
arizhih 2024-05-16 18:58:48 +02:00
parent dcfcbced18
commit da6b410439
5 changed files with 135 additions and 51 deletions

View file

@ -11,7 +11,7 @@ fn main() -> Result<(), &'static str> {
// Load a context and model. // Load a context and model.
let ctx = WhisperContext::new_with_params( let ctx = WhisperContext::new_with_params(
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin", "example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
WhisperContextParameters::default(), &mut WhisperContextParameters::default(),
) )
.expect("failed to load model"); .expect("failed to load model");
// Create a state // Create a state

View file

@ -7,7 +7,7 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar
// more dependencies than the base library. // more dependencies than the base library.
pub fn usage() -> Result<(), &'static str> { pub fn usage() -> Result<(), &'static str> {
// load a context and model // load a context and model
let ctx = WhisperContext::new_with_params("path/to/model", WhisperContextParameters::default()) let ctx = WhisperContext::new_with_params("path/to/model", &mut WhisperContextParameters::default())
.expect("failed to load model"); .expect("failed to load model");
// make a state // make a state
let mut state = ctx.create_state().expect("failed to create state"); let mut state = ctx.create_state().expect("failed to create state");

View file

@ -34,6 +34,8 @@ pub use whisper_state::WhisperState;
pub use whisper_sys_log::install_whisper_log_trampoline; pub use whisper_sys_log::install_whisper_log_trampoline;
#[cfg(feature = "whisper-cpp-tracing")] #[cfg(feature = "whisper-cpp-tracing")]
pub use whisper_sys_tracing::install_whisper_tracing_trampoline; pub use whisper_sys_tracing::install_whisper_tracing_trampoline;
pub use whisper_ctx::DtwParameters;
pub use whisper_ctx::DtwPredefinedModels;
pub type WhisperSysContext = whisper_rs_sys::whisper_context; pub type WhisperSysContext = whisper_rs_sys::whisper_context;
pub type WhisperSysState = whisper_rs_sys::whisper_state; pub type WhisperSysState = whisper_rs_sys::whisper_state;

View file

@ -26,7 +26,7 @@ impl WhisperInnerContext {
/// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);`
pub fn new_with_params( pub fn new_with_params(
path: &str, path: &str,
parameters: WhisperContextParameters, parameters: &mut WhisperContextParameters,
) -> Result<Self, WhisperError> { ) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?; let path_cstr = CString::new(path)?;
let ctx = unsafe { let ctx = unsafe {
@ -476,22 +476,13 @@ pub struct WhisperContextParameters {
/// (in that case, GPU is always enabled). /// (in that case, GPU is always enabled).
pub use_gpu: bool, pub use_gpu: bool,
/// Enable flash attention, default false /// Enable flash attention, default false
///
/// **Warning** Can't be used with DTW. DTW will be disabled if flash_attn is true
pub flash_attn: bool, pub flash_attn: bool,
/// GPU device id, default 0 /// GPU device id, default 0
pub gpu_device: c_int, pub gpu_device: c_int,
/// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default 0 /// DTW token level timestamp parameters
pub dtw_token_timestamps: bool, pub dtw_parameters: DtwParameters,
/// Preset id for DTW, default whisper_alignment_heads_preset_WHISPER_AHEADS_NONE
pub dtw_aheads_preset: whisper_rs_sys::whisper_alignment_heads_preset,
/// Number of top text layers used from model. Only with whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST preset.
pub dtw_n_top: c_int,
/// Custom aheads, only with whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM preset
/// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
pub dtw_aheads: whisper_rs_sys::whisper_aheads,
/// Memory size for DTW
///
/// **Warning**: Might be removed in next version of whisper.cpp
pub dtw_mem_size: usize,
} }
#[allow(clippy::derivable_impls)] // this impl cannot be derived #[allow(clippy::derivable_impls)] // this impl cannot be derived
@ -501,14 +492,7 @@ impl Default for WhisperContextParameters {
use_gpu: cfg!(feature = "_gpu"), use_gpu: cfg!(feature = "_gpu"),
flash_attn: false, flash_attn: false,
gpu_device: 0, gpu_device: 0,
dtw_token_timestamps: false, dtw_parameters: DtwParameters::default(),
dtw_aheads_preset: whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE,
dtw_n_top: -1,
dtw_aheads: whisper_rs_sys::whisper_aheads {
n_heads: 0,
heads: std::ptr::null(),
},
dtw_mem_size: 1024 * 1024 * 128,
} }
} }
} }
@ -528,43 +512,141 @@ impl WhisperContextParameters {
self.gpu_device = gpu_device; self.gpu_device = gpu_device;
self self
} }
pub fn dtw_token_timestamps(&mut self, dtw_token_timestamps: bool) -> &mut Self { pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters) -> &mut Self {
self.dtw_token_timestamps = dtw_token_timestamps; self.dtw_parameters = dtw_parameters;
self
}
pub fn dtw_aheads_preset(
&mut self,
dtw_aheads_preset: whisper_rs_sys::whisper_alignment_heads_preset,
) -> &mut Self {
self.dtw_aheads_preset = dtw_aheads_preset;
self
}
pub fn dtw_n_top(&mut self, dtw_n_top: c_int) -> &mut Self {
self.dtw_n_top = dtw_n_top;
self
}
pub fn dtw_aheads(&mut self, dtw_aheads: whisper_rs_sys::whisper_aheads) -> &mut Self {
self.dtw_aheads = dtw_aheads;
self
}
pub fn dtw_mem_size(&mut self, dtw_mem_size: usize) -> &mut Self {
self.dtw_mem_size = dtw_mem_size;
self self
} }
/// Builds the raw `whisper_rs_sys::whisper_context_params` value from these parameters.
///
/// `use_gpu`, `flash_attn` and `gpu_device` are forwarded unchanged; `self.dtw_parameters`
/// is expanded into the individual `dtw_*` fields of the C struct. No handling of the
/// `flash_attn` / DTW combination happens here — both values are passed through as-is.
fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
// DTW token timestamps are enabled for every variant except `Disabled`.
let dtw_token_timestamps = !matches!(self.dtw_parameters, DtwParameters::Disabled);
// Fallback values, kept for any field the selected variant does not override.
let mut dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE;
let mut dtw_n_top: c_int = -1;
let mut dtw_aheads = whisper_rs_sys::whisper_aheads {
n_heads: 0,
heads: std::ptr::null(),
};
// Fixed constant (1024 * 1024 * 128); it is not configurable through `self`.
// Presumably a size in bytes (128 MiB) — TODO confirm against whisper.cpp.
let dtw_mem_size = 1024 * 1024 * 128;
match &self.dtw_parameters {
DtwParameters::Disabled => {}
// N-top-most preset: only the preset id and `dtw_n_top` are set.
DtwParameters::TopMost { n_top } => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST;
dtw_n_top = *n_top;
}
// Custom preset: the head list is handed over as pointer + length.
// NOTE(review): `heads` points into the `Vec` owned by `self.dtw_parameters`, so the
// returned struct is only valid while that `Vec` is alive and unmodified. Confirm that
// the consumer does not keep the pointer beyond the lifetime of `self`.
DtwParameters::Custom { aheads } => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM;
dtw_aheads = whisper_rs_sys::whisper_aheads {
n_heads: aheads.len(),
heads: aheads.as_ptr(),
};
}
// Predefined presets: each model maps one-to-one onto a `WHISPER_AHEADS_*` constant;
// `dtw_n_top` and `dtw_aheads` keep their fallback values.
DtwParameters::Predefined { model_preset } => match model_preset {
DtwPredefinedModels::TinyEn => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN;
}
DtwPredefinedModels::Tiny => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY;
}
DtwPredefinedModels::BaseEn => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN;
}
DtwPredefinedModels::Base => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE;
}
DtwPredefinedModels::SmallEn => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN;
}
DtwPredefinedModels::Small => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL;
}
DtwPredefinedModels::MediumEn => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN;
}
DtwPredefinedModels::Medium => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM;
}
DtwPredefinedModels::LargeV1 => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1;
}
DtwPredefinedModels::LargeV2 => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2;
}
DtwPredefinedModels::LargeV3 => {
dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3;
}
},
}
// Assemble the C struct from the pass-through fields and the DTW values computed above.
whisper_rs_sys::whisper_context_params { whisper_rs_sys::whisper_context_params {
use_gpu: self.use_gpu, use_gpu: self.use_gpu,
flash_attn: self.flash_attn, flash_attn: self.flash_attn,
gpu_device: self.gpu_device, gpu_device: self.gpu_device,
dtw_token_timestamps: self.dtw_token_timestamps, dtw_token_timestamps,
dtw_aheads_preset: self.dtw_aheads_preset, dtw_aheads_preset,
dtw_n_top: self.dtw_n_top, dtw_n_top,
dtw_aheads: self.dtw_aheads, dtw_aheads,
dtw_mem_size: self.dtw_mem_size, dtw_mem_size,
} }
} }
} }
/// [EXPERIMENTAL] Configuration of token-level timestamps computed with DTW.
///
/// The default is [`DtwParameters::Disabled`].
#[derive(Debug, Clone)]
pub enum DtwParameters {
    /// DTW token-level timestamps are turned off.
    Disabled,
    /// Use the N top-most text layers of the loaded model.
    TopMost {
        /// Number of top text layers to use; should satisfy
        /// `0 < n_top <= n_text_layer` of the model.
        n_top: c_int,
    },
    /// Use a custom, non-empty list of `whisper_ahead` entries.
    ///
    /// For each element: `0 < n_text_layer < model n_text_layer` and
    /// `0 < n_head < model n_text_head`.
    /// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
    Custom {
        /// The alignment heads; exposed to the C side as a `whisper_aheads`
        /// (pointer + length) that refers to this vector's storage.
        aheads: Vec<whisper_rs_sys::whisper_ahead>,
    },
    /// Use one of the predefined presets for the standard models.
    Predefined {
        /// The predefined preset to select.
        model_preset: DtwPredefinedModels,
    },
}
impl Default for DtwParameters {
fn default() -> Self {
Self::Disabled
}
}
/// Predefined alignment-head presets selectable through `DtwParameters::Predefined`.
///
/// Each variant is translated to the `whisper_alignment_heads_preset_WHISPER_AHEADS_*`
/// constant of the same name when the context parameters are converted to the C struct.
///
/// The enum is field-less, so it is cheap to copy and can be compared and hashed.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DtwPredefinedModels {
    /// Selects `WHISPER_AHEADS_TINY_EN`.
    TinyEn,
    /// Selects `WHISPER_AHEADS_TINY`.
    Tiny,
    /// Selects `WHISPER_AHEADS_BASE_EN`.
    BaseEn,
    /// Selects `WHISPER_AHEADS_BASE`.
    Base,
    /// Selects `WHISPER_AHEADS_SMALL_EN`.
    SmallEn,
    /// Selects `WHISPER_AHEADS_SMALL`.
    Small,
    /// Selects `WHISPER_AHEADS_MEDIUM_EN`.
    MediumEn,
    /// Selects `WHISPER_AHEADS_MEDIUM`.
    Medium,
    /// Selects `WHISPER_AHEADS_LARGE_V1`.
    LargeV1,
    /// Selects `WHISPER_AHEADS_LARGE_V2`.
    LargeV2,
    /// Selects `WHISPER_AHEADS_LARGE_V3`.
    LargeV3,
}
#[cfg(test)] #[cfg(test)]
#[cfg(feature = "test-with-tiny-model")] #[cfg(feature = "test-with-tiny-model")]
mod test_with_tiny_model { mod test_with_tiny_model {

View file

@ -27,7 +27,7 @@ impl WhisperContext {
/// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);`
pub fn new_with_params( pub fn new_with_params(
path: &str, path: &str,
parameters: WhisperContextParameters, parameters: &mut WhisperContextParameters,
) -> Result<Self, WhisperError> { ) -> Result<Self, WhisperError> {
let ctx = WhisperInnerContext::new_with_params(path, parameters)?; let ctx = WhisperInnerContext::new_with_params(path, parameters)?;
Ok(Self::wrap(ctx)) Ok(Self::wrap(ctx))