Change get_logits to return a single slice

This commit is contained in:
Jonathan Soo 2023-05-08 09:31:54 -04:00
parent 0859b41191
commit 3811a77dce

View file

@ -268,38 +268,21 @@ impl<'a> WhisperState<'a> {
}
// logit functions
/// Get the logits obtained from the last call to [WhisperContext::decode].
/// The logits for the last token are stored in the last row of the matrix.
///
/// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it
/// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible.
///
/// # Arguments
/// * segment: The segment to fetch data for.
/// Gets logits obtained from the last call to [WhisperContext::decode].
/// As of whisper.cpp 1.4.1, only a single row of logits is available, corresponding to the last token in the input.
///
/// # Returns
/// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab.
/// A slice of logits with length equal to n_vocab.
///
/// # C++ equivalent
/// `float * whisper_get_logits(struct whisper_context * ctx)`
pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> {
pub fn get_logits(&self) -> Result<&[f32], WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let mut logits = Vec::new();
let n_vocab = self.n_vocab();
let n_tokens = self.full_n_tokens(segment)?;
for i in 0..n_tokens {
let mut row = Vec::new();
for j in 0..n_vocab {
let idx = (i * n_vocab) + j;
let val = unsafe { *ret.offset(idx as isize) };
row.push(val);
}
logits.push(row);
}
Ok(logits)
Ok(unsafe { std::slice::from_raw_parts(ret, n_vocab as usize) })
}
// model attributes