Split out any functions that return a CStr or similar object into three returning all unique types of possible return value

This commit is contained in:
Niko 2025-08-15 23:37:43 -07:00
parent 4ed94c2831
commit b07d20a3c1
No known key found for this signature in database
2 changed files with 122 additions and 60 deletions

View file

@ -1,5 +1,6 @@
use crate::error::WhisperError;
use crate::WhisperTokenId;
use std::borrow::Cow;
use std::ffi::{c_int, CStr, CString};
/// Safe Rust wrapper around a Whisper context.
@ -280,54 +281,44 @@ impl WhisperInnerContext {
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
}
// token functions
/// Convert a token ID to a string.
///
/// # Arguments
/// * token_id: ID of the token.
///
/// # Returns
/// Ok(&str) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
let c_str = self.token_to_cstr(token_id)?;
let r_str = c_str.to_str()?;
Ok(r_str)
// --- begin model_type_readable helpers ---
fn model_type_readable_cstr(&self) -> Result<&CStr, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
Ok(unsafe { CStr::from_ptr(ret) })
}
pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
Ok(self.model_type_readable_cstr()?.to_bytes())
}
pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
Ok(self.model_type_readable_cstr()?.to_str()?)
}
pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
Ok(self.model_type_readable_cstr()?.to_string_lossy())
}
// --- end model_type_readable helpers ---
/// Convert a token ID to a &CStr.
///
/// # Arguments
/// * token_id: ID of the token.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
// --- begin token functions ---
fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
Ok(unsafe { CStr::from_ptr(ret) })
}
/// Undocumented but exposed function in the C++ API.
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
Ok(self.token_to_cstr(token_id)?.to_bytes())
}
let c_str = unsafe { CStr::from_ptr(ret) };
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
Ok(self.token_to_cstr(token_id)?.to_str()?)
}
pub fn token_to_str_lossy(
&self,
token_id: WhisperTokenId,
) -> Result<Cow<'_, str>, WhisperError> {
Ok(self.token_to_cstr(token_id)?.to_string_lossy())
}
/// Get the ID of the eot token.
@ -396,6 +387,7 @@ impl WhisperInnerContext {
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
}
// --- end token functions ---
/// Print performance statistics to stderr.
///

View file

@ -1,3 +1,4 @@
use std::borrow::Cow;
use std::ffi::{c_int, CStr};
use std::sync::Arc;
@ -242,14 +243,80 @@ impl WhisperContext {
self.ctx.model_type()
}
// token functions
/// Convert a token ID to a string.
///
/// # Arguments
/// * token_id: ID of the token.
// --- begin model_type_readable ---
/// Undocumented but exposed function in the C++ API.
///
/// # Returns
/// Ok(&str) on success, Err(WhisperError) on failure.
/// * On success: `Ok(&[u8])`
/// * On error: `Err(WhisperError::NullPointer)`
///
/// # C++ equivalent
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
self.ctx.model_type_readable_bytes()
}
/// Undocumented but exposed function in the C++ API.
///
/// # Returns
/// * On success: `Ok(&str)`
/// * On error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
///
/// # C++ equivalent
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
self.ctx.model_type_readable_str()
}
/// Undocumented but exposed function in the C++ API.
///
/// This function differs from [`Self::model_type_readable_str`] in that it ignores invalid UTF-8 bytes in the input,
/// and instead replaces them with the Unicode replacement character.
///
/// # Returns
/// * On success: `Ok(Cow<str>)`
/// * On error: `Err(WhisperError::NullPointer)`
///
/// # C++ equivalent
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
self.ctx.model_type_readable_str_lossy()
}
// --- end model_type_readable ---
// --- begin token functions ---
/// Convert a token ID to a byte array.
///
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
/// There is no way to check if your index is in bounds from Rust.
/// C++ exceptions *cannot* be caught and *will* cause the Rust runtime to abort your program.
/// Use this function and its siblings with extreme caution.
///
/// # Arguments
/// * `token_id`: ID of the token.
///
/// # Returns
/// * On success: `Ok(&[u8])`
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
/// * On other error: `Err(WhisperError::NullPointer)`
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
self.ctx.token_to_bytes(token_id)
}
/// Convert a token ID to a string.
///
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
/// See [`Self::token_to_bytes`] for more information.
///
/// # Arguments
/// * `token_id`: ID of the token.
///
/// # Returns
/// * On success: `Ok(&str)`
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
/// * On other error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
@ -257,27 +324,29 @@ impl WhisperContext {
self.ctx.token_to_str(token_id)
}
/// Convert a token ID to a &CStr.
/// Convert a token ID to a string.
///
/// This function differs from [`Self::token_to_str`] in that it ignores invalid UTF-8 bytes in the input,
/// and instead replaces them with the Unicode replacement character.
///
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
/// See [`Self::token_to_bytes`] for more information.
///
/// # Arguments
/// * token_id: ID of the token.
/// * `token_id`: ID of the token.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
/// * On success: `Ok(Cow<str>)`
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
/// * On other error: `Err(WhisperError::NullPointer)`
///
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
self.ctx.token_to_cstr(token_id)
}
/// Undocumented but exposed function in the C++ API.
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
self.ctx.model_type_readable()
pub fn token_to_str_lossy(
&self,
token_id: WhisperTokenId,
) -> Result<Cow<'_, str>, WhisperError> {
self.ctx.token_to_str_lossy(token_id)
}
/// Get the ID of the eot token.
@ -346,6 +415,7 @@ impl WhisperContext {
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
self.ctx.token_lang(lang_id)
}
// --- end token functions ---
/// Print performance statistics to stderr.
///