Merge pull request #154 from newfla/hip_linux

feat: ROCm linux support
This commit is contained in:
Niko 2024-06-02 19:00:16 +00:00 committed by GitHub
commit b46876ae5a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 36 additions and 0 deletions

View file

@ -28,6 +28,7 @@ default = []
raw-api = []
coreml = ["whisper-rs-sys/coreml"]
cuda = ["whisper-rs-sys/cuda", "_gpu"]
hipblas = ["whisper-rs-sys/hipblas", "_gpu"]
opencl = ["whisper-rs-sys/opencl"]
openblas = ["whisper-rs-sys/openblas"]
metal = ["whisper-rs-sys/metal", "_gpu"]

View file

@ -71,6 +71,7 @@ All disabled by default unless otherwise specified.
**NOTE**: enabling this no longer guarantees semver compliance,
as whisper-rs-sys may be upgraded to a breaking version in a patch release of whisper-rs.
* `cuda`: enable CUDA support. Implicitly enables hidden GPU flag at runtime.
* `hipblas`: enable ROCm/hipBLAS support. Only available on linux. Implicitly enables hidden GPU flag at runtime.
* `opencl`: enable OpenCL support. Upstream whisper.cpp does not treat OpenCL as a GPU, so it is always enabled at
runtime.
* `openblas`: enable OpenBLAS support.

View file

@ -43,6 +43,7 @@ include = [
[features]
coreml = []
cuda = []
hipblas = []
opencl = []
openblas = []
metal = []

View file

@ -58,6 +58,29 @@ fn main() {
}
}
}
#[cfg(feature = "hipblas")]
{
println!("cargo:rustc-link-lib=hipblas");
println!("cargo:rustc-link-lib=rocblas");
println!("cargo:rustc-link-lib=amdhip64");
cfg_if::cfg_if! {
if #[cfg(target_os = "windows")] {
panic!("Due to a problem with the last revision of the ROCm 5.7 library, it is not possible to compile the library for the windows environment.\nSee https://github.com/ggerganov/whisper.cpp/issues/2202 for more details.")
} else {
println!("cargo:rerun-if-env-changed=HIP_PATH");
let hip_path = match env::var("HIP_PATH") {
Ok(path) =>PathBuf::from(path),
Err(_) => PathBuf::from("/opt/rocm"),
};
let hip_lib_path = hip_path.join("lib");
println!("cargo:rustc-link-search={}",hip_lib_path.display());
}
}
}
println!("cargo:rerun-if-changed=wrapper.h");
let out = PathBuf::from(env::var("OUT_DIR").unwrap());
@ -126,6 +149,16 @@ fn main() {
config.define("WHISPER_CUDA", "ON");
}
if cfg!(feature = "hipblas") {
config.define("WHISPER_HIPBLAS", "ON");
config.define("CMAKE_C_COMPILER", "hipcc");
config.define("CMAKE_CXX_COMPILER", "hipcc");
println!("cargo:rerun-if-env-changed=AMDGPU_TARGETS");
if let Ok(gpu_targets) = env::var("AMDGPU_TARGETS") {
config.define("AMDGPU_TARGETS", gpu_targets);
}
}
if cfg!(feature = "openblas") {
config.define("WHISPER_OPENBLAS", "ON");
}