feat: add new segment callback bindings

This commit is contained in:
thewh1teagle 2024-05-28 21:21:00 +03:00
parent 1f4c49ae68
commit c674b7e101
No known key found for this signature in database
GPG key ID: F7BFC3A4192804E4
2 changed files with 147 additions and 1 deletions

View file

@ -29,7 +29,7 @@ pub use whisper_ctx::WhisperContextParameters;
use whisper_ctx::WhisperInnerContext; use whisper_ctx::WhisperInnerContext;
pub use whisper_ctx_wrapper::WhisperContext; pub use whisper_ctx_wrapper::WhisperContext;
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType}; pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
pub use whisper_params::{FullParams, SamplingStrategy}; pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData};
#[cfg(feature = "raw-api")] #[cfg(feature = "raw-api")]
pub use whisper_rs_sys; pub use whisper_rs_sys;
pub use whisper_state::WhisperState; pub use whisper_state::WhisperState;

View file

@ -21,6 +21,16 @@ impl Default for SamplingStrategy {
} }
} }
/// Data describing one transcribed segment, handed to the safe segment callbacks.
///
/// A value of this type is built for each new segment and passed by value to the closure
/// registered via `set_segment_callback_safe` or `set_segment_callback_safe_lossy`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SegmentCallbackData {
    /// Index of the segment within the state's list of segments.
    pub segment: i32,
    /// Start time of the segment as reported by whisper's `t0` getter.
    // NOTE(review): presumably in whisper.cpp's 10 ms units — confirm against whisper.h.
    pub start_timestamp: i64,
    /// End time of the segment as reported by whisper's `t1` getter (same unit as the start).
    pub end_timestamp: i64,
    /// Text of the segment, copied into an owned `String`.
    pub text: String,
}
/// Boxed, type-erased closure that is handed one [`SegmentCallbackData`] per invocation.
type SegmentCallbackFn = Box<dyn FnMut(SegmentCallbackData)>;
pub struct FullParams<'a, 'b> { pub struct FullParams<'a, 'b> {
pub(crate) fp: whisper_rs_sys::whisper_full_params, pub(crate) fp: whisper_rs_sys::whisper_full_params,
phantom_lang: PhantomData<&'a str>, phantom_lang: PhantomData<&'a str>,
@ -28,6 +38,7 @@ pub struct FullParams<'a, 'b> {
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>, grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
progess_callback_safe: Option<Box<dyn FnMut(i32)>>, progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
abort_callback_safe: Option<Box<dyn FnMut() -> bool>>, abort_callback_safe: Option<Box<dyn FnMut() -> bool>>,
segment_calllback_safe: Option<SegmentCallbackFn>,
} }
impl<'a, 'b> FullParams<'a, 'b> { impl<'a, 'b> FullParams<'a, 'b> {
@ -64,6 +75,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
grammar: None, grammar: None,
progess_callback_safe: None, progess_callback_safe: None,
abort_callback_safe: None, abort_callback_safe: None,
segment_calllback_safe: None,
} }
} }
@ -392,6 +404,140 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.new_segment_callback_user_data = user_data; self.fp.new_segment_callback_user_data = user_data;
} }
/// Set the callback for segment updates.
///
/// Provides a limited segment_callback to ensure safety: the closure is handed one
/// [`SegmentCallbackData`] (index, start/end timestamps and text) for every newly produced
/// segment. Segments whose text is not valid UTF-8 are skipped silently; use
/// `set_segment_callback_safe_lossy` to receive them with the invalid bytes replaced.
/// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`.
///
/// This writes `fp.new_segment_callback` and `fp.new_segment_callback_user_data`, replacing
/// whatever was stored there before. Passing `None` clears both.
///
/// The closure is owned by these parameters: it is dropped when it is replaced, cleared, or
/// when the parameters themselves are dropped. It must not panic, because it is invoked from
/// an `extern "C"` function.
///
/// Defaults to None.
pub fn set_segment_callback_safe<O, F>(&mut self, closure: O)
where
    F: FnMut(SegmentCallbackData) + 'static,
    O: Into<Option<F>>,
{
    use std::ffi::{c_void, CStr};
    use whisper_rs_sys::{whisper_context, whisper_state};

    /// C-ABI shim: recovers the concrete closure `F` from `user_data` and feeds it every
    /// segment that was added since the previous invocation.
    extern "C" fn trampoline<F>(
        _: *mut whisper_context,
        state: *mut whisper_state,
        n_new: i32,
        user_data: *mut c_void,
    ) where
        F: FnMut(SegmentCallbackData) + 'static,
    {
        if user_data.is_null() || state.is_null() {
            return;
        }
        // SAFETY: `user_data` is the pointer to the heap-allocated `F` that the enclosing
        // setter stored together with this exact `trampoline::<F>` instantiation, and the
        // owning box is kept alive inside `FullParams`.
        let callback = unsafe { &mut *(user_data as *mut F) };
        // SAFETY: `state` was checked for null above and is supplied by whisper itself.
        let n_segments = unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(state) };
        // Never start below zero, even if `n_new` is reported larger than the total.
        let first_new = n_segments.saturating_sub(n_new).max(0);

        for i in first_new..n_segments {
            // SAFETY: `i` lies in `0..n_segments`, the valid index range for this state.
            let (text_ptr, t0, t1) = unsafe {
                (
                    whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i),
                    whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i),
                    whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i),
                )
            };
            if text_ptr.is_null() {
                continue;
            }
            // SAFETY: non-null pointer to a C string returned by whisper; it is only
            // borrowed until it has been copied into an owned `String` below.
            let raw_text = unsafe { CStr::from_ptr(text_ptr) };
            // Invalid UTF-8 is skipped on purpose; the lossy setter covers that case.
            if let Ok(text) = raw_text.to_str() {
                callback(SegmentCallbackData {
                    segment: i,
                    start_timestamp: t0,
                    end_timestamp: t1,
                    text: text.to_owned(),
                });
            }
        }
    }

    match closure.into() {
        Some(closure) => {
            // Heap-allocate the closure so its address stays stable when `self` is moved.
            let mut boxed: Box<F> = Box::new(closure);
            // Thin pointer to the concrete closure; `trampoline::<F>` casts it back.
            let callback_ptr: *mut F = &mut *boxed;
            self.fp.new_segment_callback_user_data = callback_ptr as *mut c_void;
            self.fp.new_segment_callback = Some(trampoline::<F>);
            // Keep ownership here so the allocation is freed instead of leaked; assigning
            // also drops any previously registered closure.
            let erased: SegmentCallbackFn = boxed;
            self.segment_calllback_safe = Some(erased);
        }
        None => {
            self.segment_calllback_safe = None;
            self.fp.new_segment_callback = None;
            self.fp.new_segment_callback_user_data = std::ptr::null_mut::<c_void>();
        }
    }
}
/// Set the callback for segment updates.
///
/// Provides a limited segment_callback to ensure safety with lossy handling of bad UTF-8
/// characters: the closure is handed one [`SegmentCallbackData`] (index, start/end timestamps
/// and text) for every newly produced segment, and invalid UTF-8 sequences in the text are
/// replaced with U+FFFD instead of the segment being dropped.
/// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`.
///
/// This writes `fp.new_segment_callback` and `fp.new_segment_callback_user_data`, replacing
/// whatever was stored there before. Passing `None` clears both.
///
/// The closure is owned by these parameters: it is dropped when it is replaced, cleared, or
/// when the parameters themselves are dropped. It must not panic, because it is invoked from
/// an `extern "C"` function.
///
/// Defaults to None.
pub fn set_segment_callback_safe_lossy<O, F>(&mut self, closure: O)
where
    F: FnMut(SegmentCallbackData) + 'static,
    O: Into<Option<F>>,
{
    use std::ffi::{c_void, CStr};
    use whisper_rs_sys::{whisper_context, whisper_state};

    /// C-ABI shim: recovers the concrete closure `F` from `user_data` and feeds it every
    /// segment that was added since the previous invocation.
    extern "C" fn trampoline<F>(
        _: *mut whisper_context,
        state: *mut whisper_state,
        n_new: i32,
        user_data: *mut c_void,
    ) where
        F: FnMut(SegmentCallbackData) + 'static,
    {
        if user_data.is_null() || state.is_null() {
            return;
        }
        // SAFETY: `user_data` is the pointer to the heap-allocated `F` that the enclosing
        // setter stored together with this exact `trampoline::<F>` instantiation, and the
        // owning box is kept alive inside `FullParams`.
        let callback = unsafe { &mut *(user_data as *mut F) };
        // SAFETY: `state` was checked for null above and is supplied by whisper itself.
        let n_segments = unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(state) };
        // Never start below zero, even if `n_new` is reported larger than the total.
        let first_new = n_segments.saturating_sub(n_new).max(0);

        for i in first_new..n_segments {
            // SAFETY: `i` lies in `0..n_segments`, the valid index range for this state.
            let (text_ptr, t0, t1) = unsafe {
                (
                    whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i),
                    whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i),
                    whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i),
                )
            };
            if text_ptr.is_null() {
                continue;
            }
            // SAFETY: non-null pointer to a C string returned by whisper; it is only
            // borrowed until it has been copied into an owned `String` below.
            let raw_text = unsafe { CStr::from_ptr(text_ptr) };
            callback(SegmentCallbackData {
                segment: i,
                start_timestamp: t0,
                end_timestamp: t1,
                // `into_owned` reuses the buffer when the lossy conversion already allocated.
                text: raw_text.to_string_lossy().into_owned(),
            });
        }
    }

    match closure.into() {
        Some(closure) => {
            // Heap-allocate the closure so its address stays stable when `self` is moved.
            let mut boxed: Box<F> = Box::new(closure);
            // Thin pointer to the concrete closure; `trampoline::<F>` casts it back.
            let callback_ptr: *mut F = &mut *boxed;
            self.fp.new_segment_callback_user_data = callback_ptr as *mut c_void;
            self.fp.new_segment_callback = Some(trampoline::<F>);
            // Keep ownership here so the allocation is freed instead of leaked; assigning
            // also drops any previously registered closure.
            let erased: SegmentCallbackFn = boxed;
            self.segment_calllback_safe = Some(erased);
        }
        None => {
            self.segment_calllback_safe = None;
            self.fp.new_segment_callback = None;
            self.fp.new_segment_callback_user_data = std::ptr::null_mut::<c_void>();
        }
    }
}
/// Set the callback for progress updates. /// Set the callback for progress updates.
/// ///
/// Note that is still a C callback. /// Note that is still a C callback.