Merge pull request #220 from rasteiner/fix/progress-callback-crash

Fix/progress callback crash
This commit is contained in:
Niko 2025-06-07 23:24:52 -07:00 committed by GitHub
commit 3c37ed2271
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -38,7 +38,7 @@ pub struct FullParams<'a, 'b> {
phantom_lang: PhantomData<&'a str>, phantom_lang: PhantomData<&'a str>,
phantom_tokens: PhantomData<&'b [c_int]>, phantom_tokens: PhantomData<&'b [c_int]>,
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>, grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
progess_callback_safe: Option<Arc<Box<dyn FnMut(i32)>>>, progress_callback_safe: Option<Arc<Box<dyn FnMut(i32)>>>,
abort_callback_safe: Option<Arc<Box<dyn FnMut() -> bool>>>, abort_callback_safe: Option<Arc<Box<dyn FnMut() -> bool>>>,
segment_calllback_safe: Option<Arc<SegmentCallbackFn>>, segment_calllback_safe: Option<Arc<SegmentCallbackFn>>,
} }
@ -75,7 +75,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
phantom_lang: PhantomData, phantom_lang: PhantomData,
phantom_tokens: PhantomData, phantom_tokens: PhantomData,
grammar: None, grammar: None,
progess_callback_safe: None, progress_callback_safe: None,
abort_callback_safe: None, abort_callback_safe: None,
segment_calllback_safe: None, segment_calllback_safe: None,
} }
@ -579,16 +579,18 @@ impl<'a, 'b> FullParams<'a, 'b> {
} }
match closure.into() { match closure.into() {
Some(mut closure) => { Some(closure) => {
self.fp.progress_callback = Some(trampoline::<F>); self.fp.progress_callback = Some(trampoline::<Box<dyn FnMut(i32)>>);
self.fp.progress_callback_user_data = &mut closure as *mut F as *mut c_void; let boxed_closure = Box::new(closure) as Box<dyn FnMut(i32)>;
// store the closure internally to make sure that the pointer above remains valid let boxed_closure = Box::new(boxed_closure);
self.progess_callback_safe = Some(Arc::new(Box::new(closure))); let raw_ptr = Box::into_raw(boxed_closure);
self.fp.progress_callback_user_data = raw_ptr as *mut c_void;
self.progress_callback_safe = None;
} }
None => { None => {
self.fp.progress_callback = None; self.fp.progress_callback = None;
self.fp.progress_callback_user_data = std::ptr::null_mut::<c_void>(); self.fp.progress_callback_user_data = std::ptr::null_mut::<c_void>();
self.progess_callback_safe = None; self.progress_callback_safe = None;
} }
} }
} }