feat: add new segment callback bindings
This commit is contained in:
parent
1f4c49ae68
commit
c674b7e101
2 changed files with 147 additions and 1 deletions
|
|
@ -29,7 +29,7 @@ pub use whisper_ctx::WhisperContextParameters;
|
||||||
use whisper_ctx::WhisperInnerContext;
|
use whisper_ctx::WhisperInnerContext;
|
||||||
pub use whisper_ctx_wrapper::WhisperContext;
|
pub use whisper_ctx_wrapper::WhisperContext;
|
||||||
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
|
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
|
||||||
pub use whisper_params::{FullParams, SamplingStrategy};
|
pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData};
|
||||||
#[cfg(feature = "raw-api")]
|
#[cfg(feature = "raw-api")]
|
||||||
pub use whisper_rs_sys;
|
pub use whisper_rs_sys;
|
||||||
pub use whisper_state::WhisperState;
|
pub use whisper_state::WhisperState;
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,16 @@ impl Default for SamplingStrategy {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct SegmentCallbackData {
|
||||||
|
pub segment: i32,
|
||||||
|
pub start_timestamp: i64,
|
||||||
|
pub end_timestamp: i64,
|
||||||
|
pub text: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
type SegmentCallbackFn = Box<dyn FnMut(SegmentCallbackData)>;
|
||||||
|
|
||||||
pub struct FullParams<'a, 'b> {
|
pub struct FullParams<'a, 'b> {
|
||||||
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
||||||
phantom_lang: PhantomData<&'a str>,
|
phantom_lang: PhantomData<&'a str>,
|
||||||
|
|
@ -28,6 +38,7 @@ pub struct FullParams<'a, 'b> {
|
||||||
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
|
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
|
||||||
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
|
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
|
||||||
abort_callback_safe: Option<Box<dyn FnMut() -> bool>>,
|
abort_callback_safe: Option<Box<dyn FnMut() -> bool>>,
|
||||||
|
segment_calllback_safe: Option<SegmentCallbackFn>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> FullParams<'a, 'b> {
|
impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
|
|
@ -64,6 +75,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
grammar: None,
|
grammar: None,
|
||||||
progess_callback_safe: None,
|
progess_callback_safe: None,
|
||||||
abort_callback_safe: None,
|
abort_callback_safe: None,
|
||||||
|
segment_calllback_safe: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -392,6 +404,140 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
self.fp.new_segment_callback_user_data = user_data;
|
self.fp.new_segment_callback_user_data = user_data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the callback for segment updates.
|
||||||
|
///
|
||||||
|
/// Provides a limited segment_callback to ensure safety.
|
||||||
|
/// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`
|
||||||
|
///
|
||||||
|
/// Defaults to None.
|
||||||
|
pub fn set_segment_callback_safe<O, F>(&mut self, closure: O)
|
||||||
|
where
|
||||||
|
F: FnMut(SegmentCallbackData) + 'static,
|
||||||
|
O: Into<Option<F>>,
|
||||||
|
{
|
||||||
|
use std::ffi::{c_void, CStr};
|
||||||
|
use whisper_rs_sys::{whisper_context, whisper_state};
|
||||||
|
|
||||||
|
extern "C" fn trampoline<F>(
|
||||||
|
_: *mut whisper_context,
|
||||||
|
state: *mut whisper_state,
|
||||||
|
n_new: i32,
|
||||||
|
user_data: *mut c_void,
|
||||||
|
) where
|
||||||
|
F: FnMut(SegmentCallbackData) + 'static,
|
||||||
|
{
|
||||||
|
unsafe {
|
||||||
|
let user_data = &mut *(user_data as *mut SegmentCallbackFn);
|
||||||
|
let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state);
|
||||||
|
let s0 = n_segments - n_new;
|
||||||
|
//let user_data = user_data as *mut Box<dyn FnMut(SegmentCallbackData)>;
|
||||||
|
|
||||||
|
for i in s0..n_segments {
|
||||||
|
let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i);
|
||||||
|
let text = CStr::from_ptr(text);
|
||||||
|
|
||||||
|
let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i);
|
||||||
|
let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i);
|
||||||
|
|
||||||
|
match text.to_str() {
|
||||||
|
Ok(n) => user_data(SegmentCallbackData {
|
||||||
|
segment: i,
|
||||||
|
start_timestamp: t0,
|
||||||
|
end_timestamp: t1,
|
||||||
|
text: n.to_string(),
|
||||||
|
}),
|
||||||
|
Err(_) => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match closure.into() {
|
||||||
|
Some(closure) => {
|
||||||
|
// Stable address
|
||||||
|
let closure = Box::new(closure) as SegmentCallbackFn;
|
||||||
|
// Thin pointer
|
||||||
|
let closure = Box::new(closure);
|
||||||
|
// Raw pointer
|
||||||
|
let closure = Box::into_raw(closure);
|
||||||
|
|
||||||
|
self.fp.new_segment_callback_user_data = closure as *mut c_void;
|
||||||
|
self.fp.new_segment_callback = Some(trampoline::<SegmentCallbackFn>);
|
||||||
|
self.segment_calllback_safe = None;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
self.segment_calllback_safe = None;
|
||||||
|
self.fp.new_segment_callback = None;
|
||||||
|
self.fp.new_segment_callback_user_data = std::ptr::null_mut::<c_void>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the callback for segment updates.
|
||||||
|
///
|
||||||
|
/// Provides a limited segment_callback to ensure safety with lossy handling of bad UTF-8 characters.
|
||||||
|
/// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`.
|
||||||
|
///
|
||||||
|
/// Defaults to None.
|
||||||
|
pub fn set_segment_callback_safe_lossy<O, F>(&mut self, closure: O)
|
||||||
|
where
|
||||||
|
F: FnMut(SegmentCallbackData) + 'static,
|
||||||
|
O: Into<Option<F>>,
|
||||||
|
{
|
||||||
|
use std::ffi::{c_void, CStr};
|
||||||
|
use whisper_rs_sys::{whisper_context, whisper_state};
|
||||||
|
|
||||||
|
extern "C" fn trampoline<F>(
|
||||||
|
_: *mut whisper_context,
|
||||||
|
state: *mut whisper_state,
|
||||||
|
n_new: i32,
|
||||||
|
user_data: *mut c_void,
|
||||||
|
) where
|
||||||
|
F: FnMut(SegmentCallbackData) + 'static,
|
||||||
|
{
|
||||||
|
unsafe {
|
||||||
|
let user_data = &mut *(user_data as *mut SegmentCallbackFn);
|
||||||
|
let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state);
|
||||||
|
let s0 = n_segments - n_new;
|
||||||
|
//let user_data = user_data as *mut Box<dyn FnMut(SegmentCallbackData)>;
|
||||||
|
|
||||||
|
for i in s0..n_segments {
|
||||||
|
let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i);
|
||||||
|
let text = CStr::from_ptr(text);
|
||||||
|
|
||||||
|
let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i);
|
||||||
|
let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i);
|
||||||
|
user_data(SegmentCallbackData {
|
||||||
|
segment: i,
|
||||||
|
start_timestamp: t0,
|
||||||
|
end_timestamp: t1,
|
||||||
|
text: text.to_string_lossy().to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match closure.into() {
|
||||||
|
Some(closure) => {
|
||||||
|
// Stable address
|
||||||
|
let closure = Box::new(closure) as SegmentCallbackFn;
|
||||||
|
// Thin pointer
|
||||||
|
let closure = Box::new(closure);
|
||||||
|
// Raw pointer
|
||||||
|
let closure = Box::into_raw(closure);
|
||||||
|
|
||||||
|
self.fp.new_segment_callback_user_data = closure as *mut c_void;
|
||||||
|
self.fp.new_segment_callback = Some(trampoline::<SegmentCallbackFn>);
|
||||||
|
self.segment_calllback_safe = None;
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
self.segment_calllback_safe = None;
|
||||||
|
self.fp.new_segment_callback = None;
|
||||||
|
self.fp.new_segment_callback_user_data = std::ptr::null_mut::<c_void>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Set the callback for progress updates.
|
/// Set the callback for progress updates.
|
||||||
///
|
///
|
||||||
/// Note that is still a C callback.
|
/// Note that is still a C callback.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue