Refactor the entire token/segment usage
This was spurred by noticing a trivial case of UB in the original code: all one needed was an out-of-bounds index on any of several methods with tokens or segment indexes on the state to cause UB. I took this opportunity to consolidate methods into Rust structs that verify their index before use.
This commit is contained in:
parent
35ccb56e7f
commit
15e70ffd07
6 changed files with 432 additions and 305 deletions
|
|
@ -1,7 +1,13 @@
|
|||
use std::ffi::{c_int, CStr};
|
||||
use std::ffi::c_int;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperToken, WhisperTokenData};
|
||||
use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperTokenId};
|
||||
|
||||
mod iterator;
|
||||
mod segment;
|
||||
|
||||
pub use iterator::WhisperStateSegmentIterator;
|
||||
pub use segment::{WhisperSegment, WhisperToken};
|
||||
|
||||
/// Rustified pointer to a Whisper state.
|
||||
#[derive(Debug)]
|
||||
|
|
@ -151,7 +157,7 @@ impl WhisperState {
|
|||
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
|
||||
pub fn decode(
|
||||
&mut self,
|
||||
tokens: &[WhisperToken],
|
||||
tokens: &[WhisperTokenId],
|
||||
n_past: usize,
|
||||
threads: usize,
|
||||
) -> Result<(), WhisperError> {
|
||||
|
|
@ -179,11 +185,11 @@ impl WhisperState {
|
|||
|
||||
// Language functions
|
||||
/// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||
/// Make sure to call pcm_to_mel() or set_mel() first
|
||||
/// Make sure to call [`Self::pcm_to_mel`] or [`Self::set_mel`] first
|
||||
///
|
||||
/// # Arguments
|
||||
/// * offset_ms: The offset in milliseconds to use for the language detection.
|
||||
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||
/// * `offset_ms`: The offset in milliseconds to use for the language detection.
|
||||
/// * `n_threads`: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok((i32, Vec<f32>))` on success where the i32 is detected language id and Vec<f32>
|
||||
|
|
@ -309,8 +315,8 @@ impl WhisperState {
|
|||
/// # C++ equivalent
|
||||
/// `int whisper_full_n_segments(struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn full_n_segments(&self) -> Result<c_int, WhisperError> {
|
||||
Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.ptr) })
|
||||
pub fn full_n_segments(&self) -> c_int {
|
||||
unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.ptr) }
|
||||
}
|
||||
|
||||
/// Language ID associated with the provided state.
|
||||
|
|
@ -318,281 +324,33 @@ impl WhisperState {
|
|||
/// # C++ equivalent
|
||||
/// `int whisper_full_lang_id_from_state(struct whisper_state * state);`
|
||||
#[inline]
|
||||
pub fn full_lang_id_from_state(&self) -> Result<c_int, WhisperError> {
|
||||
Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.ptr) })
|
||||
pub fn full_lang_id_from_state(&self) -> c_int {
|
||||
unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.ptr) }
|
||||
}
|
||||
|
||||
/// Get the start time of the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
|
||||
#[inline]
|
||||
pub fn full_get_segment_t0(&self, segment: c_int) -> Result<i64, WhisperError> {
|
||||
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t0_from_state(self.ptr, segment) })
|
||||
fn segment_in_bounds(&self, segment: c_int) -> bool {
|
||||
segment >= 0 && segment < self.full_n_segments()
|
||||
}
|
||||
|
||||
/// Get the end time of the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
|
||||
#[inline]
|
||||
pub fn full_get_segment_t1(&self, segment: c_int) -> Result<i64, WhisperError> {
|
||||
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) })
|
||||
}
|
||||
|
||||
fn full_get_segment_raw(&self, segment: c_int) -> Result<&CStr, WhisperError> {
|
||||
let ret =
|
||||
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
|
||||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
unsafe { Ok(CStr::from_ptr(ret)) }
|
||||
}
|
||||
|
||||
/// Get the raw bytes of the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// Get a [`WhisperSegment`] object for the specified segment index.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(Vec<u8>)` on success, with the returned bytes or
|
||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
||||
pub fn full_get_segment_bytes(&self, segment: c_int) -> Result<Vec<u8>, WhisperError> {
|
||||
Ok(self.full_get_segment_raw(segment)?.to_bytes().to_vec())
|
||||
/// `Some(WhisperSegment)` if `segment` is in bounds, otherwise [`None`].
|
||||
pub fn get_segment(&self, segment: c_int) -> Option<WhisperSegment<'_>> {
|
||||
self.segment_in_bounds(segment)
|
||||
.then(|| unsafe { WhisperSegment::new_unchecked(self, segment) })
|
||||
}
|
||||
|
||||
/// Get the text of the specified segment.
|
||||
/// Get a [`WhisperSegment`] object for the specified segment index.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(String)` on success, with the UTF-8 validated string, or
|
||||
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
||||
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
|
||||
Ok(self.full_get_segment_raw(segment)?.to_str()?.to_string())
|
||||
/// # Safety
|
||||
/// You must ensure `segment` is in bounds for this [`WhisperState`].
|
||||
pub unsafe fn get_segment_unchecked(&self, segment: c_int) -> WhisperSegment<'_> {
|
||||
WhisperSegment::new_unchecked(self, segment)
|
||||
}
|
||||
|
||||
/// Get the text of the specified segment.
|
||||
/// This function differs from [WhisperState::full_get_segment_text]
|
||||
/// in that it ignores invalid UTF-8 in whisper strings,
|
||||
/// instead opting to replace it with the replacement character.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(String)` on success, or
|
||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
||||
pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result<String, WhisperError> {
|
||||
Ok(self
|
||||
.full_get_segment_raw(segment)?
|
||||
.to_string_lossy()
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Get number of tokens in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
|
||||
#[inline]
|
||||
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
|
||||
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
|
||||
}
|
||||
|
||||
fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> {
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
segment,
|
||||
token,
|
||||
)
|
||||
};
|
||||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
unsafe { Ok(CStr::from_ptr(ret)) }
|
||||
}
|
||||
|
||||
/// Get the raw token bytes of the specified token in the specified segment.
|
||||
///
|
||||
/// Useful if you're using a language for which whisper is known to split tokens
|
||||
/// away from UTF-8 character boundaries.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(Vec<u8>)` on success, with the returned bytes or
|
||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_bytes(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<Vec<u8>, WhisperError> {
|
||||
Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec())
|
||||
}
|
||||
|
||||
/// Get the token text of the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(String)` on success, with the UTF-8 validated string, or
|
||||
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_text(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<String, WhisperError> {
|
||||
Ok(self
|
||||
.full_get_token_raw(segment, token)?
|
||||
.to_str()?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Get the token text of the specified token in the specified segment.
|
||||
/// This function differs from [WhisperState::full_get_token_text]
|
||||
/// in that it ignores invalid UTF-8 in whisper strings,
|
||||
/// instead opting to replace it with the replacement character.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(String)` on success, or
|
||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_text_lossy(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<String, WhisperError> {
|
||||
Ok(self
|
||||
.full_get_token_raw(segment, token)?
|
||||
.to_string_lossy()
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Get the token ID of the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// [crate::WhisperToken]
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_id(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<WhisperToken, WhisperError> {
|
||||
Ok(unsafe {
|
||||
whisper_rs_sys::whisper_full_get_token_id_from_state(self.ptr, segment, token)
|
||||
})
|
||||
}
|
||||
|
||||
/// Get token data for the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// [crate::WhisperTokenData]
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
#[inline]
|
||||
pub fn full_get_token_data(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<WhisperTokenData, WhisperError> {
|
||||
Ok(unsafe {
|
||||
whisper_rs_sys::whisper_full_get_token_data_from_state(self.ptr, segment, token)
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the probability of the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// f32
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
#[inline]
|
||||
pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> Result<f32, WhisperError> {
|
||||
Ok(
|
||||
unsafe {
|
||||
whisper_rs_sys::whisper_full_get_token_p_from_state(self.ptr, segment, token)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
/// Get whether the next segment is predicted as a speaker turn.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * i_segment: Segment index.
|
||||
///
|
||||
/// # Returns
|
||||
/// bool
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)`
|
||||
pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool {
|
||||
unsafe {
|
||||
whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state(
|
||||
self.ptr, i_segment,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the no_speech probability for the specified segment
|
||||
pub fn full_get_segment_no_speech_prob(&self, i_segment: c_int) -> f32 {
|
||||
unsafe {
|
||||
whisper_rs_sys::whisper_full_get_segment_no_speech_prob_from_state(self.ptr, i_segment)
|
||||
}
|
||||
/// Get an iterator over all segments.
|
||||
pub fn as_iter(&self) -> WhisperStateSegmentIterator {
|
||||
WhisperStateSegmentIterator::new(self)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue