Pass aheads by reference, add dtw_mem_size param, rustfmt

This commit is contained in:
arizhih 2024-05-17 02:05:12 +02:00
parent da6b410439
commit 482860d0d6
5 changed files with 45 additions and 37 deletions

View file

@ -11,7 +11,7 @@ fn main() -> Result<(), &'static str> {
// Load a context and model. // Load a context and model.
let ctx = WhisperContext::new_with_params( let ctx = WhisperContext::new_with_params(
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin", "example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
&mut WhisperContextParameters::default(), WhisperContextParameters::default(),
) )
.expect("failed to load model"); .expect("failed to load model");
// Create a state // Create a state

View file

@ -7,7 +7,7 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar
// more dependencies than the base library. // more dependencies than the base library.
pub fn usage() -> Result<(), &'static str> { pub fn usage() -> Result<(), &'static str> {
// load a context and model // load a context and model
let ctx = WhisperContext::new_with_params("path/to/model", &mut WhisperContextParameters::default()) let ctx = WhisperContext::new_with_params("path/to/model", WhisperContextParameters::default())
.expect("failed to load model"); .expect("failed to load model");
// make a state // make a state
let mut state = ctx.create_state().expect("failed to create state"); let mut state = ctx.create_state().expect("failed to create state");

View file

@ -22,6 +22,9 @@ pub use standalone::*;
#[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))] #[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))]
use std::sync::Once; use std::sync::Once;
pub use utilities::*; pub use utilities::*;
pub use whisper_ctx::DtwMode;
pub use whisper_ctx::DtwModelPreset;
pub use whisper_ctx::DtwParameters;
pub use whisper_ctx::WhisperContextParameters; pub use whisper_ctx::WhisperContextParameters;
use whisper_ctx::WhisperInnerContext; use whisper_ctx::WhisperInnerContext;
pub use whisper_ctx_wrapper::WhisperContext; pub use whisper_ctx_wrapper::WhisperContext;
@ -34,8 +37,6 @@ pub use whisper_state::WhisperState;
pub use whisper_sys_log::install_whisper_log_trampoline; pub use whisper_sys_log::install_whisper_log_trampoline;
#[cfg(feature = "whisper-cpp-tracing")] #[cfg(feature = "whisper-cpp-tracing")]
pub use whisper_sys_tracing::install_whisper_tracing_trampoline; pub use whisper_sys_tracing::install_whisper_tracing_trampoline;
pub use whisper_ctx::DtwParameters;
pub use whisper_ctx::DtwPredefinedModels;
pub type WhisperSysContext = whisper_rs_sys::whisper_context; pub type WhisperSysContext = whisper_rs_sys::whisper_context;
pub type WhisperSysState = whisper_rs_sys::whisper_state; pub type WhisperSysState = whisper_rs_sys::whisper_state;

View file

@ -26,7 +26,7 @@ impl WhisperInnerContext {
/// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);`
pub fn new_with_params( pub fn new_with_params(
path: &str, path: &str,
parameters: &mut WhisperContextParameters, parameters: WhisperContextParameters,
) -> Result<Self, WhisperError> { ) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?; let path_cstr = CString::new(path)?;
let ctx = unsafe { let ctx = unsafe {
@ -518,7 +518,7 @@ impl WhisperContextParameters {
} }
fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params { fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
let dtw_token_timestamps = !matches!(self.dtw_parameters, DtwParameters::Disabled); let dtw_token_timestamps = !matches!(self.dtw_parameters.mode, DtwMode::None);
let mut dtw_aheads_preset = let mut dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE;
let mut dtw_n_top: c_int = -1; let mut dtw_n_top: c_int = -1;
@ -527,16 +527,14 @@ impl WhisperContextParameters {
heads: std::ptr::null(), heads: std::ptr::null(),
}; };
let dtw_mem_size = 1024 * 1024 * 128; match &self.dtw_parameters.mode {
DtwMode::None => {}
match &self.dtw_parameters { DtwMode::TopMost { n_top } => {
DtwParameters::Disabled => {}
DtwParameters::TopMost { n_top } => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST;
dtw_n_top = *n_top; dtw_n_top = *n_top;
} }
DtwParameters::Custom { aheads } => { DtwMode::Custom { aheads } => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM;
@ -545,48 +543,48 @@ impl WhisperContextParameters {
heads: aheads.as_ptr(), heads: aheads.as_ptr(),
}; };
} }
DtwParameters::Predefined { model_preset } => match model_preset { DtwMode::ModelPreset { model_preset } => match model_preset {
DtwPredefinedModels::TinyEn => { DtwModelPreset::TinyEn => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN;
} }
DtwPredefinedModels::Tiny => { DtwModelPreset::Tiny => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY;
} }
DtwPredefinedModels::BaseEn => { DtwModelPreset::BaseEn => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN;
} }
DtwPredefinedModels::Base => { DtwModelPreset::Base => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE;
} }
DtwPredefinedModels::SmallEn => { DtwModelPreset::SmallEn => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN;
} }
DtwPredefinedModels::Small => { DtwModelPreset::Small => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL;
} }
DtwPredefinedModels::MediumEn => { DtwModelPreset::MediumEn => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN;
} }
DtwPredefinedModels::Medium => { DtwModelPreset::Medium => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM;
} }
DtwPredefinedModels::LargeV1 => { DtwModelPreset::LargeV1 => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1;
} }
DtwPredefinedModels::LargeV2 => { DtwModelPreset::LargeV2 => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2;
} }
DtwPredefinedModels::LargeV3 => { DtwModelPreset::LargeV3 => {
dtw_aheads_preset = dtw_aheads_preset =
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3; whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3;
} }
@ -601,16 +599,31 @@ impl WhisperContextParameters {
dtw_aheads_preset, dtw_aheads_preset,
dtw_n_top, dtw_n_top,
dtw_aheads, dtw_aheads,
dtw_mem_size, dtw_mem_size: self.dtw_parameters.dtw_mem_size,
} }
} }
} }
/// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled /// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum DtwParameters { pub struct DtwParameters {
pub mode: DtwMode,
pub dtw_mem_size: usize,
}
impl Default for DtwParameters {
fn default() -> Self {
Self {
mode: DtwMode::None,
dtw_mem_size: 1024 * 1024 * 128,
}
}
}
#[derive(Debug, Clone)]
pub enum DtwMode {
/// DTW token level timestamps disabled /// DTW token level timestamps disabled
Disabled, None,
/// Use N Top Most layers from loaded model /// Use N Top Most layers from loaded model
TopMost { TopMost {
/// Number of top text layers used from model, should be 0 < n_top <= model n_text_layer /// Number of top text layers used from model, should be 0 < n_top <= model n_text_layer
@ -620,20 +633,14 @@ pub enum DtwParameters {
/// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element /// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element
/// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143 /// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
Custom { Custom {
aheads: Vec<whisper_rs_sys::whisper_ahead>, aheads: &'static [whisper_rs_sys::whisper_ahead],
}, },
/// Use predefined preset for standard models /// Use predefined preset for standard models
Predefined { model_preset: DtwPredefinedModels }, ModelPreset { model_preset: DtwModelPreset },
}
impl Default for DtwParameters {
fn default() -> Self {
Self::Disabled
}
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum DtwPredefinedModels { pub enum DtwModelPreset {
TinyEn, TinyEn,
Tiny, Tiny,
BaseEn, BaseEn,

View file

@ -27,7 +27,7 @@ impl WhisperContext {
/// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);`
pub fn new_with_params( pub fn new_with_params(
path: &str, path: &str,
parameters: &mut WhisperContextParameters, parameters: WhisperContextParameters,
) -> Result<Self, WhisperError> { ) -> Result<Self, WhisperError> {
let ctx = WhisperInnerContext::new_with_params(path, parameters)?; let ctx = WhisperInnerContext::new_with_params(path, parameters)?;
Ok(Self::wrap(ctx)) Ok(Self::wrap(ctx))