feat: add safe bindings to abort callback
This commit is contained in:
parent
a9e060571a
commit
56f59edeed
1 changed files with 47 additions and 0 deletions
|
|
@ -27,6 +27,7 @@ pub struct FullParams<'a, 'b> {
|
||||||
phantom_tokens: PhantomData<&'b [c_int]>,
|
phantom_tokens: PhantomData<&'b [c_int]>,
|
||||||
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
|
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
|
||||||
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
|
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
|
||||||
|
abort_callback_safe: Option<Box<dyn FnMut() -> bool>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> FullParams<'a, 'b> {
|
impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
|
|
@ -62,6 +63,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
phantom_tokens: PhantomData,
|
phantom_tokens: PhantomData,
|
||||||
grammar: None,
|
grammar: None,
|
||||||
progess_callback_safe: None,
|
progess_callback_safe: None,
|
||||||
|
abort_callback_safe: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -450,6 +452,51 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the callback for abort conditions, potentially using a closure.
|
||||||
|
///
|
||||||
|
/// Note that, for safety, the callback only accepts a function that returns a boolean
|
||||||
|
/// indicating whether to abort or not.
|
||||||
|
///
|
||||||
|
/// See `set_progress_callback` if you need to use `whisper_context` and `whisper_state`,
|
||||||
|
/// or extend this one to support their use.
|
||||||
|
///
|
||||||
|
/// Defaults to None.
|
||||||
|
pub fn set_abort_callback_safe<O, F>(&mut self, closure: O)
|
||||||
|
where
|
||||||
|
F: FnMut() -> bool + 'static,
|
||||||
|
O: Into<Option<F>>,
|
||||||
|
{
|
||||||
|
use std::ffi::c_void;
|
||||||
|
|
||||||
|
unsafe extern "C" fn trampoline<F>(user_data: *mut c_void) -> bool
|
||||||
|
where
|
||||||
|
F: FnMut() -> bool,
|
||||||
|
{
|
||||||
|
let user_data = &mut *(user_data as *mut F);
|
||||||
|
user_data()
|
||||||
|
}
|
||||||
|
|
||||||
|
match closure.into() {
|
||||||
|
Some(closure) => {
|
||||||
|
// Stable address
|
||||||
|
let closure = Box::new(closure) as Box<dyn FnMut() -> bool>;
|
||||||
|
// Thin pointer
|
||||||
|
let closure = Box::new(closure);
|
||||||
|
// Raw pointer
|
||||||
|
let closure = Box::into_raw(closure);
|
||||||
|
|
||||||
|
self.fp.abort_callback = Some(trampoline::<F>);
|
||||||
|
self.fp.abort_callback_user_data = closure as *mut c_void;
|
||||||
|
self.abort_callback_safe = None;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
self.fp.abort_callback = None;
|
||||||
|
self.fp.abort_callback_user_data = std::ptr::null_mut::<c_void>();
|
||||||
|
self.abort_callback_safe = None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the user data to be passed to the progress callback.
|
/// Set the user data to be passed to the progress callback.
|
||||||
///
|
///
|
||||||
/// # Safety
|
/// # Safety
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue