Add a rustified progress callback.
This commit is contained in:
parent
19f47dac39
commit
ddabeb4c0b
2 changed files with 42 additions and 1 deletions
|
|
@ -47,7 +47,8 @@ fn main() {
|
|||
|
||||
let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
|
||||
let mut state = ctx.create_state().expect("failed to create key");
|
||||
let params = FullParams::new(SamplingStrategy::default());
|
||||
let mut params = FullParams::new(SamplingStrategy::default());
|
||||
params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress));
|
||||
|
||||
let st = std::time::Instant::now();
|
||||
state
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ pub struct FullParams<'a, 'b> {
|
|||
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
||||
phantom_lang: PhantomData<&'a str>,
|
||||
phantom_tokens: PhantomData<&'b [c_int]>,
|
||||
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> FullParams<'a, 'b> {
|
||||
|
|
@ -57,6 +58,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
|||
fp,
|
||||
phantom_lang: PhantomData,
|
||||
phantom_tokens: PhantomData,
|
||||
progess_callback_safe: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -370,6 +372,44 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
|||
self.fp.progress_callback = progress_callback;
|
||||
}
|
||||
|
||||
/// Set the callback for progress updates, potentially using a closure.
|
||||
///
|
||||
/// Note that, in order to ensure safety, the callback only accepts the progress in percent.
|
||||
///
|
||||
/// Defaults to None.
|
||||
pub fn set_progress_callback_safe<F>(&mut self, closure: Option<F>)
|
||||
where
|
||||
F: FnMut(i32) + 'static,
|
||||
{
|
||||
use std::ffi::c_void;
|
||||
use whisper_rs_sys::{whisper_context, whisper_state};
|
||||
|
||||
unsafe extern "C" fn trampoline<F>(
|
||||
_: *mut whisper_context,
|
||||
_: *mut whisper_state,
|
||||
progress: c_int,
|
||||
user_data: *mut c_void,
|
||||
) where
|
||||
F: FnMut(i32),
|
||||
{
|
||||
let user_data = &mut *(user_data as *mut F);
|
||||
user_data(progress);
|
||||
}
|
||||
|
||||
match closure {
|
||||
Some(mut closure) => {
|
||||
self.fp.progress_callback = Some(trampoline::<F>);
|
||||
self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void;
|
||||
self.progess_callback_safe = Some(Box::new(closure));
|
||||
}
|
||||
None => {
|
||||
self.fp.progress_callback = None;
|
||||
self.fp.progress_callback_user_data = 0 as *mut c_void;
|
||||
self.progess_callback_safe = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the user data to be passed to the progress callback.
|
||||
///
|
||||
/// # Safety
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue