diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 493db09..d875c3c 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -2,6 +2,7 @@ use std::ffi::{c_float, c_int, CString}; use std::marker::PhantomData; use whisper_rs_sys::whisper_token; +#[derive(Debug, Clone)] pub enum SamplingStrategy { Greedy { best_of: c_int, diff --git a/src/whisper_state.rs b/src/whisper_state.rs index bce1b70..b0305b0 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -268,38 +268,21 @@ impl<'a> WhisperState<'a> { } // logit functions - /// Get the logits obtained from the last call to [WhisperContext::decode]. - /// The logits for the last token are stored in the last row of the matrix. - /// - /// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it - /// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible. - /// - /// # Arguments - /// * segment: The segment to fetch data for. + /// Gets logits obtained from the last call to [WhisperContext::decode]. + /// As of whisper.cpp 1.4.1, only a single row of logits is available, corresponding to the last token in the input. /// /// # Returns - /// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab. + /// A slice of logits with length equal to n_vocab. /// /// # C++ equivalent /// `float * whisper_get_logits(struct whisper_context * ctx)` - pub fn get_logits(&self, segment: c_int) -> Result>, WhisperError> { + pub fn get_logits(&self) -> Result<&[f32], WhisperError> { let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) }; if ret.is_null() { return Err(WhisperError::NullPointer); } - let mut logits = Vec::new(); let n_vocab = self.n_vocab(); - let n_tokens = self.full_n_tokens(segment)?; - for i in 0..n_tokens { - let mut row = Vec::new(); - for j in 0..n_vocab { - let idx = (i * n_vocab) + j; - let val = unsafe { *ret.offset(idx as isize) }; - row.push(val); - } - logits.push(row); - } - Ok(logits) + Ok(unsafe { std::slice::from_raw_parts(ret, n_vocab as usize) }) } // model attributes diff --git a/sys/build.rs b/sys/build.rs index 89f1b9b..6558717 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -80,6 +80,7 @@ fn main() { .arg("-DWHISPER_BUILD_TESTS=OFF") .arg("-DWHISPER_BUILD_EXAMPLES=OFF") .arg("-DWHISPER_COREML=1") + .arg("-DWHISPER_COREML_ALLOW_FALLBACK=1") .status() .expect("Failed to generate build script");