From 253ac1bb6fb03ce787046d52b31d22bbccb53705 Mon Sep 17 00:00:00 2001 From: Niko Date: Mon, 9 Oct 2023 17:02:54 -0600 Subject: [PATCH] Add a feature flag for Metal acceleration support --- Cargo.toml | 1 + sys/Cargo.toml | 1 + sys/build.rs | 15 +++++++++++++++ 3 files changed, 17 insertions(+) diff --git a/Cargo.toml b/Cargo.toml index eab903c..9c6db08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ coreml = ["whisper-rs-sys/coreml"] cuda = ["whisper-rs-sys/cuda"] opencl = ["whisper-rs-sys/opencl"] openblas = ["whisper-rs-sys/openblas"] +metal = ["whisper-rs-sys/metal"] test-with-tiny-model = [] [package.metadata.docs.rs] diff --git a/sys/Cargo.toml b/sys/Cargo.toml index a84840a..54f171d 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -39,6 +39,7 @@ coreml = [] cuda = [] opencl = [] openblas = [] +metal = [] [dependencies] diff --git a/sys/build.rs b/sys/build.rs index afbfff7..41faedf 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -20,6 +20,12 @@ fn main() { println!("cargo:rustc-link-lib=framework=Foundation"); println!("cargo:rustc-link-lib=framework=CoreML"); } + #[cfg(feature = "metal")] + { + println!("cargo:rustc-link-lib=framework=Foundation"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=MetalKit"); + } } println!("cargo:rustc-link-search={}", env::var("OUT_DIR").unwrap()); @@ -118,6 +124,15 @@ fn main() { #[cfg(feature = "opencl")] cmd.arg("-DWHISPER_CLBLAST=ON"); + + cfg_if! { + if #[cfg(feature = "metal")] { + cmd.arg("-DWHISPER_METAL=ON"); + } else { + // Metal is enabled by default so we need to explicitly disable it + cmd.arg("-DWHISPER_METAL=OFF"); + } + }; cmd.arg("-DCMAKE_POSITION_INDEPENDENT_CODE=ON");