Merge branch 'tazz4843:master' into fix/progress-callback-crash

This commit is contained in:
Roman Steiner 2025-03-03 18:33:52 +01:00 committed by GitHub
commit 138020527c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 49 additions and 23 deletions

View file

@ -376,6 +376,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
/// Do not use this function unless you know what you are doing. /// Do not use this function unless you know what you are doing.
/// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback.
/// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library.
/// **Warning** Can't be used with DTW. DTW will produce inconsistent callback invocation
/// ///
/// Defaults to None. /// Defaults to None.
pub unsafe fn set_new_segment_callback( pub unsafe fn set_new_segment_callback(
@ -389,6 +390,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
/// ///
/// # Safety /// # Safety
/// See the safety notes for `set_new_segment_callback`. /// See the safety notes for `set_new_segment_callback`.
/// **Warning** Can't be used with DTW. DTW will produce inconsistent callback invocation
/// ///
/// Defaults to None. /// Defaults to None.
pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
@ -399,6 +401,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
/// ///
/// Provides a limited segment_callback to ensure safety. /// Provides a limited segment_callback to ensure safety.
/// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state` /// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`
/// **Warning** Can't be used with DTW. DTW will produce inconsistent callback invocation
/// ///
/// Defaults to None. /// Defaults to None.
pub fn set_segment_callback_safe<O, F>(&mut self, closure: O) pub fn set_segment_callback_safe<O, F>(&mut self, closure: O)
@ -419,20 +422,26 @@ impl<'a, 'b> FullParams<'a, 'b> {
{ {
unsafe { unsafe {
let user_data = &mut *(user_data as *mut SegmentCallbackFn); let user_data = &mut *(user_data as *mut SegmentCallbackFn);
let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, n_new); let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state);
let text = CStr::from_ptr(text); let s0 = n_segments - n_new;
//let user_data = user_data as *mut Box<dyn FnMut(SegmentCallbackData)>;
let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, n_new); for i in s0..n_segments {
let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, n_new); let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i);
let text = CStr::from_ptr(text);
match text.to_str() { let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i);
Ok(n) => user_data(SegmentCallbackData { let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i);
segment: n_new + 1,
start_timestamp: t0, match text.to_str() {
end_timestamp: t1, Ok(n) => user_data(SegmentCallbackData {
text: n.to_string(), segment: i,
}), start_timestamp: t0,
Err(_) => {} end_timestamp: t1,
text: n.to_string(),
}),
Err(_) => {}
}
} }
} }
} }
@ -462,6 +471,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
/// ///
/// Provides a limited segment_callback to ensure safety with lossy handling of bad UTF-8 characters. /// Provides a limited segment_callback to ensure safety with lossy handling of bad UTF-8 characters.
/// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`. /// See `set_new_segment_callback` if you need to use `whisper_context` and `whisper_state`.
/// **Warning** Can't be used with DTW. DTW will produce inconsistent callback invocation
/// ///
/// Defaults to None. /// Defaults to None.
pub fn set_segment_callback_safe_lossy<O, F>(&mut self, closure: O) pub fn set_segment_callback_safe_lossy<O, F>(&mut self, closure: O)
@ -482,17 +492,23 @@ impl<'a, 'b> FullParams<'a, 'b> {
{ {
unsafe { unsafe {
let user_data = &mut *(user_data as *mut SegmentCallbackFn); let user_data = &mut *(user_data as *mut SegmentCallbackFn);
let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, n_new); let n_segments = whisper_rs_sys::whisper_full_n_segments_from_state(state);
let text = CStr::from_ptr(text); let s0 = n_segments - n_new;
//let user_data = user_data as *mut Box<dyn FnMut(SegmentCallbackData)>;
let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, n_new); for i in s0..n_segments {
let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, n_new); let text = whisper_rs_sys::whisper_full_get_segment_text_from_state(state, i);
user_data(SegmentCallbackData { let text = CStr::from_ptr(text);
segment: n_new,
start_timestamp: t0, let t0 = whisper_rs_sys::whisper_full_get_segment_t0_from_state(state, i);
end_timestamp: t1, let t1 = whisper_rs_sys::whisper_full_get_segment_t1_from_state(state, i);
text: text.to_string_lossy().to_string(), user_data(SegmentCallbackData {
}); segment: i,
start_timestamp: t0,
end_timestamp: t1,
text: text.to_string_lossy().to_string(),
});
}
} }
} }

View file

@ -164,6 +164,10 @@ fn main() {
.very_verbose(true) .very_verbose(true)
.pic(true); .pic(true);
if cfg!(target_os = "windows") {
config.cxxflag("/utf-8");
}
if cfg!(feature = "coreml") { if cfg!(feature = "coreml") {
config.define("WHISPER_COREML", "ON"); config.define("WHISPER_COREML", "ON");
config.define("WHISPER_COREML_ALLOW_FALLBACK", "1"); config.define("WHISPER_COREML_ALLOW_FALLBACK", "1");
@ -214,6 +218,12 @@ fn main() {
if cfg!(feature = "openblas") { if cfg!(feature = "openblas") {
config.define("GGML_BLAS", "ON"); config.define("GGML_BLAS", "ON");
config.define("GGML_BLAS_VENDOR", "OpenBLAS");
if env::var("BLAS_INCLUDE_DIRS").is_err() {
panic!("BLAS_INCLUDE_DIRS environment variable must be set when using OpenBLAS");
}
config.define("BLAS_INCLUDE_DIRS", env::var("BLAS_INCLUDE_DIRS").unwrap());
println!("cargo:rerun-if-env-changed=BLAS_INCLUDE_DIRS");
} }
if cfg!(feature = "metal") { if cfg!(feature = "metal") {
@ -255,7 +265,7 @@ fn main() {
println!("cargo:rustc-link-lib=static=ggml"); println!("cargo:rustc-link-lib=static=ggml");
println!("cargo:rustc-link-lib=static=ggml-base"); println!("cargo:rustc-link-lib=static=ggml-base");
println!("cargo:rustc-link-lib=static=ggml-cpu"); println!("cargo:rustc-link-lib=static=ggml-cpu");
if cfg!(target_os = "macos") { if cfg!(target_os = "macos") || cfg!(feature = "openblas") {
println!("cargo:rustc-link-lib=static=ggml-blas"); println!("cargo:rustc-link-lib=static=ggml-blas");
} }
if cfg!(feature = "vulkan") { if cfg!(feature = "vulkan") {