From ddabeb4c0bbc03d66593eb2a65b8f31b454bae2f Mon Sep 17 00:00:00 2001 From: Marcin Mielniczuk Date: Fri, 11 Aug 2023 15:08:05 +0200 Subject: [PATCH 1/2] Add a rustified progress callback. --- examples/full_usage/src/main.rs | 3 ++- src/whisper_params.rs | 40 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/examples/full_usage/src/main.rs b/examples/full_usage/src/main.rs index 42c0bff..c5eca17 100644 --- a/examples/full_usage/src/main.rs +++ b/examples/full_usage/src/main.rs @@ -47,7 +47,8 @@ fn main() { let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model"); let mut state = ctx.create_state().expect("failed to create key"); - let params = FullParams::new(SamplingStrategy::default()); + let mut params = FullParams::new(SamplingStrategy::default()); + params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress)); let st = std::time::Instant::now(); state diff --git a/src/whisper_params.rs b/src/whisper_params.rs index fc20722..3370b6b 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -24,6 +24,7 @@ pub struct FullParams<'a, 'b> { pub(crate) fp: whisper_rs_sys::whisper_full_params, phantom_lang: PhantomData<&'a str>, phantom_tokens: PhantomData<&'b [c_int]>, + progess_callback_safe: Option>, } impl<'a, 'b> FullParams<'a, 'b> { @@ -57,6 +58,7 @@ impl<'a, 'b> FullParams<'a, 'b> { fp, phantom_lang: PhantomData, phantom_tokens: PhantomData, + progess_callback_safe: None, } } @@ -370,6 +372,44 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.progress_callback = progress_callback; } + /// Set the callback for progress updates, potentially using a closure. + /// + /// Note that, in order to ensure safety, the callback only accepts the progress in percent. + /// + /// Defaults to None. + pub fn set_progress_callback_safe(&mut self, closure: Option) + where + F: FnMut(i32) + 'static, + { + use std::ffi::c_void; + use whisper_rs_sys::{whisper_context, whisper_state}; + + unsafe extern "C" fn trampoline( + _: *mut whisper_context, + _: *mut whisper_state, + progress: c_int, + user_data: *mut c_void, + ) where + F: FnMut(i32), + { + let user_data = &mut *(user_data as *mut F); + user_data(progress); + } + + match closure { + Some(mut closure) => { + self.fp.progress_callback = Some(trampoline::); + self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void; + self.progess_callback_safe = Some(Box::new(closure)); + } + None => { + self.fp.progress_callback = None; + self.fp.progress_callback_user_data = 0 as *mut c_void; + self.progess_callback_safe = None; + } + } + } + /// Set the user data to be passed to the progress callback. /// /// # Safety From 0e74df12a9987822b152b5ee08501491689956f9 Mon Sep 17 00:00:00 2001 From: Marcin Mielniczuk Date: Fri, 11 Aug 2023 22:01:06 +0200 Subject: [PATCH 2/2] Improve the docs --- src/whisper_params.rs | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 3370b6b..923f646 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -356,8 +356,8 @@ impl<'a, 'b> FullParams<'a, 'b> { /// Set the callback for progress updates. /// - /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). - /// It is still a C callback. + /// Note that is still a C callback. + /// See `set_progress_callback_safe` for a limited yet safe version. /// /// # Safety /// Do not use this function unless you know what you are doing. @@ -375,11 +375,14 @@ impl<'a, 'b> FullParams<'a, 'b> { /// Set the callback for progress updates, potentially using a closure. /// /// Note that, in order to ensure safety, the callback only accepts the progress in percent. + /// See `set_progress_callback` if you need to use `whisper_context` and `whisper_state` + /// (or extend this one to support their use). /// /// Defaults to None. - pub fn set_progress_callback_safe(&mut self, closure: Option) + pub fn set_progress_callback_safe(&mut self, closure: O) where F: FnMut(i32) + 'static, + O: Into>, { use std::ffi::c_void; use whisper_rs_sys::{whisper_context, whisper_state}; @@ -396,10 +399,11 @@ impl<'a, 'b> FullParams<'a, 'b> { user_data(progress); } - match closure { + match closure.into() { Some(mut closure) => { self.fp.progress_callback = Some(trampoline::); self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void; + // store the closure internally to make sure that the pointer above remains valid self.progess_callback_safe = Some(Box::new(closure)); } None => {