Merge branch 'master' into cuda-and-opencl-support

# Conflicts:
#	sys/build.rs
This commit is contained in:
Zero 2023-05-14 13:59:49 -06:00
commit 816c17ad80
No known key found for this signature in database
GPG key ID: 3861E636EA1E0E2B
6 changed files with 19 additions and 27 deletions

View file

@ -1,3 +1,12 @@
# Version 0.7.0 (2023-05-10)
* Update upstream whisper.cpp to v1.4.0 (integer quantization support, see last point for CUDA support)
* Expose `WhisperState` as a public type, allowing for more control over the state.
* `WhisperContext::create_state` now returns a `WhisperState` instead of `()`.
* All methods that took a key argument in v0.6.0 have been moved to `WhisperState`.
* Generic key argument on `WhisperContext` has been removed.
* Note: CUDA and OpenCL acceleration is supported on the `cuda-and-opencl-support` branch of the git repo,
and will probably be released in v0.8.0.
# Version 0.6.0 (2023-04-17) # Version 0.6.0 (2023-04-17)
* Update upstream whisper.cpp to v1.3.0 * Update upstream whisper.cpp to v1.3.0
* Fix breaking changes in update, which cascade to users: * Fix breaking changes in update, which cascade to users:

View file

@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
[package] [package]
name = "whisper-rs" name = "whisper-rs"
version = "0.6.0" version = "0.7.0"
edition = "2021" edition = "2021"
description = "Rust bindings for whisper.cpp" description = "Rust bindings for whisper.cpp"
license = "Unlicense" license = "Unlicense"
@ -14,8 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
whisper-rs-sys = { path = "sys", version = "0.4" } whisper-rs-sys = { path = "sys", version = "0.5" }
dashmap = "5"
[dev-dependencies] [dev-dependencies]
hound = "3.5.0" hound = "3.5.0"

View file

@ -299,7 +299,7 @@ impl WhisperContext {
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) } unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
} }
/// token functions // token functions
/// Convert a token ID to a string. /// Convert a token ID to a string.
/// ///
/// # Arguments /// # Arguments

View file

@ -2,6 +2,7 @@ use std::ffi::{c_float, c_int, CString};
use std::marker::PhantomData; use std::marker::PhantomData;
use whisper_rs_sys::whisper_token; use whisper_rs_sys::whisper_token;
#[derive(Debug, Clone)]
pub enum SamplingStrategy { pub enum SamplingStrategy {
Greedy { Greedy {
best_of: c_int, best_of: c_int,

View file

@ -268,38 +268,21 @@ impl<'a> WhisperState<'a> {
} }
// logit functions // logit functions
/// Get the logits obtained from the last call to [WhisperContext::decode]. /// Gets logits obtained from the last call to [WhisperContext::decode].
/// The logits for the last token are stored in the last row of the matrix. /// As of whisper.cpp 1.4.1, only a single row of logits is available, corresponding to the last token in the input.
///
/// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it
/// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible.
///
/// # Arguments
/// * segment: The segment to fetch data for.
/// ///
/// # Returns /// # Returns
/// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab. /// A slice of logits with length equal to n_vocab.
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `float * whisper_get_logits(struct whisper_context * ctx)` /// `float * whisper_get_logits(struct whisper_context * ctx)`
pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> { pub fn get_logits(&self) -> Result<&[f32], WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) }; let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) };
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
let mut logits = Vec::new();
let n_vocab = self.n_vocab(); let n_vocab = self.n_vocab();
let n_tokens = self.full_n_tokens(segment)?; Ok(unsafe { std::slice::from_raw_parts(ret, n_vocab as usize) })
for i in 0..n_tokens {
let mut row = Vec::new();
for j in 0..n_vocab {
let idx = (i * n_vocab) + j;
let val = unsafe { *ret.offset(idx as isize) };
row.push(val);
}
logits.push(row);
}
Ok(logits)
} }
// model attributes // model attributes

View file

@ -1,6 +1,6 @@
[package] [package]
name = "whisper-rs-sys" name = "whisper-rs-sys"
version = "0.4.0" version = "0.5.0"
edition = "2021" edition = "2021"
description = "Rust bindings for whisper.cpp (FFI bindings)" description = "Rust bindings for whisper.cpp (FFI bindings)"
license = "Unlicense" license = "Unlicense"