Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
343a3029fb
3 changed files with 7 additions and 22 deletions
|
|
@ -2,6 +2,7 @@ use std::ffi::{c_float, c_int, CString};
|
|||
use std::marker::PhantomData;
|
||||
use whisper_rs_sys::whisper_token;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SamplingStrategy {
|
||||
Greedy {
|
||||
best_of: c_int,
|
||||
|
|
|
|||
|
|
@ -268,38 +268,21 @@ impl<'a> WhisperState<'a> {
|
|||
}
|
||||
|
||||
// logit functions
|
||||
/// Get the logits obtained from the last call to [WhisperContext::decode].
|
||||
/// The logits for the last token are stored in the last row of the matrix.
|
||||
///
|
||||
/// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it
|
||||
/// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: The segment to fetch data for.
|
||||
/// Gets logits obtained from the last call to [WhisperContext::decode].
|
||||
/// As of whisper.cpp 1.4.1, only a single row of logits is available, corresponding to the last token in the input.
|
||||
///
|
||||
/// # Returns
|
||||
/// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab.
|
||||
/// A slice of logits with length equal to n_vocab.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `float * whisper_get_logits(struct whisper_context * ctx)`
|
||||
pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> {
|
||||
pub fn get_logits(&self) -> Result<&[f32], WhisperError> {
|
||||
let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) };
|
||||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
let mut logits = Vec::new();
|
||||
let n_vocab = self.n_vocab();
|
||||
let n_tokens = self.full_n_tokens(segment)?;
|
||||
for i in 0..n_tokens {
|
||||
let mut row = Vec::new();
|
||||
for j in 0..n_vocab {
|
||||
let idx = (i * n_vocab) + j;
|
||||
let val = unsafe { *ret.offset(idx as isize) };
|
||||
row.push(val);
|
||||
}
|
||||
logits.push(row);
|
||||
}
|
||||
Ok(logits)
|
||||
Ok(unsafe { std::slice::from_raw_parts(ret, n_vocab as usize) })
|
||||
}
|
||||
|
||||
// model attributes
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ fn main() {
|
|||
.arg("-DWHISPER_BUILD_TESTS=OFF")
|
||||
.arg("-DWHISPER_BUILD_EXAMPLES=OFF")
|
||||
.arg("-DWHISPER_COREML=1")
|
||||
.arg("-DWHISPER_COREML_ALLOW_FALLBACK=1")
|
||||
.status()
|
||||
.expect("Failed to generate build script");
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue