Merge pull request #77 from marmistrz/master

Add a rustified progress callback.
This commit is contained in:
0/0 2023-08-13 04:37:53 +00:00 committed by GitHub
commit b7615db242
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 48 additions and 3 deletions

View file

@ -47,7 +47,8 @@ fn main() {
let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
let mut state = ctx.create_state().expect("failed to create key");
let params = FullParams::new(SamplingStrategy::default());
let mut params = FullParams::new(SamplingStrategy::default());
params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress));
let st = std::time::Instant::now();
state

View file

@ -24,6 +24,7 @@ pub struct FullParams<'a, 'b> {
pub(crate) fp: whisper_rs_sys::whisper_full_params,
phantom_lang: PhantomData<&'a str>,
phantom_tokens: PhantomData<&'b [c_int]>,
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
}
impl<'a, 'b> FullParams<'a, 'b> {
@ -57,6 +58,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
fp,
phantom_lang: PhantomData,
phantom_tokens: PhantomData,
progess_callback_safe: None,
}
}
@ -354,8 +356,8 @@ impl<'a, 'b> FullParams<'a, 'b> {
/// Set the callback for progress updates.
///
/// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so).
/// It is still a C callback.
/// Note that is still a C callback.
/// See `set_progress_callback_safe` for a limited yet safe version.
///
/// # Safety
/// Do not use this function unless you know what you are doing.
@ -370,6 +372,48 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.progress_callback = progress_callback;
}
/// Set the callback for progress updates, potentially using a closure.
///
/// Note that, in order to ensure safety, the callback only accepts the progress in percent.
/// See `set_progress_callback` if you need to use `whisper_context` and `whisper_state`
/// (or extend this one to support their use).
///
/// Defaults to None.
pub fn set_progress_callback_safe<O, F>(&mut self, closure: O)
where
F: FnMut(i32) + 'static,
O: Into<Option<F>>,
{
use std::ffi::c_void;
use whisper_rs_sys::{whisper_context, whisper_state};
unsafe extern "C" fn trampoline<F>(
_: *mut whisper_context,
_: *mut whisper_state,
progress: c_int,
user_data: *mut c_void,
) where
F: FnMut(i32),
{
let user_data = &mut *(user_data as *mut F);
user_data(progress);
}
match closure.into() {
Some(mut closure) => {
self.fp.progress_callback = Some(trampoline::<F>);
self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void;
// store the closure internally to make sure that the pointer above remains valid
self.progess_callback_safe = Some(Box::new(closure));
}
None => {
self.fp.progress_callback = None;
self.fp.progress_callback_user_data = 0 as *mut c_void;
self.progess_callback_safe = None;
}
}
}
/// Set the user data to be passed to the progress callback.
///
/// # Safety