Merge branch 'tazz4843:master' into linux-cuda-build
This commit is contained in:
commit
3f7e43252c
5 changed files with 468 additions and 113 deletions
|
|
@ -43,7 +43,7 @@ fn main() {
|
|||
}
|
||||
|
||||
let original_samples = parse_wav_file(audio_path);
|
||||
let mut samples = Vec::with_capacity(original_samples.len());
|
||||
let mut samples = vec![0.0f32; original_samples.len()];
|
||||
whisper_rs::convert_integer_to_float_audio(&original_samples, &mut samples)
|
||||
.expect("failed to convert samples");
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ mod error;
|
|||
mod standalone;
|
||||
mod utilities;
|
||||
mod whisper_ctx;
|
||||
mod whisper_ctx_wrapper;
|
||||
mod whisper_grammar;
|
||||
mod whisper_params;
|
||||
mod whisper_state;
|
||||
|
|
@ -21,8 +22,9 @@ pub use standalone::*;
|
|||
#[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))]
|
||||
use std::sync::Once;
|
||||
pub use utilities::*;
|
||||
pub use whisper_ctx::WhisperContext;
|
||||
pub use whisper_ctx::WhisperContextParameters;
|
||||
use whisper_ctx::WhisperInnerContext;
|
||||
pub use whisper_ctx_wrapper::WhisperContext;
|
||||
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
|
||||
pub use whisper_params::{FullParams, SamplingStrategy};
|
||||
#[cfg(feature = "raw-api")]
|
||||
|
|
|
|||
|
|
@ -1,19 +1,18 @@
|
|||
use crate::error::WhisperError;
|
||||
use crate::whisper_state::WhisperState;
|
||||
use crate::WhisperToken;
|
||||
use std::ffi::{c_int, CStr, CString};
|
||||
|
||||
/// Safe Rust wrapper around a Whisper context.
|
||||
///
|
||||
/// You likely want to create this with [WhisperContext::new_with_params],
|
||||
/// create a state with [WhisperContext::create_state],
|
||||
/// You likely want to create this with [WhisperInnerContext::new_with_params],
|
||||
/// create a state with [WhisperInnerContext::create_state],
|
||||
/// then run a full transcription with [WhisperState::full].
|
||||
#[derive(Debug)]
|
||||
pub struct WhisperContext {
|
||||
ctx: *mut whisper_rs_sys::whisper_context,
|
||||
pub struct WhisperInnerContext {
|
||||
pub(crate) ctx: *mut whisper_rs_sys::whisper_context,
|
||||
}
|
||||
|
||||
impl WhisperContext {
|
||||
impl WhisperInnerContext {
|
||||
/// Create a new WhisperContext from a file, with parameters.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
@ -71,68 +70,6 @@ impl WhisperContext {
|
|||
}
|
||||
}
|
||||
|
||||
/// Create a new WhisperContext from a file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * path: The path to the model file.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(Self) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_context * whisper_init_from_file_no_state(const char * path_model)`
|
||||
#[deprecated = "Use `new_with_params` instead"]
|
||||
pub fn new(path: &str) -> Result<Self, WhisperError> {
|
||||
let path_cstr = CString::new(path)?;
|
||||
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) };
|
||||
if ctx.is_null() {
|
||||
Err(WhisperError::InitError)
|
||||
} else {
|
||||
Ok(Self { ctx })
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new WhisperContext from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * buffer: The buffer containing the model.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(Self) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size)`
|
||||
#[deprecated = "Use `new_from_buffer_with_params` instead"]
|
||||
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
|
||||
let ctx = unsafe {
|
||||
whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len())
|
||||
};
|
||||
if ctx.is_null() {
|
||||
Err(WhisperError::InitError)
|
||||
} else {
|
||||
Ok(Self { ctx })
|
||||
}
|
||||
}
|
||||
|
||||
// we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does
|
||||
|
||||
/// Create a new state object, ready for use.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(WhisperState) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);`
|
||||
pub fn create_state(&self) -> Result<WhisperState, WhisperError> {
|
||||
let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) };
|
||||
if state.is_null() {
|
||||
Err(WhisperError::InitError)
|
||||
} else {
|
||||
// SAFETY: this is known to be a valid pointer to a `whisper_state` struct
|
||||
Ok(WhisperState::new(self.ctx, state))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the provided text into tokens.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
@ -518,23 +455,9 @@ impl WhisperContext {
|
|||
pub fn token_transcribe(&self) -> WhisperToken {
|
||||
unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) }
|
||||
}
|
||||
|
||||
/// Get whether the next segment is predicted as a speaker turn
|
||||
///
|
||||
/// # Arguments
|
||||
/// * i_segment: Segment index.
|
||||
///
|
||||
/// # Returns
|
||||
/// bool
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment)`
|
||||
pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool {
|
||||
unsafe { whisper_rs_sys::whisper_full_get_segment_speaker_turn_next(self.ctx, i_segment) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WhisperContext {
|
||||
impl Drop for WhisperInnerContext {
|
||||
#[inline]
|
||||
fn drop(&mut self) {
|
||||
unsafe { whisper_rs_sys::whisper_free(self.ctx) };
|
||||
|
|
@ -543,8 +466,8 @@ impl Drop for WhisperContext {
|
|||
|
||||
// following implementations are safe
|
||||
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
|
||||
unsafe impl Send for WhisperContext {}
|
||||
unsafe impl Sync for WhisperContext {}
|
||||
unsafe impl Send for WhisperInnerContext {}
|
||||
unsafe impl Sync for WhisperInnerContext {}
|
||||
|
||||
pub struct WhisperContextParameters {
|
||||
/// Use GPU if available.
|
||||
|
|
@ -588,7 +511,7 @@ mod test_with_tiny_model {
|
|||
|
||||
#[test]
|
||||
fn test_tokenize_round_trip() {
|
||||
let ctx = WhisperContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
|
||||
let ctx = WhisperInnerContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
|
||||
let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.";
|
||||
let tokens = ctx.tokenize(text_in, 1024).unwrap();
|
||||
let text_out = tokens
|
||||
|
|
|
|||
427
src/whisper_ctx_wrapper.rs
Normal file
427
src/whisper_ctx_wrapper.rs
Normal file
|
|
@ -0,0 +1,427 @@
|
|||
use std::ffi::{c_int, CStr};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
WhisperContextParameters, WhisperError, WhisperInnerContext, WhisperState, WhisperToken,
|
||||
};
|
||||
|
||||
pub struct WhisperContext {
|
||||
ctx: Arc<WhisperInnerContext>,
|
||||
}
|
||||
|
||||
impl WhisperContext {
|
||||
fn wrap(ctx: WhisperInnerContext) -> Self {
|
||||
Self { ctx: Arc::new(ctx) }
|
||||
}
|
||||
|
||||
/// Create a new WhisperContext from a file, with parameters.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * path: The path to the model file.
|
||||
/// * parameters: A parameter struct containing the parameters to use.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(Self) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);`
|
||||
pub fn new_with_params(
|
||||
path: &str,
|
||||
parameters: WhisperContextParameters,
|
||||
) -> Result<Self, WhisperError> {
|
||||
let ctx = WhisperInnerContext::new_with_params(path, parameters)?;
|
||||
Ok(Self::wrap(ctx))
|
||||
}
|
||||
|
||||
/// Create a new WhisperContext from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * buffer: The buffer containing the model.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(Self) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);`
|
||||
pub fn new_from_buffer_with_params(
|
||||
buffer: &[u8],
|
||||
parameters: WhisperContextParameters,
|
||||
) -> Result<Self, WhisperError> {
|
||||
let ctx = WhisperInnerContext::new_from_buffer_with_params(buffer, parameters)?;
|
||||
Ok(Self::wrap(ctx))
|
||||
}
|
||||
|
||||
/// Convert the provided text into tokens.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * text: The text to convert.
|
||||
///
|
||||
/// # Returns
|
||||
/// `Ok(Vec<WhisperToken>)` on success, `Err(WhisperError)` on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);`
|
||||
pub fn tokenize(
|
||||
&self,
|
||||
text: &str,
|
||||
max_tokens: usize,
|
||||
) -> Result<Vec<WhisperToken>, WhisperError> {
|
||||
self.ctx.tokenize(text, max_tokens)
|
||||
}
|
||||
|
||||
/// Get n_vocab.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_n_vocab (struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn n_vocab(&self) -> c_int {
|
||||
self.ctx.n_vocab()
|
||||
}
|
||||
|
||||
/// Get n_text_ctx.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_n_text_ctx (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn n_text_ctx(&self) -> c_int {
|
||||
self.ctx.n_text_ctx()
|
||||
}
|
||||
|
||||
/// Get n_audio_ctx.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_n_audio_ctx (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn n_audio_ctx(&self) -> c_int {
|
||||
self.ctx.n_audio_ctx()
|
||||
}
|
||||
|
||||
/// Does this model support multiple languages?
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_is_multilingual(struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn is_multilingual(&self) -> bool {
|
||||
self.ctx.is_multilingual()
|
||||
}
|
||||
|
||||
/// Get model_n_vocab.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_vocab (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_n_vocab(&self) -> c_int {
|
||||
self.ctx.model_n_vocab()
|
||||
}
|
||||
|
||||
/// Get model_n_audio_ctx.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_audio_ctx (struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn model_n_audio_ctx(&self) -> c_int {
|
||||
self.ctx.model_n_audio_ctx()
|
||||
}
|
||||
|
||||
/// Get model_n_audio_state.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_audio_state(struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_n_audio_state(&self) -> c_int {
|
||||
self.ctx.model_n_audio_state()
|
||||
}
|
||||
|
||||
/// Get model_n_audio_head.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_audio_head (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_n_audio_head(&self) -> c_int {
|
||||
self.ctx.model_n_audio_head()
|
||||
}
|
||||
|
||||
/// Get model_n_audio_layer.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_audio_layer(struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_n_audio_layer(&self) -> c_int {
|
||||
self.ctx.model_n_audio_layer()
|
||||
}
|
||||
|
||||
/// Get model_n_text_ctx.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_text_ctx (struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn model_n_text_ctx(&self) -> c_int {
|
||||
self.ctx.model_n_text_ctx()
|
||||
}
|
||||
|
||||
/// Get model_n_text_state.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_text_state (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_n_text_state(&self) -> c_int {
|
||||
self.ctx.model_n_text_state()
|
||||
}
|
||||
|
||||
/// Get model_n_text_head.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_text_head (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_n_text_head(&self) -> c_int {
|
||||
self.ctx.model_n_text_head()
|
||||
}
|
||||
|
||||
/// Get model_n_text_layer.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_text_layer (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_n_text_layer(&self) -> c_int {
|
||||
self.ctx.model_n_text_layer()
|
||||
}
|
||||
|
||||
/// Get model_n_mels.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_n_mels (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_n_mels(&self) -> c_int {
|
||||
self.ctx.model_n_mels()
|
||||
}
|
||||
|
||||
/// Get model_ftype.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_ftype (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_ftype(&self) -> c_int {
|
||||
self.ctx.model_ftype()
|
||||
}
|
||||
|
||||
/// Get model_type.
|
||||
///
|
||||
/// # Returns
|
||||
/// c_int
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_model_type (struct whisper_context * ctx);`
|
||||
#[inline]
|
||||
pub fn model_type(&self) -> c_int {
|
||||
self.ctx.model_type()
|
||||
}
|
||||
|
||||
// token functions
|
||||
/// Convert a token ID to a string.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * token_id: ID of the token.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(&str) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||
pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> {
|
||||
self.ctx.token_to_str(token_id)
|
||||
}
|
||||
|
||||
/// Convert a token ID to a &CStr.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * token_id: ID of the token.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
|
||||
pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> {
|
||||
self.ctx.token_to_cstr(token_id)
|
||||
}
|
||||
|
||||
/// Undocumented but exposed function in the C++ API.
|
||||
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
|
||||
self.ctx.model_type_readable()
|
||||
}
|
||||
|
||||
/// Get the ID of the eot token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn token_eot(&self) -> WhisperToken {
|
||||
self.ctx.token_eot()
|
||||
}
|
||||
|
||||
/// Get the ID of the sot token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn token_sot(&self) -> WhisperToken {
|
||||
self.ctx.token_sot()
|
||||
}
|
||||
|
||||
/// Get the ID of the solm token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn token_solm(&self) -> WhisperToken {
|
||||
self.ctx.token_solm()
|
||||
}
|
||||
|
||||
/// Get the ID of the prev token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn token_prev(&self) -> WhisperToken {
|
||||
self.ctx.token_prev()
|
||||
}
|
||||
|
||||
/// Get the ID of the nosp token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_nosp(struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn token_nosp(&self) -> WhisperToken {
|
||||
self.ctx.token_nosp()
|
||||
}
|
||||
|
||||
/// Get the ID of the not token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_not (struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn token_not(&self) -> WhisperToken {
|
||||
self.ctx.token_not()
|
||||
}
|
||||
|
||||
/// Get the ID of the beg token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn token_beg(&self) -> WhisperToken {
|
||||
self.ctx.token_beg()
|
||||
}
|
||||
|
||||
/// Get the ID of a specified language token
|
||||
///
|
||||
/// # Arguments
|
||||
/// * lang_id: ID of the language
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)`
|
||||
#[inline]
|
||||
pub fn token_lang(&self, lang_id: c_int) -> WhisperToken {
|
||||
self.ctx.token_lang(lang_id)
|
||||
}
|
||||
|
||||
/// Print performance statistics to stderr.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `void whisper_print_timings(struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn print_timings(&self) {
|
||||
self.ctx.print_timings()
|
||||
}
|
||||
|
||||
/// Reset performance statistics.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `void whisper_reset_timings(struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn reset_timings(&self) {
|
||||
self.ctx.reset_timings()
|
||||
}
|
||||
|
||||
// task tokens
|
||||
/// Get the ID of the translate task token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_translate ()`
|
||||
pub fn token_translate(&self) -> WhisperToken {
|
||||
self.ctx.token_translate()
|
||||
}
|
||||
|
||||
/// Get the ID of the transcribe task token.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_token_transcribe()`
|
||||
pub fn token_transcribe(&self) -> WhisperToken {
|
||||
self.ctx.token_transcribe()
|
||||
}
|
||||
|
||||
// we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does
|
||||
|
||||
/// Create a new state object, ready for use.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(WhisperState) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);`
|
||||
pub fn create_state(&self) -> Result<WhisperState, WhisperError> {
|
||||
let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx.ctx) };
|
||||
if state.is_null() {
|
||||
Err(WhisperError::InitError)
|
||||
} else {
|
||||
// SAFETY: this is known to be a valid pointer to a `whisper_state` struct
|
||||
Ok(WhisperState::new(self.ctx.clone(), state))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,19 +1,20 @@
|
|||
use crate::{FullParams, WhisperContext, WhisperError, WhisperToken, WhisperTokenData};
|
||||
use std::ffi::{c_int, CStr};
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperToken, WhisperTokenData};
|
||||
|
||||
/// Rustified pointer to a Whisper state.
|
||||
#[derive(Debug)]
|
||||
pub struct WhisperState<'a> {
|
||||
ctx: *mut whisper_rs_sys::whisper_context,
|
||||
pub struct WhisperState {
|
||||
ctx: Arc<WhisperInnerContext>,
|
||||
ptr: *mut whisper_rs_sys::whisper_state,
|
||||
_phantom: PhantomData<&'a WhisperContext>,
|
||||
}
|
||||
|
||||
unsafe impl<'a> Send for WhisperState<'a> {}
|
||||
unsafe impl<'a> Sync for WhisperState<'a> {}
|
||||
unsafe impl Send for WhisperState {}
|
||||
|
||||
impl<'a> Drop for WhisperState<'a> {
|
||||
unsafe impl Sync for WhisperState {}
|
||||
|
||||
impl Drop for WhisperState {
|
||||
fn drop(&mut self) {
|
||||
unsafe {
|
||||
whisper_rs_sys::whisper_free_state(self.ptr);
|
||||
|
|
@ -21,16 +22,12 @@ impl<'a> Drop for WhisperState<'a> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a> WhisperState<'a> {
|
||||
impl WhisperState {
|
||||
pub(crate) fn new(
|
||||
ctx: *mut whisper_rs_sys::whisper_context,
|
||||
ctx: Arc<WhisperInnerContext>,
|
||||
ptr: *mut whisper_rs_sys::whisper_state,
|
||||
) -> Self {
|
||||
Self {
|
||||
ctx,
|
||||
ptr,
|
||||
_phantom: PhantomData,
|
||||
}
|
||||
Self { ctx, ptr }
|
||||
}
|
||||
|
||||
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
|
||||
|
|
@ -51,7 +48,7 @@ impl<'a> WhisperState<'a> {
|
|||
}
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_pcm_to_mel_with_state(
|
||||
self.ctx,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
pcm.as_ptr(),
|
||||
pcm.len() as c_int,
|
||||
|
|
@ -90,7 +87,7 @@ impl<'a> WhisperState<'a> {
|
|||
}
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
|
||||
self.ctx,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
pcm.as_ptr(),
|
||||
pcm.len() as c_int,
|
||||
|
|
@ -127,7 +124,7 @@ impl<'a> WhisperState<'a> {
|
|||
let n_len = (data.len() / hop_size) * 2;
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_set_mel_with_state(
|
||||
self.ctx,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
data.as_ptr(),
|
||||
n_len as c_int,
|
||||
|
|
@ -161,7 +158,7 @@ impl<'a> WhisperState<'a> {
|
|||
}
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_encode_with_state(
|
||||
self.ctx,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
offset as c_int,
|
||||
threads as c_int,
|
||||
|
|
@ -202,7 +199,7 @@ impl<'a> WhisperState<'a> {
|
|||
}
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_decode_with_state(
|
||||
self.ctx,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
tokens.as_ptr(),
|
||||
tokens.len() as c_int,
|
||||
|
|
@ -240,7 +237,7 @@ impl<'a> WhisperState<'a> {
|
|||
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_lang_auto_detect_with_state(
|
||||
self.ctx,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
offset_ms as c_int,
|
||||
threads as c_int,
|
||||
|
|
@ -309,7 +306,7 @@ impl<'a> WhisperState<'a> {
|
|||
/// `int whisper_n_vocab (struct whisper_context * ctx)`
|
||||
#[inline]
|
||||
pub fn n_vocab(&self) -> c_int {
|
||||
unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) }
|
||||
unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx.ctx) }
|
||||
}
|
||||
|
||||
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
|
|
@ -335,7 +332,7 @@ impl<'a> WhisperState<'a> {
|
|||
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_full_with_state(
|
||||
self.ctx,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
params.fp,
|
||||
data.as_ptr(),
|
||||
|
|
@ -495,7 +492,10 @@ impl<'a> WhisperState<'a> {
|
|||
) -> Result<String, WhisperError> {
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||
self.ctx, self.ptr, segment, token,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
segment,
|
||||
token,
|
||||
)
|
||||
};
|
||||
if ret.is_null() {
|
||||
|
|
@ -527,7 +527,10 @@ impl<'a> WhisperState<'a> {
|
|||
) -> Result<String, WhisperError> {
|
||||
let ret = unsafe {
|
||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||
self.ctx, self.ptr, segment, token,
|
||||
self.ctx.ctx,
|
||||
self.ptr,
|
||||
segment,
|
||||
token,
|
||||
)
|
||||
};
|
||||
if ret.is_null() {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue