diff --git a/examples/audio_transcription.rs b/examples/audio_transcription.rs index 7661e04..2a8499b 100644 --- a/examples/audio_transcription.rs +++ b/examples/audio_transcription.rs @@ -9,9 +9,37 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar /// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout. fn main() -> Result<(), &'static str> { // Load a context and model. + let mut context_param = WhisperContextParameters::default(); + + // Enable DTW token level timestamp for known model by using model preset + context_param.dtw_parameters.mode = whisper_rs::DtwMode::ModelPreset { + model_preset: whisper_rs::DtwModelPreset::BaseEn, + }; + + // Enable DTW token level timestamp for unknown model by providing custom aheads + // see details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143 + // values corresponds to ggml-base.en.bin, result will be the same as with DtwModelPreset::BaseEn + let custom_aheads = [ + (3, 1), + (4, 2), + (4, 3), + (4, 7), + (5, 1), + (5, 2), + (5, 4), + (5, 6), + ] + .map(|(n_text_layer, n_head)| whisper_rs_sys::whisper_ahead { + n_text_layer, + n_head, + }); + context_param.dtw_parameters.mode = whisper_rs::DtwMode::Custom { + aheads: &custom_aheads, + }; + let ctx = WhisperContext::new_with_params( "example/path/to/model/whisper.cpp/models/ggml-base.en.bin", - WhisperContextParameters::default(), + context_param, ) .expect("failed to load model"); // Create a state @@ -33,6 +61,8 @@ fn main() -> Result<(), &'static str> { params.set_print_progress(false); params.set_print_realtime(false); params.set_print_timestamps(false); + // Enable token level timestamps + params.set_token_timestamps(true); // Open the audio file. let reader = hound::WavReader::open("audio.wav").expect("failed to open file"); @@ -87,8 +117,24 @@ fn main() -> Result<(), &'static str> { .full_get_segment_t1(i) .expect("failed to get end timestamp"); + let first_token_dtw_ts = if let Ok(token_count) = state.full_n_tokens(i) { + if token_count > 0 { + if let Ok(token_data) = state.full_get_token_data(i, 0) { + token_data.t_dtw + } else { + -1i64 + } + } else { + -1i64 + } + } else { + -1i64 + }; // Print the segment to stdout. - println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); + println!( + "[{} - {} ({})]: {}", + start_timestamp, end_timestamp, first_token_dtw_ts, segment + ); // Format the segment information as a string. let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment); diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 5ab7e21..1a0f0fc 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -469,7 +469,7 @@ impl Drop for WhisperInnerContext { unsafe impl Send for WhisperInnerContext {} unsafe impl Sync for WhisperInnerContext {} -pub struct WhisperContextParameters { +pub struct WhisperContextParameters<'a> { /// Use GPU if available. /// /// **Warning**: Does not have an effect if OpenCL is selected as GPU backend @@ -482,11 +482,11 @@ pub struct WhisperContextParameters { /// GPU device id, default 0 pub gpu_device: c_int, /// DTW token level timestamp parameters - pub dtw_parameters: DtwParameters, + pub dtw_parameters: DtwParameters<'a>, } #[allow(clippy::derivable_impls)] // this impl cannot be derived -impl Default for WhisperContextParameters { +impl<'a> Default for WhisperContextParameters<'a> { fn default() -> Self { Self { use_gpu: cfg!(feature = "_gpu"), @@ -496,7 +496,7 @@ impl Default for WhisperContextParameters { } } } -impl WhisperContextParameters { +impl<'a> WhisperContextParameters<'a> { pub fn new() -> Self { Self::default() } @@ -512,7 +512,7 @@ impl WhisperContextParameters { self.gpu_device = gpu_device; self } - pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters) -> &mut Self { + pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters<'a>) -> &mut Self { self.dtw_parameters = dtw_parameters; self } @@ -606,12 +606,12 @@ impl WhisperContextParameters { /// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled #[derive(Debug, Clone)] -pub struct DtwParameters { - pub mode: DtwMode, +pub struct DtwParameters<'a> { + pub mode: DtwMode<'a>, pub dtw_mem_size: usize, } -impl Default for DtwParameters { +impl Default for DtwParameters<'_> { fn default() -> Self { Self { mode: DtwMode::None, @@ -621,7 +621,7 @@ impl Default for DtwParameters { } #[derive(Debug, Clone)] -pub enum DtwMode { +pub enum DtwMode<'a> { /// DTW token level timestamps disabled None, /// Use N Top Most layers from loaded model @@ -633,7 +633,7 @@ pub enum DtwMode { /// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element /// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143 Custom { - aheads: &'static [whisper_rs_sys::whisper_ahead], + aheads: &'a [whisper_rs_sys::whisper_ahead], }, /// Use predefined preset for standard models ModelPreset { model_preset: DtwModelPreset },