Refactor the entire token/segment usage
This was spurred by noticing a trivial case of UB in the original code: all one needed was an out-of-bounds index on any of several methods with tokens or segment indexes on the state to cause UB. I took this opportunity to consolidate methods into Rust structs that verify their index before use.
This commit is contained in:
parent
35ccb56e7f
commit
15e70ffd07
6 changed files with 432 additions and 305 deletions
|
|
@ -31,14 +31,14 @@ pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
|
||||||
pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData};
|
pub use whisper_params::{FullParams, SamplingStrategy, SegmentCallbackData};
|
||||||
#[cfg(feature = "raw-api")]
|
#[cfg(feature = "raw-api")]
|
||||||
pub use whisper_rs_sys;
|
pub use whisper_rs_sys;
|
||||||
pub use whisper_state::WhisperState;
|
pub use whisper_state::{WhisperSegment, WhisperState, WhisperStateSegmentIterator, WhisperToken};
|
||||||
pub use whisper_vad::*;
|
pub use whisper_vad::*;
|
||||||
|
|
||||||
pub type WhisperSysContext = whisper_rs_sys::whisper_context;
|
pub type WhisperSysContext = whisper_rs_sys::whisper_context;
|
||||||
pub type WhisperSysState = whisper_rs_sys::whisper_state;
|
pub type WhisperSysState = whisper_rs_sys::whisper_state;
|
||||||
|
|
||||||
pub type WhisperTokenData = whisper_rs_sys::whisper_token_data;
|
pub type WhisperTokenData = whisper_rs_sys::whisper_token_data;
|
||||||
pub type WhisperToken = whisper_rs_sys::whisper_token;
|
pub type WhisperTokenId = whisper_rs_sys::whisper_token;
|
||||||
pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback;
|
pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback;
|
||||||
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
|
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
|
||||||
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
|
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::error::WhisperError;
|
use crate::error::WhisperError;
|
||||||
use crate::WhisperToken;
|
use crate::WhisperTokenId;
|
||||||
use std::ffi::{c_int, CStr, CString};
|
use std::ffi::{c_int, CStr, CString};
|
||||||
|
|
||||||
/// Safe Rust wrapper around a Whisper context.
|
/// Safe Rust wrapper around a Whisper context.
|
||||||
|
|
@ -84,12 +84,12 @@ impl WhisperInnerContext {
|
||||||
&self,
|
&self,
|
||||||
text: &str,
|
text: &str,
|
||||||
max_tokens: usize,
|
max_tokens: usize,
|
||||||
) -> Result<Vec<WhisperToken>, WhisperError> {
|
) -> Result<Vec<WhisperTokenId>, WhisperError> {
|
||||||
// convert the text to a nul-terminated C string. Will raise an error if the text contains
|
// convert the text to a nul-terminated C string. Will raise an error if the text contains
|
||||||
// any nul bytes.
|
// any nul bytes.
|
||||||
let text = CString::new(text)?;
|
let text = CString::new(text)?;
|
||||||
// allocate at least max_tokens to ensure the memory is valid
|
// allocate at least max_tokens to ensure the memory is valid
|
||||||
let mut tokens: Vec<WhisperToken> = Vec::with_capacity(max_tokens);
|
let mut tokens: Vec<WhisperTokenId> = Vec::with_capacity(max_tokens);
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_tokenize(
|
whisper_rs_sys::whisper_tokenize(
|
||||||
self.ctx,
|
self.ctx,
|
||||||
|
|
@ -307,7 +307,7 @@ impl WhisperInnerContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||||
pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> {
|
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
|
||||||
let c_str = self.token_to_cstr(token_id)?;
|
let c_str = self.token_to_cstr(token_id)?;
|
||||||
let r_str = c_str.to_str()?;
|
let r_str = c_str.to_str()?;
|
||||||
Ok(r_str)
|
Ok(r_str)
|
||||||
|
|
@ -323,7 +323,7 @@ impl WhisperInnerContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||||
pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> {
|
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
|
||||||
let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
|
let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
|
||||||
if ret.is_null() {
|
if ret.is_null() {
|
||||||
return Err(WhisperError::NullPointer);
|
return Err(WhisperError::NullPointer);
|
||||||
|
|
@ -351,7 +351,7 @@ impl WhisperInnerContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_eot(&self) -> WhisperToken {
|
pub fn token_eot(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_eot(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_eot(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -360,7 +360,7 @@ impl WhisperInnerContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_sot(&self) -> WhisperToken {
|
pub fn token_sot(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_sot(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_sot(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -369,7 +369,7 @@ impl WhisperInnerContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_solm(&self) -> WhisperToken {
|
pub fn token_solm(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_solm(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_solm(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -378,7 +378,7 @@ impl WhisperInnerContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_prev(&self) -> WhisperToken {
|
pub fn token_prev(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_prev(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_prev(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -387,7 +387,7 @@ impl WhisperInnerContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_nosp(struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_nosp(struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_nosp(&self) -> WhisperToken {
|
pub fn token_nosp(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_nosp(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_nosp(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -396,7 +396,7 @@ impl WhisperInnerContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_not (struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_not (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_not(&self) -> WhisperToken {
|
pub fn token_not(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_not(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_not(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -405,7 +405,7 @@ impl WhisperInnerContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_beg(&self) -> WhisperToken {
|
pub fn token_beg(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -417,7 +417,7 @@ impl WhisperInnerContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)`
|
/// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_lang(&self, lang_id: c_int) -> WhisperToken {
|
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
|
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -444,7 +444,7 @@ impl WhisperInnerContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_translate ()`
|
/// `whisper_token whisper_token_translate ()`
|
||||||
pub fn token_translate(&self) -> WhisperToken {
|
pub fn token_translate(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_translate(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_translate(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -452,7 +452,7 @@ impl WhisperInnerContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_transcribe()`
|
/// `whisper_token whisper_token_transcribe()`
|
||||||
pub fn token_transcribe(&self) -> WhisperToken {
|
pub fn token_transcribe(&self) -> WhisperTokenId {
|
||||||
unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ use std::ffi::{c_int, CStr};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
WhisperContextParameters, WhisperError, WhisperInnerContext, WhisperState, WhisperToken,
|
WhisperContextParameters, WhisperError, WhisperInnerContext, WhisperState, WhisperTokenId,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct WhisperContext {
|
pub struct WhisperContext {
|
||||||
|
|
@ -57,7 +57,7 @@ impl WhisperContext {
|
||||||
/// * text: The text to convert.
|
/// * text: The text to convert.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// `Ok(Vec<WhisperToken>)` on success, `Err(WhisperError)` on failure.
|
/// `Ok(Vec<WhisperTokenId>)` on success, `Err(WhisperError)` on failure.
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);`
|
/// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);`
|
||||||
|
|
@ -65,7 +65,7 @@ impl WhisperContext {
|
||||||
&self,
|
&self,
|
||||||
text: &str,
|
text: &str,
|
||||||
max_tokens: usize,
|
max_tokens: usize,
|
||||||
) -> Result<Vec<WhisperToken>, WhisperError> {
|
) -> Result<Vec<WhisperTokenId>, WhisperError> {
|
||||||
self.ctx.tokenize(text, max_tokens)
|
self.ctx.tokenize(text, max_tokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -269,7 +269,7 @@ impl WhisperContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||||
pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> {
|
pub fn token_to_str(&self, token_id: WhisperTokenId) -> Result<&str, WhisperError> {
|
||||||
self.ctx.token_to_str(token_id)
|
self.ctx.token_to_str(token_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -283,7 +283,7 @@ impl WhisperContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||||
pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> {
|
pub fn token_to_cstr(&self, token_id: WhisperTokenId) -> Result<&CStr, WhisperError> {
|
||||||
self.ctx.token_to_cstr(token_id)
|
self.ctx.token_to_cstr(token_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -301,7 +301,7 @@ impl WhisperContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_eot(&self) -> WhisperToken {
|
pub fn token_eot(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_eot()
|
self.ctx.token_eot()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -310,7 +310,7 @@ impl WhisperContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_sot(&self) -> WhisperToken {
|
pub fn token_sot(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_sot()
|
self.ctx.token_sot()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -319,7 +319,7 @@ impl WhisperContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_solm(&self) -> WhisperToken {
|
pub fn token_solm(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_solm()
|
self.ctx.token_solm()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -328,7 +328,7 @@ impl WhisperContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_prev(&self) -> WhisperToken {
|
pub fn token_prev(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_prev()
|
self.ctx.token_prev()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -337,7 +337,7 @@ impl WhisperContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_nosp(struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_nosp(struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_nosp(&self) -> WhisperToken {
|
pub fn token_nosp(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_nosp()
|
self.ctx.token_nosp()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -346,7 +346,7 @@ impl WhisperContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_not (struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_not (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_not(&self) -> WhisperToken {
|
pub fn token_not(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_not()
|
self.ctx.token_not()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -355,7 +355,7 @@ impl WhisperContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
|
/// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_beg(&self) -> WhisperToken {
|
pub fn token_beg(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_beg()
|
self.ctx.token_beg()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -367,7 +367,7 @@ impl WhisperContext {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)`
|
/// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn token_lang(&self, lang_id: c_int) -> WhisperToken {
|
pub fn token_lang(&self, lang_id: c_int) -> WhisperTokenId {
|
||||||
self.ctx.token_lang(lang_id)
|
self.ctx.token_lang(lang_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -394,7 +394,7 @@ impl WhisperContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_translate ()`
|
/// `whisper_token whisper_token_translate ()`
|
||||||
pub fn token_translate(&self) -> WhisperToken {
|
pub fn token_translate(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_translate()
|
self.ctx.token_translate()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -402,7 +402,7 @@ impl WhisperContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_token_transcribe()`
|
/// `whisper_token whisper_token_transcribe()`
|
||||||
pub fn token_transcribe(&self) -> WhisperToken {
|
pub fn token_transcribe(&self) -> WhisperTokenId {
|
||||||
self.ctx.token_transcribe()
|
self.ctx.token_transcribe()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,13 @@
|
||||||
use std::ffi::{c_int, CStr};
|
use std::ffi::c_int;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperToken, WhisperTokenData};
|
use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperTokenId};
|
||||||
|
|
||||||
|
mod iterator;
|
||||||
|
mod segment;
|
||||||
|
|
||||||
|
pub use iterator::WhisperStateSegmentIterator;
|
||||||
|
pub use segment::{WhisperSegment, WhisperToken};
|
||||||
|
|
||||||
/// Rustified pointer to a Whisper state.
|
/// Rustified pointer to a Whisper state.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
|
|
@ -151,7 +157,7 @@ impl WhisperState {
|
||||||
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
|
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
|
||||||
pub fn decode(
|
pub fn decode(
|
||||||
&mut self,
|
&mut self,
|
||||||
tokens: &[WhisperToken],
|
tokens: &[WhisperTokenId],
|
||||||
n_past: usize,
|
n_past: usize,
|
||||||
threads: usize,
|
threads: usize,
|
||||||
) -> Result<(), WhisperError> {
|
) -> Result<(), WhisperError> {
|
||||||
|
|
@ -179,11 +185,11 @@ impl WhisperState {
|
||||||
|
|
||||||
// Language functions
|
// Language functions
|
||||||
/// Use mel data at offset_ms to try and auto-detect the spoken language
|
/// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||||
/// Make sure to call pcm_to_mel() or set_mel() first
|
/// Make sure to call [`Self::pcm_to_mel`] or [`Self::set_mel`] first
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * offset_ms: The offset in milliseconds to use for the language detection.
|
/// * `offset_ms`: The offset in milliseconds to use for the language detection.
|
||||||
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
/// * `n_threads`: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// `Ok((i32, Vec<f32>))` on success where the i32 is detected language id and Vec<f32>
|
/// `Ok((i32, Vec<f32>))` on success where the i32 is detected language id and Vec<f32>
|
||||||
|
|
@ -309,8 +315,8 @@ impl WhisperState {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_full_n_segments(struct whisper_context * ctx)`
|
/// `int whisper_full_n_segments(struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn full_n_segments(&self) -> Result<c_int, WhisperError> {
|
pub fn full_n_segments(&self) -> c_int {
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.ptr) })
|
unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.ptr) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Language ID associated with the provided state.
|
/// Language ID associated with the provided state.
|
||||||
|
|
@ -318,281 +324,33 @@ impl WhisperState {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_full_lang_id_from_state(struct whisper_state * state);`
|
/// `int whisper_full_lang_id_from_state(struct whisper_state * state);`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn full_lang_id_from_state(&self) -> Result<c_int, WhisperError> {
|
pub fn full_lang_id_from_state(&self) -> c_int {
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.ptr) })
|
unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.ptr) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the start time of the specified segment.
|
fn segment_in_bounds(&self, segment: c_int) -> bool {
|
||||||
///
|
segment >= 0 && segment < self.full_n_segments()
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_get_segment_t0(&self, segment: c_int) -> Result<i64, WhisperError> {
|
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t0_from_state(self.ptr, segment) })
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the end time of the specified segment.
|
/// Get a [`WhisperSegment`] object for the specified segment index.
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_get_segment_t1(&self, segment: c_int) -> Result<i64, WhisperError> {
|
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn full_get_segment_raw(&self, segment: c_int) -> Result<&CStr, WhisperError> {
|
|
||||||
let ret =
|
|
||||||
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
|
|
||||||
if ret.is_null() {
|
|
||||||
return Err(WhisperError::NullPointer);
|
|
||||||
}
|
|
||||||
unsafe { Ok(CStr::from_ptr(ret)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the raw bytes of the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// `Ok(Vec<u8>)` on success, with the returned bytes or
|
/// `Some(WhisperSegment)` if `segment` is in bounds, otherwise [`None`].
|
||||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
pub fn get_segment(&self, segment: c_int) -> Option<WhisperSegment<'_>> {
|
||||||
///
|
self.segment_in_bounds(segment)
|
||||||
/// # C++ equivalent
|
.then(|| unsafe { WhisperSegment::new_unchecked(self, segment) })
|
||||||
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
|
||||||
pub fn full_get_segment_bytes(&self, segment: c_int) -> Result<Vec<u8>, WhisperError> {
|
|
||||||
Ok(self.full_get_segment_raw(segment)?.to_bytes().to_vec())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the text of the specified segment.
|
/// Get a [`WhisperSegment`] object for the specified segment index.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Safety
|
||||||
/// * segment: Segment index.
|
/// You must ensure `segment` is in bounds for this [`WhisperState`].
|
||||||
///
|
pub unsafe fn get_segment_unchecked(&self, segment: c_int) -> WhisperSegment<'_> {
|
||||||
/// # Returns
|
WhisperSegment::new_unchecked(self, segment)
|
||||||
/// `Ok(String)` on success, with the UTF-8 validated string, or
|
|
||||||
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
|
||||||
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
|
|
||||||
Ok(self.full_get_segment_raw(segment)?.to_str()?.to_string())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the text of the specified segment.
|
/// Get an iterator over all segments.
|
||||||
/// This function differs from [WhisperState::full_get_segment_text]
|
pub fn as_iter(&self) -> WhisperStateSegmentIterator {
|
||||||
/// in that it ignores invalid UTF-8 in whisper strings,
|
WhisperStateSegmentIterator::new(self)
|
||||||
/// instead opting to replace it with the replacement character.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// `Ok(String)` on success, or
|
|
||||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
|
||||||
pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result<String, WhisperError> {
|
|
||||||
Ok(self
|
|
||||||
.full_get_segment_raw(segment)?
|
|
||||||
.to_string_lossy()
|
|
||||||
.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get number of tokens in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// c_int
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
|
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
|
|
||||||
}
|
|
||||||
|
|
||||||
fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> {
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
|
||||||
self.ctx.ctx,
|
|
||||||
self.ptr,
|
|
||||||
segment,
|
|
||||||
token,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret.is_null() {
|
|
||||||
return Err(WhisperError::NullPointer);
|
|
||||||
}
|
|
||||||
unsafe { Ok(CStr::from_ptr(ret)) }
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the raw token bytes of the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// Useful if you're using a language for which whisper is known to split tokens
|
|
||||||
/// away from UTF-8 character boundaries.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// `Ok(Vec<u8>)` on success, with the returned bytes or
|
|
||||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
pub fn full_get_token_bytes(
|
|
||||||
&self,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<Vec<u8>, WhisperError> {
|
|
||||||
Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the token text of the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// `Ok(String)` on success, with the UTF-8 validated string, or
|
|
||||||
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
pub fn full_get_token_text(
|
|
||||||
&self,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<String, WhisperError> {
|
|
||||||
Ok(self
|
|
||||||
.full_get_token_raw(segment, token)?
|
|
||||||
.to_str()?
|
|
||||||
.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the token text of the specified token in the specified segment.
|
|
||||||
/// This function differs from [WhisperState::full_get_token_text]
|
|
||||||
/// in that it ignores invalid UTF-8 in whisper strings,
|
|
||||||
/// instead opting to replace it with the replacement character.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// `Ok(String)` on success, or
|
|
||||||
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
pub fn full_get_token_text_lossy(
|
|
||||||
&self,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<String, WhisperError> {
|
|
||||||
Ok(self
|
|
||||||
.full_get_token_raw(segment, token)?
|
|
||||||
.to_string_lossy()
|
|
||||||
.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the token ID of the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// [crate::WhisperToken]
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
pub fn full_get_token_id(
|
|
||||||
&self,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<WhisperToken, WhisperError> {
|
|
||||||
Ok(unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_token_id_from_state(self.ptr, segment, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get token data for the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// [crate::WhisperTokenData]
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_get_token_data(
|
|
||||||
&self,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<WhisperTokenData, WhisperError> {
|
|
||||||
Ok(unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_token_data_from_state(self.ptr, segment, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the probability of the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// f32
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> Result<f32, WhisperError> {
|
|
||||||
Ok(
|
|
||||||
unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_token_p_from_state(self.ptr, segment, token)
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get whether the next segment is predicted as a speaker turn.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * i_segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// bool
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)`
|
|
||||||
pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool {
|
|
||||||
unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state(
|
|
||||||
self.ptr, i_segment,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the no_speech probability for the specified segment
|
|
||||||
pub fn full_get_segment_no_speech_prob(&self, i_segment: c_int) -> f32 {
|
|
||||||
unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_segment_no_speech_prob_from_state(self.ptr, i_segment)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
28
src/whisper_state/iterator.rs
Normal file
28
src/whisper_state/iterator.rs
Normal file
|
|
@ -0,0 +1,28 @@
|
||||||
|
use crate::whisper_state::WhisperSegment;
|
||||||
|
use crate::WhisperState;
|
||||||
|
use std::ffi::c_int;
|
||||||
|
|
||||||
|
/// An iterator over a [`WhisperState`]'s result.
|
||||||
|
pub struct WhisperStateSegmentIterator<'a> {
|
||||||
|
state_ptr: &'a WhisperState,
|
||||||
|
current_segment: c_int,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> WhisperStateSegmentIterator<'a> {
|
||||||
|
pub(super) fn new(state_ptr: &'a WhisperState) -> Self {
|
||||||
|
Self {
|
||||||
|
state_ptr,
|
||||||
|
current_segment: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> Iterator for WhisperStateSegmentIterator<'a> {
|
||||||
|
type Item = WhisperSegment<'a>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
let ret = self.state_ptr.get_segment(self.current_segment);
|
||||||
|
self.current_segment += 1;
|
||||||
|
ret
|
||||||
|
}
|
||||||
|
}
|
||||||
341
src/whisper_state/segment.rs
Normal file
341
src/whisper_state/segment.rs
Normal file
|
|
@ -0,0 +1,341 @@
|
||||||
|
use crate::{WhisperError, WhisperState, WhisperTokenData, WhisperTokenId};
|
||||||
|
use std::borrow::Cow;
|
||||||
|
use std::ffi::{c_int, CStr};
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
/// A segment returned by Whisper after running the transcription pipeline.
|
||||||
|
pub struct WhisperSegment<'a> {
|
||||||
|
state: &'a WhisperState,
|
||||||
|
|
||||||
|
segment_idx: c_int,
|
||||||
|
token_count: c_int,
|
||||||
|
}
|
||||||
|
impl<'a> WhisperSegment<'a> {
|
||||||
|
/// # Safety
|
||||||
|
/// You must ensure `segment_idx` is in bounds for the linked [`WhisperState`].
|
||||||
|
pub(super) unsafe fn new_unchecked(state: &'a WhisperState, segment_idx: c_int) -> Self {
|
||||||
|
assert!(
|
||||||
|
state.segment_in_bounds(segment_idx),
|
||||||
|
"tried to create a WhisperSegment out of bounds for linked state"
|
||||||
|
);
|
||||||
|
Self {
|
||||||
|
state,
|
||||||
|
segment_idx,
|
||||||
|
token_count: unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_n_tokens_from_state(state.ptr, segment_idx)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the start time of the specified segment.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Start time in centiseconds (10s of milliseconds)
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
|
||||||
|
pub fn start_timestamp(&self) -> i64 {
|
||||||
|
unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_segment_t0_from_state(self.state.ptr, self.segment_idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the end time of the specified segment.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// End time in centiseconds (10s of milliseconds)
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
|
||||||
|
pub fn end_timestamp(&self) -> i64 {
|
||||||
|
unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.state.ptr, self.segment_idx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get number of tokens in this segment.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// `c_int`
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
|
||||||
|
pub fn n_tokens(&self) -> c_int {
|
||||||
|
self.token_count
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get whether the next segment is predicted as a speaker turn.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// `bool`
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)`
|
||||||
|
pub fn next_segment_speaker_turn(&self) -> bool {
|
||||||
|
unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state(
|
||||||
|
self.state.ptr,
|
||||||
|
self.segment_idx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the no_speech probability for the specified segment
|
||||||
|
pub fn no_speech_probability(&self) -> f32 {
|
||||||
|
unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_segment_no_speech_prob_from_state(
|
||||||
|
self.state.ptr,
|
||||||
|
self.segment_idx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_raw_cstr(&self) -> Result<&CStr, WhisperError> {
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_segment_text_from_state(
|
||||||
|
self.state.ptr,
|
||||||
|
self.segment_idx,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret.is_null() {
|
||||||
|
return Err(WhisperError::NullPointer);
|
||||||
|
}
|
||||||
|
Ok(unsafe { CStr::from_ptr(ret) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the raw bytes of this segment.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: The raw bytes, with no null terminator
|
||||||
|
/// * On failure: [`WhisperError::NullPointer`]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
||||||
|
pub fn to_bytes(&self) -> Result<&[u8], WhisperError> {
|
||||||
|
Ok(self.to_raw_cstr()?.to_bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the text of this segment.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: the UTF-8 validated string.
|
||||||
|
/// * On failure: [`WhisperError::NullPointer`] or [`WhisperError::InvalidUtf8`]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
||||||
|
pub fn to_str(&self) -> Result<&str, WhisperError> {
|
||||||
|
Ok(self.to_raw_cstr()?.to_str()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the text of this segment.
|
||||||
|
///
|
||||||
|
/// This function differs from [`Self::to_str`]
|
||||||
|
/// in that it ignores invalid UTF-8 in strings,
|
||||||
|
/// and instead replaces it with the replacement character.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: The valid string, with any invalid UTF-8 replaced with the replacement character
|
||||||
|
/// * On failure: [`WhisperError::NullPointer`]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
||||||
|
pub fn to_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
|
||||||
|
Ok(self.to_raw_cstr()?.to_string_lossy())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn token_in_bounds(&self, token_idx: c_int) -> bool {
|
||||||
|
token_idx >= 0 && token_idx < self.token_count
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_token(&self, token: c_int) -> Option<WhisperToken<'_, '_>> {
|
||||||
|
self.token_in_bounds(token)
|
||||||
|
.then(|| unsafe { WhisperToken::new_unchecked(self, token) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// # Safety
|
||||||
|
/// You must ensure `token` is in bounds for this [`WhisperSegment`].
|
||||||
|
pub unsafe fn get_token_unchecked(&self, token: c_int) -> WhisperToken<'_, '_> {
|
||||||
|
WhisperToken::new_unchecked(self, token)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write the contents of this segment to the output.
|
||||||
|
/// This will panic if Whisper returns a null pointer.
|
||||||
|
///
|
||||||
|
/// Uses [`Self::to_str_lossy`] internally.
|
||||||
|
impl fmt::Display for WhisperSegment<'_> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"{}",
|
||||||
|
self.to_str_lossy()
|
||||||
|
.expect("got null pointer during string write")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for WhisperSegment<'_> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("WhisperSegment")
|
||||||
|
.field("segment", &self.segment_idx)
|
||||||
|
.field("n_tokens", &self.token_count)
|
||||||
|
.field("start_ts", &self.start_timestamp())
|
||||||
|
.field("end_ts", &self.end_timestamp())
|
||||||
|
.field(
|
||||||
|
"next_segment_speaker_turn",
|
||||||
|
&self.next_segment_speaker_turn(),
|
||||||
|
)
|
||||||
|
.field("no_speech_probability", &self.no_speech_probability())
|
||||||
|
.field("text", &self.to_str_lossy())
|
||||||
|
.finish_non_exhaustive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct WhisperToken<'a, 'b: 'a> {
|
||||||
|
segment: &'a WhisperSegment<'b>,
|
||||||
|
token_idx: c_int,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, 'b> WhisperToken<'a, 'b> {
|
||||||
|
/// # Safety
|
||||||
|
/// You must ensure `token_idx` is in bounds for this [`WhisperSegment`].
|
||||||
|
unsafe fn new_unchecked(segment: &'a WhisperSegment<'b>, token_idx: c_int) -> Self {
|
||||||
|
Self { segment, token_idx }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the token ID of this token in its segment.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// [`WhisperTokenId`]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn token_id(&self) -> Result<WhisperTokenId, WhisperError> {
|
||||||
|
Ok(unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_token_id_from_state(
|
||||||
|
self.segment.state.ptr,
|
||||||
|
self.segment.segment_idx,
|
||||||
|
self.token_idx,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get token data for this token in its segment.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// [`WhisperTokenData`]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn token_data(&self) -> Result<WhisperTokenData, WhisperError> {
|
||||||
|
Ok(unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_token_data_from_state(
|
||||||
|
self.segment.state.ptr,
|
||||||
|
self.segment.segment_idx,
|
||||||
|
self.token_idx,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the probability of this token in its segment.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// `f32`
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn token_probability(&self) -> Result<f32, WhisperError> {
|
||||||
|
Ok(unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_token_p_from_state(
|
||||||
|
self.segment.state.ptr,
|
||||||
|
self.segment.segment_idx,
|
||||||
|
self.token_idx,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_raw_cstr(&self) -> Result<&CStr, WhisperError> {
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||||
|
self.segment.state.ctx.ctx,
|
||||||
|
self.segment.state.ptr,
|
||||||
|
self.segment.segment_idx,
|
||||||
|
self.token_idx,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret.is_null() {
|
||||||
|
return Err(WhisperError::NullPointer);
|
||||||
|
}
|
||||||
|
Ok(unsafe { CStr::from_ptr(ret) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the raw bytes of this token.
|
||||||
|
///
|
||||||
|
/// Useful if you're using a language for which Whisper is known to split tokens
|
||||||
|
/// away from UTF-8 character boundaries.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: The raw bytes, with no null terminator
|
||||||
|
/// * On failure: [`WhisperError::NullPointer`]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn to_bytes(&self) -> Result<&[u8], WhisperError> {
|
||||||
|
Ok(self.to_raw_cstr()?.to_bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the text of this token.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: the UTF-8 validated string.
|
||||||
|
/// * On failure: [`WhisperError::NullPointer`] or [`WhisperError::InvalidUtf8`]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn to_str(&self) -> Result<&str, WhisperError> {
|
||||||
|
Ok(self.to_raw_cstr()?.to_str()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the text of this token.
|
||||||
|
///
|
||||||
|
/// This function differs from [`Self::to_str`]
|
||||||
|
/// in that it ignores invalid UTF-8 in strings,
|
||||||
|
/// and instead replaces it with the replacement character.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// * On success: The valid string, with any invalid UTF-8 replaced with the replacement character
|
||||||
|
/// * On failure: [`WhisperError::NullPointer`]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn to_str_lossy(&self) -> Result<Cow<'_, str>, WhisperError> {
|
||||||
|
Ok(self.to_raw_cstr()?.to_string_lossy())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write the contents of this token to the output.
|
||||||
|
/// This will panic if Whisper returns a null pointer.
|
||||||
|
///
|
||||||
|
/// Uses [`Self::to_str_lossy`] internally.
|
||||||
|
impl fmt::Display for WhisperToken<'_, '_> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"{}",
|
||||||
|
self.to_str_lossy()
|
||||||
|
.expect("got null pointer during string write")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Debug for WhisperToken<'_, '_> {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
f.debug_struct("WhisperToken")
|
||||||
|
.field("segment_idx", &self.segment.segment_idx)
|
||||||
|
.field("token_idx", &self.token_idx)
|
||||||
|
.field("token_id", &self.token_id())
|
||||||
|
.field("token_data", &self.token_data())
|
||||||
|
.field("token_probability", &self.token_probability())
|
||||||
|
.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue