From 6169229e607994f71e264ff8619fcf930596b0a6 Mon Sep 17 00:00:00 2001 From: Yuniru Yuni Date: Tue, 25 Apr 2023 22:48:45 +0900 Subject: [PATCH] refactor: delete map for State and expose struct with lifetime --- src/whisper_ctx.rs | 193 ++++++++++++++++++++----------------------- src/whisper_state.rs | 20 +++-- 2 files changed, 101 insertions(+), 112 deletions(-) diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 11ee051..8eae734 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -2,23 +2,18 @@ use crate::error::WhisperError; use crate::whisper_params::FullParams; use crate::whisper_state::WhisperState; use crate::{WhisperToken, WhisperTokenData}; -use dashmap::DashMap; use std::ffi::{c_int, CStr, CString}; -use std::hash::Hash; /// Safe Rust wrapper around a Whisper context. /// /// You likely want to create this with [WhisperContext::new], /// then run a full transcription with [WhisperContext::full]. #[derive(Debug)] -pub struct WhisperContext { +pub struct WhisperContext { ctx: *mut whisper_rs_sys::whisper_context, - - /// Map of state IDs to state objects. - state_map: DashMap, } -impl WhisperContext { +impl WhisperContext { /// Create a new WhisperContext from a file. /// /// # Arguments @@ -35,10 +30,7 @@ impl WhisperContext { if ctx.is_null() { Err(WhisperError::InitError) } else { - Ok(Self { - ctx, - state_map: DashMap::new(), - }) + Ok(Self { ctx }) } } @@ -59,10 +51,7 @@ impl WhisperContext { if ctx.is_null() { Err(WhisperError::InitError) } else { - Ok(Self { - ctx, - state_map: DashMap::new(), - }) + Ok(Self { ctx }) } } @@ -79,28 +68,16 @@ impl WhisperContext { /// /// # C++ equivalent /// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);` - pub fn create_key(&self, id: K) -> Result<(), WhisperError> { - if self.state_map.contains_key(&id) { - return Err(WhisperError::StateIdAlreadyExists); - } + pub fn create_state(&self) -> Result { let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) }; if state.is_null() { Err(WhisperError::InitError) } else { // SAFETY: this is known to be a valid pointer to a `whisper_state` struct - self.state_map - .insert(id, unsafe { WhisperState::new(state) }); - Ok(()) + Ok(WhisperState::new(state)) } } - fn get_state_ptr(&self, id: &K) -> Result<*mut whisper_rs_sys::whisper_state, WhisperError> { - self.state_map - .get(id) - .map(|s| s.value().as_ptr()) - .ok_or(WhisperError::StateIdDoesNotExist) - } - /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram. /// The resulting spectrogram is stored in the context transparently. /// @@ -113,15 +90,19 @@ impl WhisperContext { /// /// # C++ equivalent /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)` - pub fn pcm_to_mel(&self, key: &K, pcm: &[f32], threads: usize) -> Result<(), WhisperError> { + pub fn pcm_to_mel( + &self, + state: &WhisperState, + pcm: &[f32], + threads: usize, + ) -> Result<(), WhisperError> { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } - let state_ptr = self.get_state_ptr(key)?; let ret = unsafe { whisper_rs_sys::whisper_pcm_to_mel_with_state( self.ctx, - state_ptr, + state.as_ptr(), pcm.as_ptr(), pcm.len() as c_int, threads as c_int, @@ -151,18 +132,17 @@ impl WhisperContext { /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)` pub fn pcm_to_mel_phase_vocoder( &self, - key: &K, + state: &WhisperState, pcm: &[f32], threads: usize, ) -> Result<(), WhisperError> { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } - let state_ptr = self.get_state_ptr(key)?; let ret = unsafe { whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state( self.ctx, - state_ptr, + state.as_ptr(), pcm.as_ptr(), pcm.len() as c_int, threads as c_int, @@ -193,13 +173,11 @@ impl WhisperContext { /// /// # C++ equivalent /// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)` - pub fn set_mel(&self, key: &K, data: &[f32]) -> Result<(), WhisperError> { - let state_ptr = self.get_state_ptr(key)?; - + pub fn set_mel(&self, state: &WhisperState, data: &[f32]) -> Result<(), WhisperError> { let ret = unsafe { whisper_rs_sys::whisper_set_mel_with_state( self.ctx, - state_ptr, + state.as_ptr(), data.as_ptr(), data.len() as c_int, 80 as c_int, @@ -226,15 +204,19 @@ impl WhisperContext { /// /// # C++ equivalent /// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)` - pub fn encode(&self, key: &K, offset: usize, threads: usize) -> Result<(), WhisperError> { + pub fn encode( + &self, + state: &WhisperState, + offset: usize, + threads: usize, + ) -> Result<(), WhisperError> { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } - let state_ptr = self.get_state_ptr(key)?; let ret = unsafe { whisper_rs_sys::whisper_encode_with_state( self.ctx, - state_ptr, + state.as_ptr(), offset as c_int, threads as c_int, ) @@ -265,7 +247,7 @@ impl WhisperContext { /// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)` pub fn decode( &self, - key: &K, + state: &WhisperState, tokens: &[WhisperToken], n_past: usize, threads: usize, @@ -273,11 +255,10 @@ impl WhisperContext { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } - let state_ptr = self.get_state_ptr(key)?; let ret = unsafe { whisper_rs_sys::whisper_decode_with_state( self.ctx, - state_ptr, + state.as_ptr(), tokens.as_ptr(), tokens.len() as c_int, n_past as c_int, @@ -342,20 +323,19 @@ impl WhisperContext { /// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)` pub fn lang_detect( &self, - key: &K, + state: &WhisperState, offset_ms: usize, threads: usize, ) -> Result, WhisperError> { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } - let state_ptr = self.get_state_ptr(key)?; let mut lang_probs: Vec = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1]; let ret = unsafe { whisper_rs_sys::whisper_lang_auto_detect_with_state( self.ctx, - state_ptr, + state.as_ptr(), offset_ms as c_int, threads as c_int, lang_probs.as_mut_ptr(), @@ -392,8 +372,8 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_n_len_from_state(struct whisper_context * ctx)` #[inline] - pub fn n_len(&self, key: &K) -> Result { - Ok(unsafe { whisper_rs_sys::whisper_n_len_from_state(self.get_state_ptr(key)?) }) + pub fn n_len(&self, state: &WhisperState) -> Result { + Ok(unsafe { whisper_rs_sys::whisper_n_len_from_state(state.as_ptr()) }) } /// Get n_vocab. @@ -600,16 +580,18 @@ impl WhisperContext { /// /// # C++ equivalent /// `float * whisper_get_logits(struct whisper_context * ctx)` - pub fn get_logits(&self, key: &K, segment: c_int) -> Result>, WhisperError> { - let state_ptr = self.get_state_ptr(key)?; - - let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(state_ptr) }; + pub fn get_logits( + &self, + state: &WhisperState, + segment: c_int, + ) -> Result>, WhisperError> { + let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(state.as_ptr()) }; if ret.is_null() { return Err(WhisperError::NullPointer); } let mut logits = Vec::new(); let n_vocab = self.n_vocab(); - let n_tokens = self.full_n_tokens(key, segment)?; + let n_tokens = self.full_n_tokens(state, segment)?; for i in 0..n_tokens { let mut row = Vec::new(); for j in 0..n_vocab { @@ -756,12 +738,16 @@ impl WhisperContext { /// /// # C++ equivalent /// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)` - pub fn full(&self, key: &K, params: FullParams, data: &[f32]) -> Result { - let state_ptr = self.get_state_ptr(key)?; + pub fn full( + &self, + state: &WhisperState, + params: FullParams, + data: &[f32], + ) -> Result { let ret = unsafe { whisper_rs_sys::whisper_full_with_state( self.ctx, - state_ptr, + state.as_ptr(), params.fp, data.as_ptr(), data.len() as c_int, @@ -786,8 +772,8 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_full_n_segments(struct whisper_context * ctx)` #[inline] - pub fn full_n_segments(&self, key: &K) -> Result { - Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.get_state_ptr(key)?) }) + pub fn full_n_segments(&self, state: &WhisperState) -> Result { + Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(state.as_ptr()) }) } /// Language ID associated with the provided state. @@ -795,8 +781,8 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_full_lang_id_from_state(struct whisper_state * state);` #[inline] - pub fn full_lang_id_from_state(&self, key: &K) -> Result { - Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.get_state_ptr(key)?) }) + pub fn full_lang_id_from_state(&self, state: &WhisperState) -> Result { + Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(state.as_ptr()) }) } /// Get the start time of the specified segment. @@ -807,12 +793,13 @@ impl WhisperContext { /// # C++ equivalent /// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)` #[inline] - pub fn full_get_segment_t0(&self, key: &K, segment: c_int) -> Result { + pub fn full_get_segment_t0( + &self, + state: &WhisperState, + segment: c_int, + ) -> Result { Ok(unsafe { - whisper_rs_sys::whisper_full_get_segment_t0_from_state( - self.get_state_ptr(key)?, - segment, - ) + whisper_rs_sys::whisper_full_get_segment_t0_from_state(state.as_ptr(), segment) }) } @@ -824,12 +811,13 @@ impl WhisperContext { /// # C++ equivalent /// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)` #[inline] - pub fn full_get_segment_t1(&self, key: &K, segment: c_int) -> Result { + pub fn full_get_segment_t1( + &self, + state: &WhisperState, + segment: c_int, + ) -> Result { Ok(unsafe { - whisper_rs_sys::whisper_full_get_segment_t1_from_state( - self.get_state_ptr(key)?, - segment, - ) + whisper_rs_sys::whisper_full_get_segment_t1_from_state(state.as_ptr(), segment) }) } @@ -843,11 +831,14 @@ impl WhisperContext { /// /// # C++ equivalent /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` - pub fn full_get_segment_text(&self, key: &K, segment: c_int) -> Result { - let state_ptr = self.get_state_ptr(key)?; - - let ret = - unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(state_ptr, segment) }; + pub fn full_get_segment_text( + &self, + state: &WhisperState, + segment: c_int, + ) -> Result { + let ret = unsafe { + whisper_rs_sys::whisper_full_get_segment_text_from_state(state.as_ptr(), segment) + }; if ret.is_null() { return Err(WhisperError::NullPointer); } @@ -867,10 +858,12 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)` #[inline] - pub fn full_n_tokens(&self, key: &K, segment: c_int) -> Result { - Ok(unsafe { - whisper_rs_sys::whisper_full_n_tokens_from_state(self.get_state_ptr(key)?, segment) - }) + pub fn full_n_tokens( + &self, + state: &WhisperState, + segment: c_int, + ) -> Result { + Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(state.as_ptr(), segment) }) } /// Get the token text of the specified token in the specified segment. @@ -886,14 +879,16 @@ impl WhisperContext { /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` pub fn full_get_token_text( &self, - key: &K, + state: &WhisperState, segment: c_int, token: c_int, ) -> Result { - let state_ptr = self.get_state_ptr(key)?; let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text_from_state( - self.ctx, state_ptr, segment, token, + self.ctx, + state.as_ptr(), + segment, + token, ) }; if ret.is_null() { @@ -917,16 +912,12 @@ impl WhisperContext { /// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)` pub fn full_get_token_id( &self, - key: &K, + state: &WhisperState, segment: c_int, token: c_int, ) -> Result { Ok(unsafe { - whisper_rs_sys::whisper_full_get_token_id_from_state( - self.get_state_ptr(key)?, - segment, - token, - ) + whisper_rs_sys::whisper_full_get_token_id_from_state(state.as_ptr(), segment, token) }) } @@ -944,16 +935,12 @@ impl WhisperContext { #[inline] pub fn full_get_token_data( &self, - key: &K, + state: &WhisperState, segment: c_int, token: c_int, ) -> Result { Ok(unsafe { - whisper_rs_sys::whisper_full_get_token_data_from_state( - self.get_state_ptr(key)?, - segment, - token, - ) + whisper_rs_sys::whisper_full_get_token_data_from_state(state.as_ptr(), segment, token) }) } @@ -971,21 +958,17 @@ impl WhisperContext { #[inline] pub fn full_get_token_prob( &self, - key: &K, + state: &WhisperState, segment: c_int, token: c_int, ) -> Result { Ok(unsafe { - whisper_rs_sys::whisper_full_get_token_p_from_state( - self.get_state_ptr(key)?, - segment, - token, - ) + whisper_rs_sys::whisper_full_get_token_p_from_state(state.as_ptr(), segment, token) }) } } -impl Drop for WhisperContext { +impl Drop for WhisperContext { #[inline] fn drop(&mut self) { unsafe { whisper_rs_sys::whisper_free(self.ctx) }; @@ -994,5 +977,5 @@ impl Drop for WhisperContext { // following implementations are safe // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 -unsafe impl Send for WhisperContext {} -unsafe impl Sync for WhisperContext {} +unsafe impl Send for WhisperContext {} +unsafe impl Sync for WhisperContext {} diff --git a/src/whisper_state.rs b/src/whisper_state.rs index 64453e9..d873152 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -1,13 +1,16 @@ +use std::marker::PhantomData; + /// Rustified pointer to a Whisper state. #[derive(Debug)] -pub struct WhisperState { +pub struct WhisperState<'a> { ptr: *mut whisper_rs_sys::whisper_state, + _phantom: PhantomData<&'a ()>, } -unsafe impl Send for WhisperState {} -unsafe impl Sync for WhisperState {} +unsafe impl<'a> Send for WhisperState<'a> {} +unsafe impl<'a> Sync for WhisperState<'a> {} -impl Drop for WhisperState { +impl<'a> Drop for WhisperState<'a> { fn drop(&mut self) { unsafe { whisper_rs_sys::whisper_free_state(self.ptr); @@ -15,9 +18,12 @@ impl Drop for WhisperState { } } -impl WhisperState { - pub(crate) unsafe fn new(ptr: *mut whisper_rs_sys::whisper_state) -> Self { - Self { ptr } +impl<'a> WhisperState<'a> { + pub(crate) fn new(ptr: *mut whisper_rs_sys::whisper_state) -> Self { + Self { + ptr, + _phantom: PhantomData, + } } pub(crate) fn as_ptr(&self) -> *mut whisper_rs_sys::whisper_state {