From 3811a77dceee7fb497670f8c5f7daaa265eedeec Mon Sep 17 00:00:00 2001 From: Jonathan Soo Date: Mon, 8 May 2023 09:31:54 -0400 Subject: [PATCH] Change get_logits to return a single slice --- src/whisper_state.rs | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/src/whisper_state.rs b/src/whisper_state.rs index bce1b70..b0305b0 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -268,38 +268,21 @@ impl<'a> WhisperState<'a> { } // logit functions - /// Get the logits obtained from the last call to [WhisperContext::decode]. - /// The logits for the last token are stored in the last row of the matrix. - /// - /// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it - /// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible. - /// - /// # Arguments - /// * segment: The segment to fetch data for. + /// Gets logits obtained from the last call to [WhisperContext::decode]. + /// As of whisper.cpp 1.4.1, only a single row of logits is available, corresponding to the last token in the input. /// /// # Returns - /// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab. + /// A slice of logits with length equal to n_vocab. /// /// # C++ equivalent /// `float * whisper_get_logits(struct whisper_context * ctx)` - pub fn get_logits(&self, segment: c_int) -> Result>, WhisperError> { + pub fn get_logits(&self) -> Result<&[f32], WhisperError> { let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) }; if ret.is_null() { return Err(WhisperError::NullPointer); } - let mut logits = Vec::new(); let n_vocab = self.n_vocab(); - let n_tokens = self.full_n_tokens(segment)?; - for i in 0..n_tokens { - let mut row = Vec::new(); - for j in 0..n_vocab { - let idx = (i * n_vocab) + j; - let val = unsafe { *ret.offset(idx as isize) }; - row.push(val); - } - logits.push(row); - } - Ok(logits) + Ok(unsafe { std::slice::from_raw_parts(ret, n_vocab as usize) }) } // model attributes