Use proper lifetimes, add DTW usage example
This commit is contained in:
parent
482860d0d6
commit
0c8798c986
2 changed files with 58 additions and 12 deletions
|
|
@ -9,9 +9,37 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar
|
||||||
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
|
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
|
||||||
fn main() -> Result<(), &'static str> {
|
fn main() -> Result<(), &'static str> {
|
||||||
// Load a context and model.
|
// Load a context and model.
|
||||||
|
let mut context_param = WhisperContextParameters::default();
|
||||||
|
|
||||||
|
// Enable DTW token level timestamp for known model by using model preset
|
||||||
|
context_param.dtw_parameters.mode = whisper_rs::DtwMode::ModelPreset {
|
||||||
|
model_preset: whisper_rs::DtwModelPreset::BaseEn,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Enable DTW token level timestamp for unknown model by providing custom aheads
|
||||||
|
// see details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
|
||||||
|
// values corresponds to ggml-base.en.bin, result will be the same as with DtwModelPreset::BaseEn
|
||||||
|
let custom_aheads = [
|
||||||
|
(3, 1),
|
||||||
|
(4, 2),
|
||||||
|
(4, 3),
|
||||||
|
(4, 7),
|
||||||
|
(5, 1),
|
||||||
|
(5, 2),
|
||||||
|
(5, 4),
|
||||||
|
(5, 6),
|
||||||
|
]
|
||||||
|
.map(|(n_text_layer, n_head)| whisper_rs_sys::whisper_ahead {
|
||||||
|
n_text_layer,
|
||||||
|
n_head,
|
||||||
|
});
|
||||||
|
context_param.dtw_parameters.mode = whisper_rs::DtwMode::Custom {
|
||||||
|
aheads: &custom_aheads,
|
||||||
|
};
|
||||||
|
|
||||||
let ctx = WhisperContext::new_with_params(
|
let ctx = WhisperContext::new_with_params(
|
||||||
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
|
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
|
||||||
WhisperContextParameters::default(),
|
context_param,
|
||||||
)
|
)
|
||||||
.expect("failed to load model");
|
.expect("failed to load model");
|
||||||
// Create a state
|
// Create a state
|
||||||
|
|
@ -33,6 +61,8 @@ fn main() -> Result<(), &'static str> {
|
||||||
params.set_print_progress(false);
|
params.set_print_progress(false);
|
||||||
params.set_print_realtime(false);
|
params.set_print_realtime(false);
|
||||||
params.set_print_timestamps(false);
|
params.set_print_timestamps(false);
|
||||||
|
// Enable token level timestamps
|
||||||
|
params.set_token_timestamps(true);
|
||||||
|
|
||||||
// Open the audio file.
|
// Open the audio file.
|
||||||
let reader = hound::WavReader::open("audio.wav").expect("failed to open file");
|
let reader = hound::WavReader::open("audio.wav").expect("failed to open file");
|
||||||
|
|
@ -87,8 +117,24 @@ fn main() -> Result<(), &'static str> {
|
||||||
.full_get_segment_t1(i)
|
.full_get_segment_t1(i)
|
||||||
.expect("failed to get end timestamp");
|
.expect("failed to get end timestamp");
|
||||||
|
|
||||||
|
let first_token_dtw_ts = if let Ok(token_count) = state.full_n_tokens(i) {
|
||||||
|
if token_count > 0 {
|
||||||
|
if let Ok(token_data) = state.full_get_token_data(i, 0) {
|
||||||
|
token_data.t_dtw
|
||||||
|
} else {
|
||||||
|
-1i64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
-1i64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
-1i64
|
||||||
|
};
|
||||||
// Print the segment to stdout.
|
// Print the segment to stdout.
|
||||||
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
|
println!(
|
||||||
|
"[{} - {} ({})]: {}",
|
||||||
|
start_timestamp, end_timestamp, first_token_dtw_ts, segment
|
||||||
|
);
|
||||||
|
|
||||||
// Format the segment information as a string.
|
// Format the segment information as a string.
|
||||||
let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment);
|
let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment);
|
||||||
|
|
|
||||||
|
|
@ -469,7 +469,7 @@ impl Drop for WhisperInnerContext {
|
||||||
unsafe impl Send for WhisperInnerContext {}
|
unsafe impl Send for WhisperInnerContext {}
|
||||||
unsafe impl Sync for WhisperInnerContext {}
|
unsafe impl Sync for WhisperInnerContext {}
|
||||||
|
|
||||||
pub struct WhisperContextParameters {
|
pub struct WhisperContextParameters<'a> {
|
||||||
/// Use GPU if available.
|
/// Use GPU if available.
|
||||||
///
|
///
|
||||||
/// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
|
/// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
|
||||||
|
|
@ -482,11 +482,11 @@ pub struct WhisperContextParameters {
|
||||||
/// GPU device id, default 0
|
/// GPU device id, default 0
|
||||||
pub gpu_device: c_int,
|
pub gpu_device: c_int,
|
||||||
/// DTW token level timestamp parameters
|
/// DTW token level timestamp parameters
|
||||||
pub dtw_parameters: DtwParameters,
|
pub dtw_parameters: DtwParameters<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::derivable_impls)] // this impl cannot be derived
|
#[allow(clippy::derivable_impls)] // this impl cannot be derived
|
||||||
impl Default for WhisperContextParameters {
|
impl<'a> Default for WhisperContextParameters<'a> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
use_gpu: cfg!(feature = "_gpu"),
|
use_gpu: cfg!(feature = "_gpu"),
|
||||||
|
|
@ -496,7 +496,7 @@ impl Default for WhisperContextParameters {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl WhisperContextParameters {
|
impl<'a> WhisperContextParameters<'a> {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self::default()
|
Self::default()
|
||||||
}
|
}
|
||||||
|
|
@ -512,7 +512,7 @@ impl WhisperContextParameters {
|
||||||
self.gpu_device = gpu_device;
|
self.gpu_device = gpu_device;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters) -> &mut Self {
|
pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters<'a>) -> &mut Self {
|
||||||
self.dtw_parameters = dtw_parameters;
|
self.dtw_parameters = dtw_parameters;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
@ -606,12 +606,12 @@ impl WhisperContextParameters {
|
||||||
|
|
||||||
/// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled
|
/// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct DtwParameters {
|
pub struct DtwParameters<'a> {
|
||||||
pub mode: DtwMode,
|
pub mode: DtwMode<'a>,
|
||||||
pub dtw_mem_size: usize,
|
pub dtw_mem_size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for DtwParameters {
|
impl Default for DtwParameters<'_> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
mode: DtwMode::None,
|
mode: DtwMode::None,
|
||||||
|
|
@ -621,7 +621,7 @@ impl Default for DtwParameters {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum DtwMode {
|
pub enum DtwMode<'a> {
|
||||||
/// DTW token level timestamps disabled
|
/// DTW token level timestamps disabled
|
||||||
None,
|
None,
|
||||||
/// Use N Top Most layers from loaded model
|
/// Use N Top Most layers from loaded model
|
||||||
|
|
@ -633,7 +633,7 @@ pub enum DtwMode {
|
||||||
/// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element
|
/// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element
|
||||||
/// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
|
/// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
|
||||||
Custom {
|
Custom {
|
||||||
aheads: &'static [whisper_rs_sys::whisper_ahead],
|
aheads: &'a [whisper_rs_sys::whisper_ahead],
|
||||||
},
|
},
|
||||||
/// Use predefined preset for standard models
|
/// Use predefined preset for standard models
|
||||||
ModelPreset { model_preset: DtwModelPreset },
|
ModelPreset { model_preset: DtwModelPreset },
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue