Add a rustified progress callback.

This commit is contained in:
Marcin Mielniczuk 2023-08-11 15:08:05 +02:00
parent 19f47dac39
commit ddabeb4c0b
2 changed files with 42 additions and 1 deletions

View file

@ -47,7 +47,8 @@ fn main() {
let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model"); let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
let mut state = ctx.create_state().expect("failed to create key"); let mut state = ctx.create_state().expect("failed to create key");
let params = FullParams::new(SamplingStrategy::default()); let mut params = FullParams::new(SamplingStrategy::default());
params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress));
let st = std::time::Instant::now(); let st = std::time::Instant::now();
state state

View file

@ -24,6 +24,7 @@ pub struct FullParams<'a, 'b> {
pub(crate) fp: whisper_rs_sys::whisper_full_params, pub(crate) fp: whisper_rs_sys::whisper_full_params,
phantom_lang: PhantomData<&'a str>, phantom_lang: PhantomData<&'a str>,
phantom_tokens: PhantomData<&'b [c_int]>, phantom_tokens: PhantomData<&'b [c_int]>,
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
} }
impl<'a, 'b> FullParams<'a, 'b> { impl<'a, 'b> FullParams<'a, 'b> {
@ -57,6 +58,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
fp, fp,
phantom_lang: PhantomData, phantom_lang: PhantomData,
phantom_tokens: PhantomData, phantom_tokens: PhantomData,
progess_callback_safe: None,
} }
} }
@ -370,6 +372,44 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.progress_callback = progress_callback; self.fp.progress_callback = progress_callback;
} }
/// Set the callback for progress updates, potentially using a closure.
///
/// Note that, in order to ensure safety, the callback only accepts the progress in percent.
///
/// Defaults to None.
pub fn set_progress_callback_safe<F>(&mut self, closure: Option<F>)
where
F: FnMut(i32) + 'static,
{
use std::ffi::c_void;
use whisper_rs_sys::{whisper_context, whisper_state};
unsafe extern "C" fn trampoline<F>(
_: *mut whisper_context,
_: *mut whisper_state,
progress: c_int,
user_data: *mut c_void,
) where
F: FnMut(i32),
{
let user_data = &mut *(user_data as *mut F);
user_data(progress);
}
match closure {
Some(mut closure) => {
self.fp.progress_callback = Some(trampoline::<F>);
self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void;
self.progess_callback_safe = Some(Box::new(closure));
}
None => {
self.fp.progress_callback = None;
self.fp.progress_callback_user_data = 0 as *mut c_void;
self.progess_callback_safe = None;
}
}
}
/// Set the user data to be passed to the progress callback. /// Set the user data to be passed to the progress callback.
/// ///
/// # Safety /// # Safety