Split out any functions that return a CStr or similar object into three returning all unique types of possible return value
This commit is contained in:
parent
4ed94c2831
commit
b07d20a3c1
2 changed files with 122 additions and 60 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
use crate::error::WhisperError;
|
use crate::error::WhisperError;
|
||||||
use crate::WhisperTokenId;
|
use crate::WhisperTokenId;
|
||||||
|
use std::borrow::Cow;
|
||||||
use std::ffi::{c_int, CStr, CString};
|
use std::ffi::{c_int, CStr, CString};
|
||||||
|
|
||||||
/// Safe Rust wrapper around a Whisper context.
|
/// Safe Rust wrapper around a Whisper context.
|
||||||
|
|
@ -280,54 +281,44 @@ impl WhisperInnerContext {
|
||||||
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
// token functions
|
// --- begin model_type_readable helpers ---
|
||||||
/// Convert a token ID to a string.
|
fn model_type_readable_cstr(&self) -> Result<&CStr, WhisperError> {
|
||||||
///
|
let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
|
||||||
/// # Arguments
|
if ret.is_null() {
|
||||||
/// * token_id: ID of the token.
|
return Err(WhisperError::NullPointer);
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(&str) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
|
||||||
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
|
|
||||||
let c_str = self.token_to_cstr(token_id)?;
|
|
||||||
let r_str = c_str.to_str()?;
|
|
||||||
Ok(r_str)
|
|
||||||
}
|
}
|
||||||
|
Ok(unsafe { CStr::from_ptr(ret) })
|
||||||
|
}
|
||||||
|
pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
|
||||||
|
Ok(self.model_type_readable_cstr()?.to_bytes())
|
||||||
|
}
|
||||||
|
pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
|
||||||
|
Ok(self.model_type_readable_cstr()?.to_str()?)
|
||||||
|
}
|
||||||
|
pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
|
||||||
|
Ok(self.model_type_readable_cstr()?.to_string_lossy())
|
||||||
|
}
|
||||||
|
// --- end model_type_readable helpers ---
|
||||||
|
|
||||||
/// Convert a token ID to a &CStr.
|
// --- begin token functions ---
|
||||||
///
|
fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
|
||||||
/// # Arguments
|
|
||||||
/// * token_id: ID of the token.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
|
||||||
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
|
|
||||||
let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
|
let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
|
||||||
if ret.is_null() {
|
if ret.is_null() {
|
||||||
return Err(WhisperError::NullPointer);
|
return Err(WhisperError::NullPointer);
|
||||||
}
|
}
|
||||||
Ok(unsafe { CStr::from_ptr(ret) })
|
Ok(unsafe { CStr::from_ptr(ret) })
|
||||||
}
|
}
|
||||||
|
pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
|
||||||
/// Undocumented but exposed function in the C++ API.
|
Ok(self.token_to_cstr(token_id)?.to_bytes())
|
||||||
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
|
||||||
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
|
|
||||||
let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
|
|
||||||
if ret.is_null() {
|
|
||||||
return Err(WhisperError::NullPointer);
|
|
||||||
}
|
}
|
||||||
let c_str = unsafe { CStr::from_ptr(ret) };
|
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
|
||||||
let r_str = c_str.to_str()?;
|
Ok(self.token_to_cstr(token_id)?.to_str()?)
|
||||||
Ok(r_str.to_string())
|
}
|
||||||
|
pub fn token_to_str_lossy(
|
||||||
|
&self,
|
||||||
|
token_id: WhisperTokenId,
|
||||||
|
) -> Result<Cow<'_, str>, WhisperError> {
|
||||||
|
Ok(self.token_to_cstr(token_id)?.to_string_lossy())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the ID of the eot token.
|
/// Get the ID of the eot token.
|
||||||
|
|
@ -396,6 +387,7 @@ impl WhisperInnerContext {
|
||||||
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
|
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
|
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
|
||||||
}
|
}
|
||||||
|
// --- end token functions ---
|
||||||
|
|
||||||
/// Print performance statistics to stderr.
|
/// Print performance statistics to stderr.
|
||||||
///
|
///
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use std::borrow::Cow;
|
||||||
use std::ffi::{c_int, CStr};
|
use std::ffi::{c_int, CStr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
|
@ -242,14 +243,80 @@ impl WhisperContext {
|
||||||
self.ctx.model_type()
|
self.ctx.model_type()
|
||||||
}
|
}
|
||||||
|
|
||||||
// token functions
|
// --- begin model_type_readable ---
|
||||||
/// Convert a token ID to a string.
|
/// Undocumented but exposed function in the C++ API.
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * token_id: ID of the token.
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// Ok(&str) on success, Err(WhisperError) on failure.
|
/// * On success: `Ok(&[u8])`
|
||||||
|
/// * On error: `Err(WhisperError::NullPointer)`
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||||
|
pub fn model_type_readable_bytes(&self) -> Result<&[u8], WhisperError> {
|
||||||
|
self.ctx.model_type_readable_bytes()
|
||||||
|
}
|
||||||
|
/// Undocumented but exposed function in the C++ API.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: `Ok(&str)`
|
||||||
|
/// * On error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||||
|
pub fn model_type_readable_str(&self) -> Result<&str, WhisperError> {
|
||||||
|
self.ctx.model_type_readable_str()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Undocumented but exposed function in the C++ API.
|
||||||
|
///
|
||||||
|
/// This function differs from [`Self::model_type_readable_str`] in that it ignores invalid UTF-8 bytes in the input,
|
||||||
|
/// and instead replaces them with the Unicode replacement character.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: `Ok(Cow<str>)`
|
||||||
|
/// * On error: `Err(WhisperError::NullPointer)`
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||||
|
pub fn model_type_readable_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
|
||||||
|
self.ctx.model_type_readable_str_lossy()
|
||||||
|
}
|
||||||
|
// --- end model_type_readable ---
|
||||||
|
|
||||||
|
// --- begin token functions ---
|
||||||
|
/// Convert a token ID to a byte array.
|
||||||
|
///
|
||||||
|
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
|
||||||
|
/// There is no way to check if your index is in bounds from Rust.
|
||||||
|
/// C++ exceptions *cannot* be caught and *will* cause the Rust runtime to abort your program.
|
||||||
|
/// Use this function and its siblings with extreme caution.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `token_id`: ID of the token.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: `Ok(&[u8])`
|
||||||
|
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
|
||||||
|
/// * On other error: `Err(WhisperError::NullPointer)`
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||||
|
pub fn token_to_bytes(&self, token_id: WhisperTokenId) -> Result<&[u8], WhisperError> {
|
||||||
|
self.ctx.token_to_bytes(token_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a token ID to a string.
|
||||||
|
///
|
||||||
|
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
|
||||||
|
/// See [`Self::token_to_bytes`] for more information.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `token_id`: ID of the token.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: `Ok(&str)`
|
||||||
|
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
|
||||||
|
/// * On other error: `Err(WhisperError::NullPointer)` or `Err(WhisperError::InvalidUtf8)`
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||||
|
|
@ -257,27 +324,29 @@ impl WhisperContext {
|
||||||
self.ctx.token_to_str(token_id)
|
self.ctx.token_to_str(token_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert a token ID to a &CStr.
|
/// Convert a token ID to a string.
|
||||||
|
///
|
||||||
|
/// This function differs from [`Self::token_to_str`] in that it ignores invalid UTF-8 bytes in the input,
|
||||||
|
/// and instead replaces them with the Unicode replacement character.
|
||||||
|
///
|
||||||
|
/// **Danger**: this function is liable to throw a C++ exception if you pass an out-of-bounds index.
|
||||||
|
/// See [`Self::token_to_bytes`] for more information.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * token_id: ID of the token.
|
/// * `token_id`: ID of the token.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
/// * On success: `Ok(Cow<str>)`
|
||||||
|
/// * On out-of-bounds index: foreign runtime exception, causing your entire program to abort.
|
||||||
|
/// * On other error: `Err(WhisperError::NullPointer)`
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||||
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
|
pub fn token_to_str_lossy(
|
||||||
self.ctx.token_to_cstr(token_id)
|
&self,
|
||||||
}
|
token_id: WhisperTokenId,
|
||||||
|
) -> Result<Cow<'_, str>, WhisperError> {
|
||||||
/// Undocumented but exposed function in the C++ API.
|
self.ctx.token_to_str_lossy(token_id)
|
||||||
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
|
||||||
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
|
|
||||||
self.ctx.model_type_readable()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the ID of the eot token.
|
/// Get the ID of the eot token.
|
||||||
|
|
@ -346,6 +415,7 @@ impl WhisperContext {
|
||||||
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
|
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
|
||||||
self.ctx.token_lang(lang_id)
|
self.ctx.token_lang(lang_id)
|
||||||
}
|
}
|
||||||
|
// --- end token functions ---
|
||||||
|
|
||||||
/// Print performance statistics to stderr.
|
/// Print performance statistics to stderr.
|
||||||
///
|
///
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue