fix: metal

This commit is contained in:
hlhr202 2024-05-07 00:52:33 +08:00
parent 4dca14d5ec
commit 4bc5709e58
4 changed files with 38 additions and 4 deletions

1
.gitignore vendored
View file

@ -1,3 +1,4 @@
**/target **/target
**/Cargo.lock **/Cargo.lock
/.idea /.idea
/.vscode

View file

@ -37,6 +37,13 @@ unsafe extern "C" fn whisper_cpp_log_trampoline(
/// You should only call this once (subsequent calls have no ill effect). /// You should only call this once (subsequent calls have no ill effect).
pub fn install_whisper_log_trampoline() { pub fn install_whisper_log_trampoline() {
crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe { crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe {
whisper_rs_sys::whisper_log_set(Some(whisper_cpp_log_trampoline), std::ptr::null_mut()) whisper_rs_sys::whisper_log_set(Some(whisper_cpp_log_trampoline), std::ptr::null_mut());
#[cfg(feature = "metal")]
{
whisper_rs_sys::ggml_metal_log_set_callback(
Some(whisper_cpp_log_trampoline),
std::ptr::null_mut(),
);
}
}); });
} }

View file

@ -37,6 +37,13 @@ unsafe extern "C" fn whisper_cpp_tracing_trampoline(
/// You should only call this once (subsequent calls have no effect). /// You should only call this once (subsequent calls have no effect).
pub fn install_whisper_tracing_trampoline() { pub fn install_whisper_tracing_trampoline() {
crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe { crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe {
whisper_rs_sys::whisper_log_set(Some(whisper_cpp_tracing_trampoline), std::ptr::null_mut()) whisper_rs_sys::whisper_log_set(Some(whisper_cpp_tracing_trampoline), std::ptr::null_mut());
#[cfg(feature = "metal")]
{
whisper_rs_sys::ggml_metal_log_set_callback(
Some(whisper_cpp_tracing_trampoline),
std::ptr::null_mut(),
);
}
}); });
} }

View file

@ -77,8 +77,13 @@ fn main() {
.expect("Failed to copy bindings.rs"); .expect("Failed to copy bindings.rs");
} else { } else {
let bindings = bindgen::Builder::default() let bindings = bindgen::Builder::default()
.header("wrapper.h") .header("wrapper.h");
.clang_arg("-I./whisper.cpp")
#[cfg(feature = "metal")]
let bindings = bindings.header("whisper.cpp/ggml-metal.h");
let bindings = bindings.clang_arg("-I./whisper.cpp")
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new())) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate(); .generate();
@ -160,6 +165,20 @@ fn main() {
// for whatever reason this file is generated during build and triggers cargo complaining // for whatever reason this file is generated during build and triggers cargo complaining
_ = std::fs::remove_file("bindings/javascript/package.json"); _ = std::fs::remove_file("bindings/javascript/package.json");
if cfg!(feature = "metal") {
// copy metal shader to the root of the crate
let _ = std::fs::copy(
out.join("whisper.cpp").join("ggml-metal.metal"),
out.parent()
.unwrap()
.parent()
.unwrap()
.parent()
.unwrap()
.join("ggml-metal.metal"),
);
}
} }
// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462 // From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462