add changes from whisper.cpp update

This commit is contained in:
Zero 2023-04-17 17:57:00 -06:00
parent 7c78c128a1
commit 13d44e5881
No known key found for this signature in database
GPG key ID: 3861E636EA1E0E2B
9 changed files with 536 additions and 140 deletions

View file

@ -1,24 +1,24 @@
use crate::error::WhisperError;
use crate::whisper_params::FullParams;
use crate::whisper_state::WhisperState;
use crate::{WhisperToken, WhisperTokenData};
use dashmap::DashMap;
use std::ffi::{c_int, CStr, CString};
use std::hash::Hash;
/// Safe Rust wrapper around a Whisper context.
///
/// You likely want to create this with [WhisperContext::new],
/// then run a full transcription with [WhisperContext::full].
#[derive(Debug)]
pub struct WhisperContext {
pub struct WhisperContext<K: Hash + Eq> {
ctx: *mut whisper_rs_sys::whisper_context,
/// has the spectrogram been initialized in at least one way?
spectrogram_initialized: bool,
/// has the data been encoded?
encode_complete: bool,
/// has decode been called at least once?
decode_once: bool,
/// Map of state IDs to state objects.
state_map: DashMap<K, WhisperState>,
}
impl WhisperContext {
impl<K: Hash + Eq> WhisperContext<K> {
/// Create a new WhisperContext from a file.
///
/// # Arguments
@ -31,15 +31,13 @@ impl WhisperContext {
/// `struct whisper_context * whisper_init_from_file(const char * path_model);`
pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file(path_cstr.as_ptr()) };
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) };
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
Ok(Self {
ctx,
spectrogram_initialized: false,
encode_complete: false,
decode_once: false,
state_map: DashMap::new(),
})
}
}
@ -55,22 +53,54 @@ impl WhisperContext {
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);`
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
let ctx =
unsafe { whisper_rs_sys::whisper_init_from_buffer(buffer.as_ptr() as _, buffer.len()) };
let ctx = unsafe {
whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len())
};
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
Ok(Self {
ctx,
spectrogram_initialized: false,
encode_complete: false,
decode_once: false,
state_map: DashMap::new(),
})
}
}
// we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does
/// Create a new state object, ready for use.
///
/// # Arguments
/// * id: The ID of the state object. Must be unique.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
/// If the ID is already in use, returns Err(WhisperError::StateIdAlreadyExists).
///
/// # C++ equivalent
/// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);`
pub fn create_key(&self, id: K) -> Result<(), WhisperError> {
if self.state_map.contains_key(&id) {
return Err(WhisperError::StateIdAlreadyExists);
}
let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) };
if state.is_null() {
Err(WhisperError::InitError)
} else {
// SAFETY: this is known to be a valid pointer to a `whisper_state` struct
self.state_map
.insert(id, unsafe { WhisperState::new(state) });
Ok(())
}
}
fn get_state_ptr(&self, id: &K) -> Result<*mut whisper_rs_sys::whisper_state, WhisperError> {
self.state_map
.get(id)
.map(|s| s.value().as_ptr())
.ok_or(WhisperError::StateIdDoesNotExist)
}
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
/// The resulting spectrogram is stored in the context transparently.
///
@ -83,13 +113,15 @@ impl WhisperContext {
///
/// # C++ equivalent
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
pub fn pcm_to_mel(&self, key: &K, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe {
whisper_rs_sys::whisper_pcm_to_mel(
whisper_rs_sys::whisper_pcm_to_mel_with_state(
self.ctx,
state_ptr,
pcm.as_ptr(),
pcm.len() as c_int,
threads as c_int,
@ -98,14 +130,54 @@ impl WhisperContext {
if ret == -1 {
Err(WhisperError::UnableToCalculateSpectrogram)
} else if ret == 0 {
self.spectrogram_initialized = true;
Ok(())
} else {
Err(WhisperError::GenericError(ret))
}
}
/// This can be used to set a custom log mel spectrogram inside the provided whisper context.
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
/// Applies a Phase Vocoder to speed up the audio x2.
/// The resulting spectrogram is stored in the context transparently.
///
/// # Arguments
/// * pcm: The raw PCM audio.
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
pub fn pcm_to_mel_phase_vocoder(
&self,
key: &K,
pcm: &[f32],
threads: usize,
) -> Result<(), WhisperError> {
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe {
whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
self.ctx,
state_ptr,
pcm.as_ptr(),
pcm.len() as c_int,
threads as c_int,
)
};
if ret == -1 {
Err(WhisperError::UnableToCalculateSpectrogram)
} else if ret == 0 {
Ok(())
} else {
Err(WhisperError::GenericError(ret))
}
}
/// This can be used to set a custom log mel spectrogram inside the provided whisper state.
/// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
///
/// # Note
@ -121,10 +193,13 @@ impl WhisperContext {
///
/// # C++ equivalent
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> {
pub fn set_mel(&self, key: &K, data: &[f32]) -> Result<(), WhisperError> {
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe {
whisper_rs_sys::whisper_set_mel(
whisper_rs_sys::whisper_set_mel_with_state(
self.ctx,
state_ptr,
data.as_ptr(),
data.len() as c_int,
80 as c_int,
@ -133,7 +208,6 @@ impl WhisperContext {
if ret == -1 {
Err(WhisperError::InvalidMelBands)
} else if ret == 0 {
self.spectrogram_initialized = true;
Ok(())
} else {
Err(WhisperError::GenericError(ret))
@ -152,19 +226,22 @@ impl WhisperContext {
///
/// # C++ equivalent
/// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)`
pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> {
if !self.spectrogram_initialized {
return Err(WhisperError::SpectrogramNotInitialized);
}
pub fn encode(&self, key: &K, offset: usize, threads: usize) -> Result<(), WhisperError> {
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let ret =
unsafe { whisper_rs_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) };
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe {
whisper_rs_sys::whisper_encode_with_state(
self.ctx,
state_ptr,
offset as c_int,
threads as c_int,
)
};
if ret == -1 {
Err(WhisperError::UnableToCalculateEvaluation)
} else if ret == 0 {
self.encode_complete = true;
Ok(())
} else {
Err(WhisperError::GenericError(ret))
@ -187,20 +264,20 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
pub fn decode(
&mut self,
&self,
key: &K,
tokens: &[WhisperToken],
n_past: usize,
threads: usize,
) -> Result<(), WhisperError> {
if !self.encode_complete {
return Err(WhisperError::EncodeNotComplete);
}
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe {
whisper_rs_sys::whisper_decode(
whisper_rs_sys::whisper_decode_with_state(
self.ctx,
state_ptr,
tokens.as_ptr(),
tokens.len() as c_int,
n_past as c_int,
@ -210,7 +287,6 @@ impl WhisperContext {
if ret == -1 {
Err(WhisperError::UnableToCalculateEvaluation)
} else if ret == 0 {
self.decode_once = true;
Ok(())
} else {
Err(WhisperError::GenericError(ret))
@ -228,7 +304,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);`
pub fn tokenize(
&mut self,
&self,
text: &str,
max_tokens: usize,
) -> Result<Vec<WhisperToken>, WhisperError> {
@ -265,20 +341,21 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)`
pub fn lang_detect(
&mut self,
&self,
key: &K,
offset_ms: usize,
threads: usize,
) -> Result<Vec<f32>, WhisperError> {
if !self.spectrogram_initialized {
return Err(WhisperError::SpectrogramNotInitialized);
}
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
let state_ptr = self.get_state_ptr(key)?;
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
let ret = unsafe {
whisper_rs_sys::whisper_lang_auto_detect(
whisper_rs_sys::whisper_lang_auto_detect_with_state(
self.ctx,
state_ptr,
offset_ms as c_int,
threads as c_int,
lang_probs.as_mut_ptr(),
@ -296,7 +373,10 @@ impl WhisperContext {
// and abort, as this will cause Undefined Behavior
// might get here due to the unwind being caught by a user-installed panic handler
if lang_probs.len() != ret as usize {
eprintln!("lang_probs length mismatch: this is a bug in whisper.cpp, aborting");
eprintln!(
"lang_probs length mismatch: this is a bug in whisper.cpp, \
aborting to avoid Undefined Behavior"
);
std::process::abort();
}
Ok(lang_probs)
@ -307,13 +387,13 @@ impl WhisperContext {
/// Get the mel spectrogram length.
///
/// # Returns
/// c_int
/// Ok(c_int) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_n_len (struct whisper_context * ctx)`
/// `int whisper_n_len_from_state(struct whisper_context * ctx)`
#[inline]
pub fn n_len(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_n_len(self.ctx) }
pub fn n_len(&self, key: &K) -> Result<c_int, WhisperError> {
Ok(unsafe { whisper_rs_sys::whisper_n_len_from_state(self.get_state_ptr(key)?) })
}
/// Get n_vocab.
@ -334,7 +414,7 @@ impl WhisperContext {
/// c_int
///
/// # C++ equivalent
/// `int whisper_n_text_ctx (struct whisper_context * ctx)`
/// `int whisper_n_text_ctx (struct whisper_context * ctx);`
#[inline]
pub fn n_text_ctx(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) }
@ -346,7 +426,7 @@ impl WhisperContext {
/// c_int
///
/// # C++ equivalent
/// `int whisper_n_audio_ctx (struct whisper_context * ctx)`
/// `int whisper_n_audio_ctx (struct whisper_context * ctx);`
#[inline]
pub fn n_audio_ctx(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_n_audio_ctx(self.ctx) }
@ -361,6 +441,150 @@ impl WhisperContext {
unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 }
}
/// Get model_n_vocab.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_vocab (struct whisper_context * ctx);`
#[inline]
pub fn model_n_vocab(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_vocab(self.ctx) }
}
/// Get model_n_audio_ctx.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_audio_ctx (struct whisper_context * ctx)`
#[inline]
pub fn model_n_audio_ctx(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_audio_ctx(self.ctx) }
}
/// Get model_n_audio_state.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_audio_state(struct whisper_context * ctx);`
#[inline]
pub fn model_n_audio_state(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_audio_state(self.ctx) }
}
/// Get model_n_audio_head.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_audio_head (struct whisper_context * ctx);`
#[inline]
pub fn model_n_audio_head(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_audio_head(self.ctx) }
}
/// Get model_n_audio_layer.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_audio_layer(struct whisper_context * ctx);`
#[inline]
pub fn model_n_audio_layer(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_audio_layer(self.ctx) }
}
/// Get model_n_text_ctx.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_text_ctx (struct whisper_context * ctx)`
#[inline]
pub fn model_n_text_ctx(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_text_ctx(self.ctx) }
}
/// Get model_n_text_state.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_text_state (struct whisper_context * ctx);`
#[inline]
pub fn model_n_text_state(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_text_state(self.ctx) }
}
/// Get model_n_text_head.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_text_head (struct whisper_context * ctx);`
#[inline]
pub fn model_n_text_head(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_text_head(self.ctx) }
}
/// Get model_n_text_layer.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_text_layer (struct whisper_context * ctx);`
#[inline]
pub fn model_n_text_layer(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_text_layer(self.ctx) }
}
/// Get model_n_mels.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_mels (struct whisper_context * ctx);`
#[inline]
pub fn model_n_mels(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_n_mels(self.ctx) }
}
/// Get model_f16.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_f16 (struct whisper_context * ctx);`
#[inline]
pub fn model_f16(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_f16(self.ctx) }
}
/// Get model_type.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_type (struct whisper_context * ctx);`
#[inline]
pub fn model_type(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
}
// logit functions
/// Get the logits obtained from the last call to [WhisperContext::decode].
/// The logits for the last token are stored in the last row of the matrix.
@ -376,18 +600,16 @@ impl WhisperContext {
///
/// # C++ equivalent
/// `float * whisper_get_logits(struct whisper_context * ctx)`
pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> {
if !self.spectrogram_initialized {
return Err(WhisperError::SpectrogramNotInitialized);
}
pub fn get_logits(&self, key: &K, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> {
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe { whisper_rs_sys::whisper_get_logits(self.ctx) };
let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(state_ptr) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let mut logits = Vec::new();
let n_vocab = self.n_vocab();
let n_tokens = self.full_n_tokens(segment);
let n_tokens = self.full_n_tokens(key, segment)?;
for i in 0..n_tokens {
let mut row = Vec::new();
for j in 0..n_vocab {
@ -421,6 +643,21 @@ impl WhisperContext {
Ok(r_str.to_string())
}
/// Undocumented but exposed function in the C++ API.
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
}
/// Get the ID of the eot token.
///
/// # C++ equivalent
@ -519,51 +756,15 @@ impl WhisperContext {
///
/// # C++ equivalent
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
pub fn full(&self, key: &K, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe {
whisper_rs_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int)
};
if ret == -1 {
Err(WhisperError::UnableToCalculateSpectrogram)
} else if ret == 7 {
Err(WhisperError::FailedToEncode)
} else if ret == 8 {
Err(WhisperError::FailedToDecode)
} else if ret == 0 {
Ok(ret)
} else {
Err(WhisperError::GenericError(ret))
}
}
/// Split the input audio into chunks and delegate to [WhisperContext::full].
///
/// It seems this approach can offer some speedup in some cases,
/// however, the accuracy can be worse at the start and end of chunks.
///
/// # Arguments
/// * params: [crate::FullParams] struct.
/// * pcm: PCM audio data.
/// * n_processors: Number of threads to use.
///
/// # Returns
/// Ok(c_int) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_full_parallel(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples, int n_processors)`
pub fn full_parallel(
&mut self,
params: FullParams,
data: &[f32],
n_processors: c_int,
) -> Result<c_int, WhisperError> {
let ret = unsafe {
whisper_rs_sys::whisper_full_parallel(
whisper_rs_sys::whisper_full_with_state(
self.ctx,
state_ptr,
params.fp,
data.as_ptr(),
data.len() as c_int,
n_processors,
)
};
if ret == -1 {
@ -573,8 +774,6 @@ impl WhisperContext {
} else if ret == 8 {
Err(WhisperError::FailedToDecode)
} else if ret == 0 {
// note 0 is returned on success and also when initializing other contexts fails,
// causing some audio to not be processed
Ok(ret)
} else {
Err(WhisperError::GenericError(ret))
@ -587,8 +786,17 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_full_n_segments(struct whisper_context * ctx)`
#[inline]
pub fn full_n_segments(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_full_n_segments(self.ctx) }
pub fn full_n_segments(&self, key: &K) -> Result<c_int, WhisperError> {
Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.get_state_ptr(key)?) })
}
/// Language ID associated with the provided state.
///
/// # C++ equivalent
/// `int whisper_full_lang_id_from_state(struct whisper_state * state);`
#[inline]
pub fn full_lang_id_from_state(&self, key: &K) -> Result<c_int, WhisperError> {
Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.get_state_ptr(key)?) })
}
/// Get the start time of the specified segment.
@ -599,8 +807,13 @@ impl WhisperContext {
/// # C++ equivalent
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
#[inline]
pub fn full_get_segment_t0(&self, segment: c_int) -> i64 {
unsafe { whisper_rs_sys::whisper_full_get_segment_t0(self.ctx, segment) }
pub fn full_get_segment_t0(&self, key: &K, segment: c_int) -> Result<i64, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_get_segment_t0_from_state(
self.get_state_ptr(key)?,
segment,
)
})
}
/// Get the end time of the specified segment.
@ -611,8 +824,13 @@ impl WhisperContext {
/// # C++ equivalent
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
#[inline]
pub fn full_get_segment_t1(&self, segment: c_int) -> i64 {
unsafe { whisper_rs_sys::whisper_full_get_segment_t1(self.ctx, segment) }
pub fn full_get_segment_t1(&self, key: &K, segment: c_int) -> Result<i64, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_get_segment_t1_from_state(
self.get_state_ptr(key)?,
segment,
)
})
}
/// Get the text of the specified segment.
@ -625,8 +843,11 @@ impl WhisperContext {
///
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_full_get_segment_text(self.ctx, segment) };
pub fn full_get_segment_text(&self, key: &K, segment: c_int) -> Result<String, WhisperError> {
let state_ptr = self.get_state_ptr(key)?;
let ret =
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(state_ptr, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
@ -646,8 +867,10 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
#[inline]
pub fn full_n_tokens(&self, segment: c_int) -> c_int {
unsafe { whisper_rs_sys::whisper_full_n_tokens(self.ctx, segment) }
pub fn full_n_tokens(&self, key: &K, segment: c_int) -> Result<c_int, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_n_tokens_from_state(self.get_state_ptr(key)?, segment)
})
}
/// Get the token text of the specified token in the specified segment.
@ -663,10 +886,16 @@ impl WhisperContext {
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_text(
&self,
key: &K,
segment: c_int,
token: c_int,
) -> Result<String, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text(self.ctx, segment, token) };
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe {
whisper_rs_sys::whisper_full_get_token_text_from_state(
self.ctx, state_ptr, segment, token,
)
};
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
@ -686,8 +915,19 @@ impl WhisperContext {
///
/// # C++ equivalent
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_id(&self, segment: c_int, token: c_int) -> WhisperToken {
unsafe { whisper_rs_sys::whisper_full_get_token_id(self.ctx, segment, token) }
pub fn full_get_token_id(
&self,
key: &K,
segment: c_int,
token: c_int,
) -> Result<WhisperToken, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_get_token_id_from_state(
self.get_state_ptr(key)?,
segment,
token,
)
})
}
/// Get token data for the specified token in the specified segment.
@ -702,8 +942,19 @@ impl WhisperContext {
/// # C++ equivalent
/// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)`
#[inline]
pub fn full_get_token_data(&self, segment: c_int, token: c_int) -> WhisperTokenData {
unsafe { whisper_rs_sys::whisper_full_get_token_data(self.ctx, segment, token) }
pub fn full_get_token_data(
&self,
key: &K,
segment: c_int,
token: c_int,
) -> Result<WhisperTokenData, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_get_token_data_from_state(
self.get_state_ptr(key)?,
segment,
token,
)
})
}
/// Get the probability of the specified token in the specified segment.
@ -718,12 +969,23 @@ impl WhisperContext {
/// # C++ equivalent
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
#[inline]
pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> f32 {
unsafe { whisper_rs_sys::whisper_full_get_token_p(self.ctx, segment, token) }
pub fn full_get_token_prob(
&self,
key: &K,
segment: c_int,
token: c_int,
) -> Result<f32, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_get_token_p_from_state(
self.get_state_ptr(key)?,
segment,
token,
)
})
}
}
impl Drop for WhisperContext {
impl<K: Hash + Eq> Drop for WhisperContext<K> {
#[inline]
fn drop(&mut self) {
unsafe { whisper_rs_sys::whisper_free(self.ctx) };
@ -732,6 +994,5 @@ impl Drop for WhisperContext {
// following implementations are safe
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
// concurrent usage is prevented by &mut self on methods that modify the struct
unsafe impl Send for WhisperContext {}
unsafe impl Sync for WhisperContext {}
unsafe impl<K: Hash + Eq> Send for WhisperContext<K> {}
unsafe impl<K: Hash + Eq> Sync for WhisperContext<K> {}