Split out any functions that return a CStr or similar object into three returning all unique types of possible return value
This commit is contained in:
parent
4ed94c2831
commit
b07d20a3c1
2 changed files with 122 additions and 60 deletions
|
|
@ -1,5 +1,6 @@
|
|||
use crate::error::WhisperError;
|
||||
use crate::WhisperTokenId;
|
||||
use std::borrow::Cow;
|
||||
use std::ffi::{c_int, CStr, CString};
|
||||
|
||||
/// Safe Rust wrapper around a Whisper context.
|
||||
|
|
@ -280,54 +281,44 @@ impl WhisperInnerContext {
|
|||
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
|
||||
}
|
||||
|
||||
// token functions
|
||||
/// Convert a token ID to a string.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * token_id: ID of the token.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(&str) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
|
||||
let c_str = self.token_to_cstr(token_id)?;
|
||||
let r_str = c_str.to_str()?;
|
||||
Ok(r_str)
|
||||
// --- begin model_type_readable helpers ---
|
||||
fn model_type_readable_cstr(&self) -> Result<&CStr, WhisperError> {
|
||||
let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
|
||||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
Ok(unsafe { CStr::from_ptr(ret) })
|
||||
}
|
||||
pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
|
||||
Ok(self.model_type_readable_cstr()?.to_bytes())
|
||||
}
|
||||
pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
|
||||
Ok(self.model_type_readable_cstr()?.to_str()?)
|
||||
}
|
||||
pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
|
||||
Ok(self.model_type_readable_cstr()?.to_string_lossy())
|
||||
}
|
||||
// --- end model_type_readable helpers ---
|
||||
|
||||
/// Convert a token ID to a &CStr.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * token_id: ID of the token.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
|
||||
// --- begin token functions ---
|
||||
fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
|
||||
let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
|
||||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
Ok(unsafe { CStr::from_ptr(ret) })
|
||||
}
|
||||
|
||||
/// Undocumented but exposed function in the C++ API.
|
||||
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
|
||||
let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
|
||||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
let c_str = unsafe { CStr::from_ptr(ret) };
|
||||
let r_str = c_str.to_str()?;
|
||||
Ok(r_str.to_string())
|
||||
pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
|
||||
Ok(self.token_to_cstr(token_id)?.to_bytes())
|
||||
}
|
||||
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
|
||||
Ok(self.token_to_cstr(token_id)?.to_str()?)
|
||||
}
|
||||
pub fn token_to_str_lossy(
|
||||
&self,
|
||||
token_id: WhisperTokenId,
|
||||
) -> Result<Cow<'_, str>, WhisperError> {
|
||||
Ok(self.token_to_cstr(token_id)?.to_string_lossy())
|
||||
}
|
||||
|
||||
/// Get the ID of the eot token.
|
||||
|
|
@ -396,6 +387,7 @@ impl WhisperInnerContext {
|
|||
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
|
||||
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
|
||||
}
|
||||
// --- end token functions ---
|
||||
|
||||
/// Print performance statistics to stderr.
|
||||
///
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use std::borrow::Cow;
|
||||
use std::ffi::{c_int, CStr};
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
@ -242,14 +243,80 @@ impl WhisperContext {
|
|||
self.ctx.model_type()
|
||||
}
|
||||
|
||||
// token functions
|
||||
/// Convert a token ID to a string.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * token_id: ID of the token.
|
||||
// --- begin model_type_readable ---
|
||||
/// Undocumented but exposed function in the C++ API.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(&str) on success, Err(WhisperError) on failure.
|
||||
/// * On success: `Ok(&[u8])`
|
||||
/// * On error: `Err(WhisperError::NullPointer)`
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||
pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
|
||||
self.ctx.model_type_readable_bytes()
|
||||
}
|
||||
/// Undocumented but exposed function in the C++ API.
|
||||
///
|
||||
/// # Returns
|
||||
/// * On success: `Ok(&str)`
|
||||
/// * On error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||
pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
|
||||
self.ctx.model_type_readable_str()
|
||||
}
|
||||
|
||||
/// Undocumented but exposed function in the C++ API.
|
||||
///
|
||||
/// This function differs from [`Self::model_type_readable_str`] in that it ignores invalid UTF-8 bytes in the input,
|
||||
/// and instead replaces them with the Unicode replacement character.
|
||||
///
|
||||
/// # Returns
|
||||
/// * On success: `Ok(Cow<str>)`
|
||||
/// * On error: `Err(WhisperError::NullPointer)`
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||
pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
|
||||
self.ctx.model_type_readable_str_lossy()
|
||||
}
|
||||
// --- end model_type_readable ---
|
||||
|
||||
// --- begin token functions ---
|
||||
/// Convert a token ID to a byte array.
|
||||
///
|
||||
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
|
||||
/// There is no way to check if your index is in bounds from Rust.
|
||||
/// C++ exceptions *cannot* be caught and *will* cause the Rust runtime to abort your program.
|
||||
/// Use this function and its siblings with extreme caution.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token_id`: ID of the token.
|
||||
///
|
||||
/// # Returns
|
||||
/// * On success: `Ok(&[u8])`
|
||||
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
|
||||
/// * On other error: `Err(WhisperError::NullPointer)`
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||
pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
|
||||
self.ctx.token_to_bytes(token_id)
|
||||
}
|
||||
|
||||
/// Convert a token ID to a string.
|
||||
///
|
||||
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
|
||||
/// See [`Self::token_to_bytes`] for more information.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `token_id`: ID of the token.
|
||||
///
|
||||
/// # Returns
|
||||
/// * On success: `Ok(&str)`
|
||||
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
|
||||
/// * On other error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||
|
|
@ -257,27 +324,29 @@ impl WhisperContext {
|
|||
self.ctx.token_to_str(token_id)
|
||||
}
|
||||
|
||||
/// Convert a token ID to a &CStr.
|
||||
/// Convert a token ID to a string.
|
||||
///
|
||||
/// This function differs from [`Self::token_to_str`] in that it ignores invalid UTF-8 bytes in the input,
|
||||
/// and instead replaces them with the Unicode replacement character.
|
||||
///
|
||||
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
|
||||
/// See [`Self::token_to_bytes`] for more information.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * token_id: ID of the token.
|
||||
/// * `token_id`: ID of the token.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
/// * On success: `Ok(Cow<str>)`
|
||||
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
|
||||
/// * On other error: `Err(WhisperError::NullPointer)`
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
|
||||
self.ctx.token_to_cstr(token_id)
|
||||
}
|
||||
|
||||
/// Undocumented but exposed function in the C++ API.
|
||||
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
|
||||
self.ctx.model_type_readable()
|
||||
pub fn token_to_str_lossy(
|
||||
&self,
|
||||
token_id: WhisperTokenId,
|
||||
) -> Result<Cow<'_, str>, WhisperError> {
|
||||
self.ctx.token_to_str_lossy(token_id)
|
||||
}
|
||||
|
||||
/// Get the ID of the eot token.
|
||||
|
|
@ -346,6 +415,7 @@ impl WhisperContext {
|
|||
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
|
||||
self.ctx.token_lang(lang_id)
|
||||
}
|
||||
// --- end token functions ---
|
||||
|
||||
/// Print performance statistics to stderr.
|
||||
///
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue