Merge branch 'tazz4843:master' into linux-cuda-build

This commit is contained in:
arizhih 2024-05-07 13:14:14 +02:00 committed by GitHub
commit 3f7e43252c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 468 additions and 113 deletions

View file

@ -43,7 +43,7 @@ fn main() {
} }
let original_samples = parse_wav_file(audio_path); let original_samples = parse_wav_file(audio_path);
let mut samples = Vec::with_capacity(original_samples.len()); let mut samples = vec![0.0f32; original_samples.len()];
whisper_rs::convert_integer_to_float_audio(&original_samples, &mut samples) whisper_rs::convert_integer_to_float_audio(&original_samples, &mut samples)
.expect("failed to convert samples"); .expect("failed to convert samples");

View file

@ -5,6 +5,7 @@ mod error;
mod standalone; mod standalone;
mod utilities; mod utilities;
mod whisper_ctx; mod whisper_ctx;
mod whisper_ctx_wrapper;
mod whisper_grammar; mod whisper_grammar;
mod whisper_params; mod whisper_params;
mod whisper_state; mod whisper_state;
@ -21,8 +22,9 @@ pub use standalone::*;
#[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))] #[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))]
use std::sync::Once; use std::sync::Once;
pub use utilities::*; pub use utilities::*;
pub use whisper_ctx::WhisperContext;
pub use whisper_ctx::WhisperContextParameters; pub use whisper_ctx::WhisperContextParameters;
use whisper_ctx::WhisperInnerContext;
pub use whisper_ctx_wrapper::WhisperContext;
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType}; pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
pub use whisper_params::{FullParams, SamplingStrategy}; pub use whisper_params::{FullParams, SamplingStrategy};
#[cfg(feature = "raw-api")] #[cfg(feature = "raw-api")]

View file

@ -1,19 +1,18 @@
use crate::error::WhisperError; use crate::error::WhisperError;
use crate::whisper_state::WhisperState;
use crate::WhisperToken; use crate::WhisperToken;
use std::ffi::{c_int, CStr, CString}; use std::ffi::{c_int, CStr, CString};
/// Safe Rust wrapper around a Whisper context. /// Safe Rust wrapper around a Whisper context.
/// ///
/// You likely want to create this with [WhisperContext::new_with_params], /// You likely want to create this with [WhisperInnerContext::new_with_params],
/// create a state with [WhisperContext::create_state], /// create a state with [WhisperInnerContext::create_state],
/// then run a full transcription with [WhisperState::full]. /// then run a full transcription with [WhisperState::full].
#[derive(Debug)] #[derive(Debug)]
pub struct WhisperContext { pub struct WhisperInnerContext {
ctx: *mut whisper_rs_sys::whisper_context, pub(crate) ctx: *mut whisper_rs_sys::whisper_context,
} }
impl WhisperContext { impl WhisperInnerContext {
/// Create a new WhisperContext from a file, with parameters. /// Create a new WhisperContext from a file, with parameters.
/// ///
/// # Arguments /// # Arguments
@ -71,68 +70,6 @@ impl WhisperContext {
} }
} }
/// Create a new WhisperContext from a file.
///
/// # Arguments
/// * path: The path to the model file.
///
/// # Returns
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_file_no_state(const char * path_model)`
#[deprecated = "Use `new_with_params` instead"]
pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) };
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
Ok(Self { ctx })
}
}
/// Create a new WhisperContext from a buffer.
///
/// # Arguments
/// * buffer: The buffer containing the model.
///
/// # Returns
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size)`
#[deprecated = "Use `new_from_buffer_with_params` instead"]
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
let ctx = unsafe {
whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len())
};
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
Ok(Self { ctx })
}
}
// we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does
/// Create a new state object, ready for use.
///
/// # Returns
/// Ok(WhisperState) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);`
pub fn create_state(&self) -> Result<WhisperState, WhisperError> {
let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) };
if state.is_null() {
Err(WhisperError::InitError)
} else {
// SAFETY: this is known to be a valid pointer to a `whisper_state` struct
Ok(WhisperState::new(self.ctx, state))
}
}
/// Convert the provided text into tokens. /// Convert the provided text into tokens.
/// ///
/// # Arguments /// # Arguments
@ -518,23 +455,9 @@ impl WhisperContext {
pub fn token_transcribe(&self) -> WhisperToken { pub fn token_transcribe(&self) -> WhisperToken {
unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) } unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) }
} }
/// Get whether the next segment is predicted as a speaker turn
///
/// # Arguments
/// * i_segment: Segment index.
///
/// # Returns
/// bool
///
/// # C++ equivalent
/// `bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool {
unsafe { whisper_rs_sys::whisper_full_get_segment_speaker_turn_next(self.ctx, i_segment) }
}
} }
impl Drop for WhisperContext { impl Drop for WhisperInnerContext {
#[inline] #[inline]
fn drop(&mut self) { fn drop(&mut self) {
unsafe { whisper_rs_sys::whisper_free(self.ctx) }; unsafe { whisper_rs_sys::whisper_free(self.ctx) };
@ -543,8 +466,8 @@ impl Drop for WhisperContext {
// following implementations are safe // following implementations are safe
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
unsafe impl Send for WhisperContext {} unsafe impl Send for WhisperInnerContext {}
unsafe impl Sync for WhisperContext {} unsafe impl Sync for WhisperInnerContext {}
pub struct WhisperContextParameters { pub struct WhisperContextParameters {
/// Use GPU if available. /// Use GPU if available.
@ -588,7 +511,7 @@ mod test_with_tiny_model {
#[test] #[test]
fn test_tokenize_round_trip() { fn test_tokenize_round_trip() {
let ctx = WhisperContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'"); let ctx = WhisperInnerContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country."; let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.";
let tokens = ctx.tokenize(text_in, 1024).unwrap(); let tokens = ctx.tokenize(text_in, 1024).unwrap();
let text_out = tokens let text_out = tokens

427
src/whisper_ctx_wrapper.rs Normal file
View file

@ -0,0 +1,427 @@
use std::ffi::{c_int, CStr};
use std::sync::Arc;
use crate::{
WhisperContextParameters, WhisperError, WhisperInnerContext, WhisperState, WhisperToken,
};
/// Safe, shareable Rust wrapper around a Whisper context.
///
/// The underlying [WhisperInnerContext] is held behind an [Arc], so every
/// [WhisperState] created through [WhisperContext::create_state] can keep its
/// own handle to the context instead of borrowing from this wrapper.
///
/// You likely want to create this with [WhisperContext::new_with_params],
/// create a state with [WhisperContext::create_state],
/// then run a full transcription with [WhisperState::full].
#[derive(Debug)]
pub struct WhisperContext {
    /// Reference-counted handle to the context that owns the raw `whisper_context` pointer.
    ctx: Arc<WhisperInnerContext>,
}
impl WhisperContext {
    /// Move an already-initialised inner context behind an `Arc`.
    fn wrap(ctx: WhisperInnerContext) -> Self {
        Self { ctx: Arc::new(ctx) }
    }

    /// Create a new WhisperContext from a file, with parameters.
    ///
    /// # Arguments
    /// * path: The path to the model file.
    /// * parameters: A parameter struct containing the parameters to use.
    ///
    /// # Returns
    /// Ok(Self) on success, Err(WhisperError) on failure.
    ///
    /// # C++ equivalent
    /// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);`
    pub fn new_with_params(
        path: &str,
        parameters: WhisperContextParameters,
    ) -> Result<Self, WhisperError> {
        let ctx = WhisperInnerContext::new_with_params(path, parameters)?;
        Ok(Self::wrap(ctx))
    }

    /// Create a new WhisperContext from a buffer, with parameters.
    ///
    /// # Arguments
    /// * buffer: The buffer containing the model.
    /// * parameters: A parameter struct containing the parameters to use.
    ///
    /// # Returns
    /// Ok(Self) on success, Err(WhisperError) on failure.
    ///
    /// # C++ equivalent
    /// `struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);`
    pub fn new_from_buffer_with_params(
        buffer: &[u8],
        parameters: WhisperContextParameters,
    ) -> Result<Self, WhisperError> {
        let ctx = WhisperInnerContext::new_from_buffer_with_params(buffer, parameters)?;
        Ok(Self::wrap(ctx))
    }

    /// Convert the provided text into tokens.
    ///
    /// # Arguments
    /// * text: The text to convert.
    /// * max_tokens: Forwarded unchanged to the inner context; presumably the
    ///   `n_max_tokens` limit of the C++ call (confirm against the inner implementation).
    ///
    /// # Returns
    /// `Ok(Vec<WhisperToken>)` on success, `Err(WhisperError)` on failure.
    ///
    /// # C++ equivalent
    /// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);`
    pub fn tokenize(
        &self,
        text: &str,
        max_tokens: usize,
    ) -> Result<Vec<WhisperToken>, WhisperError> {
        self.ctx.tokenize(text, max_tokens)
    }

    /// Get n_vocab.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_n_vocab        (struct whisper_context * ctx)`
    #[inline]
    pub fn n_vocab(&self) -> c_int {
        self.ctx.n_vocab()
    }

    /// Get n_text_ctx.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_n_text_ctx     (struct whisper_context * ctx);`
    #[inline]
    pub fn n_text_ctx(&self) -> c_int {
        self.ctx.n_text_ctx()
    }

    /// Get n_audio_ctx.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_n_audio_ctx     (struct whisper_context * ctx);`
    #[inline]
    pub fn n_audio_ctx(&self) -> c_int {
        self.ctx.n_audio_ctx()
    }

    /// Does this model support multiple languages?
    ///
    /// # Returns
    /// bool
    ///
    /// # C++ equivalent
    /// `int whisper_is_multilingual(struct whisper_context * ctx)`
    #[inline]
    pub fn is_multilingual(&self) -> bool {
        self.ctx.is_multilingual()
    }

    /// Get model_n_vocab.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_vocab      (struct whisper_context * ctx);`
    #[inline]
    pub fn model_n_vocab(&self) -> c_int {
        self.ctx.model_n_vocab()
    }

    /// Get model_n_audio_ctx.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_audio_ctx    (struct whisper_context * ctx)`
    #[inline]
    pub fn model_n_audio_ctx(&self) -> c_int {
        self.ctx.model_n_audio_ctx()
    }

    /// Get model_n_audio_state.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_audio_state(struct whisper_context * ctx);`
    #[inline]
    pub fn model_n_audio_state(&self) -> c_int {
        self.ctx.model_n_audio_state()
    }

    /// Get model_n_audio_head.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_audio_head (struct whisper_context * ctx);`
    #[inline]
    pub fn model_n_audio_head(&self) -> c_int {
        self.ctx.model_n_audio_head()
    }

    /// Get model_n_audio_layer.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_audio_layer(struct whisper_context * ctx);`
    #[inline]
    pub fn model_n_audio_layer(&self) -> c_int {
        self.ctx.model_n_audio_layer()
    }

    /// Get model_n_text_ctx.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_text_ctx     (struct whisper_context * ctx)`
    #[inline]
    pub fn model_n_text_ctx(&self) -> c_int {
        self.ctx.model_n_text_ctx()
    }

    /// Get model_n_text_state.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_text_state (struct whisper_context * ctx);`
    #[inline]
    pub fn model_n_text_state(&self) -> c_int {
        self.ctx.model_n_text_state()
    }

    /// Get model_n_text_head.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_text_head  (struct whisper_context * ctx);`
    #[inline]
    pub fn model_n_text_head(&self) -> c_int {
        self.ctx.model_n_text_head()
    }

    /// Get model_n_text_layer.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_text_layer (struct whisper_context * ctx);`
    #[inline]
    pub fn model_n_text_layer(&self) -> c_int {
        self.ctx.model_n_text_layer()
    }

    /// Get model_n_mels.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_n_mels       (struct whisper_context * ctx);`
    #[inline]
    pub fn model_n_mels(&self) -> c_int {
        self.ctx.model_n_mels()
    }

    /// Get model_ftype.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_ftype          (struct whisper_context * ctx);`
    #[inline]
    pub fn model_ftype(&self) -> c_int {
        self.ctx.model_ftype()
    }

    /// Get model_type.
    ///
    /// # Returns
    /// c_int
    ///
    /// # C++ equivalent
    /// `int whisper_model_type         (struct whisper_context * ctx);`
    #[inline]
    pub fn model_type(&self) -> c_int {
        self.ctx.model_type()
    }

    // token functions

    /// Convert a token ID to a string.
    ///
    /// # Arguments
    /// * token_id: ID of the token.
    ///
    /// # Returns
    /// Ok(&str) on success, Err(WhisperError) on failure.
    ///
    /// # C++ equivalent
    /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
    pub fn token_to_str(&self, token_id: WhisperToken) -> Result<&str, WhisperError> {
        self.ctx.token_to_str(token_id)
    }

    /// Convert a token ID to a &CStr.
    ///
    /// # Arguments
    /// * token_id: ID of the token.
    ///
    /// # Returns
    /// Ok(&CStr) on success, Err(WhisperError) on failure.
    ///
    /// # C++ equivalent
    /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
    pub fn token_to_cstr(&self, token_id: WhisperToken) -> Result<&CStr, WhisperError> {
        self.ctx.token_to_cstr(token_id)
    }

    /// Undocumented but exposed function in the C++ API.
    /// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
    ///
    /// # Returns
    /// Ok(String) on success, Err(WhisperError) on failure.
    pub fn model_type_readable(&self) -> Result<String, WhisperError> {
        self.ctx.model_type_readable()
    }

    /// Get the ID of the eot token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
    #[inline]
    pub fn token_eot(&self) -> WhisperToken {
        self.ctx.token_eot()
    }

    /// Get the ID of the sot token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
    #[inline]
    pub fn token_sot(&self) -> WhisperToken {
        self.ctx.token_sot()
    }

    /// Get the ID of the solm token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
    #[inline]
    pub fn token_solm(&self) -> WhisperToken {
        self.ctx.token_solm()
    }

    /// Get the ID of the prev token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
    #[inline]
    pub fn token_prev(&self) -> WhisperToken {
        self.ctx.token_prev()
    }

    /// Get the ID of the nosp token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_nosp(struct whisper_context * ctx)`
    #[inline]
    pub fn token_nosp(&self) -> WhisperToken {
        self.ctx.token_nosp()
    }

    /// Get the ID of the not token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_not (struct whisper_context * ctx)`
    #[inline]
    pub fn token_not(&self) -> WhisperToken {
        self.ctx.token_not()
    }

    /// Get the ID of the beg token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
    #[inline]
    pub fn token_beg(&self) -> WhisperToken {
        self.ctx.token_beg()
    }

    /// Get the ID of a specified language token
    ///
    /// # Arguments
    /// * lang_id: ID of the language
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)`
    #[inline]
    pub fn token_lang(&self, lang_id: c_int) -> WhisperToken {
        self.ctx.token_lang(lang_id)
    }

    /// Print performance statistics to stderr.
    ///
    /// # C++ equivalent
    /// `void whisper_print_timings(struct whisper_context * ctx)`
    #[inline]
    pub fn print_timings(&self) {
        self.ctx.print_timings()
    }

    /// Reset performance statistics.
    ///
    /// # C++ equivalent
    /// `void whisper_reset_timings(struct whisper_context * ctx)`
    #[inline]
    pub fn reset_timings(&self) {
        self.ctx.reset_timings()
    }

    // task tokens

    /// Get the ID of the translate task token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_translate(struct whisper_context * ctx)`
    pub fn token_translate(&self) -> WhisperToken {
        self.ctx.token_translate()
    }

    /// Get the ID of the transcribe task token.
    ///
    /// # C++ equivalent
    /// `whisper_token whisper_token_transcribe(struct whisper_context * ctx)`
    pub fn token_transcribe(&self) -> WhisperToken {
        self.ctx.token_transcribe()
    }

    // `whisper_init()` is intentionally not wrapped here: it is unclear what
    // `whisper_model_loader` is expected to do.

    /// Create a new state object, ready for use.
    ///
    /// The returned state receives its own `Arc` handle to the inner context.
    ///
    /// # Returns
    /// Ok(WhisperState) on success, Err(WhisperError) on failure.
    ///
    /// # C++ equivalent
    /// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);`
    pub fn create_state(&self) -> Result<WhisperState, WhisperError> {
        // SAFETY: the raw context pointer is owned by the inner context, which frees it
        // only when dropped; `self` holds an `Arc` to it for the whole call.
        let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx.ctx) };
        if state.is_null() {
            Err(WhisperError::InitError)
        } else {
            // A non-null return is taken to be a valid `whisper_state` pointer.
            // `Arc::clone` only bumps the reference count; the context itself is not copied.
            Ok(WhisperState::new(Arc::clone(&self.ctx), state))
        }
    }
}

View file

@ -1,19 +1,20 @@
use crate::{FullParams, WhisperContext, WhisperError, WhisperToken, WhisperTokenData};
use std::ffi::{c_int, CStr}; use std::ffi::{c_int, CStr};
use std::marker::PhantomData; use std::sync::Arc;
use crate::{FullParams, WhisperError, WhisperInnerContext, WhisperToken, WhisperTokenData};
/// Rustified pointer to a Whisper state. /// Rustified pointer to a Whisper state.
#[derive(Debug)] #[derive(Debug)]
pub struct WhisperState<'a> { pub struct WhisperState {
ctx: *mut whisper_rs_sys::whisper_context, ctx: Arc<WhisperInnerContext>,
ptr: *mut whisper_rs_sys::whisper_state, ptr: *mut whisper_rs_sys::whisper_state,
_phantom: PhantomData<&'a WhisperContext>,
} }
unsafe impl<'a> Send for WhisperState<'a> {} unsafe impl Send for WhisperState {}
unsafe impl<'a> Sync for WhisperState<'a> {}
impl<'a> Drop for WhisperState<'a> { unsafe impl Sync for WhisperState {}
impl Drop for WhisperState {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { unsafe {
whisper_rs_sys::whisper_free_state(self.ptr); whisper_rs_sys::whisper_free_state(self.ptr);
@ -21,16 +22,12 @@ impl<'a> Drop for WhisperState<'a> {
} }
} }
impl<'a> WhisperState<'a> { impl WhisperState {
pub(crate) fn new( pub(crate) fn new(
ctx: *mut whisper_rs_sys::whisper_context, ctx: Arc<WhisperInnerContext>,
ptr: *mut whisper_rs_sys::whisper_state, ptr: *mut whisper_rs_sys::whisper_state,
) -> Self { ) -> Self {
Self { Self { ctx, ptr }
ctx,
ptr,
_phantom: PhantomData,
}
} }
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram. /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
@ -51,7 +48,7 @@ impl<'a> WhisperState<'a> {
} }
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_pcm_to_mel_with_state( whisper_rs_sys::whisper_pcm_to_mel_with_state(
self.ctx, self.ctx.ctx,
self.ptr, self.ptr,
pcm.as_ptr(), pcm.as_ptr(),
pcm.len() as c_int, pcm.len() as c_int,
@ -90,7 +87,7 @@ impl<'a> WhisperState<'a> {
} }
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state( whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
self.ctx, self.ctx.ctx,
self.ptr, self.ptr,
pcm.as_ptr(), pcm.as_ptr(),
pcm.len() as c_int, pcm.len() as c_int,
@ -127,7 +124,7 @@ impl<'a> WhisperState<'a> {
let n_len = (data.len() / hop_size) * 2; let n_len = (data.len() / hop_size) * 2;
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_set_mel_with_state( whisper_rs_sys::whisper_set_mel_with_state(
self.ctx, self.ctx.ctx,
self.ptr, self.ptr,
data.as_ptr(), data.as_ptr(),
n_len as c_int, n_len as c_int,
@ -161,7 +158,7 @@ impl<'a> WhisperState<'a> {
} }
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_encode_with_state( whisper_rs_sys::whisper_encode_with_state(
self.ctx, self.ctx.ctx,
self.ptr, self.ptr,
offset as c_int, offset as c_int,
threads as c_int, threads as c_int,
@ -202,7 +199,7 @@ impl<'a> WhisperState<'a> {
} }
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_decode_with_state( whisper_rs_sys::whisper_decode_with_state(
self.ctx, self.ctx.ctx,
self.ptr, self.ptr,
tokens.as_ptr(), tokens.as_ptr(),
tokens.len() as c_int, tokens.len() as c_int,
@ -240,7 +237,7 @@ impl<'a> WhisperState<'a> {
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1]; let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_lang_auto_detect_with_state( whisper_rs_sys::whisper_lang_auto_detect_with_state(
self.ctx, self.ctx.ctx,
self.ptr, self.ptr,
offset_ms as c_int, offset_ms as c_int,
threads as c_int, threads as c_int,
@ -309,7 +306,7 @@ impl<'a> WhisperState<'a> {
/// `int whisper_n_vocab (struct whisper_context * ctx)` /// `int whisper_n_vocab (struct whisper_context * ctx)`
#[inline] #[inline]
pub fn n_vocab(&self) -> c_int { pub fn n_vocab(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) } unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx.ctx) }
} }
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text /// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
@ -335,7 +332,7 @@ impl<'a> WhisperState<'a> {
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_full_with_state( whisper_rs_sys::whisper_full_with_state(
self.ctx, self.ctx.ctx,
self.ptr, self.ptr,
params.fp, params.fp,
data.as_ptr(), data.as_ptr(),
@ -495,7 +492,10 @@ impl<'a> WhisperState<'a> {
) -> Result<String, WhisperError> { ) -> Result<String, WhisperError> {
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_full_get_token_text_from_state( whisper_rs_sys::whisper_full_get_token_text_from_state(
self.ctx, self.ptr, segment, token, self.ctx.ctx,
self.ptr,
segment,
token,
) )
}; };
if ret.is_null() { if ret.is_null() {
@ -527,7 +527,10 @@ impl<'a> WhisperState<'a> {
) -> Result<String, WhisperError> { ) -> Result<String, WhisperError> {
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_full_get_token_text_from_state( whisper_rs_sys::whisper_full_get_token_text_from_state(
self.ctx, self.ptr, segment, token, self.ctx.ctx,
self.ptr,
segment,
token,
) )
}; };
if ret.is_null() { if ret.is_null() {