update to latest whisper.cpp

This commit is contained in:
0/0 2023-02-07 10:04:10 -07:00
parent 0da39195c0
commit edac524756
No known key found for this signature in database
GPG key ID: 3861E636EA1E0E2B
4 changed files with 277 additions and 55 deletions

View file

@ -19,7 +19,7 @@ pub struct WhisperContext {
}
impl WhisperContext {
/// Create a new WhisperContext.
/// Create a new WhisperContext from a file.
///
/// # Arguments
/// * path: The path to the model file.
@ -28,10 +28,10 @@ impl WhisperContext {
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init(const char * path_model);`
/// `struct whisper_context * whisper_init_from_file(const char * path_model);`
pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_rs_sys::whisper_init(path_cstr.as_ptr()) };
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file(path_cstr.as_ptr()) };
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
@ -44,6 +44,33 @@ impl WhisperContext {
}
}
/// Create a new WhisperContext from a buffer.
///
/// # Arguments
/// * buffer: The buffer containing the model.
///
/// # Returns
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);`
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
let ctx =
unsafe { whisper_rs_sys::whisper_init_from_buffer(buffer.as_ptr() as _, buffer.len()) };
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
Ok(Self {
ctx,
spectrogram_initialized: false,
encode_complete: false,
decode_once: false,
})
}
}
// we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
/// The resulting spectrogram is stored in the context transparently.
///
@ -190,40 +217,90 @@ impl WhisperContext {
}
}
// Token sampling functions
/// Return the token with the highest probability.
/// Make sure to call [WhisperContext::decode] first.
/// Convert the provided text into tokens.
///
/// # Arguments
/// * needs_timestamp
/// * text: The text to convert.
///
/// # Returns
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
/// Ok(Vec<WhisperToken>) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)`
pub fn sample_best(&mut self) -> Result<WhisperTokenData, WhisperError> {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
/// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);`
pub fn tokenize(
&mut self,
text: &str,
max_tokens: usize,
) -> Result<Vec<WhisperToken>, WhisperError> {
// allocate at least max_tokens to ensure the memory is valid
let mut tokens: Vec<WhisperToken> = Vec::with_capacity(max_tokens);
let ret = unsafe {
whisper_rs_sys::whisper_tokenize(
self.ctx,
text.as_ptr() as *const _,
tokens.as_mut_ptr(),
max_tokens as c_int,
)
};
if ret == -1 {
Err(WhisperError::InvalidText)
} else {
// SAFETY: when ret != -1, we know that the length of the vector is at least ret tokens
unsafe { tokens.set_len(ret as usize) };
Ok(tokens)
}
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx) };
Ok(ret)
}
/// Return the token with the most probable timestamp.
/// Make sure to call [WhisperContext::decode] first.
// Language functions
/// Use mel data at offset_ms to try and auto-detect the spoken language
/// Make sure to call pcm_to_mel() or set_mel() first
///
/// # Arguments
/// * offset_ms: The offset in milliseconds to use for the language detection.
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
///
/// # Returns
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
/// Ok(Vec<f32>) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `whisper_token whisper_sample_timestamp(struct whisper_context * ctx)`
pub fn sample_timestamp(&mut self, is_initial: bool) -> Result<WhisperTokenData, WhisperError> {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
/// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)`
pub fn lang_detect(
&mut self,
offset_ms: usize,
threads: usize,
) -> Result<Vec<f32>, WhisperError> {
if !self.spectrogram_initialized {
return Err(WhisperError::SpectrogramNotInitialized);
}
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
let ret = unsafe {
whisper_rs_sys::whisper_lang_auto_detect(
self.ctx,
offset_ms as c_int,
threads as c_int,
lang_probs.as_mut_ptr(),
)
};
if ret == -1 {
Err(WhisperError::UnableToCalculateEvaluation)
} else {
assert_eq!(
ret as usize,
lang_probs.len(),
"lang_probs length mismatch: this is a bug in whisper.cpp"
);
// if we're still running, double check that the length is correct, otherwise print to stderr
// and abort, as this will cause Undefined Behavior
// might get here due to the unwind being caught by a user-installed panic handler
if lang_probs.len() != ret as usize {
eprintln!("lang_probs length mismatch: this is a bug in whisper.cpp, aborting");
std::process::abort();
}
Ok(lang_probs)
}
let ret = unsafe { whisper_rs_sys::whisper_sample_timestamp(self.ctx, is_initial) };
Ok(ret)
}
// model attributes
@ -263,6 +340,18 @@ impl WhisperContext {
unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) }
}
/// Get n_audio_ctx.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_n_audio_ctx (struct whisper_context * ctx)`
#[inline]
pub fn n_audio_ctx(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_n_audio_ctx(self.ctx) }
}
/// Does this model support multiple languages?
///
/// # C++ equivalent
@ -272,25 +361,46 @@ impl WhisperContext {
unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 }
}
/// The probabilities for the next token.
/// Make sure to call [WhisperContext::decode] first.
// logit functions
/// Get the logits obtained from the last call to [WhisperContext::decode].
/// The logits for the last token are stored in the last row of the matrix.
///
/// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it
/// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible.
///
/// # Arguments
/// * segment: The segment to fetch data for.
///
/// # Returns
/// Ok(*const f32) on success, Err(WhisperError) on failure.
/// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab.
///
/// # C++ equivalent
/// `float * whisper_get_probs(struct whisper_context * ctx)`
pub fn get_probs(&mut self) -> Result<*const f32, WhisperError> {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
/// `float * whisper_get_logits(struct whisper_context * ctx)`
pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> {
if !self.spectrogram_initialized {
return Err(WhisperError::SpectrogramNotInitialized);
}
let ret = unsafe { whisper_rs_sys::whisper_get_probs(self.ctx) };
let ret = unsafe { whisper_rs_sys::whisper_get_logits(self.ctx) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
Ok(ret)
let mut logits = Vec::new();
let n_vocab = self.n_vocab();
let n_tokens = self.full_n_tokens(segment);
for i in 0..n_tokens {
let mut row = Vec::new();
for j in 0..n_vocab {
let idx = (i * n_vocab) + j;
let val = unsafe { *ret.offset(idx as isize) };
row.push(val);
}
logits.push(row);
}
Ok(logits)
}
// token functions
/// Convert a token ID to a string.
///
/// # Arguments
@ -311,7 +421,6 @@ impl WhisperContext {
Ok(r_str.to_string())
}
// special tokens
/// Get the ID of the eot token.
///
/// # C++ equivalent
@ -366,6 +475,18 @@ impl WhisperContext {
unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
}
/// Get the ID of a specified language token
///
/// # Arguments
/// * lang_id: ID of the language
///
/// # C++ equivalent
/// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)`
#[inline]
pub fn token_lang(&self, lang_id: c_int) -> WhisperToken {
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
}
/// Print performance statistics to stderr.
///
/// # C++ equivalent