feat: add safe bindings to abort callback

This commit is contained in:
thewh1teagle 2024-04-15 00:48:53 +03:00
parent a9e060571a
commit 56f59edeed
No known key found for this signature in database
GPG key ID: F7BFC3A4192804E4

View file

@ -27,6 +27,7 @@ pub struct FullParams<'a, 'b> {
phantom_tokens: PhantomData<&'b [c_int]>, phantom_tokens: PhantomData<&'b [c_int]>,
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>, grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
progess_callback_safe: Option<Box<dyn FnMut(i32)>>, progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
abort_callback_safe: Option<Box<dyn FnMut() -> bool>>,
} }
impl<'a, 'b> FullParams<'a, 'b> { impl<'a, 'b> FullParams<'a, 'b> {
@ -62,6 +63,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
phantom_tokens: PhantomData, phantom_tokens: PhantomData,
grammar: None, grammar: None,
progess_callback_safe: None, progess_callback_safe: None,
abort_callback_safe: None,
} }
} }
@ -450,6 +452,51 @@ impl<'a, 'b> FullParams<'a, 'b> {
} }
} }
/// Set the callback for abort conditions, potentially using a closure.
///
/// Note that, for safety, the callback only accepts a function that returns a boolean
/// indicating whether to abort or not.
///
/// See `set_progress_callback` if you need to use `whisper_context` and `whisper_state`,
/// or extend this one to support their use.
///
/// Defaults to None.
pub fn set_abort_callback_safe<O, F>(&mut self, closure: O)
where
F: FnMut() -> bool + 'static,
O: Into<Option<F>>,
{
use std::ffi::c_void;
unsafe extern "C" fn trampoline<F>(user_data: *mut c_void) -> bool
where
F: FnMut() -> bool,
{
let user_data = &mut *(user_data as *mut F);
user_data()
}
match closure.into() {
Some(closure) => {
// Stable address
let closure = Box::new(closure) as Box<dyn FnMut() -> bool>;
// Thin pointer
let closure = Box::new(closure);
// Raw pointer
let closure = Box::into_raw(closure);
self.fp.abort_callback = Some(trampoline::<F>);
self.fp.abort_callback_user_data = closure as *mut c_void;
self.abort_callback_safe = None;
}
None => {
self.fp.abort_callback = None;
self.fp.abort_callback_user_data = std::ptr::null_mut::<c_void>();
self.abort_callback_safe = None;
}
}
}
/// Set the user data to be passed to the progress callback. /// Set the user data to be passed to the progress callback.
/// ///
/// # Safety /// # Safety