diff --git a/.gitignore b/.gitignore index c3e612d..4f72f63 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ **/target **/Cargo.lock /.idea +/.vscode \ No newline at end of file diff --git a/src/standalone.rs b/src/standalone.rs index 8179943..73b68b1 100644 --- a/src/standalone.rs +++ b/src/standalone.rs @@ -81,7 +81,13 @@ pub unsafe fn set_log_callback( log_callback: crate::WhisperLogCallback, user_data: *mut std::ffi::c_void, ) { - unsafe { whisper_rs_sys::whisper_log_set(log_callback, user_data) } + unsafe { + whisper_rs_sys::whisper_log_set(log_callback, user_data); + #[cfg(feature = "metal")] + { + whisper_rs_sys::ggml_backend_metal_log_set_callback(log_callback, user_data); + } + } } /// Print system information. diff --git a/src/whisper_sys_log.rs b/src/whisper_sys_log.rs index 9b5be22..ba2e51d 100644 --- a/src/whisper_sys_log.rs +++ b/src/whisper_sys_log.rs @@ -37,6 +37,13 @@ unsafe extern "C" fn whisper_cpp_log_trampoline( /// You should only call this once (subsequent calls have no ill effect). pub fn install_whisper_log_trampoline() { crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe { - whisper_rs_sys::whisper_log_set(Some(whisper_cpp_log_trampoline), std::ptr::null_mut()) + whisper_rs_sys::whisper_log_set(Some(whisper_cpp_log_trampoline), std::ptr::null_mut()); + #[cfg(feature = "metal")] + { + whisper_rs_sys::ggml_backend_metal_log_set_callback( + Some(whisper_cpp_log_trampoline), + std::ptr::null_mut(), + ); + } }); } diff --git a/src/whisper_sys_tracing.rs b/src/whisper_sys_tracing.rs index 6c6d316..ea24b54 100644 --- a/src/whisper_sys_tracing.rs +++ b/src/whisper_sys_tracing.rs @@ -37,6 +37,13 @@ unsafe extern "C" fn whisper_cpp_tracing_trampoline( /// You should only call this once (subsequent calls have no effect). pub fn install_whisper_tracing_trampoline() { crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe { - whisper_rs_sys::whisper_log_set(Some(whisper_cpp_tracing_trampoline), std::ptr::null_mut()) + whisper_rs_sys::whisper_log_set(Some(whisper_cpp_tracing_trampoline), std::ptr::null_mut()); + #[cfg(feature = "metal")] + { + whisper_rs_sys::ggml_backend_metal_log_set_callback( + Some(whisper_cpp_tracing_trampoline), + std::ptr::null_mut(), + ); + } }); } diff --git a/sys/build.rs b/sys/build.rs index 248fa11..dad3d57 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -101,8 +101,12 @@ fn main() { let _: u64 = std::fs::copy("src/bindings.rs", out.join("bindings.rs")) .expect("Failed to copy bindings.rs"); } else { - let bindings = bindgen::Builder::default() - .header("wrapper.h") + let bindings = bindgen::Builder::default().header("wrapper.h"); + + #[cfg(feature = "metal")] + let bindings = bindings.header("whisper.cpp/ggml-metal.h"); + + let bindings = bindings .clang_arg("-I./whisper.cpp") .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .generate(); @@ -169,6 +173,8 @@ fn main() { if cfg!(feature = "metal") { config.define("WHISPER_METAL", "ON"); + config.define("WHISPER_METAL_NDEBUG", "ON"); + config.define("WHISPER_METAL_EMBED_LIBRARY", "ON"); } else { // Metal is enabled by default, so we need to explicitly disable it config.define("WHISPER_METAL", "OFF");