From 37cc545ddd2c9700c1445bffa706b354d96b1344 Mon Sep 17 00:00:00 2001 From: Niko Date: Mon, 29 Sep 2025 12:40:24 -0700 Subject: [PATCH] Promote `WhisperToken` to its own module --- src/whisper_state/mod.rs | 4 +- src/whisper_state/segment.rs | 160 +++-------------------------------- src/whisper_state/token.rs | 153 +++++++++++++++++++++++++++++++++ 3 files changed, 166 insertions(+), 151 deletions(-) create mode 100644 src/whisper_state/token.rs diff --git a/src/whisper_state/mod.rs b/src/whisper_state/mod.rs index d4fae64..2210b1f 100644 --- a/src/whisper_state/mod.rs +++ b/src/whisper_state/mod.rs @@ -5,9 +5,11 @@ use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperTokenId}; mod iterator; mod segment; +mod token; pub use iterator::WhisperStateSegmentIterator; -pub use segment::{WhisperSegment, WhisperToken}; +pub use segment::WhisperSegment; +pub use token::WhisperToken; /// Rustified pointer to a Whisper state. #[derive(Debug)] diff --git a/src/whisper_state/segment.rs b/src/whisper_state/segment.rs index 1ed6273..7e41b93 100644 --- a/src/whisper_state/segment.rs +++ b/src/whisper_state/segment.rs @@ -1,4 +1,4 @@ -use crate::{WhisperError, WhisperState, WhisperTokenData, WhisperTokenId}; +use crate::{WhisperError, WhisperState, WhisperToken}; use std::borrow::Cow; use std::ffi::{c_int, CStr}; use std::fmt; @@ -27,6 +27,15 @@ impl<'a> WhisperSegment<'a> { } } + pub(super) fn get_state(&self) -> &WhisperState { + self.state + } + + /// Get the index of this segment. + pub fn segment_index(&self) -> c_int { + self.segment_idx + } + /// Get the start time of the specified segment. /// /// # Returns @@ -201,152 +210,3 @@ impl fmt::Debug for WhisperSegment<'_> { .finish_non_exhaustive() } } - -pub struct WhisperToken<'a, 'b: 'a> { - segment: &'a WhisperSegment<'b>, - token_idx: c_int, -} - -impl<'a, 'b> WhisperToken<'a, 'b> { - /// # Safety - /// You must ensure `token_idx` is in bounds for this [`WhisperSegment`]. 
- unsafe fn new_unchecked(segment: &'a WhisperSegment<'b>, token_idx: c_int) -> Self { - Self { segment, token_idx } - } - - /// Get the token ID of this token in its segment. - /// - /// # Returns - /// [`WhisperTokenId`] - /// - /// # C++ equivalent - /// `whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn token_id(&self) -> WhisperTokenId { - unsafe { - whisper_rs_sys::whisper_full_get_token_id_from_state( - self.segment.state.ptr, - self.segment.segment_idx, - self.token_idx, - ) - } - } - - /// Get token data for this token in its segment. - /// - /// # Returns - /// [`WhisperTokenData`] - /// - /// # C++ equivalent - /// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn token_data(&self) -> WhisperTokenData { - unsafe { - whisper_rs_sys::whisper_full_get_token_data_from_state( - self.segment.state.ptr, - self.segment.segment_idx, - self.token_idx, - ) - } - } - - /// Get the probability of this token in its segment. - /// - /// # Returns - /// `f32` - /// - /// # C++ equivalent - /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn token_probability(&self) -> f32 { - unsafe { - whisper_rs_sys::whisper_full_get_token_p_from_state( - self.segment.state.ptr, - self.segment.segment_idx, - self.token_idx, - ) - } - } - - fn to_raw_cstr(&self) -> Result<&CStr, WhisperError> { - let ret = unsafe { - whisper_rs_sys::whisper_full_get_token_text_from_state( - self.segment.state.ctx.ctx, - self.segment.state.ptr, - self.segment.segment_idx, - self.token_idx, - ) - }; - if ret.is_null() { - return Err(WhisperError::NullPointer); - } - Ok(unsafe { CStr::from_ptr(ret) }) - } - - /// Get the raw bytes of this token. - /// - /// Useful if you're using a language for which Whisper is known to split tokens - /// away from UTF-8 character boundaries. 
- /// - /// # Returns - /// * On success: The raw bytes, with no null terminator - /// * On failure: [`WhisperError::NullPointer`] - /// - /// # C++ equivalent - /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn to_bytes(&self) -> Result<&[u8], WhisperError> { - Ok(self.to_raw_cstr()?.to_bytes()) - } - - /// Get the text of this token. - /// - /// # Returns - /// * On success: the UTF-8 validated string. - /// * On failure: [`WhisperError::NullPointer`] or [`WhisperError::InvalidUtf8`] - /// - /// # C++ equivalent - /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn to_str(&self) -> Result<&str, WhisperError> { - Ok(self.to_raw_cstr()?.to_str()?) - } - - /// Get the text of this token. - /// - /// This function differs from [`Self::to_str`] - /// in that it ignores invalid UTF-8 in strings, - /// and instead replaces it with the replacement character. - /// - /// # Returns - /// * On success: The valid string, with any invalid UTF-8 replaced with the replacement character - /// * On failure: [`WhisperError::NullPointer`] - /// - /// # C++ equivalent - /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn to_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> { - Ok(self.to_raw_cstr()?.to_string_lossy()) - } -} - -/// Write the contents of this token to the output. -/// This will panic if Whisper returns a null pointer. -/// -/// Uses [`Self::to_str_lossy`] internally. 
-impl fmt::Display for WhisperToken<'_, '_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "{}", - self.to_str_lossy() - .expect("got null pointer during string write") - ) - } -} - -impl fmt::Debug for WhisperToken<'_, '_> { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("WhisperToken") - .field("segment_idx", &self.segment.segment_idx) - .field("token_idx", &self.token_idx) - .field("token_id", &self.token_id()) - .field("token_data", &self.token_data()) - .field("token_probability", &self.token_probability()) - .finish() - } -} diff --git a/src/whisper_state/token.rs b/src/whisper_state/token.rs new file mode 100644 index 0000000..c57435d --- /dev/null +++ b/src/whisper_state/token.rs @@ -0,0 +1,153 @@ +use crate::{WhisperError, WhisperSegment, WhisperTokenData, WhisperTokenId}; +use std::borrow::Cow; +use std::ffi::{c_int, CStr}; +use std::fmt; + +pub struct WhisperToken<'a, 'b: 'a> { + segment: &'a WhisperSegment<'b>, + token_idx: c_int, +} + +impl<'a, 'b> WhisperToken<'a, 'b> { + /// # Safety + /// You must ensure `token_idx` is in bounds for this [`WhisperSegment`]. + pub(crate) unsafe fn new_unchecked(segment: &'a WhisperSegment<'b>, token_idx: c_int) -> Self { + Self { segment, token_idx } + } + + /// Get the token ID of this token in its segment. + /// + /// # Returns + /// [`WhisperTokenId`] + /// + /// # C++ equivalent + /// `whisper_token whisper_full_get_token_id(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn token_id(&self) -> WhisperTokenId { + unsafe { + whisper_rs_sys::whisper_full_get_token_id_from_state( + self.segment.get_state().ptr, + self.segment.segment_index(), + self.token_idx, + ) + } + } + + /// Get token data for this token in its segment. 
+ /// + /// # Returns + /// [`WhisperTokenData`] + /// + /// # C++ equivalent + /// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn token_data(&self) -> WhisperTokenData { + unsafe { + whisper_rs_sys::whisper_full_get_token_data_from_state( + self.segment.get_state().ptr, + self.segment.segment_index(), + self.token_idx, + ) + } + } + + /// Get the probability of this token in its segment. + /// + /// # Returns + /// `f32` + /// + /// # C++ equivalent + /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn token_probability(&self) -> f32 { + unsafe { + whisper_rs_sys::whisper_full_get_token_p_from_state( + self.segment.get_state().ptr, + self.segment.segment_index(), + self.token_idx, + ) + } + } + + fn to_raw_cstr(&self) -> Result<&CStr, WhisperError> { + let ret = unsafe { + whisper_rs_sys::whisper_full_get_token_text_from_state( + self.segment.get_state().ctx.ctx, + self.segment.get_state().ptr, + self.segment.segment_index(), + self.token_idx, + ) + }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + Ok(unsafe { CStr::from_ptr(ret) }) + } + + /// Get the raw bytes of this token. + /// + /// Useful if you're using a language for which Whisper is known to split tokens + /// away from UTF-8 character boundaries. + /// + /// # Returns + /// * On success: The raw bytes, with no null terminator + /// * On failure: [`WhisperError::NullPointer`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn to_bytes(&self) -> Result<&[u8], WhisperError> { + Ok(self.to_raw_cstr()?.to_bytes()) + } + + /// Get the text of this token. + /// + /// # Returns + /// * On success: the UTF-8 validated string. 
+ /// * On failure: [`WhisperError::NullPointer`] or [`WhisperError::InvalidUtf8`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn to_str(&self) -> Result<&str, WhisperError> { + Ok(self.to_raw_cstr()?.to_str()?) + } + + /// Get the text of this token. + /// + /// This function differs from [`Self::to_str`] + /// in that it ignores invalid UTF-8 in strings, + /// and instead replaces it with the replacement character. + /// + /// # Returns + /// * On success: The valid string, with any invalid UTF-8 replaced with the replacement character + /// * On failure: [`WhisperError::NullPointer`] + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn to_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> { + Ok(self.to_raw_cstr()?.to_string_lossy()) + } +} + +/// Write the contents of this token to the output. +/// This will panic if Whisper returns a null pointer. +/// +/// Uses [`Self::to_str_lossy`] internally. +impl fmt::Display for WhisperToken<'_, '_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + self.to_str_lossy() + .expect("got null pointer during string write") + ) + } +} + +impl fmt::Debug for WhisperToken<'_, '_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("WhisperToken") + .field("segment_idx", &self.segment.segment_index()) + .field("token_idx", &self.token_idx) + .field("token_id", &self.token_id()) + .field("token_data", &self.token_data()) + .field("token_probability", &self.token_probability()) + .finish() + } +}