diff --git a/examples/full_usage/src/main.rs b/examples/full_usage/src/main.rs index 42c0bff..c5eca17 100644 --- a/examples/full_usage/src/main.rs +++ b/examples/full_usage/src/main.rs @@ -47,7 +47,8 @@ fn main() { let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model"); let mut state = ctx.create_state().expect("failed to create key"); - let params = FullParams::new(SamplingStrategy::default()); + let mut params = FullParams::new(SamplingStrategy::default()); + params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress)); let st = std::time::Instant::now(); state diff --git a/src/whisper_params.rs b/src/whisper_params.rs index fc20722..923f646 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -24,6 +24,7 @@ pub struct FullParams<'a, 'b> { pub(crate) fp: whisper_rs_sys::whisper_full_params, phantom_lang: PhantomData<&'a str>, phantom_tokens: PhantomData<&'b [c_int]>, + progess_callback_safe: Option>, } impl<'a, 'b> FullParams<'a, 'b> { @@ -57,6 +58,7 @@ impl<'a, 'b> FullParams<'a, 'b> { fp, phantom_lang: PhantomData, phantom_tokens: PhantomData, + progess_callback_safe: None, } } @@ -354,8 +356,8 @@ impl<'a, 'b> FullParams<'a, 'b> { /// Set the callback for progress updates. /// - /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). - /// It is still a C callback. + /// Note that is still a C callback. + /// See `set_progress_callback_safe` for a limited yet safe version. /// /// # Safety /// Do not use this function unless you know what you are doing. @@ -370,6 +372,48 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.progress_callback = progress_callback; } + /// Set the callback for progress updates, potentially using a closure. + /// + /// Note that, in order to ensure safety, the callback only accepts the progress in percent. + /// See `set_progress_callback` if you need to use `whisper_context` and `whisper_state` + /// (or extend this one to support their use). + /// + /// Defaults to None. + pub fn set_progress_callback_safe(&mut self, closure: O) + where + F: FnMut(i32) + 'static, + O: Into>, + { + use std::ffi::c_void; + use whisper_rs_sys::{whisper_context, whisper_state}; + + unsafe extern "C" fn trampoline( + _: *mut whisper_context, + _: *mut whisper_state, + progress: c_int, + user_data: *mut c_void, + ) where + F: FnMut(i32), + { + let user_data = &mut *(user_data as *mut F); + user_data(progress); + } + + match closure.into() { + Some(mut closure) => { + self.fp.progress_callback = Some(trampoline::); + self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void; + // store the closure internally to make sure that the pointer above remains valid + self.progess_callback_safe = Some(Box::new(closure)); + } + None => { + self.fp.progress_callback = None; + self.fp.progress_callback_user_data = 0 as *mut c_void; + self.progess_callback_safe = None; + } + } + } + /// Set the user data to be passed to the progress callback. /// /// # Safety