From 15e70ffd07affbce7913534a233f9af49527956f Mon Sep 17 00:00:00 2001 From: Niko Date: Sat, 2 Aug 2025 18:30:48 -0700 Subject: [PATCH] Refactor the entire token/segment usage This was spurred by noticing a trivial case of UB in the original code: all one needed was an out-of-bounds index on any of several methods with tokens or segment indexes on the state to cause UB. I took this opportunity to consolidate methods into Rust structs that verify their index before use. --- src/lib.rs | 4 +- src/whisper_ctx.rs | 30 +-- src/whisper_ctx_wrapper.rs | 30 +-- src/whisper_state.rs | 304 ++++-------------------------- src/whisper_state/iterator.rs | 28 +++ src/whisper_state/segment.rs | 341 ++++++++++++++++++++++++++++++++++ 6 files changed, 432 insertions(+), 305 deletions(-) create mode 100644 src/whisper_state/iterator.rs create mode 100644 src/whisper_state/segment.rs diff --git a/src/lib.rs b/src/lib.rs index cf20757..3222b26 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,14 +31,14 @@ pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType}; pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData}; #[cfg(feature = "raw-api")] pub use whisper_rs_sys; -pub use whisper_state::WhisperState; +pub use whisper_state::{WhisperSegment, WhisperState, WhisperStateSegmentIterator, WhisperToken}; pub use whisper_vad::*; pub type WhisperSysContext = whisper_rs_sys::whisper_context; pub type WhisperSysState = whisper_rs_sys::whisper_state; pub type WhisperTokenData = whisper_rs_sys::whisper_token_data; -pub type WhisperToken = whisper_rs_sys::whisper_token; +pub type WhisperTokenId = whisper_rs_sys::whisper_token; pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback; pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback; pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback; diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 0e7be5c..ee36ff0 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -1,5 +1,5 @@ use crate::error::WhisperError; -use crate::WhisperToken; +use crate::WhisperTokenId; use std::ffi::{c_int, CStr, CString}; /// Safe Rust wrapper around a Whisper context. @@ -84,12 +84,12 @@ impl WhisperInnerContext { &self, text: &str, max_tokens: usize, - ) -> Result, WhisperError> { + ) -> Result, WhisperError> { // convert the text to a nul-terminated C string. Will raise an error if the text contains // any nul bytes. let text = CString::new(text)?; // allocate at least max_tokens to ensure the memory is valid - let mut tokens: Vec = Vec::with_capacity(max_tokens); + let mut tokens: Vec = Vec::with_capacity(max_tokens); let ret = unsafe { whisper_rs_sys::whisper_tokenize( self.ctx, @@ -307,7 +307,7 @@ impl WhisperInnerContext { /// /// # C++ equivalent /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` - pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> { + pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> { let c_str = self.token_to_cstr(token_id)?; let r_str = c_str.to_str()?; Ok(r_str) @@ -323,7 +323,7 @@ impl WhisperInnerContext { /// /// # C++ equivalent /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` - pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> { + pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> { let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) }; if ret.is_null() { return Err(WhisperError::NullPointer); @@ -351,7 +351,7 @@ impl WhisperInnerContext { /// # C++ equivalent /// `whisper_token whisper_token_eot (struct whisper_context * ctx)` #[inline] - pub fn token_eot(&self) -> WhisperToken { + pub fn token_eot(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_eot(self.ctx) } } @@ -360,7 +360,7 @@ impl WhisperInnerContext { /// # C++ equivalent /// `whisper_token whisper_token_sot (struct whisper_context * ctx)` #[inline] - pub fn token_sot(&self) -> WhisperToken { + pub fn token_sot(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_sot(self.ctx) } } @@ -369,7 +369,7 @@ impl WhisperInnerContext { /// # C++ equivalent /// `whisper_token whisper_token_solm(struct whisper_context * ctx)` #[inline] - pub fn token_solm(&self) -> WhisperToken { + pub fn token_solm(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_solm(self.ctx) } } @@ -378,7 +378,7 @@ impl WhisperInnerContext { /// # C++ equivalent /// `whisper_token whisper_token_prev(struct whisper_context * ctx)` #[inline] - pub fn token_prev(&self) -> WhisperToken { + pub fn token_prev(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_prev(self.ctx) } } @@ -387,7 +387,7 @@ impl WhisperInnerContext { /// # C++ equivalent /// `whisper_token whisper_token_nosp(struct whisper_context * ctx)` #[inline] - pub fn token_nosp(&self) -> WhisperToken { + pub fn token_nosp(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_nosp(self.ctx) } } @@ -396,7 +396,7 @@ impl WhisperInnerContext { /// # C++ equivalent /// `whisper_token whisper_token_not (struct whisper_context * ctx)` #[inline] - pub fn token_not(&self) -> WhisperToken { + pub fn token_not(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_not(self.ctx) } } @@ -405,7 +405,7 @@ impl WhisperInnerContext { /// # C++ equivalent /// `whisper_token whisper_token_beg (struct whisper_context * ctx)` #[inline] - pub fn token_beg(&self) -> WhisperToken { + pub fn token_beg(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) } } @@ -417,7 +417,7 @@ impl WhisperInnerContext { /// # C++ equivalent /// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)` #[inline] - pub fn token_lang(&self, lang_id: c_int) -> WhisperToken { + pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) } } @@ -444,7 +444,7 @@ impl WhisperInnerContext { /// /// # C++ equivalent /// `whisper_token whisper_token_translate ()` - pub fn token_translate(&self) -> WhisperToken { + pub fn token_translate(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_translate(self.ctx) } } @@ -452,7 +452,7 @@ impl WhisperInnerContext { /// /// # C++ equivalent /// `whisper_token whisper_token_transcribe()` - pub fn token_transcribe(&self) -> WhisperToken { + pub fn token_transcribe(&self) -> WhisperTokenId { unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) } } } diff --git a/src/whisper_ctx_wrapper.rs b/src/whisper_ctx_wrapper.rs index ff3caff..5d2eec8 100644 --- a/src/whisper_ctx_wrapper.rs +++ b/src/whisper_ctx_wrapper.rs @@ -2,7 +2,7 @@ use std::ffi::{c_int, CStr}; use std::sync::Arc; use crate::{ - WhisperContextParameters, WhisperError, WhisperInnerContext, WhisperState, WhisperToken, + WhisperContextParameters, WhisperError, WhisperInnerContext, WhisperState, WhisperTokenId, }; pub struct WhisperContext { @@ -57,7 +57,7 @@ impl WhisperContext { /// * text: The text to convert. /// /// # Returns - /// `Ok(Vec)` on success, `Err(WhisperError)` on failure. + /// `Ok(Vec)` on success, `Err(WhisperError)` on failure. /// /// # C++ equivalent /// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);` @@ -65,7 +65,7 @@ impl WhisperContext { &self, text: &str, max_tokens: usize, - ) -> Result, WhisperError> { + ) -> Result, WhisperError> { self.ctx.tokenize(text, max_tokens) } @@ -269,7 +269,7 @@ impl WhisperContext { /// /// # C++ equivalent /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` - pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> { + pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> { self.ctx.token_to_str(token_id) } @@ -283,7 +283,7 @@ impl WhisperContext { /// /// # C++ equivalent /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` - pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> { + pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> { self.ctx.token_to_cstr(token_id) } @@ -301,7 +301,7 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token whisper_token_eot (struct whisper_context * ctx)` #[inline] - pub fn token_eot(&self) -> WhisperToken { + pub fn token_eot(&self) -> WhisperTokenId { self.ctx.token_eot() } @@ -310,7 +310,7 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token whisper_token_sot (struct whisper_context * ctx)` #[inline] - pub fn token_sot(&self) -> WhisperToken { + pub fn token_sot(&self) -> WhisperTokenId { self.ctx.token_sot() } @@ -319,7 +319,7 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token whisper_token_solm(struct whisper_context * ctx)` #[inline] - pub fn token_solm(&self) -> WhisperToken { + pub fn token_solm(&self) -> WhisperTokenId { self.ctx.token_solm() } @@ -328,7 +328,7 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token whisper_token_prev(struct whisper_context * ctx)` #[inline] - pub fn token_prev(&self) -> WhisperToken { + pub fn token_prev(&self) -> WhisperTokenId { self.ctx.token_prev() } @@ -337,7 +337,7 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token whisper_token_nosp(struct whisper_context * ctx)` #[inline] - pub fn token_nosp(&self) -> WhisperToken { + pub fn token_nosp(&self) -> WhisperTokenId { self.ctx.token_nosp() } @@ -346,7 +346,7 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token whisper_token_not (struct whisper_context * ctx)` #[inline] - pub fn token_not(&self) -> WhisperToken { + pub fn token_not(&self) -> WhisperTokenId { self.ctx.token_not() } @@ -355,7 +355,7 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token whisper_token_beg (struct whisper_context * ctx)` #[inline] - pub fn token_beg(&self) -> WhisperToken { + pub fn token_beg(&self) -> WhisperTokenId { self.ctx.token_beg() } @@ -367,7 +367,7 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)` #[inline] - pub fn token_lang(&self, lang_id: c_int) -> WhisperToken { + pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId { self.ctx.token_lang(lang_id) } @@ -394,7 +394,7 @@ impl WhisperContext { /// /// # C++ equivalent /// `whisper_token whisper_token_translate ()` - pub fn token_translate(&self) -> WhisperToken { + pub fn token_translate(&self) -> WhisperTokenId { self.ctx.token_translate() } @@ -402,7 +402,7 @@ impl WhisperContext { /// /// # C++ equivalent /// `whisper_token whisper_token_transcribe()` - pub fn token_transcribe(&self) -> WhisperToken { + pub fn token_transcribe(&self) -> WhisperTokenId { self.ctx.token_transcribe() } diff --git a/src/whisper_state.rs b/src/whisper_state.rs index 81ff014..b5aee36 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -1,7 +1,13 @@ -use std::ffi::{c_int, CStr}; +use std::ffi::c_int; use std::sync::Arc; -use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperToken, WhisperTokenData}; +use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperTokenId}; + +mod iterator; +mod segment; + +pub use iterator::WhisperStateSegmentIterator; +pub use segment::{WhisperSegment, WhisperToken}; /// Rustified pointer to a Whisper state. #[derive(Debug)] @@ -151,7 +157,7 @@ impl WhisperState { /// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)` pub fn decode( &mut self, - tokens: &[WhisperToken], + tokens: &[WhisperTokenId], n_past: usize, threads: usize, ) -> Result<(), WhisperError> { @@ -179,11 +185,11 @@ impl WhisperState { // Language functions /// Use mel data at offset_ms to try and auto-detect the spoken language - /// Make sure to call pcm_to_mel() or set_mel() first + /// Make sure to call [`Self::pcm_to_mel`] or [`Self::set_mel`] first /// /// # Arguments - /// * offset_ms: The offset in milliseconds to use for the language detection. - /// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. + /// * `offset_ms`: The offset in milliseconds to use for the language detection. + /// * `n_threads`: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. /// /// # Returns /// `Ok((i32, Vec))` on success where the i32 is detected language id and Vec @@ -309,8 +315,8 @@ impl WhisperState { /// # C++ equivalent /// `int whisper_full_n_segments(struct whisper_context * ctx)` #[inline] - pub fn full_n_segments(&self) -> Result { - Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.ptr) }) + pub fn full_n_segments(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.ptr) } } /// Language ID associated with the provided state. @@ -318,281 +324,33 @@ impl WhisperState { /// # C++ equivalent /// `int whisper_full_lang_id_from_state(struct whisper_state * state);` #[inline] - pub fn full_lang_id_from_state(&self) -> Result { - Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.ptr) }) + pub fn full_lang_id_from_state(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.ptr) } } - /// Get the start time of the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// - /// # C++ equivalent - /// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)` - #[inline] - pub fn full_get_segment_t0(&self, segment: c_int) -> Result { - Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t0_from_state(self.ptr, segment) }) + fn segment_in_bounds(&self, segment: c_int) -> bool { + segment >= 0 && segment < self.full_n_segments() } - /// Get the end time of the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// - /// # C++ equivalent - /// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)` - #[inline] - pub fn full_get_segment_t1(&self, segment: c_int) -> Result { - Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) }) - } - - fn full_get_segment_raw(&self, segment: c_int) -> Result<&CStr, WhisperError> { - let ret = - unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) }; - if ret.is_null() { - return Err(WhisperError::NullPointer); - } - unsafe { Ok(CStr::from_ptr(ret)) } - } - - /// Get the raw bytes of the specified segment. - /// - /// # Arguments - /// * segment: Segment index. + /// Get a [`WhisperSegment`] object for the specified segment index. /// /// # Returns - /// `Ok(Vec)` on success, with the returned bytes or - /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) - /// - /// # C++ equivalent - /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` - pub fn full_get_segment_bytes(&self, segment: c_int) -> Result, WhisperError> { - Ok(self.full_get_segment_raw(segment)?.to_bytes().to_vec()) + /// `Some(WhisperSegment)` if `segment` is in bounds, otherwise [`None`]. + pub fn get_segment(&self, segment: c_int) -> Option> { + self.segment_in_bounds(segment) + .then(|| unsafe { WhisperSegment::new_unchecked(self, segment) }) } - /// Get the text of the specified segment. + /// Get a [`WhisperSegment`] object for the specified segment index. /// - /// # Arguments - /// * segment: Segment index. - /// - /// # Returns - /// `Ok(String)` on success, with the UTF-8 validated string, or - /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`) - /// - /// # C++ equivalent - /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` - pub fn full_get_segment_text(&self, segment: c_int) -> Result { - Ok(self.full_get_segment_raw(segment)?.to_str()?.to_string()) + /// # Safety + /// You must ensure `segment` is in bounds for this [`WhisperState`]. + pub unsafe fn get_segment_unchecked(&self, segment: c_int) -> WhisperSegment<'_> { + WhisperSegment::new_unchecked(self, segment) } - /// Get the text of the specified segment. - /// This function differs from [WhisperState::full_get_segment_text] - /// in that it ignores invalid UTF-8 in whisper strings, - /// instead opting to replace it with the replacement character. - /// - /// # Arguments - /// * segment: Segment index. - /// - /// # Returns - /// `Ok(String)` on success, or - /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) - /// - /// # C++ equivalent - /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` - pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result { - Ok(self - .full_get_segment_raw(segment)? - .to_string_lossy() - .to_string()) - } - - /// Get number of tokens in the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// - /// # Returns - /// c_int - /// - /// # C++ equivalent - /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)` - #[inline] - pub fn full_n_tokens(&self, segment: c_int) -> Result { - Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) }) - } - - fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> { - let ret = unsafe { - whisper_rs_sys::whisper_full_get_token_text_from_state( - self.ctx.ctx, - self.ptr, - segment, - token, - ) - }; - if ret.is_null() { - return Err(WhisperError::NullPointer); - } - unsafe { Ok(CStr::from_ptr(ret)) } - } - - /// Get the raw token bytes of the specified token in the specified segment. - /// - /// Useful if you're using a language for which whisper is known to split tokens - /// away from UTF-8 character boundaries. - /// - /// # Arguments - /// * segment: Segment index. - /// * token: Token index. - /// - /// # Returns - /// `Ok(Vec)` on success, with the returned bytes or - /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) - /// - /// # C++ equivalent - /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn full_get_token_bytes( - &self, - segment: c_int, - token: c_int, - ) -> Result, WhisperError> { - Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec()) - } - - /// Get the token text of the specified token in the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// * token: Token index. - /// - /// # Returns - /// `Ok(String)` on success, with the UTF-8 validated string, or - /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`) - /// - /// # C++ equivalent - /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn full_get_token_text( - &self, - segment: c_int, - token: c_int, - ) -> Result { - Ok(self - .full_get_token_raw(segment, token)? - .to_str()? - .to_string()) - } - - /// Get the token text of the specified token in the specified segment. - /// This function differs from [WhisperState::full_get_token_text] - /// in that it ignores invalid UTF-8 in whisper strings, - /// instead opting to replace it with the replacement character. - /// - /// # Arguments - /// * segment: Segment index. - /// * token: Token index. - /// - /// # Returns - /// `Ok(String)` on success, or - /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) - /// - /// # C++ equivalent - /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn full_get_token_text_lossy( - &self, - segment: c_int, - token: c_int, - ) -> Result { - Ok(self - .full_get_token_raw(segment, token)? - .to_string_lossy() - .to_string()) - } - - /// Get the token ID of the specified token in the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// * token: Token index. - /// - /// # Returns - /// [crate::WhisperToken] - /// - /// # C++ equivalent - /// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)` - pub fn full_get_token_id( - &self, - segment: c_int, - token: c_int, - ) -> Result { - Ok(unsafe { - whisper_rs_sys::whisper_full_get_token_id_from_state(self.ptr, segment, token) - }) - } - - /// Get token data for the specified token in the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// * token: Token index. - /// - /// # Returns - /// [crate::WhisperTokenData] - /// - /// # C++ equivalent - /// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)` - #[inline] - pub fn full_get_token_data( - &self, - segment: c_int, - token: c_int, - ) -> Result { - Ok(unsafe { - whisper_rs_sys::whisper_full_get_token_data_from_state(self.ptr, segment, token) - }) - } - - /// Get the probability of the specified token in the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// * token: Token index. - /// - /// # Returns - /// f32 - /// - /// # C++ equivalent - /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` - #[inline] - pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> Result { - Ok( - unsafe { - whisper_rs_sys::whisper_full_get_token_p_from_state(self.ptr, segment, token) - }, - ) - } - - /// Get whether the next segment is predicted as a speaker turn. - /// - /// # Arguments - /// * i_segment: Segment index. - /// - /// # Returns - /// bool - /// - /// # C++ equivalent - /// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)` - pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool { - unsafe { - whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state( - self.ptr, i_segment, - ) - } - } - - /// Get the no_speech probability for the specified segment - pub fn full_get_segment_no_speech_prob(&self, i_segment: c_int) -> f32 { - unsafe { - whisper_rs_sys::whisper_full_get_segment_no_speech_prob_from_state(self.ptr, i_segment) - } + /// Get an iterator over all segments. + pub fn as_iter(&self) -> WhisperStateSegmentIterator { + WhisperStateSegmentIterator::new(self) } } diff --git a/src/whisper_state/iterator.rs b/src/whisper_state/iterator.rs new file mode 100644 index 0000000..88d845f --- /dev/null +++ b/src/whisper_state/iterator.rs @@ -0,0 +1,28 @@ +use crate::whisper_state::WhisperSegment; +use crate::WhisperState; +use std::ffi::c_int; + +/// An iterator over a [`WhisperState`]'s result. +pub struct WhisperStateSegmentIterator<'a> { + state_ptr: &'a WhisperState, + current_segment: c_int, +} + +impl<'a> WhisperStateSegmentIterator<'a> { + pub(super) fn new(state_ptr: &'a WhisperState) -> Self { + Self { + state_ptr, + current_segment: 0, + } + } +} + +impl<'a> Iterator for WhisperStateSegmentIterator<'a> { + type Item = WhisperSegment<'a>; + + fn next(&mut self) -> Option { + let ret = self.state_ptr.get_segment(self.current_segment); + self.current_segment += 1; + ret + } +} diff --git a/src/whisper_state/segment.rs b/src/whisper_state/segment.rs new file mode 100644 index 0000000..d681ec9 --- /dev/null +++ b/src/whisper_state/segment.rs @@ -0,0 +1,341 @@ +use crate::{WhisperError, WhisperState, WhisperTokenData, WhisperTokenId}; +use std::borrow::Cow; +use std::ffi::{c_int, CStr}; +use std::fmt; + +/// A segment returned by Whisper after running the transcription pipeline. +pub struct WhisperSegment<'a> { + state: &'a WhisperState, + + segment_idx: c_int, + token_count: c_int, +} +impl<'a> WhisperSegment<'a> { + /// # Safety + /// You must ensure `segment_idx` is in bounds for the linked [`WhisperState`]. + pub(super) unsafe fn new_unchecked(state: &'a WhisperState, segment_idx: c_int) -> Self { + assert!( + state.segment_in_bounds(segment_idx), + "tried to create a WhisperSegment out of bounds for linked state" + ); + Self { + state, + segment_idx, + token_count: unsafe { + whisper_rs_sys::whisper_full_n_tokens_from_state(state.ptr, segment_idx) + }, + } + } + + /// Get the start time of the specified segment. + /// + /// # Returns + /// Start time in centiseconds (10s of milliseconds) + /// + /// # C++ equivalent + /// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)` + pub fn start_timestamp(&self) -> i64 { + unsafe { + whisper_rs_sys::whisper_full_get_segment_t0_from_state(self.state.ptr, self.segment_idx) + } + } + + /// Get the end time of the specified segment. + /// + /// # Returns + /// End time in centiseconds (10s of milliseconds) + /// + /// # C++ equivalent + /// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)` + pub fn end_timestamp(&self) -> i64 { + unsafe { + whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.state.ptr, self.segment_idx) + } + } + + /// Get number of tokens in this segment. + /// + /// # Returns + /// `c_int` + /// + /// # C++ equivalent + /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)` + pub fn n_tokens(&self) -> c_int { + self.token_count + } + + /// Get whether the next segment is predicted as a speaker turn. + /// + /// # Returns + /// `bool` + /// + /// # C++ equivalent + /// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)` + pub fn next_segment_speaker_turn(&self) -> bool { + unsafe { + whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state( + self.state.ptr, + self.segment_idx, + ) + } + } + + /// Get the no_speech probability for the specified segment + pub fn no_speech_probability(&self) -> f32 { + unsafe { + whisper_rs_sys::whisper_full_get_segment_no_speech_prob_from_state( + self.state.ptr, + self.segment_idx, + ) + } + } + + fn to_raw_cstr(&self) -> Result<&CStr, WhisperError> { + let ret = unsafe { + whisper_rs_sys::whisper_full_get_segment_text_from_state( + self.state.ptr, + self.segment_idx, + ) + }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + Ok(unsafe { CStr::from_ptr(ret) }) + } + + /// Get the raw bytes of this segment. + /// + /// # Returns + /// * On success: The raw bytes, with no null terminator + /// * On failure: [`WhisperError::NullPointer`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` + pub fn to_bytes(&self) -> Result<&[u8], WhisperError> { + Ok(self.to_raw_cstr()?.to_bytes()) + } + + /// Get the text of this segment. + /// + /// # Returns + /// * On success: the UTF-8 validated string. + /// * On failure: [`WhisperError::NullPointer`] or [`WhisperError::InvalidUtf8`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` + pub fn to_str(&self) -> Result<&str, WhisperError> { + Ok(self.to_raw_cstr()?.to_str()?) + } + + /// Get the text of this segment. + /// + /// This function differs from [`Self::to_str`] + /// in that it ignores invalid UTF-8 in strings, + /// and instead replaces it with the replacement character. + /// + /// # Returns + /// * On success: The valid string, with any invalid UTF-8 replaced with the replacement character + /// * On failure: [`WhisperError::NullPointer`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` + pub fn to_str_lossy(&self) -> Result, WhisperError> { + Ok(self.to_raw_cstr()?.to_string_lossy()) + } + + fn token_in_bounds(&self, token_idx: c_int) -> bool { + token_idx >= 0 && token_idx < self.token_count + } + + pub fn get_token(&self, token: c_int) -> Option> { + self.token_in_bounds(token) + .then(|| unsafe { WhisperToken::new_unchecked(self, token) }) + } + + /// # Safety + /// You must ensure `token` is in bounds for this [`WhisperSegment`]. + pub unsafe fn get_token_unchecked(&self, token: c_int) -> WhisperToken<'_, '_> { + WhisperToken::new_unchecked(self, token) + } +} + +/// Write the contents of this segment to the output. +/// This will panic if Whisper returns a null pointer. +/// +/// Uses [`Self::to_str_lossy`] internally. +impl fmt::Display for WhisperSegment<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + self.to_str_lossy() + .expect("got null pointer during string write") + ) + } +} + +impl fmt::Debug for WhisperSegment<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WhisperSegment") + .field("segment", &self.segment_idx) + .field("n_tokens", &self.token_count) + .field("start_ts", &self.start_timestamp()) + .field("end_ts", &self.end_timestamp()) + .field( + "next_segment_speaker_turn", + &self.next_segment_speaker_turn(), + ) + .field("no_speech_probability", &self.no_speech_probability()) + .field("text", &self.to_str_lossy()) + .finish_non_exhaustive() + } +} + +pub struct WhisperToken<'a, 'b: 'a> { + segment: &'a WhisperSegment<'b>, + token_idx: c_int, +} + +impl<'a, 'b> WhisperToken<'a, 'b> { + /// # Safety + /// You must ensure `token_idx` is in bounds for this [`WhisperSegment`]. + unsafe fn new_unchecked(segment: &'a WhisperSegment<'b>, token_idx: c_int) -> Self { + Self { segment, token_idx } + } + + /// Get the token ID of this token in its segment. + /// + /// # Returns + /// [`WhisperTokenId`] + /// + /// # C++ equivalent + /// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)` + pub fn token_id(&self) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_get_token_id_from_state( + self.segment.state.ptr, + self.segment.segment_idx, + self.token_idx, + ) + }) + } + + /// Get token data for this token in its segment. + /// + /// # Returns + /// [`WhisperTokenData`] + /// + /// # C++ equivalent + /// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn token_data(&self) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_get_token_data_from_state( + self.segment.state.ptr, + self.segment.segment_idx, + self.token_idx, + ) + }) + } + + /// Get the probability of this token in its segment. + /// + /// # Returns + /// `f32` + /// + /// # C++ equivalent + /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn token_probability(&self) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_get_token_p_from_state( + self.segment.state.ptr, + self.segment.segment_idx, + self.token_idx, + ) + }) + } + + fn to_raw_cstr(&self) -> Result<&CStr, WhisperError> { + let ret = unsafe { + whisper_rs_sys::whisper_full_get_token_text_from_state( + self.segment.state.ctx.ctx, + self.segment.state.ptr, + self.segment.segment_idx, + self.token_idx, + ) + }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + Ok(unsafe { CStr::from_ptr(ret) }) + } + + /// Get the raw bytes of this token. + /// + /// Useful if you're using a language for which Whisper is known to split tokens + /// away from UTF-8 character boundaries. + /// + /// # Returns + /// * On success: The raw bytes, with no null terminator + /// * On failure: [`WhisperError::NullPointer`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn to_bytes(&self) -> Result<&[u8], WhisperError> { + Ok(self.to_raw_cstr()?.to_bytes()) + } + + /// Get the text of this token. + /// + /// # Returns + /// * On success: the UTF-8 validated string. + /// * On failure: [`WhisperError::NullPointer`] or [`WhisperError::InvalidUtf8`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn to_str(&self) -> Result<&str, WhisperError> { + Ok(self.to_raw_cstr()?.to_str()?) + } + + /// Get the text of this token. + /// + /// This function differs from [`Self::to_str`] + /// in that it ignores invalid UTF-8 in strings, + /// and instead replaces it with the replacement character. + /// + /// # Returns + /// * On success: The valid string, with any invalid UTF-8 replaced with the replacement character + /// * On failure: [`WhisperError::NullPointer`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn to_str_lossy(&self) -> Result, WhisperError> { + Ok(self.to_raw_cstr()?.to_string_lossy()) + } +} + +/// Write the contents of this token to the output. +/// This will panic if Whisper returns a null pointer. +/// +/// Uses [`Self::to_str_lossy`] internally. +impl fmt::Display for WhisperToken<'_, '_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + self.to_str_lossy() + .expect("got null pointer during string write") + ) + } +} + +impl fmt::Debug for WhisperToken<'_, '_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WhisperToken") + .field("segment_idx", &self.segment.segment_idx) + .field("token_idx", &self.token_idx) + .field("token_id", &self.token_id()) + .field("token_data", &self.token_data()) + .field("token_probability", &self.token_probability()) + .finish() + } +}