Add OpenVINO support

This commit is contained in:
Zero 2023-08-28 16:28:43 -06:00
parent 776abc3c91
commit 5c140c14d4
No known key found for this signature in database
GPG key ID: 3861E636EA1E0E2B
6 changed files with 100 additions and 2 deletions

View file

@ -56,6 +56,44 @@ impl WhisperContext {
// We don't implement `whisper_init()` here, since it is unclear what `whisper_model_loader` does.
/// Using this context, enable use of OpenVINO for encoder inference.
///
/// # Arguments
/// * `model_path`: An optional path to the OpenVINO encoder IR model.
/// If set to `None`,
/// the path will be generated from the ggml model path
/// that was passed in to whisper_init_from_file.
/// For example, if the model path was "/path/to/ggml-base.en.bin",
/// then the OpenVINO IR model path will be assumed as "/path/to/ggml-base.en-encoder-openvino.xml".
///
/// * `device`: The OpenVINO device to use for inference (e.g. "CPU", "GPU")
///
/// * `cache_dir`: Optional cache directory that can speed up init time,
/// especially for GPU, by caching compiled 'blobs' there.
/// Set to `None` if not used (a null pointer is passed to the C API).
///
/// # Returns
/// `true` on success, `false` if OpenVINO was not enabled at compile time
/// (enable the `openvino` feature flag in your Cargo.toml).
///
/// # Panics
/// Panics if `model_path`, `device`, or `cache_dir` contains an interior NUL byte,
/// since they cannot be represented as C strings.
///
/// # C++ equivalent
/// `int whisper_ctx_init_openvino_encoder(struct whisper_context * ctx, const char * model_path, const char * device, const char * cache_dir);`
#[cfg(feature = "openvino")]
pub fn init_openvino_encoder(&mut self, model_path: Option<&str>, device: &str, cache_dir: Option<&str>) -> bool {
    // Own the CStrings in locals so they stay alive across the FFI call.
    let model_path = model_path.map(|s| CString::new(s).unwrap());
    let device = CString::new(device).unwrap();
    let cache_dir = cache_dir.map(|s| CString::new(s).unwrap());
    let ret = unsafe {
        // BUG FIX: take pointers via `.as_ref()` so the CString is only
        // borrowed. The previous `model_path.map(|s| s.as_ptr())` moved each
        // CString into the closure, where it was dropped as soon as
        // `as_ptr()` returned — handing a dangling pointer to the C API
        // (the exact hazard called out in the `CString::as_ptr` docs).
        whisper_rs_sys::whisper_ctx_init_openvino_encoder(
            self.ctx,
            model_path
                .as_ref()
                .map(|s| s.as_ptr())
                .unwrap_or(std::ptr::null()),
            device.as_ptr(),
            cache_dir
                .as_ref()
                .map(|s| s.as_ptr())
                .unwrap_or(std::ptr::null()),
        )
    };
    // The C API returns a nonzero int on success.
    ret != 0
}
/// Create a new state object, ready for use.
///
/// # Returns