Split each function that returns a CStr or similar object into three functions, one for each distinct kind of return value (raw bytes, `&str`, and lossy `Cow<str>`)

This commit is contained in:
Niko 2025-08-15 23:37:43 -07:00
parent 4ed94c2831
commit b07d20a3c1
No known key found for this signature in database
2 changed files with 122 additions and 60 deletions

View file

@ -1,5 +1,6 @@
use crate::error::WhisperError; use crate::error::WhisperError;
use crate::WhisperTokenId; use crate::WhisperTokenId;
use std::borrow::Cow;
use std::ffi::{c_int, CStr, CString}; use std::ffi::{c_int, CStr, CString};
/// Safe Rust wrapper around a Whisper context. /// Safe Rust wrapper around a Whisper context.
@ -280,54 +281,44 @@ impl WhisperInnerContext {
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) } unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
} }
// token functions // --- begin model_type_readable helpers ---
/// Convert a token ID to a string. fn model_type_readable_cstr(&self) -> Result<&CStr, WhisperError> {
/// let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
/// # Arguments if ret.is_null() {
/// * token_id: ID of the token. return Err(WhisperError::NullPointer);
///
/// # Returns
/// Ok(&str) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
let c_str = self.token_to_cstr(token_id)?;
let r_str = c_str.to_str()?;
Ok(r_str)
} }
Ok(unsafe { CStr::from_ptr(ret) })
}
pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
Ok(self.model_type_readable_cstr()?.to_bytes())
}
/// Returns the readable model-type string as a borrowed `&str`.
///
/// # Errors
/// Forwards any error from `model_type_readable_cstr`; a UTF-8 validation
/// failure from `CStr::to_str` is converted into a `WhisperError` by `?`.
pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
    let raw_name = self.model_type_readable_cstr()?;
    let text = raw_name.to_str()?;
    Ok(text)
}
/// Returns the readable model-type string, substituting U+FFFD for any
/// invalid UTF-8 sequences instead of failing.
///
/// # Errors
/// Forwards any error from `model_type_readable_cstr` unchanged.
pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
    self.model_type_readable_cstr()
        .map(|raw_name| raw_name.to_string_lossy())
}
// --- end model_type_readable helpers ---
/// Convert a token ID to a &CStr. // --- begin token functions ---
/// fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
/// # Arguments
/// * token_id: ID of the token.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) }; let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
Ok(unsafe { CStr::from_ptr(ret) }) Ok(unsafe { CStr::from_ptr(ret) })
} }
pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
/// Undocumented but exposed function in the C++ API. Ok(self.token_to_cstr(token_id)?.to_bytes())
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
} }
let c_str = unsafe { CStr::from_ptr(ret) }; pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
let r_str = c_str.to_str()?; Ok(self.token_to_cstr(token_id)?.to_str()?)
Ok(r_str.to_string()) }
pub fn token_to_str_lossy(
&self,
token_id: WhisperTokenId,
) -> Result<Cow<'_, str>, WhisperError> {
Ok(self.token_to_cstr(token_id)?.to_string_lossy())
} }
/// Get the ID of the eot token. /// Get the ID of the eot token.
@ -396,6 +387,7 @@ impl WhisperInnerContext {
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId { pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) } unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
} }
// --- end token functions ---
/// Print performance statistics to stderr. /// Print performance statistics to stderr.
/// ///

View file

@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::ffi::{c_int, CStr}; use std::ffi::{c_int, CStr};
use std::sync::Arc; use std::sync::Arc;
@ -242,14 +243,80 @@ impl WhisperContext {
self.ctx.model_type() self.ctx.model_type()
} }
// token functions // --- begin model_type_readable ---
/// Convert a token ID to a string. /// Undocumented but exposed function in the C++ API.
///
/// # Arguments
/// * token_id: ID of the token.
/// ///
/// # Returns /// # Returns
/// Ok(&str) on success, Err(WhisperError) on failure. /// * On success: `Ok(&[u8])`
/// * On error: `Err(WhisperError::NullPointer)`
///
/// # C++ equivalent
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
// Thin delegation to the inner context; the returned slice's lifetime is tied to `&self`.
self.ctx.model_type_readable_bytes()
}
/// Returns the string produced by `whisper_model_type_readable` as a `&str`.
///
/// The underlying function is undocumented but exposed in the C++ API.
/// Judging by its name, the string is the model type in human-readable form.
/// The returned `&str` borrows from `self`.
///
/// # Returns
/// * On success: `Ok(&str)`
/// * On error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
///
/// # C++ equivalent
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
self.ctx.model_type_readable_str()
}
/// Returns the string produced by `whisper_model_type_readable`, converted lossily to UTF-8.
///
/// The underlying function is undocumented but exposed in the C++ API.
///
/// This function differs from [`Self::model_type_readable_str`] in that it does not fail on
/// invalid UTF-8: each invalid sequence is replaced with the Unicode replacement character (U+FFFD).
///
/// # Returns
/// * On success: `Ok(Cow<str>)`
/// * On error: `Err(WhisperError::NullPointer)`
///
/// # C++ equivalent
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
self.ctx.model_type_readable_str_lossy()
}
// --- end model_type_readable ---
// --- begin token functions ---
/// Convert a token ID to its raw bytes, returned as a byte slice borrowed from `self`.
///
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds token ID.
/// There is no way to check from Rust whether your token ID is in bounds.
/// C++ exceptions *cannot* be caught and *will* cause the Rust runtime to abort your program.
/// Use this function and its siblings with extreme caution.
///
/// # Arguments
/// * `token_id`: ID of the token.
///
/// # Returns
/// * On success: `Ok(&[u8])`
/// * On out-of-bounds token ID: foreign runtime exception, causing your entire program to abort.
/// * On other error: `Err(WhisperError::NullPointer)`
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
self.ctx.token_to_bytes(token_id)
}
/// Convert a token ID to a string.
///
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
/// See [`Self::token_to_bytes`] for more information.
///
/// # Arguments
/// * `token_id`: ID of the token.
///
/// # Returns
/// * On success: `Ok(&str)`
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
/// * On other error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
@ -257,27 +324,29 @@ impl WhisperContext {
self.ctx.token_to_str(token_id) self.ctx.token_to_str(token_id)
} }
/// Convert a token ID to a &CStr. /// Convert a token ID to a string.
///
/// This function differs from [`Self::token_to_str`] in that it does not fail on invalid UTF-8 bytes in the input;
/// instead, each invalid sequence is replaced with the Unicode replacement character (U+FFFD).
///
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
/// See [`Self::token_to_bytes`] for more information.
/// ///
/// # Arguments /// # Arguments
/// * token_id: ID of the token. /// * `token_id`: ID of the token.
/// ///
/// # Returns /// # Returns
/// Ok(String) on success, Err(WhisperError) on failure. /// * On success: `Ok(Cow<str>)`
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
/// * On other error: `Err(WhisperError::NullPointer)`
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> { pub fn token_to_str_lossy(
self.ctx.token_to_cstr(token_id) &self,
} token_id: WhisperTokenId,
) -> Result<Cow<'_, str>, WhisperError> {
/// Undocumented but exposed function in the C++ API. self.ctx.token_to_str_lossy(token_id)
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
self.ctx.model_type_readable()
} }
/// Get the ID of the eot token. /// Get the ID of the eot token.
@ -346,6 +415,7 @@ impl WhisperContext {
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId { pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
self.ctx.token_lang(lang_id) self.ctx.token_lang(lang_id)
} }
// --- end token functions ---
/// Print performance statistics to stderr. /// Print performance statistics to stderr.
/// ///