From 3811a77dceee7fb497670f8c5f7daaa265eedeec Mon Sep 17 00:00:00 2001 From: Jonathan Soo Date: Mon, 8 May 2023 09:31:54 -0400 Subject: [PATCH 1/6] Change get_logits to return a single slice --- src/whisper_state.rs | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/src/whisper_state.rs b/src/whisper_state.rs index bce1b70..b0305b0 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -268,38 +268,21 @@ impl<'a> WhisperState<'a> { } // logit functions - /// Get the logits obtained from the last call to [WhisperContext::decode]. - /// The logits for the last token are stored in the last row of the matrix. - /// - /// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it - /// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible. - /// - /// # Arguments - /// * segment: The segment to fetch data for. + /// Gets logits obtained from the last call to [WhisperContext::decode]. + /// As of whisper.cpp 1.4.1, only a single row of logits is available, corresponding to the last token in the input. /// /// # Returns - /// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab. + /// A slice of logits with length equal to n_vocab. /// /// # C++ equivalent /// `float * whisper_get_logits(struct whisper_context * ctx)` - pub fn get_logits(&self, segment: c_int) -> Result>, WhisperError> { + pub fn get_logits(&self) -> Result<&[f32], WhisperError> { let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) }; if ret.is_null() { return Err(WhisperError::NullPointer); } - let mut logits = Vec::new(); let n_vocab = self.n_vocab(); - let n_tokens = self.full_n_tokens(segment)?; - for i in 0..n_tokens { - let mut row = Vec::new(); - for j in 0..n_vocab { - let idx = (i * n_vocab) + j; - let val = unsafe { *ret.offset(idx as isize) }; - row.push(val); - } - logits.push(row); - } - Ok(logits) + Ok(unsafe { std::slice::from_raw_parts(ret, n_vocab as usize) }) } // model attributes From 47f53af20d8ddba574f729969ef69fbf51897ae8 Mon Sep 17 00:00:00 2001 From: Jonathan Soo Date: Mon, 8 May 2023 09:35:09 -0400 Subject: [PATCH 2/6] Make SamplingStrategy Debug + Clone --- src/whisper_params.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 493db09..d875c3c 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -2,6 +2,7 @@ use std::ffi::{c_float, c_int, CString}; use std::marker::PhantomData; use whisper_rs_sys::whisper_token; +#[derive(Debug, Clone)] pub enum SamplingStrategy { Greedy { best_of: c_int, From 5c0d6ce9571131a155d9d59f3004e8b091e9229d Mon Sep 17 00:00:00 2001 From: Jonathan Soo Date: Mon, 8 May 2023 10:31:03 -0400 Subject: [PATCH 3/6] Allow fallback to GGML when CoreML model not available. --- sys/build.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/sys/build.rs b/sys/build.rs index 89f1b9b..6558717 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -80,6 +80,7 @@ fn main() { .arg("-DWHISPER_BUILD_TESTS=OFF") .arg("-DWHISPER_BUILD_EXAMPLES=OFF") .arg("-DWHISPER_COREML=1") + .arg("-DWHISPER_COREML_ALLOW_FALLBACK=1") .status() .expect("Failed to generate build script"); From fc062e7541942ca13bf3ddfa469215ea55726003 Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 10 May 2023 14:03:16 -0600 Subject: [PATCH 4/6] fix incorrect doc comments --- src/whisper_ctx.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 75aba43..c8c906c 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -299,7 +299,7 @@ impl WhisperContext { unsafe { whisper_rs_sys::whisper_model_type(self.ctx) } } - /// token functions + // token functions /// Convert a token ID to a string. /// /// # Arguments From 060b9b569749d9eb01d286740bb7bf7c92dc4351 Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 10 May 2023 14:07:11 -0600 Subject: [PATCH 5/6] update whisper-rs-sys to v0.5.0 --- sys/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sys/Cargo.toml b/sys/Cargo.toml index f73213a..6bce7f3 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "whisper-rs-sys" -version = "0.4.0" +version = "0.5.0" edition = "2021" description = "Rust bindings for whisper.cpp (FFI bindings)" license = "Unlicense" From 7251519c618d8505ac0257f27c5848f86d3111a6 Mon Sep 17 00:00:00 2001 From: Zero Date: Wed, 10 May 2023 14:07:22 -0600 Subject: [PATCH 6/6] update whisper-rs to v0.7.0 --- CHANGELOG.md | 9 +++++++++ Cargo.toml | 5 ++--- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fab056c..a488f87 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,12 @@ +# Version 0.7.0 (2023-05-10) +* Update upstream whisper.cpp to v1.4.0 (integer quantization support, see last point for CUDA support) +* Expose `WhisperState` as a public type, allowing for more control over the state. + * `WhisperContext::create_state` now returns a `WhisperState` instead of `()`. + * All methods that took a key argument in v0.6.0 have been moved to `WhisperState`. +* Generic key argument on `WhisperContext` has been removed. +* Note: CUDA and OpenCL acceleration is supported on the `cuda-and-opencl-support` branch of the git repo, + and will probably be released in v0.8.0. + # Version 0.6.0 (2023-04-17) * Update upstream whisper.cpp to v1.3.0 * Fix breaking changes in update, which cascade to users: diff --git a/Cargo.toml b/Cargo.toml index 5cac3fb..e90bc26 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ exclude = ["examples/full_usage"] [package] name = "whisper-rs" -version = "0.6.0" +version = "0.7.0" edition = "2021" description = "Rust bindings for whisper.cpp" license = "Unlicense" @@ -14,8 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -whisper-rs-sys = { path = "sys", version = "0.4" } -dashmap = "5" +whisper-rs-sys = { path = "sys", version = "0.5" } [dev-dependencies] hound = "3.5.0"