Merge pull request #153 from thewh1teagle/feat/new-segment-callback
Feat/new segment callback
This commit is contained in:
commit
e6271bf0f3
2 changed files with 147 additions and 1 deletions
|
|
@ -29,7 +29,7 @@ pub use whisper_ctx::WhisperContextParameters;
|
|||
use whisper_ctx::WhisperInnerContext;
|
||||
pub use whisper_ctx_wrapper::WhisperContext;
|
||||
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
|
||||
pub use whisper_params::{FullParams, SamplingStrategy};
|
||||
pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData};
|
||||
#[cfg(feature = "raw-api")]
|
||||
pub use whisper_rs_sys;
|
||||
pub use whisper_state::WhisperState;
|
||||
|
|
|
|||
|
|
@ -21,6 +21,16 @@ impl Default for SamplingStrategy {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SegmentCallbackData {
|
||||
pub segment: i32,
|
||||
pub start_timestamp: i64,
|
||||
pub end_timestamp: i64,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
type SegmentCallbackFn = Box<dyn FnMut(SegmentCallbackData)>;
|
||||
|
||||
pub struct FullParams<'a, 'b> {
|
||||
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
||||
phantom_lang: PhantomData<&'a str>,
|
||||
|
|
@ -28,6 +38,7 @@ pub struct FullParams<'a, 'b> {
|
|||
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
|
||||
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
|
||||
abort_callback_safe: Option<Box<dyn FnMut() -> bool>>,
|
||||
segment_calllback_safe: Option<SegmentCallbackFn>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> FullParams<'a, 'b> {
|
||||
|
|
@ -64,6 +75,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
|||
grammar: None,
|
||||
progess_callback_safe: None,
|
||||
abort_callback_safe: None,
|
||||
segment_calllback_safe: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -392,6 +404,140 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
|||
self.fp.new_segment_callback_user_data = user_data;
|
||||
}
|
||||
|
||||
/// Set the callback for segment updates.
|
||||
///
|
||||
/// Provides a limited segment_callback to ensure safety.
|
||||
/// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`
|
||||
///
|
||||
/// Defaults to None.
|
||||
pub fn set_segment_callback_safe<O, F>(&mut self, closure: O)
|
||||
where
|
||||
F: FnMut(SegmentCallbackData) + 'static,
|
||||
O: Into<Option<F>>,
|
||||
{
|
||||
use std::ffi::{c_void, CStr};
|
||||
use whisper_rs_sys::{whisper_context, whisper_state};
|
||||
|
||||
extern "C" fn trampoline<F>(
|
||||
_: *mut whisper_context,
|
||||
state: *mut whisper_state,
|
||||
n_new: i32,
|
||||
user_data: *mut c_void,
|
||||
) where
|
||||
F: FnMut(SegmentCallbackData) + 'static,
|
||||
{
|
||||
unsafe {
|
||||
let user_data = &mut *(user_data as *mut SegmentCallbackFn);
|
||||
let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state);
|
||||
let s0 = n_segments - n_new;
|
||||
//let user_data = user_data as *mut Box<dyn FnMut(SegmentCallbackData)>;
|
||||
|
||||
for i in s0..n_segments {
|
||||
let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i);
|
||||
let text = CStr::from_ptr(text);
|
||||
|
||||
let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i);
|
||||
let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i);
|
||||
|
||||
match text.to_str() {
|
||||
Ok(n) => user_data(SegmentCallbackData {
|
||||
segment: i,
|
||||
start_timestamp: t0,
|
||||
end_timestamp: t1,
|
||||
text: n.to_string(),
|
||||
}),
|
||||
Err(_) => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match closure.into() {
|
||||
Some(closure) => {
|
||||
// Stable address
|
||||
let closure = Box::new(closure) as SegmentCallbackFn;
|
||||
// Thin pointer
|
||||
let closure = Box::new(closure);
|
||||
// Raw pointer
|
||||
let closure = Box::into_raw(closure);
|
||||
|
||||
self.fp.new_segment_callback_user_data = closure as *mut c_void;
|
||||
self.fp.new_segment_callback = Some(trampoline::<SegmentCallbackFn>);
|
||||
self.segment_calllback_safe = None;
|
||||
}
|
||||
None => {
|
||||
self.segment_calllback_safe = None;
|
||||
self.fp.new_segment_callback = None;
|
||||
self.fp.new_segment_callback_user_data = std::ptr::null_mut::<c_void>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the callback for segment updates.
|
||||
///
|
||||
/// Provides a limited segment_callback to ensure safety with lossy handling of bad UTF-8 characters.
|
||||
/// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`.
|
||||
///
|
||||
/// Defaults to None.
|
||||
pub fn set_segment_callback_safe_lossy<O, F>(&mut self, closure: O)
|
||||
where
|
||||
F: FnMut(SegmentCallbackData) + 'static,
|
||||
O: Into<Option<F>>,
|
||||
{
|
||||
use std::ffi::{c_void, CStr};
|
||||
use whisper_rs_sys::{whisper_context, whisper_state};
|
||||
|
||||
extern "C" fn trampoline<F>(
|
||||
_: *mut whisper_context,
|
||||
state: *mut whisper_state,
|
||||
n_new: i32,
|
||||
user_data: *mut c_void,
|
||||
) where
|
||||
F: FnMut(SegmentCallbackData) + 'static,
|
||||
{
|
||||
unsafe {
|
||||
let user_data = &mut *(user_data as *mut SegmentCallbackFn);
|
||||
let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state);
|
||||
let s0 = n_segments - n_new;
|
||||
//let user_data = user_data as *mut Box<dyn FnMut(SegmentCallbackData)>;
|
||||
|
||||
for i in s0..n_segments {
|
||||
let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i);
|
||||
let text = CStr::from_ptr(text);
|
||||
|
||||
let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i);
|
||||
let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i);
|
||||
user_data(SegmentCallbackData {
|
||||
segment: i,
|
||||
start_timestamp: t0,
|
||||
end_timestamp: t1,
|
||||
text: text.to_string_lossy().to_string(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match closure.into() {
|
||||
Some(closure) => {
|
||||
// Stable address
|
||||
let closure = Box::new(closure) as SegmentCallbackFn;
|
||||
// Thin pointer
|
||||
let closure = Box::new(closure);
|
||||
// Raw pointer
|
||||
let closure = Box::into_raw(closure);
|
||||
|
||||
self.fp.new_segment_callback_user_data = closure as *mut c_void;
|
||||
self.fp.new_segment_callback = Some(trampoline::<SegmentCallbackFn>);
|
||||
self.segment_calllback_safe = None;
|
||||
}
|
||||
None => {
|
||||
self.segment_calllback_safe = None;
|
||||
self.fp.new_segment_callback = None;
|
||||
self.fp.new_segment_callback_user_data = std::ptr::null_mut::<c_void>();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the callback for progress updates.
|
||||
///
|
||||
/// Note that is still a C callback.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue