From 099faf4e2ee08ec6b2fb936907208a8ac0641f30 Mon Sep 17 00:00:00 2001 From: Niko Date: Tue, 18 Feb 2025 17:03:13 -0700 Subject: [PATCH] Convert `full_get_token_*` and similar to use an internal helper instead of duplicating code --- src/whisper_state.rs | 87 +++++++++++++++++++++++++++----------------- 1 file changed, 54 insertions(+), 33 deletions(-) diff --git a/src/whisper_state.rs b/src/whisper_state.rs index c82dd18..4a9c254 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -425,22 +425,7 @@ impl WhisperState { Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) }) } - /// Get the token text of the specified token in the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// * token: Token index. - /// - /// # Returns - /// Ok(String) on success, Err(WhisperError) on failure. - /// - /// # C++ equivalent - /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn full_get_token_text( - &self, - segment: c_int, - token: c_int, - ) -> Result { + fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> { let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text_from_state( self.ctx.ctx, @@ -452,9 +437,53 @@ impl WhisperState { if ret.is_null() { return Err(WhisperError::NullPointer); } - let c_str = unsafe { CStr::from_ptr(ret) }; - let r_str = c_str.to_str()?; - Ok(r_str.to_string()) + unsafe { Ok(CStr::from_ptr(ret)) } + } + + /// Get the raw token bytes of the specified token in the specified segment. + /// + /// Useful if you're using a language for which whisper is known to split tokens + /// away from UTF-8 character boundaries. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// `Ok(Vec)` on success, with the returned bytes or + /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn full_get_token_bytes( + &self, + segment: c_int, + token: c_int, + ) -> Result, WhisperError> { + Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec()) + } + + /// Get the token text of the specified token in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// `Ok(String)` on success, with the UTF-8 validated string, or + /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`) + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn full_get_token_text( + &self, + segment: c_int, + token: c_int, + ) -> Result { + Ok(self + .full_get_token_raw(segment, token)? + .to_str()? + .to_string()) } /// Get the token text of the specified token in the specified segment. @@ -467,7 +496,8 @@ impl WhisperState { /// * token: Token index. /// /// # Returns - /// Ok(String) on success, Err(WhisperError) on failure. + /// `Ok(String)` on success, or + /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) /// /// # C++ equivalent /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` @@ -476,19 +506,10 @@ impl WhisperState { segment: c_int, token: c_int, ) -> Result { - let ret = unsafe { - whisper_rs_sys::whisper_full_get_token_text_from_state( - self.ctx.ctx, - self.ptr, - segment, - token, - ) - }; - if ret.is_null() { - return Err(WhisperError::NullPointer); - } - let c_str = unsafe { CStr::from_ptr(ret) }; - Ok(c_str.to_string_lossy().to_string()) + Ok(self + .full_get_token_raw(segment, token)? + .to_string_lossy() + .to_string()) } /// Get the token ID of the specified token in the specified segment.