From 253ac1bb6fb03ce787046d52b31d22bbccb53705 Mon Sep 17 00:00:00 2001 From: Niko Date: Mon, 9 Oct 2023 17:02:54 -0600 Subject: [PATCH 1/4] Add a feature flag for Metal acceleration support --- Cargo.toml | 1 + sys/Cargo.toml | 1 + sys/build.rs | 15 +++++++++++++++ 3 files changed, 17 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index eab903c..9c6db08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ coreml = ["whisper-rs-sys/coreml"] cuda = ["whisper-rs-sys/cuda"] opencl = ["whisper-rs-sys/opencl"] openblas = ["whisper-rs-sys/openblas"] +metal = ["whisper-rs-sys/metal"] test-with-tiny-model = [] [package.metadata.docs.rs] diff --git a/sys/Cargo.toml b/sys/Cargo.toml index a84840a..54f171d 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -39,6 +39,7 @@ coreml = [] cuda = [] opencl = [] openblas = [] +metal = [] [dependencies] diff --git a/sys/build.rs b/sys/build.rs index afbfff7..41faedf 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -20,6 +20,12 @@ fn main() { println!("cargo:rustc-link-lib=framework=Foundation"); println!("cargo:rustc-link-lib=framework=CoreML"); } + #[cfg(feature = "metal")] + { + println!("cargo:rustc-link-lib=framework=Foundation"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=MetalKit"); + } } println!("cargo:rustc-link-search={}", env::var("OUT_DIR").unwrap()); @@ -118,6 +124,15 @@ fn main() { #[cfg(feature = "opencl")] cmd.arg("-DWHISPER_CLBLAST=ON"); + + cfg_if! { + if #[cfg(feature = "metal")] { + cmd.arg("-DWHISPER_METAL=ON"); + } else { + // Metal is enabled by default so we need to explicitly disable it + cmd.arg("-DWHISPER_METAL=OFF"); + } + }; cmd.arg("-DCMAKE_POSITION_INDEPENDENT_CODE=ON"); From 74dd93bf54a855f535f477c9620791c9c3c2672c Mon Sep 17 00:00:00 2001 From: Niko Date: Mon, 9 Oct 2023 17:05:05 -0600 Subject: [PATCH 2/4] `cargo fmt` --- sys/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sys/build.rs b/sys/build.rs index 41faedf..be574ad 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -124,7 +124,7 @@ fn main() { #[cfg(feature = "opencl")] cmd.arg("-DWHISPER_CLBLAST=ON"); - + cfg_if! { if #[cfg(feature = "metal")] { cmd.arg("-DWHISPER_METAL=ON"); From ccccfe758d8bda2b2c40651cd8eed00f5fefe85f Mon Sep 17 00:00:00 2001 From: Niko Date: Fri, 27 Oct 2023 08:16:09 -0600 Subject: [PATCH 3/4] Incorporate changes from sandbox-friendly PR --- sys/build.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sys/build.rs b/sys/build.rs index 43b0d23..9a97c3e 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -130,6 +130,13 @@ fn main() { if cfg!(feature = "opencl") { config.define("WHISPER_CLBLAST", "ON"); } + + if cfg!(feature = "metal") { + config.define("WHISPER_METAL", "ON"); + } else { + // Metal is enabled by default, so we need to explicitly disable it + config.define("WHISPER_METAL", "OFF"); + } let destination = config.build(); From ba1b79138e0f524b27b92ddafdd4373ebc27f405 Mon Sep 17 00:00:00 2001 From: Niko Date: Fri, 27 Oct 2023 08:17:11 -0600 Subject: [PATCH 4/4] `cargo fmt` --- sys/build.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sys/build.rs b/sys/build.rs index 9a97c3e..e74ae3e 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -130,7 +130,7 @@ fn main() { if cfg!(feature = "opencl") { config.define("WHISPER_CLBLAST", "ON"); } - + if cfg!(feature = "metal") { config.define("WHISPER_METAL", "ON"); } else {