From 4bc5709e58e727f5d58c4f466b3cbc2ab619df61 Mon Sep 17 00:00:00 2001 From: hlhr202 Date: Tue, 7 May 2024 00:52:33 +0800 Subject: [PATCH] fix: metal --- .gitignore | 1 + src/whisper_sys_log.rs | 9 ++++++++- src/whisper_sys_tracing.rs | 9 ++++++++- sys/build.rs | 23 +++++++++++++++++++++-- 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index c3e612d..4f72f63 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ **/target **/Cargo.lock /.idea +/.vscode \ No newline at end of file diff --git a/src/whisper_sys_log.rs b/src/whisper_sys_log.rs index 9b5be22..5e58239 100644 --- a/src/whisper_sys_log.rs +++ b/src/whisper_sys_log.rs @@ -37,6 +37,13 @@ unsafe extern "C" fn whisper_cpp_log_trampoline( /// You should only call this once (subsequent calls have no ill effect). pub fn install_whisper_log_trampoline() { crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe { - whisper_rs_sys::whisper_log_set(Some(whisper_cpp_log_trampoline), std::ptr::null_mut()) + whisper_rs_sys::whisper_log_set(Some(whisper_cpp_log_trampoline), std::ptr::null_mut()); + #[cfg(feature = "metal")] + { + whisper_rs_sys::ggml_metal_log_set_callback( + Some(whisper_cpp_log_trampoline), + std::ptr::null_mut(), + ); + } }); } diff --git a/src/whisper_sys_tracing.rs b/src/whisper_sys_tracing.rs index 6c6d316..eeae123 100644 --- a/src/whisper_sys_tracing.rs +++ b/src/whisper_sys_tracing.rs @@ -37,6 +37,13 @@ unsafe extern "C" fn whisper_cpp_tracing_trampoline( /// You should only call this once (subsequent calls have no effect). pub fn install_whisper_tracing_trampoline() { crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe { - whisper_rs_sys::whisper_log_set(Some(whisper_cpp_tracing_trampoline), std::ptr::null_mut()) + whisper_rs_sys::whisper_log_set(Some(whisper_cpp_tracing_trampoline), std::ptr::null_mut()); + #[cfg(feature = "metal")] + { + whisper_rs_sys::ggml_metal_log_set_callback( + Some(whisper_cpp_tracing_trampoline), + std::ptr::null_mut(), + ); + } }); } diff --git a/sys/build.rs b/sys/build.rs index 0535637..354e93b 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -77,8 +77,13 @@ fn main() { .expect("Failed to copy bindings.rs"); } else { let bindings = bindgen::Builder::default() - .header("wrapper.h") - .clang_arg("-I./whisper.cpp") + .header("wrapper.h"); + + + #[cfg(feature = "metal")] + let bindings = bindings.header("whisper.cpp/ggml-metal.h"); + + let bindings = bindings.clang_arg("-I./whisper.cpp") .parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .generate(); @@ -160,6 +165,20 @@ fn main() { // for whatever reason this file is generated during build and triggers cargo complaining _ = std::fs::remove_file("bindings/javascript/package.json"); + + if cfg!(feature = "metal") { + // copy metal shader to the root of the crate + let _ = std::fs::copy( + out.join("whisper.cpp").join("ggml-metal.metal"), + out.parent() + .unwrap() + .parent() + .unwrap() + .parent() + .unwrap() + .join("ggml-metal.metal"), + ); + } } // From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462