Add a rustified progress callback.
This commit is contained in:
parent
19f47dac39
commit
ddabeb4c0b
2 changed files with 42 additions and 1 deletions
|
|
@ -47,7 +47,8 @@ fn main() {
|
||||||
|
|
||||||
let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
|
let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
|
||||||
let mut state = ctx.create_state().expect("failed to create key");
|
let mut state = ctx.create_state().expect("failed to create key");
|
||||||
let params = FullParams::new(SamplingStrategy::default());
|
let mut params = FullParams::new(SamplingStrategy::default());
|
||||||
|
params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress));
|
||||||
|
|
||||||
let st = std::time::Instant::now();
|
let st = std::time::Instant::now();
|
||||||
state
|
state
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ pub struct FullParams<'a, 'b> {
|
||||||
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
||||||
phantom_lang: PhantomData<&'a str>,
|
phantom_lang: PhantomData<&'a str>,
|
||||||
phantom_tokens: PhantomData<&'b [c_int]>,
|
phantom_tokens: PhantomData<&'b [c_int]>,
|
||||||
|
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> FullParams<'a, 'b> {
|
impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
|
|
@ -57,6 +58,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
fp,
|
fp,
|
||||||
phantom_lang: PhantomData,
|
phantom_lang: PhantomData,
|
||||||
phantom_tokens: PhantomData,
|
phantom_tokens: PhantomData,
|
||||||
|
progess_callback_safe: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -370,6 +372,44 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
self.fp.progress_callback = progress_callback;
|
self.fp.progress_callback = progress_callback;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the callback for progress updates, potentially using a closure.
|
||||||
|
///
|
||||||
|
/// Note that, in order to ensure safety, the callback only accepts the progress in percent.
|
||||||
|
///
|
||||||
|
/// Defaults to None.
|
||||||
|
pub fn set_progress_callback_safe<F>(&mut self, closure: Option<F>)
|
||||||
|
where
|
||||||
|
F: FnMut(i32) + 'static,
|
||||||
|
{
|
||||||
|
use std::ffi::c_void;
|
||||||
|
use whisper_rs_sys::{whisper_context, whisper_state};
|
||||||
|
|
||||||
|
unsafe extern "C" fn trampoline<F>(
|
||||||
|
_: *mut whisper_context,
|
||||||
|
_: *mut whisper_state,
|
||||||
|
progress: c_int,
|
||||||
|
user_data: *mut c_void,
|
||||||
|
) where
|
||||||
|
F: FnMut(i32),
|
||||||
|
{
|
||||||
|
let user_data = &mut *(user_data as *mut F);
|
||||||
|
user_data(progress);
|
||||||
|
}
|
||||||
|
|
||||||
|
match closure {
|
||||||
|
Some(mut closure) => {
|
||||||
|
self.fp.progress_callback = Some(trampoline::<F>);
|
||||||
|
self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void;
|
||||||
|
self.progess_callback_safe = Some(Box::new(closure));
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
self.fp.progress_callback = None;
|
||||||
|
self.fp.progress_callback_user_data = 0 as *mut c_void;
|
||||||
|
self.progess_callback_safe = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the user data to be passed to the progress callback.
|
/// Set the user data to be passed to the progress callback.
|
||||||
///
|
///
|
||||||
/// # Safety
|
/// # Safety
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue