Merge pull request #51 from tazz4843/cuda-and-opencl-support

Add CUDA and OpenCL support
This commit is contained in:
0/0 2023-05-14 20:28:49 +00:00 committed by GitHub
commit 61124601d6
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 61 additions and 43 deletions

View file

@ -20,8 +20,12 @@ whisper-rs-sys = { path = "sys", version = "0.5" }
hound = "3.5.0" hound = "3.5.0"
[features] [features]
default = []
simd = [] simd = []
coreml = ["whisper-rs-sys/coreml"] coreml = ["whisper-rs-sys/coreml"]
cuda = ["whisper-rs-sys/cuda"]
opencl = ["whisper-rs-sys/opencl"]
test-with-tiny-model = [] test-with-tiny-model = []
[package.metadata.docs.rs] [package.metadata.docs.rs]

View file

@ -27,8 +27,11 @@ include = [
[features] [features]
coreml = [] coreml = []
cuda = []
opencl = []
[dependencies] [dependencies]
[build-dependencies] [build-dependencies]
bindgen = "0.64" bindgen = "0.64"
cfg-if = "1"

View file

@ -2,6 +2,7 @@
extern crate bindgen; extern crate bindgen;
use cfg_if::cfg_if;
use std::env; use std::env;
use std::path::PathBuf; use std::path::PathBuf;
@ -25,6 +26,20 @@ fn main() {
println!("cargo:rustc-link-lib=static=whisper"); println!("cargo:rustc-link-lib=static=whisper");
#[cfg(feature = "coreml")] #[cfg(feature = "coreml")]
println!("cargo:rustc-link-lib=static=whisper.coreml"); println!("cargo:rustc-link-lib=static=whisper.coreml");
#[cfg(feature = "opencl")]
{
println!("cargo:rustc-link-lib=clblast");
println!("cargo:rustc-link-lib=OpenCL");
}
#[cfg(feature = "cuda")]
{
println!("cargo:rustc-link-lib=cublas");
println!("cargo:rustc-link-lib=culibos");
println!("cargo:rustc-link-lib=cudart");
println!("cargo:rustc-link-lib=cublasLt");
println!("cargo:rustc-link-search=/usr/local/cuda/lib64");
println!("cargo:rustc-link-search=/opt/cuda/lib64");
}
println!("cargo:rerun-if-changed=wrapper.h"); println!("cargo:rerun-if-changed=wrapper.h");
if env::var("WHISPER_DONT_GENERATE_BINDINGS").is_ok() { if env::var("WHISPER_DONT_GENERATE_BINDINGS").is_ok() {
@ -70,33 +85,28 @@ fn main() {
_ = std::fs::create_dir("build"); _ = std::fs::create_dir("build");
env::set_current_dir("build").expect("Unable to change directory to whisper.cpp build"); env::set_current_dir("build").expect("Unable to change directory to whisper.cpp build");
#[cfg(feature = "coreml")] let mut cmd = std::process::Command::new("cmake");
let code = std::process::Command::new("cmake") cmd.arg("..")
.arg("..")
.arg("-DCMAKE_BUILD_TYPE=Release") .arg("-DCMAKE_BUILD_TYPE=Release")
.arg("-DBUILD_SHARED_LIBS=OFF") .arg("-DBUILD_SHARED_LIBS=OFF")
.arg("-DWHISPER_ALL_WARNINGS=OFF") .arg("-DWHISPER_ALL_WARNINGS=OFF")
.arg("-DWHISPER_ALL_WARNINGS_3RD_PARTY=OFF") .arg("-DWHISPER_ALL_WARNINGS_3RD_PARTY=OFF")
.arg("-DWHISPER_BUILD_TESTS=OFF") .arg("-DWHISPER_BUILD_TESTS=OFF")
.arg("-DWHISPER_BUILD_EXAMPLES=OFF") .arg("-DWHISPER_BUILD_EXAMPLES=OFF");
.arg("-DWHISPER_COREML=1")
.arg("-DWHISPER_COREML_ALLOW_FALLBACK=1")
.status()
.expect("Failed to generate build script");
#[cfg(not(feature = "coreml"))] #[cfg(feature = "coreml")]
let code = std::process::Command::new("cmake") cmd.arg("-DWHISPER_COREML=ON")
.arg("..") .arg("-DWHISPER_COREML_ALLOW_FALLBACK=1");
.arg("-DCMAKE_BUILD_TYPE=Release")
.arg("-DBUILD_SHARED_LIBS=OFF") #[cfg(feature = "cuda")]
.arg("-DWHISPER_ALL_WARNINGS=OFF") cmd.arg("-DWHISPER_CUBLAS=ON");
.arg("-DWHISPER_ALL_WARNINGS_3RD_PARTY=OFF")
.arg("-DWHISPER_BUILD_TESTS=OFF") #[cfg(feature = "opencl")]
.arg("-DWHISPER_BUILD_EXAMPLES=OFF") cmd.arg("-DWHISPER_CLBLAST=ON");
.status()
.expect("Failed to generate build script"); let code = cmd.status().expect("Failed to run `cmake`");
if code.code() != Some(0) { if code.code() != Some(0) {
panic!("Failed to generate build script"); panic!("Failed to run `cmake`");
} }
let code = std::process::Command::new("cmake") let code = std::process::Command::new("cmake")
@ -111,31 +121,32 @@ fn main() {
} }
// move libwhisper.a to where Cargo expects it (OUT_DIR) // move libwhisper.a to where Cargo expects it (OUT_DIR)
#[cfg(target_os = "windows")] cfg_if! {
{ if #[cfg(target_os = "windows")] {
std::fs::copy( std::fs::copy(
"Release/whisper.lib", "Release/whisper.lib",
format!("{}/whisper.lib", env::var("OUT_DIR").unwrap()), format!("{}/whisper.lib", env::var("OUT_DIR").unwrap()),
) )
.expect("Failed to copy libwhisper.lib"); .expect("Failed to copy libwhisper.lib");
} else {
std::fs::copy(
"libwhisper.a",
format!("{}/libwhisper.a", env::var("OUT_DIR").unwrap()),
)
.expect("Failed to copy libwhisper.a");
}
} }
#[cfg(not(target_os = "windows"))] // if on iOS or macOS, with coreml feature enabled, copy libwhisper.coreml.a as well
{ cfg_if! {
std::fs::copy( if #[cfg(all(feature = "coreml", any(target_os = "ios", target_os = "macos")))]
"libwhisper.a", {
format!("{}/libwhisper.a", env::var("OUT_DIR").unwrap()), std::fs::copy(
) "libwhisper.coreml.a",
.expect("Failed to copy libwhisper.a"); format!("{}/libwhisper.coreml.a", env::var("OUT_DIR").unwrap()),
} )
#[cfg(feature = "coreml")] .expect("Failed to copy libwhisper.coreml.a");
#[cfg(not(target_os = "windows"))] }
{
std::fs::copy(
"libwhisper.coreml.a",
format!("{}/libwhisper.coreml.a", env::var("OUT_DIR").unwrap()),
)
.expect("Failed to copy libwhisper.coreml.a");
} }
// clean the whisper build directory to prevent Cargo from complaining during crate publish // clean the whisper build directory to prevent Cargo from complaining during crate publish