diff --git a/examples/full_usage/src/main.rs b/examples/full_usage/src/main.rs index 42c0bff..c5eca17 100644 --- a/examples/full_usage/src/main.rs +++ b/examples/full_usage/src/main.rs @@ -47,7 +47,8 @@ fn main() { let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model"); let mut state = ctx.create_state().expect("failed to create key"); - let params = FullParams::new(SamplingStrategy::default()); + let mut params = FullParams::new(SamplingStrategy::default()); + params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress)); let st = std::time::Instant::now(); state diff --git a/src/whisper_params.rs b/src/whisper_params.rs index fc20722..3370b6b 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -24,6 +24,7 @@ pub struct FullParams<'a, 'b> { pub(crate) fp: whisper_rs_sys::whisper_full_params, phantom_lang: PhantomData<&'a str>, phantom_tokens: PhantomData<&'b [c_int]>, + progess_callback_safe: Option>, } impl<'a, 'b> FullParams<'a, 'b> { @@ -57,6 +58,7 @@ impl<'a, 'b> FullParams<'a, 'b> { fp, phantom_lang: PhantomData, phantom_tokens: PhantomData, + progess_callback_safe: None, } } @@ -370,6 +372,44 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.progress_callback = progress_callback; } + /// Set the callback for progress updates, potentially using a closure. + /// + /// Note that, in order to ensure safety, the callback only accepts the progress in percent. + /// + /// Defaults to None. + pub fn set_progress_callback_safe(&mut self, closure: Option) + where + F: FnMut(i32) + 'static, + { + use std::ffi::c_void; + use whisper_rs_sys::{whisper_context, whisper_state}; + + unsafe extern "C" fn trampoline( + _: *mut whisper_context, + _: *mut whisper_state, + progress: c_int, + user_data: *mut c_void, + ) where + F: FnMut(i32), + { + let user_data = &mut *(user_data as *mut F); + user_data(progress); + } + + match closure { + Some(mut closure) => { + self.fp.progress_callback = Some(trampoline::); + self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void; + self.progess_callback_safe = Some(Box::new(closure)); + } + None => { + self.fp.progress_callback = None; + self.fp.progress_callback_user_data = 0 as *mut c_void; + self.progess_callback_safe = None; + } + } + } + /// Set the user data to be passed to the progress callback. /// /// # Safety