From b07d20a3c135249cfe1261af0c10995728186173 Mon Sep 17 00:00:00 2001
From: Niko
Date: Fri, 15 Aug 2025 23:37:43 -0700
Subject: [PATCH] Split out any functions that return a CStr or similar object into three returning all unique types of possible return value

---
 src/whisper_ctx.rs         |  72 +++++++++++-------------
 src/whisper_ctx_wrapper.rs | 110 ++++++++++++++++++++++++++++++-------
 2 files changed, 122 insertions(+), 60 deletions(-)

diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs
index 1f65624..8f76fbe 100644
--- a/src/whisper_ctx.rs
+++ b/src/whisper_ctx.rs
@@ -1,5 +1,6 @@
 use crate::error::WhisperError;
 use crate::WhisperTokenId;
+use std::borrow::Cow;
 use std::ffi::{c_int, CStr, CString};
 
 /// Safe Rust wrapper around a Whisper context.
@@ -280,54 +281,44 @@ impl WhisperInnerContext {
         unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
     }
 
-    // token functions
-    /// Convert a token ID to a string.
-    ///
-    /// # Arguments
-    /// * token_id: ID of the token.
-    ///
-    /// # Returns
-    /// Ok(&str) on success, Err(WhisperError) on failure.
-    ///
-    /// # C++ equivalent
-    /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
-    pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
-        let c_str = self.token_to_cstr(token_id)?;
-        let r_str = c_str.to_str()?;
-        Ok(r_str)
+    // --- begin model_type_readable helpers ---
+    fn model_type_readable_cstr(&self) -> Result<&CStr, WhisperError> {
+        let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
+        if ret.is_null() {
+            return Err(WhisperError::NullPointer);
+        }
+        Ok(unsafe { CStr::from_ptr(ret) })
     }
+    pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
+        Ok(self.model_type_readable_cstr()?.to_bytes())
+    }
+    pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
+        Ok(self.model_type_readable_cstr()?.to_str()?)
+    }
+    pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
+        Ok(self.model_type_readable_cstr()?.to_string_lossy())
+    }
+    // --- end model_type_readable helpers ---
 
-    /// Convert a token ID to a &CStr.
-    ///
-    /// # Arguments
-    /// * token_id: ID of the token.
-    ///
-    /// # Returns
-    /// Ok(String) on success, Err(WhisperError) on failure.
-    ///
-    /// # C++ equivalent
-    /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
-    pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
+    // --- begin token functions ---
+    fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
         let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
         if ret.is_null() {
             return Err(WhisperError::NullPointer);
         }
         Ok(unsafe { CStr::from_ptr(ret) })
     }
-
-    /// Undocumented but exposed function in the C++ API.
-    /// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
-    ///
-    /// # Returns
-    /// Ok(String) on success, Err(WhisperError) on failure.
-    pub fn model_type_readable(&self) -> Result<String, WhisperError> {
-        let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
-        if ret.is_null() {
-            return Err(WhisperError::NullPointer);
-        }
-        let c_str = unsafe { CStr::from_ptr(ret) };
-        let r_str = c_str.to_str()?;
-        Ok(r_str.to_string())
+    pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
+        Ok(self.token_to_cstr(token_id)?.to_bytes())
+    }
+    pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
+        Ok(self.token_to_cstr(token_id)?.to_str()?)
+    }
+    pub fn token_to_str_lossy(
+        &self,
+        token_id: WhisperTokenId,
+    ) -> Result<Cow<'_, str>, WhisperError> {
+        Ok(self.token_to_cstr(token_id)?.to_string_lossy())
     }
 
     /// Get the ID of the eot token.
@@ -396,6 +387,7 @@ impl WhisperInnerContext {
     pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
         unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
     }
+    // --- end token functions ---
 
     /// Print performance statistics to stderr.
     ///
diff --git a/src/whisper_ctx_wrapper.rs b/src/whisper_ctx_wrapper.rs
index 653fbb6..339d6cc 100644
--- a/src/whisper_ctx_wrapper.rs
+++ b/src/whisper_ctx_wrapper.rs
@@ -1,3 +1,4 @@
+use std::borrow::Cow;
 use std::ffi::{c_int, CStr};
 use std::sync::Arc;
 
@@ -242,14 +243,80 @@ impl WhisperContext {
         self.ctx.model_type()
     }
 
-    // token functions
-    /// Convert a token ID to a string.
-    ///
-    /// # Arguments
-    /// * token_id: ID of the token.
+    // --- begin model_type_readable ---
+    /// Undocumented but exposed function in the C++ API.
     ///
     /// # Returns
-    /// Ok(&str) on success, Err(WhisperError) on failure.
+    /// * On success: `Ok(&[u8])`
+    /// * On error: `Err(WhisperError::NullPointer)`
+    ///
+    /// # C++ equivalent
+    /// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
+    pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
+        self.ctx.model_type_readable_bytes()
+    }
+    /// Undocumented but exposed function in the C++ API.
+    ///
+    /// # Returns
+    /// * On success: `Ok(&str)`
+    /// * On error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
+    ///
+    /// # C++ equivalent
+    /// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
+    pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
+        self.ctx.model_type_readable_str()
+    }
+
+    /// Undocumented but exposed function in the C++ API.
+    ///
+    /// This function differs from [`Self::model_type_readable_str`] in that it ignores invalid UTF-8 bytes in the input,
+    /// and instead replaces them with the Unicode replacement character.
+    ///
+    /// # Returns
+    /// * On success: `Ok(Cow<str>)`
+    /// * On error: `Err(WhisperError::NullPointer)`
+    ///
+    /// # C++ equivalent
+    /// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
+    pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
+        self.ctx.model_type_readable_str_lossy()
+    }
+    // --- end model_type_readable ---
+
+    // --- begin token functions ---
+    /// Convert a token ID to a byte array.
+    ///
+    /// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
+    /// There is no way to check if your index is in bounds from Rust.
+    /// C++ exceptions *cannot* be caught and *will* cause the Rust runtime to abort your program.
+    /// Use this function and its siblings with extreme caution.
+    ///
+    /// # Arguments
+    /// * `token_id`: ID of the token.
+    ///
+    /// # Returns
+    /// * On success: `Ok(&[u8])`
+    /// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
+    /// * On other error: `Err(WhisperError::NullPointer)`
+    ///
+    /// # C++ equivalent
+    /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
+    pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
+        self.ctx.token_to_bytes(token_id)
+    }
+
+    /// Convert a token ID to a string.
+    ///
+    /// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
+    /// See [`Self::token_to_bytes`] for more information.
+    ///
+    /// # Arguments
+    /// * `token_id`: ID of the token.
+    ///
+    /// # Returns
+    /// * On success: `Ok(&str)`
+    /// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
+    /// * On other error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
     ///
     /// # C++ equivalent
     /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
@@ -257,27 +324,29 @@ impl WhisperContext {
         self.ctx.token_to_str(token_id)
     }
 
-    /// Convert a token ID to a &CStr.
+    /// Convert a token ID to a string.
+    ///
+    /// This function differs from [`Self::token_to_str`] in that it ignores invalid UTF-8 bytes in the input,
+    /// and instead replaces them with the Unicode replacement character.
+    ///
+    /// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
+    /// See [`Self::token_to_bytes`] for more information.
     ///
     /// # Arguments
-    /// * token_id: ID of the token.
+    /// * `token_id`: ID of the token.
     ///
     /// # Returns
-    /// Ok(String) on success, Err(WhisperError) on failure.
+    /// * On success: `Ok(Cow<str>)`
+    /// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
+    /// * On other error: `Err(WhisperError::NullPointer)`
     ///
     /// # C++ equivalent
     /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
-    pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
-        self.ctx.token_to_cstr(token_id)
-    }
-
-    /// Undocumented but exposed function in the C++ API.
-    /// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
-    ///
-    /// # Returns
-    /// Ok(String) on success, Err(WhisperError) on failure.
-    pub fn model_type_readable(&self) -> Result<String, WhisperError> {
-        self.ctx.model_type_readable()
+    pub fn token_to_str_lossy(
+        &self,
+        token_id: WhisperTokenId,
+    ) -> Result<Cow<'_, str>, WhisperError> {
+        self.ctx.token_to_str_lossy(token_id)
     }
 
     /// Get the ID of the eot token.
@@ -346,6 +415,7 @@ impl WhisperContext {
     pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
         self.ctx.token_lang(lang_id)
     }
+    // --- end token functions ---
 
     /// Print performance statistics to stderr.
     ///