From 56f59edeeda27093030895ccb8374aa3d65f50e8 Mon Sep 17 00:00:00 2001 From: thewh1teagle <61390950+thewh1teagle@users.noreply.github.com> Date: Mon, 15 Apr 2024 00:48:53 +0300 Subject: [PATCH] feat: add safe bindings to abort callback --- src/whisper_params.rs | 47 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 92552dd..dbf0bb5 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -27,6 +27,7 @@ pub struct FullParams<'a, 'b> { phantom_tokens: PhantomData<&'b [c_int]>, grammar: Option>, progess_callback_safe: Option>, + abort_callback_safe: Option bool>>, } impl<'a, 'b> FullParams<'a, 'b> { @@ -62,6 +63,7 @@ impl<'a, 'b> FullParams<'a, 'b> { phantom_tokens: PhantomData, grammar: None, progess_callback_safe: None, + abort_callback_safe: None, } } @@ -450,6 +452,51 @@ impl<'a, 'b> FullParams<'a, 'b> { } } + /// Set the callback for abort conditions, potentially using a closure. + /// + /// Note that, for safety, the callback only accepts a function that returns a boolean + /// indicating whether to abort or not. + /// + /// See `set_progress_callback` if you need to use `whisper_context` and `whisper_state`, + /// or extend this one to support their use. + /// + /// Defaults to None. + pub fn set_abort_callback_safe(&mut self, closure: O) + where + F: FnMut() -> bool + 'static, + O: Into>, + { + use std::ffi::c_void; + + unsafe extern "C" fn trampoline(user_data: *mut c_void) -> bool + where + F: FnMut() -> bool, + { + let user_data = &mut *(user_data as *mut F); + user_data() + } + + match closure.into() { + Some(closure) => { + // Stable address + let closure = Box::new(closure) as Box bool>; + // Thin pointer + let closure = Box::new(closure); + // Raw pointer + let closure = Box::into_raw(closure); + + self.fp.abort_callback = Some(trampoline::); + self.fp.abort_callback_user_data = closure as *mut c_void; + self.abort_callback_safe = None; + } + None => { + self.fp.abort_callback = None; + self.fp.abort_callback_user_data = std::ptr::null_mut::(); + self.abort_callback_safe = None; + } + } + } + /// Set the user data to be passed to the progress callback. /// /// # Safety