update upstream whisper.cpp and fix breaking changes

This commit is contained in:
0/0 2022-10-28 19:37:42 -06:00
parent 4ccd746949
commit 1632ac11fe
No known key found for this signature in database
GPG key ID: DE8D5010C0AAA3DC
10 changed files with 657 additions and 402 deletions

View file

@ -194,11 +194,11 @@ impl WhisperContext {
///
/// # C++ equivalent
/// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)`
pub fn sample_best(&mut self, needs_timestamp: bool) -> Result<WhisperToken, WhisperError> {
pub fn sample_best(&mut self) -> Result<WhisperToken, WhisperError> {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
}
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx, needs_timestamp) };
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx) };
Ok(ret)
}
@ -446,6 +446,90 @@ impl WhisperContext {
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
}
/// Get number of tokens in the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # Returns
/// Ok(c_int) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_full_n_tokens(self.ctx, segment) };
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
Ok(ret as c_int)
}
}
/// Get the token text of the specified token in the specified segment.
///
/// # Arguments
/// * segment: Segment index.
/// * token: Token index.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_text(
&self,
segment: c_int,
token: c_int,
) -> Result<String, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text(self.ctx, segment, token) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
}
/// Get the token ID of the specified token in the specified segment.
///
/// # Arguments
/// * segment: Segment index.
/// * token: Token index.
///
/// # Returns
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_id(
&self,
segment: c_int,
token: c_int,
) -> Result<WhisperToken, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_full_get_token_id(self.ctx, segment, token) };
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
Ok(ret as WhisperToken)
}
}
/// Get the probability of the specified token in the specified segment.
///
/// # Arguments
/// * segment: Segment index.
/// * token: Token index.
///
/// # Returns
/// f32
///
/// # C++ equivalent
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
#[inline]
pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> f32 {
unsafe { whisper_rs_sys::whisper_full_get_token_p(self.ctx, segment, token) }
}
}
impl Drop for WhisperContext {