diff --git a/src/whisper_state.rs b/src/whisper_state.rs index d9b02c3..99f2ff1 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -1,19 +1,20 @@ -use crate::{FullParams, WhisperContext, WhisperError, WhisperToken, WhisperTokenData}; use std::ffi::{c_int, CStr}; -use std::marker::PhantomData; +use std::sync::Arc; + +use crate::{FullParams, WhisperContext, WhisperError, WhisperToken, WhisperTokenData}; /// Rustified pointer to a Whisper state. #[derive(Debug)] -pub struct WhisperState<'a> { - ctx: *mut whisper_rs_sys::whisper_context, +pub struct WhisperState { + ctx: Arc, ptr: *mut whisper_rs_sys::whisper_state, - _phantom: PhantomData<&'a WhisperContext>, } -unsafe impl<'a> Send for WhisperState<'a> {} -unsafe impl<'a> Sync for WhisperState<'a> {} +unsafe impl Send for WhisperState {} -impl<'a> Drop for WhisperState<'a> { +unsafe impl Sync for WhisperState {} + +impl Drop for WhisperState { fn drop(&mut self) { unsafe { whisper_rs_sys::whisper_free_state(self.ptr); @@ -21,15 +22,14 @@ impl<'a> Drop for WhisperState<'a> { } } -impl<'a> WhisperState<'a> { +impl WhisperState { pub(crate) fn new( - ctx: *mut whisper_rs_sys::whisper_context, + ctx: Arc, ptr: *mut whisper_rs_sys::whisper_state, ) -> Self { Self { ctx, ptr, - _phantom: PhantomData, } } @@ -45,13 +45,13 @@ impl<'a> WhisperState<'a> { /// /// # C++ equivalent /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)` - pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> { + pub fn pcm_to_mel(&self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } let ret = unsafe { whisper_rs_sys::whisper_pcm_to_mel_with_state( - self.ctx, + self.ctx.ctx, self.ptr, pcm.as_ptr(), pcm.len() as c_int, @@ -81,7 +81,7 @@ impl<'a> WhisperState<'a> { /// # C++ equivalent /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)` pub fn pcm_to_mel_phase_vocoder( - &mut self, + &self, pcm: &[f32], threads: usize, ) -> Result<(), WhisperError> { @@ -90,7 +90,7 @@ impl<'a> WhisperState<'a> { } let ret = unsafe { whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state( - self.ctx, + self.ctx.ctx, self.ptr, pcm.as_ptr(), pcm.len() as c_int, @@ -122,12 +122,12 @@ impl<'a> WhisperState<'a> { /// /// # C++ equivalent /// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)` - pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> { + pub fn set_mel(&self, data: &[f32]) -> Result<(), WhisperError> { let hop_size = 160; let n_len = (data.len() / hop_size) * 2; let ret = unsafe { whisper_rs_sys::whisper_set_mel_with_state( - self.ctx, + self.ctx.ctx, self.ptr, data.as_ptr(), n_len as c_int, @@ -155,13 +155,13 @@ impl<'a> WhisperState<'a> { /// /// # C++ equivalent /// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)` - pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> { + pub fn encode(&self, offset: usize, threads: usize) -> Result<(), WhisperError> { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } let ret = unsafe { whisper_rs_sys::whisper_encode_with_state( - self.ctx, + self.ctx.ctx, self.ptr, offset as c_int, threads as c_int, @@ -192,7 +192,7 @@ impl<'a> WhisperState<'a> { /// # C++ equivalent /// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)` pub fn decode( - &mut self, + &self, tokens: &[WhisperToken], n_past: usize, threads: usize, @@ -202,7 +202,7 @@ impl<'a> WhisperState<'a> { } let ret = unsafe { whisper_rs_sys::whisper_decode_with_state( - self.ctx, + self.ctx.ctx, self.ptr, tokens.as_ptr(), tokens.len() as c_int, @@ -240,7 +240,7 @@ impl<'a> WhisperState<'a> { let mut lang_probs: Vec = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1]; let ret = unsafe { whisper_rs_sys::whisper_lang_auto_detect_with_state( - self.ctx, + self.ctx.ctx, self.ptr, offset_ms as c_int, threads as c_int, @@ -309,7 +309,7 @@ impl<'a> WhisperState<'a> { /// `int whisper_n_vocab (struct whisper_context * ctx)` #[inline] pub fn n_vocab(&self) -> c_int { - unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) } + unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx.ctx) } } /// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text @@ -327,7 +327,7 @@ impl<'a> WhisperState<'a> { /// /// # C++ equivalent /// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)` - pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result { + pub fn full(&self, params: FullParams, data: &[f32]) -> Result { if data.is_empty() { // can randomly trigger segmentation faults if we don't check this return Err(WhisperError::NoSamples); @@ -335,7 +335,7 @@ impl<'a> WhisperState<'a> { let ret = unsafe { whisper_rs_sys::whisper_full_with_state( - self.ctx, + self.ctx.ctx, self.ptr, params.fp, data.as_ptr(), @@ -495,7 +495,7 @@ impl<'a> WhisperState<'a> { ) -> Result { let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text_from_state( - self.ctx, self.ptr, segment, token, + self.ctx.ctx, self.ptr, segment, token, ) }; if ret.is_null() { @@ -527,7 +527,7 @@ impl<'a> WhisperState<'a> { ) -> Result { let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text_from_state( - self.ctx, self.ptr, segment, token, + self.ctx.ctx, self.ptr, segment, token, ) }; if ret.is_null() { @@ -610,7 +610,7 @@ impl<'a> WhisperState<'a> { /// /// # C++ equivalent /// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)` - pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool { + pub fn full_get_segment_speaker_turn_next(&self, i_segment: c_int) -> bool { unsafe { whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state( self.ptr, i_segment,