diff --git a/examples/full_usage/src/main.rs b/examples/full_usage/src/main.rs index a6d0557..c3352fd 100644 --- a/examples/full_usage/src/main.rs +++ b/examples/full_usage/src/main.rs @@ -43,7 +43,7 @@ fn main() { } let original_samples = parse_wav_file(audio_path); - let mut samples = Vec::with_capacity(original_samples.len()); + let mut samples = vec![0.0f32; original_samples.len()]; whisper_rs::convert_integer_to_float_audio(&original_samples, &mut samples) .expect("failed to convert samples"); diff --git a/src/lib.rs b/src/lib.rs index a6da664..f5dc43f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ mod error; mod standalone; mod utilities; mod whisper_ctx; +mod whisper_ctx_wrapper; mod whisper_grammar; mod whisper_params; mod whisper_state; @@ -21,8 +22,9 @@ pub use standalone::*; #[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))] use std::sync::Once; pub use utilities::*; -pub use whisper_ctx::WhisperContext; pub use whisper_ctx::WhisperContextParameters; +use whisper_ctx::WhisperInnerContext; +pub use whisper_ctx_wrapper::WhisperContext; pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType}; pub use whisper_params::{FullParams, SamplingStrategy}; #[cfg(feature = "raw-api")] diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 8d948cd..30bd75a 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -1,19 +1,18 @@ use crate::error::WhisperError; -use crate::whisper_state::WhisperState; use crate::WhisperToken; use std::ffi::{c_int, CStr, CString}; /// Safe Rust wrapper around a Whisper context. /// -/// You likely want to create this with [WhisperContext::new_with_params], -/// create a state with [WhisperContext::create_state], +/// You likely want to create this with [WhisperInnerContext::new_with_params], +/// create a state with [WhisperInnerContext::create_state], /// then run a full transcription with [WhisperState::full]. #[derive(Debug)] -pub struct WhisperContext { - ctx: *mut whisper_rs_sys::whisper_context, +pub struct WhisperInnerContext { + pub(crate) ctx: *mut whisper_rs_sys::whisper_context, } -impl WhisperContext { +impl WhisperInnerContext { /// Create a new WhisperContext from a file, with parameters. /// /// # Arguments @@ -71,68 +70,6 @@ impl WhisperContext { } } - /// Create a new WhisperContext from a file. - /// - /// # Arguments - /// * path: The path to the model file. - /// - /// # Returns - /// Ok(Self) on success, Err(WhisperError) on failure. - /// - /// # C++ equivalent - /// `struct whisper_context * whisper_init_from_file_no_state(const char * path_model)` - #[deprecated = "Use `new_with_params` instead"] - pub fn new(path: &str) -> Result { - let path_cstr = CString::new(path)?; - let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) }; - if ctx.is_null() { - Err(WhisperError::InitError) - } else { - Ok(Self { ctx }) - } - } - - /// Create a new WhisperContext from a buffer. - /// - /// # Arguments - /// * buffer: The buffer containing the model. - /// - /// # Returns - /// Ok(Self) on success, Err(WhisperError) on failure. - /// - /// # C++ equivalent - /// `struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size)` - #[deprecated = "Use `new_from_buffer_with_params` instead"] - pub fn new_from_buffer(buffer: &[u8]) -> Result { - let ctx = unsafe { - whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len()) - }; - if ctx.is_null() { - Err(WhisperError::InitError) - } else { - Ok(Self { ctx }) - } - } - - // we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does - - /// Create a new state object, ready for use. - /// - /// # Returns - /// Ok(WhisperState) on success, Err(WhisperError) on failure. - /// - /// # C++ equivalent - /// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);` - pub fn create_state(&self) -> Result { - let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) }; - if state.is_null() { - Err(WhisperError::InitError) - } else { - // SAFETY: this is known to be a valid pointer to a `whisper_state` struct - Ok(WhisperState::new(self.ctx, state)) - } - } - /// Convert the provided text into tokens. /// /// # Arguments @@ -518,23 +455,9 @@ impl WhisperContext { pub fn token_transcribe(&self) -> WhisperToken { unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) } } - - /// Get whether the next segment is predicted as a speaker turn - /// - /// # Arguments - /// * i_segment: Segment index. - /// - /// # Returns - /// bool - /// - /// # C++ equivalent - /// `bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment)` - pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool { - unsafe { whisper_rs_sys::whisper_full_get_segment_speaker_turn_next(self.ctx, i_segment) } - } } -impl Drop for WhisperContext { +impl Drop for WhisperInnerContext { #[inline] fn drop(&mut self) { unsafe { whisper_rs_sys::whisper_free(self.ctx) }; @@ -543,8 +466,8 @@ impl Drop for WhisperContext { // following implementations are safe // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 -unsafe impl Send for WhisperContext {} -unsafe impl Sync for WhisperContext {} +unsafe impl Send for WhisperInnerContext {} +unsafe impl Sync for WhisperInnerContext {} pub struct WhisperContextParameters { /// Use GPU if available. @@ -588,7 +511,7 @@ mod test_with_tiny_model { #[test] fn test_tokenize_round_trip() { - let ctx = WhisperContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'"); + let ctx = WhisperInnerContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'"); let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country."; let tokens = ctx.tokenize(text_in, 1024).unwrap(); let text_out = tokens diff --git a/src/whisper_ctx_wrapper.rs b/src/whisper_ctx_wrapper.rs new file mode 100644 index 0000000..ff3caff --- /dev/null +++ b/src/whisper_ctx_wrapper.rs @@ -0,0 +1,427 @@ +use std::ffi::{c_int, CStr}; +use std::sync::Arc; + +use crate::{ + WhisperContextParameters, WhisperError, WhisperInnerContext, WhisperState, WhisperToken, +}; + +pub struct WhisperContext { + ctx: Arc, +} + +impl WhisperContext { + fn wrap(ctx: WhisperInnerContext) -> Self { + Self { ctx: Arc::new(ctx) } + } + + /// Create a new WhisperContext from a file, with parameters. + /// + /// # Arguments + /// * path: The path to the model file. + /// * parameters: A parameter struct containing the parameters to use. + /// + /// # Returns + /// Ok(Self) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);` + pub fn new_with_params( + path: &str, + parameters: WhisperContextParameters, + ) -> Result { + let ctx = WhisperInnerContext::new_with_params(path, parameters)?; + Ok(Self::wrap(ctx)) + } + + /// Create a new WhisperContext from a buffer. + /// + /// # Arguments + /// * buffer: The buffer containing the model. + /// + /// # Returns + /// Ok(Self) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);` + pub fn new_from_buffer_with_params( + buffer: &[u8], + parameters: WhisperContextParameters, + ) -> Result { + let ctx = WhisperInnerContext::new_from_buffer_with_params(buffer, parameters)?; + Ok(Self::wrap(ctx)) + } + + /// Convert the provided text into tokens. + /// + /// # Arguments + /// * text: The text to convert. + /// + /// # Returns + /// `Ok(Vec)` on success, `Err(WhisperError)` on failure. + /// + /// # C++ equivalent + /// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);` + pub fn tokenize( + &self, + text: &str, + max_tokens: usize, + ) -> Result, WhisperError> { + self.ctx.tokenize(text, max_tokens) + } + + /// Get n_vocab. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_n_vocab (struct whisper_context * ctx)` + #[inline] + pub fn n_vocab(&self) -> c_int { + self.ctx.n_vocab() + } + + /// Get n_text_ctx. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_n_text_ctx (struct whisper_context * ctx);` + #[inline] + pub fn n_text_ctx(&self) -> c_int { + self.ctx.n_text_ctx() + } + + /// Get n_audio_ctx. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_n_audio_ctx (struct whisper_context * ctx);` + #[inline] + pub fn n_audio_ctx(&self) -> c_int { + self.ctx.n_audio_ctx() + } + + /// Does this model support multiple languages? + /// + /// # C++ equivalent + /// `int whisper_is_multilingual(struct whisper_context * ctx)` + #[inline] + pub fn is_multilingual(&self) -> bool { + self.ctx.is_multilingual() + } + + /// Get model_n_vocab. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_vocab (struct whisper_context * ctx);` + #[inline] + pub fn model_n_vocab(&self) -> c_int { + self.ctx.model_n_vocab() + } + + /// Get model_n_audio_ctx. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_audio_ctx (struct whisper_context * ctx)` + #[inline] + pub fn model_n_audio_ctx(&self) -> c_int { + self.ctx.model_n_audio_ctx() + } + + /// Get model_n_audio_state. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_audio_state(struct whisper_context * ctx);` + #[inline] + pub fn model_n_audio_state(&self) -> c_int { + self.ctx.model_n_audio_state() + } + + /// Get model_n_audio_head. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_audio_head (struct whisper_context * ctx);` + #[inline] + pub fn model_n_audio_head(&self) -> c_int { + self.ctx.model_n_audio_head() + } + + /// Get model_n_audio_layer. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_audio_layer(struct whisper_context * ctx);` + #[inline] + pub fn model_n_audio_layer(&self) -> c_int { + self.ctx.model_n_audio_layer() + } + + /// Get model_n_text_ctx. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_text_ctx (struct whisper_context * ctx)` + #[inline] + pub fn model_n_text_ctx(&self) -> c_int { + self.ctx.model_n_text_ctx() + } + + /// Get model_n_text_state. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_text_state (struct whisper_context * ctx);` + #[inline] + pub fn model_n_text_state(&self) -> c_int { + self.ctx.model_n_text_state() + } + + /// Get model_n_text_head. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_text_head (struct whisper_context * ctx);` + #[inline] + pub fn model_n_text_head(&self) -> c_int { + self.ctx.model_n_text_head() + } + + /// Get model_n_text_layer. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_text_layer (struct whisper_context * ctx);` + #[inline] + pub fn model_n_text_layer(&self) -> c_int { + self.ctx.model_n_text_layer() + } + + /// Get model_n_mels. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_mels (struct whisper_context * ctx);` + #[inline] + pub fn model_n_mels(&self) -> c_int { + self.ctx.model_n_mels() + } + + /// Get model_ftype. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_ftype (struct whisper_context * ctx);` + #[inline] + pub fn model_ftype(&self) -> c_int { + self.ctx.model_ftype() + } + + /// Get model_type. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_type (struct whisper_context * ctx);` + #[inline] + pub fn model_type(&self) -> c_int { + self.ctx.model_type() + } + + // token functions + /// Convert a token ID to a string. + /// + /// # Arguments + /// * token_id: ID of the token. + /// + /// # Returns + /// Ok(&str) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` + pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> { + self.ctx.token_to_str(token_id) + } + + /// Convert a token ID to a &CStr. + /// + /// # Arguments + /// * token_id: ID of the token. + /// + /// # Returns + /// Ok(String) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` + pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> { + self.ctx.token_to_cstr(token_id) + } + + /// Undocumented but exposed function in the C++ API. + /// `const char * whisper_model_type_readable(struct whisper_context * ctx);` + /// + /// # Returns + /// Ok(String) on success, Err(WhisperError) on failure. + pub fn model_type_readable(&self) -> Result { + self.ctx.model_type_readable() + } + + /// Get the ID of the eot token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_eot (struct whisper_context * ctx)` + #[inline] + pub fn token_eot(&self) -> WhisperToken { + self.ctx.token_eot() + } + + /// Get the ID of the sot token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_sot (struct whisper_context * ctx)` + #[inline] + pub fn token_sot(&self) -> WhisperToken { + self.ctx.token_sot() + } + + /// Get the ID of the solm token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_solm(struct whisper_context * ctx)` + #[inline] + pub fn token_solm(&self) -> WhisperToken { + self.ctx.token_solm() + } + + /// Get the ID of the prev token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_prev(struct whisper_context * ctx)` + #[inline] + pub fn token_prev(&self) -> WhisperToken { + self.ctx.token_prev() + } + + /// Get the ID of the nosp token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_nosp(struct whisper_context * ctx)` + #[inline] + pub fn token_nosp(&self) -> WhisperToken { + self.ctx.token_nosp() + } + + /// Get the ID of the not token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_not (struct whisper_context * ctx)` + #[inline] + pub fn token_not(&self) -> WhisperToken { + self.ctx.token_not() + } + + /// Get the ID of the beg token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_beg (struct whisper_context * ctx)` + #[inline] + pub fn token_beg(&self) -> WhisperToken { + self.ctx.token_beg() + } + + /// Get the ID of a specified language token + /// + /// # Arguments + /// * lang_id: ID of the language + /// + /// # C++ equivalent + /// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)` + #[inline] + pub fn token_lang(&self, lang_id: c_int) -> WhisperToken { + self.ctx.token_lang(lang_id) + } + + /// Print performance statistics to stderr. + /// + /// # C++ equivalent + /// `void whisper_print_timings(struct whisper_context * ctx)` + #[inline] + pub fn print_timings(&self) { + self.ctx.print_timings() + } + + /// Reset performance statistics. + /// + /// # C++ equivalent + /// `void whisper_reset_timings(struct whisper_context * ctx)` + #[inline] + pub fn reset_timings(&self) { + self.ctx.reset_timings() + } + + // task tokens + /// Get the ID of the translate task token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_translate ()` + pub fn token_translate(&self) -> WhisperToken { + self.ctx.token_translate() + } + + /// Get the ID of the transcribe task token. + /// + /// # C++ equivalent + /// `whisper_token whisper_token_transcribe()` + pub fn token_transcribe(&self) -> WhisperToken { + self.ctx.token_transcribe() + } + + // we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does + + /// Create a new state object, ready for use. + /// + /// # Returns + /// Ok(WhisperState) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);` + pub fn create_state(&self) -> Result { + let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx.ctx) }; + if state.is_null() { + Err(WhisperError::InitError) + } else { + // SAFETY: this is known to be a valid pointer to a `whisper_state` struct + Ok(WhisperState::new(self.ctx.clone(), state)) + } + } +} diff --git a/src/whisper_state.rs b/src/whisper_state.rs index d9b02c3..6805aa5 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -1,19 +1,20 @@ -use crate::{FullParams, WhisperContext, WhisperError, WhisperToken, WhisperTokenData}; use std::ffi::{c_int, CStr}; -use std::marker::PhantomData; +use std::sync::Arc; + +use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperToken, WhisperTokenData}; /// Rustified pointer to a Whisper state. #[derive(Debug)] -pub struct WhisperState<'a> { - ctx: *mut whisper_rs_sys::whisper_context, +pub struct WhisperState { + ctx: Arc, ptr: *mut whisper_rs_sys::whisper_state, - _phantom: PhantomData<&'a WhisperContext>, } -unsafe impl<'a> Send for WhisperState<'a> {} -unsafe impl<'a> Sync for WhisperState<'a> {} +unsafe impl Send for WhisperState {} -impl<'a> Drop for WhisperState<'a> { +unsafe impl Sync for WhisperState {} + +impl Drop for WhisperState { fn drop(&mut self) { unsafe { whisper_rs_sys::whisper_free_state(self.ptr); @@ -21,16 +22,12 @@ impl<'a> Drop for WhisperState<'a> { } } -impl<'a> WhisperState<'a> { +impl WhisperState { pub(crate) fn new( - ctx: *mut whisper_rs_sys::whisper_context, + ctx: Arc, ptr: *mut whisper_rs_sys::whisper_state, ) -> Self { - Self { - ctx, - ptr, - _phantom: PhantomData, - } + Self { ctx, ptr } } /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram. @@ -51,7 +48,7 @@ impl<'a> WhisperState<'a> { } let ret = unsafe { whisper_rs_sys::whisper_pcm_to_mel_with_state( - self.ctx, + self.ctx.ctx, self.ptr, pcm.as_ptr(), pcm.len() as c_int, @@ -90,7 +87,7 @@ impl<'a> WhisperState<'a> { } let ret = unsafe { whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state( - self.ctx, + self.ctx.ctx, self.ptr, pcm.as_ptr(), pcm.len() as c_int, @@ -127,7 +124,7 @@ impl<'a> WhisperState<'a> { let n_len = (data.len() / hop_size) * 2; let ret = unsafe { whisper_rs_sys::whisper_set_mel_with_state( - self.ctx, + self.ctx.ctx, self.ptr, data.as_ptr(), n_len as c_int, @@ -161,7 +158,7 @@ impl<'a> WhisperState<'a> { } let ret = unsafe { whisper_rs_sys::whisper_encode_with_state( - self.ctx, + self.ctx.ctx, self.ptr, offset as c_int, threads as c_int, @@ -202,7 +199,7 @@ impl<'a> WhisperState<'a> { } let ret = unsafe { whisper_rs_sys::whisper_decode_with_state( - self.ctx, + self.ctx.ctx, self.ptr, tokens.as_ptr(), tokens.len() as c_int, @@ -240,7 +237,7 @@ impl<'a> WhisperState<'a> { let mut lang_probs: Vec = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1]; let ret = unsafe { whisper_rs_sys::whisper_lang_auto_detect_with_state( - self.ctx, + self.ctx.ctx, self.ptr, offset_ms as c_int, threads as c_int, @@ -309,7 +306,7 @@ impl<'a> WhisperState<'a> { /// `int whisper_n_vocab (struct whisper_context * ctx)` #[inline] pub fn n_vocab(&self) -> c_int { - unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) } + unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx.ctx) } } /// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text @@ -335,7 +332,7 @@ impl<'a> WhisperState<'a> { let ret = unsafe { whisper_rs_sys::whisper_full_with_state( - self.ctx, + self.ctx.ctx, self.ptr, params.fp, data.as_ptr(), @@ -495,7 +492,10 @@ impl<'a> WhisperState<'a> { ) -> Result { let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text_from_state( - self.ctx, self.ptr, segment, token, + self.ctx.ctx, + self.ptr, + segment, + token, ) }; if ret.is_null() { @@ -527,7 +527,10 @@ impl<'a> WhisperState<'a> { ) -> Result { let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text_from_state( - self.ctx, self.ptr, segment, token, + self.ctx.ctx, + self.ptr, + segment, + token, ) }; if ret.is_null() {