Convert full_get_token_* and similar to use an internal helper instead of duplicating code
This commit is contained in:
parent
37cba931f6
commit
099faf4e2e
1 changed files with 54 additions and 33 deletions
|
|
@ -425,22 +425,7 @@ impl WhisperState {
|
|||
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
|
||||
}
|
||||
|
||||
/// Get the token text of the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_text(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<String, WhisperError> {
|
||||
fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> {
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||
self.ctx.ctx,
|
||||
|
|
@ -452,9 +437,53 @@ impl WhisperState {
|
|||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
let c_str = unsafe { CStr::from_ptr(ret) };
|
||||
let r_str = c_str.to_str()?;
|
||||
Ok(r_str.to_string())
|
||||
unsafe { Ok(CStr::from_ptr(ret)) }
|
||||
}
|
||||
|
||||
/// Get the raw token bytes of the specified token in the specified segment.
|
||||
///
|
||||
/// Useful if you're using a language for which whisper is known to split tokens
|
||||
/// away from UTF-8 character boundaries.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(Vec<u8>)` on success, with the returned bytes or
|
||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_bytes(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<Vec<u8>, WhisperError> {
|
||||
Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec())
|
||||
}
|
||||
|
||||
/// Get the token text of the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(String)` on success, with the UTF-8 validated string, or
|
||||
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_text(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<String, WhisperError> {
|
||||
Ok(self
|
||||
.full_get_token_raw(segment, token)?
|
||||
.to_str()?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Get the token text of the specified token in the specified segment.
|
||||
|
|
@ -467,7 +496,8 @@ impl WhisperState {
|
|||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
/// `Ok(String)` on success, or
|
||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
|
|
@ -476,19 +506,10 @@ impl WhisperState {
|
|||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<String, WhisperError> {
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
segment,
|
||||
token,
|
||||
)
|
||||
};
|
||||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
let c_str = unsafe { CStr::from_ptr(ret) };
|
||||
Ok(c_str.to_string_lossy().to_string())
|
||||
Ok(self
|
||||
.full_get_token_raw(segment, token)?
|
||||
.to_string_lossy()
|
||||
.to_string())
|
||||
}
|
||||
|
||||
/// Get the token ID of the specified token in the specified segment.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue