diff --git a/Cargo.toml b/Cargo.toml index eab903c..9c6db08 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ coreml = ["whisper-rs-sys/coreml"] cuda = ["whisper-rs-sys/cuda"] opencl = ["whisper-rs-sys/opencl"] openblas = ["whisper-rs-sys/openblas"] +metal = ["whisper-rs-sys/metal"] test-with-tiny-model = [] [package.metadata.docs.rs] diff --git a/sys/Cargo.toml b/sys/Cargo.toml index 2c1b997..91dc858 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -39,6 +39,7 @@ coreml = [] cuda = [] opencl = [] openblas = [] +metal = [] [build-dependencies] cmake = "0.1" diff --git a/sys/build.rs b/sys/build.rs index 633b9f0..e74ae3e 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -20,6 +20,12 @@ fn main() { println!("cargo:rustc-link-lib=framework=Foundation"); println!("cargo:rustc-link-lib=framework=CoreML"); } + #[cfg(feature = "metal")] + { + println!("cargo:rustc-link-lib=framework=Foundation"); + println!("cargo:rustc-link-lib=framework=Metal"); + println!("cargo:rustc-link-lib=framework=MetalKit"); + } } #[cfg(feature = "coreml")] @@ -125,6 +131,13 @@ fn main() { config.define("WHISPER_CLBLAST", "ON"); } + if cfg!(feature = "metal") { + config.define("WHISPER_METAL", "ON"); + } else { + // Metal is enabled by default, so we need to explicitly disable it + config.define("WHISPER_METAL", "OFF"); + } + let destination = config.build(); if env::var("TARGET").unwrap().contains("window") {