diff --git a/sys/Cargo.toml b/sys/Cargo.toml index e8d78ad..5bbc6dd 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -17,6 +17,8 @@ include = [ "whisper.cpp/ggml.h", "whisper.cpp/ggml-opencl.c", "whisper.cpp/ggml-opencl.h", + "whisper.cpp/ggml-cuda.cu", + "whisper.cpp/ggml-cuda.h", "whisper.cpp/LICENSE", "whisper.cpp/whisper.cpp", "whisper.cpp/whisper.h", diff --git a/sys/build.rs b/sys/build.rs index 93d33ef..1c58968 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -34,11 +34,18 @@ fn main() { #[cfg(feature = "cuda")] { println!("cargo:rustc-link-lib=cublas"); - println!("cargo:rustc-link-lib=culibos"); println!("cargo:rustc-link-lib=cudart"); println!("cargo:rustc-link-lib=cublasLt"); - println!("cargo:rustc-link-search=/usr/local/cuda/lib64"); - println!("cargo:rustc-link-search=/opt/cuda/lib64"); + cfg_if! { + if #[cfg(target_os = "windows")] { + let cuda_path = PathBuf::from(env::var("CUDA_PATH").unwrap()).join("lib/x64"); + println!("cargo:rustc-link-search={}", cuda_path.display()); + } else { + println!("cargo:rustc-link-lib=culibos"); + println!("cargo:rustc-link-search=/usr/local/cuda/lib64"); + println!("cargo:rustc-link-search=/opt/cuda/lib64"); + } + } } println!("cargo:rerun-if-changed=wrapper.h");