From 9d978add9dabeff58f41c0037045fa79def2dc14 Mon Sep 17 00:00:00 2001 From: jiahua Date: Thu, 25 Apr 2024 15:33:06 +0800 Subject: [PATCH] refactor(state): create_state from whisper context wrapper --- src/lib.rs | 2 ++ src/whisper_ctx.rs | 20 +------------------- src/whisper_ctx_wrapper.rs | 37 +++++++++++++++++++++++++++++++++++++ 3 files changed, 40 insertions(+), 19 deletions(-) create mode 100644 src/whisper_ctx_wrapper.rs diff --git a/src/lib.rs b/src/lib.rs index a6da664..d9b75cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,6 +12,7 @@ mod whisper_state; mod whisper_sys_log; #[cfg(feature = "whisper-cpp-tracing")] mod whisper_sys_tracing; +mod whisper_ctx_wrapper; #[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))] static LOG_TRAMPOLINE_INSTALL: Once = Once::new(); @@ -22,6 +23,7 @@ pub use standalone::*; use std::sync::Once; pub use utilities::*; pub use whisper_ctx::WhisperContext; +pub use whisper_ctx_wrapper::WhisperContextWrapper; pub use whisper_ctx::WhisperContextParameters; pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType}; pub use whisper_params::{FullParams, SamplingStrategy}; diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 8d948cd..6f71af3 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -10,7 +10,7 @@ use std::ffi::{c_int, CStr, CString}; /// then run a full transcription with [WhisperState::full]. #[derive(Debug)] pub struct WhisperContext { - ctx: *mut whisper_rs_sys::whisper_context, + pub(crate) ctx: *mut whisper_rs_sys::whisper_context, } impl WhisperContext { @@ -114,24 +114,6 @@ impl WhisperContext { } } - // we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does - - /// Create a new state object, ready for use. - /// - /// # Returns - /// Ok(WhisperState) on success, Err(WhisperError) on failure. - /// - /// # C++ equivalent - /// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);` - pub fn create_state(&self) -> Result { - let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) }; - if state.is_null() { - Err(WhisperError::InitError) - } else { - // SAFETY: this is known to be a valid pointer to a `whisper_state` struct - Ok(WhisperState::new(self.ctx, state)) - } - } /// Convert the provided text into tokens. /// diff --git a/src/whisper_ctx_wrapper.rs b/src/whisper_ctx_wrapper.rs new file mode 100644 index 0000000..74fa4d0 --- /dev/null +++ b/src/whisper_ctx_wrapper.rs @@ -0,0 +1,37 @@ +use std::sync::Arc; + +use crate::{WhisperContext, WhisperContextParameters, WhisperError, WhisperState}; + +pub struct WhisperContextWrapper { + ctx: Arc, +} + +impl WhisperContextWrapper { + /// wrapper of WhisperContext::new_with_params. + pub fn new_with_params( + path: &str, + parameters: WhisperContextParameters, + ) -> Result { + let ctx = WhisperContext::new_with_params(path, parameters)?; + Ok(Self { ctx: Arc::new(ctx) }) + } + + // we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does + + /// Create a new state object, ready for use. + /// + /// # Returns + /// Ok(WhisperState) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);` + pub fn create_state(&self) -> Result { + let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx.ctx) }; + if state.is_null() { + Err(WhisperError::InitError) + } else { + // SAFETY: this is known to be a valid pointer to a `whisper_state` struct + Ok(WhisperState::new(self.ctx.clone(), state)) + } + } +} \ No newline at end of file