add changes from whisper.cpp update

This commit is contained in:
Zero 2023-04-17 17:57:00 -06:00
parent 7c78c128a1
commit 13d44e5881
No known key found for this signature in database
GPG key ID: 3861E636EA1E0E2B
9 changed files with 536 additions and 140 deletions

View file

@ -15,6 +15,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
[dependencies] [dependencies]
whisper-rs-sys = { path = "sys", version = "0.4" } whisper-rs-sys = { path = "sys", version = "0.4" }
dashmap = "5"
[dev-dependencies] [dev-dependencies]
hound = "3.5.0" hound = "3.5.0"

View file

@ -9,11 +9,12 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout. /// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
fn main() -> Result<(), &'static str> { fn main() -> Result<(), &'static str> {
// Load a context and model. // Load a context and model.
let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin") let ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
.expect("failed to load model"); .expect("failed to load model");
// Create a single global key.
ctx.create_key(()).expect("failed to create key");
// Create a params object for running the model. // Create a params object for running the model.
// Currently, only the Greedy sampling strategy is implemented, with BeamSearch as a WIP.
// The number of past samples to consider defaults to 0. // The number of past samples to consider defaults to 0.
let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
@ -62,18 +63,27 @@ fn main() -> Result<(), &'static str> {
} }
// Run the model. // Run the model.
ctx.full(params, &audio[..]).expect("failed to run model"); ctx.full(&(), params, &audio[..])
.expect("failed to run model");
// Create a file to write the transcript to. // Create a file to write the transcript to.
let mut file = File::create("transcript.txt").expect("failed to create file"); let mut file = File::create("transcript.txt").expect("failed to create file");
// Iterate through the segments of the transcript. // Iterate through the segments of the transcript.
let num_segments = ctx.full_n_segments(); let num_segments = ctx
.full_n_segments(&())
.expect("failed to get number of segments");
for i in 0..num_segments { for i in 0..num_segments {
// Get the transcribed text and timestamps for the current segment. // Get the transcribed text and timestamps for the current segment.
let segment = ctx.full_get_segment_text(i).expect("failed to get segment"); let segment = ctx
let start_timestamp = ctx.full_get_segment_t0(i); .full_get_segment_text(&(), i)
let end_timestamp = ctx.full_get_segment_t1(i); .expect("failed to get segment");
let start_timestamp = ctx
.full_get_segment_t0(&(), i)
.expect("failed to get start timestamp");
let end_timestamp = ctx
.full_get_segment_t1(&(), i)
.expect("failed to get end timestamp");
// Print the segment to stdout. // Print the segment to stdout.
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);

View file

@ -7,7 +7,10 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
// more dependencies than the base library. // more dependencies than the base library.
pub fn usage() -> Result<(), &'static str> { pub fn usage() -> Result<(), &'static str> {
// load a context and model // load a context and model
let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model"); let ctx = WhisperContext::new("path/to/model").expect("failed to load model");
// make a sample key
// here, since we only use this model once, we use a unique global key
ctx.create_key(()).expect("failed to create key");
// create a params object // create a params object
// note that currently the only implemented strategy is Greedy, BeamSearch is a WIP // note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
@ -41,15 +44,24 @@ pub fn usage() -> Result<(), &'static str> {
)?; )?;
// now we can run the model // now we can run the model
ctx.full(params, &audio_data[..]) // note the key we use here is the one we created above
ctx.full(&(), params, &audio_data[..])
.expect("failed to run model"); .expect("failed to run model");
// fetch the results // fetch the results
let num_segments = ctx.full_n_segments(); let num_segments = ctx
.full_n_segments(&())
.expect("failed to get number of segments");
for i in 0..num_segments { for i in 0..num_segments {
let segment = ctx.full_get_segment_text(i).expect("failed to get segment"); let segment = ctx
let start_timestamp = ctx.full_get_segment_t0(i); .full_get_segment_text(&(), i)
let end_timestamp = ctx.full_get_segment_t1(i); .expect("failed to get segment");
let start_timestamp = ctx
.full_get_segment_t0(&(), i)
.expect("failed to get segment start timestamp");
let end_timestamp = ctx
.full_get_segment_t1(&(), i)
.expect("failed to get segment end timestamp");
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
} }

View file

@ -45,18 +45,19 @@ fn main() {
let original_samples = parse_wav_file(audio_path); let original_samples = parse_wav_file(audio_path);
let samples = whisper_rs::convert_integer_to_float_audio(&original_samples); let samples = whisper_rs::convert_integer_to_float_audio(&original_samples);
let mut ctx = let ctx =
WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model"); WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
ctx.create_key(()).expect("failed to create key");
let params = FullParams::new(SamplingStrategy::default()); let params = FullParams::new(SamplingStrategy::default());
ctx.full(params, &samples) ctx.full(&(), params, &samples)
.expect("failed to convert samples"); .expect("failed to convert samples");
let num_segments = ctx.full_n_segments(); let num_segments = ctx.full_n_segments(&()).expect("failed to get number of segments");
for i in 0..num_segments { for i in 0..num_segments {
let segment = ctx.full_get_segment_text(i).expect("failed to get segment"); let segment = ctx.full_get_segment_text(&(), i).expect("failed to get segment");
let start_timestamp = ctx.full_get_segment_t0(i); let start_timestamp = ctx.full_get_segment_t0(&(), i).expect("failed to get start timestamp");
let end_timestamp = ctx.full_get_segment_t1(i); let end_timestamp = ctx.full_get_segment_t1(&(), i).expect("failed to get end timestamp");
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
} }
} }

View file

@ -37,6 +37,12 @@ pub enum WhisperError {
GenericError(c_int), GenericError(c_int),
/// Whisper failed to convert the provided text into tokens. /// Whisper failed to convert the provided text into tokens.
InvalidText, InvalidText,
/// Creating a state pointer failed. Check stderr for more information.
FailedToCreateState,
/// State pointer ID already exists.
StateIdAlreadyExists,
/// State pointer ID does not exist.
StateIdDoesNotExist,
} }
impl From<Utf8Error> for WhisperError { impl From<Utf8Error> for WhisperError {

View file

@ -6,6 +6,7 @@ mod standalone;
mod utilities; mod utilities;
mod whisper_ctx; mod whisper_ctx;
mod whisper_params; mod whisper_params;
mod whisper_state;
pub use error::WhisperError; pub use error::WhisperError;
pub use standalone::*; pub use standalone::*;
@ -17,3 +18,5 @@ pub type WhisperTokenData = whisper_rs_sys::whisper_token_data;
pub type WhisperToken = whisper_rs_sys::whisper_token; pub type WhisperToken = whisper_rs_sys::whisper_token;
pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback; pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback;
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback; pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback;

View file

@ -1,24 +1,24 @@
use crate::error::WhisperError; use crate::error::WhisperError;
use crate::whisper_params::FullParams; use crate::whisper_params::FullParams;
use crate::whisper_state::WhisperState;
use crate::{WhisperToken, WhisperTokenData}; use crate::{WhisperToken, WhisperTokenData};
use dashmap::DashMap;
use std::ffi::{c_int, CStr, CString}; use std::ffi::{c_int, CStr, CString};
use std::hash::Hash;
/// Safe Rust wrapper around a Whisper context. /// Safe Rust wrapper around a Whisper context.
/// ///
/// You likely want to create this with [WhisperContext::new], /// You likely want to create this with [WhisperContext::new],
/// then run a full transcription with [WhisperContext::full]. /// then run a full transcription with [WhisperContext::full].
#[derive(Debug)] #[derive(Debug)]
pub struct WhisperContext { pub struct WhisperContext<K: Hash + Eq> {
ctx: *mut whisper_rs_sys::whisper_context, ctx: *mut whisper_rs_sys::whisper_context,
/// has the spectrogram been initialized in at least one way?
spectrogram_initialized: bool, /// Map of state IDs to state objects.
/// has the data been encoded? state_map: DashMap<K, WhisperState>,
encode_complete: bool,
/// has decode been called at least once?
decode_once: bool,
} }
impl WhisperContext { impl<K: Hash + Eq> WhisperContext<K> {
/// Create a new WhisperContext from a file. /// Create a new WhisperContext from a file.
/// ///
/// # Arguments /// # Arguments
@ -31,15 +31,13 @@ impl WhisperContext {
/// `struct whisper_context * whisper_init_from_file(const char * path_model);` /// `struct whisper_context * whisper_init_from_file(const char * path_model);`
pub fn new(path: &str) -> Result<Self, WhisperError> { pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?; let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file(path_cstr.as_ptr()) }; let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) };
if ctx.is_null() { if ctx.is_null() {
Err(WhisperError::InitError) Err(WhisperError::InitError)
} else { } else {
Ok(Self { Ok(Self {
ctx, ctx,
spectrogram_initialized: false, state_map: DashMap::new(),
encode_complete: false,
decode_once: false,
}) })
} }
} }
@ -55,22 +53,54 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);` /// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);`
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> { pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
let ctx = let ctx = unsafe {
unsafe { whisper_rs_sys::whisper_init_from_buffer(buffer.as_ptr() as _, buffer.len()) }; whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len())
};
if ctx.is_null() { if ctx.is_null() {
Err(WhisperError::InitError) Err(WhisperError::InitError)
} else { } else {
Ok(Self { Ok(Self {
ctx, ctx,
spectrogram_initialized: false, state_map: DashMap::new(),
encode_complete: false,
decode_once: false,
}) })
} }
} }
// we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does // we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does
/// Create a new state object, ready for use.
///
/// # Arguments
/// * id: The ID of the state object. Must be unique.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
/// If the ID is already in use, returns Err(WhisperError::StateIdAlreadyExists).
///
/// # C++ equivalent
/// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);`
pub fn create_key(&self, id: K) -> Result<(), WhisperError> {
if self.state_map.contains_key(&id) {
return Err(WhisperError::StateIdAlreadyExists);
}
let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) };
if state.is_null() {
Err(WhisperError::InitError)
} else {
// SAFETY: this is known to be a valid pointer to a `whisper_state` struct
self.state_map
.insert(id, unsafe { WhisperState::new(state) });
Ok(())
}
}
/// Look up the raw `whisper_state` pointer registered under `id`.
///
/// # Returns
/// Ok(pointer) if `id` is present in the state map,
/// Err(WhisperError::StateIdDoesNotExist) otherwise.
fn get_state_ptr(&self, id: &K) -> Result<*mut whisper_rs_sys::whisper_state, WhisperError> {
    match self.state_map.get(id) {
        Some(entry) => Ok(entry.value().as_ptr()),
        None => Err(WhisperError::StateIdDoesNotExist),
    }
}
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram. /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
/// The resulting spectrogram is stored in the context transparently. /// The resulting spectrogram is stored in the context transparently.
/// ///
@ -83,13 +113,15 @@ impl WhisperContext {
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)` /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> { pub fn pcm_to_mel(&self, key: &K, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
if threads < 1 { if threads < 1 {
return Err(WhisperError::InvalidThreadCount); return Err(WhisperError::InvalidThreadCount);
} }
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_pcm_to_mel( whisper_rs_sys::whisper_pcm_to_mel_with_state(
self.ctx, self.ctx,
state_ptr,
pcm.as_ptr(), pcm.as_ptr(),
pcm.len() as c_int, pcm.len() as c_int,
threads as c_int, threads as c_int,
@ -98,14 +130,54 @@ impl WhisperContext {
if ret == -1 { if ret == -1 {
Err(WhisperError::UnableToCalculateSpectrogram) Err(WhisperError::UnableToCalculateSpectrogram)
} else if ret == 0 { } else if ret == 0 {
self.spectrogram_initialized = true;
Ok(()) Ok(())
} else { } else {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
} }
} }
/// This can be used to set a custom log mel spectrogram inside the provided whisper context. /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
/// Applies a Phase Vocoder to speed up the audio x2.
/// The work is performed against the state registered under `key`.
///
/// # Arguments
/// * key: ID of a state previously registered on this context.
/// * pcm: The raw PCM audio.
/// * threads: How many threads to use. Must be at least 1, returns an error otherwise.
///
/// # Returns
/// Ok(()) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_pcm_to_mel_phase_vocoder_with_state(struct whisper_context * ctx, struct whisper_state * state, const float * samples, int n_samples, int n_threads)`
pub fn pcm_to_mel_phase_vocoder(
    &self,
    key: &K,
    pcm: &[f32],
    threads: usize,
) -> Result<(), WhisperError> {
    // Reject a zero thread count before touching the state map.
    if threads == 0 {
        return Err(WhisperError::InvalidThreadCount);
    }
    let state = self.get_state_ptr(key)?;
    let samples = pcm.as_ptr();
    let n_samples = pcm.len() as c_int;
    let n_threads = threads as c_int;
    let status = unsafe {
        whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
            self.ctx, state, samples, n_samples, n_threads,
        )
    };
    // 0 is success, -1 is the spectrogram failure; any other code is passed through verbatim.
    match status {
        0 => Ok(()),
        -1 => Err(WhisperError::UnableToCalculateSpectrogram),
        code => Err(WhisperError::GenericError(code)),
    }
}
/// This can be used to set a custom log mel spectrogram inside the provided whisper state.
/// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. /// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
/// ///
/// # Note /// # Note
@ -121,10 +193,13 @@ impl WhisperContext {
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)` /// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> { pub fn set_mel(&self, key: &K, data: &[f32]) -> Result<(), WhisperError> {
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_set_mel( whisper_rs_sys::whisper_set_mel_with_state(
self.ctx, self.ctx,
state_ptr,
data.as_ptr(), data.as_ptr(),
data.len() as c_int, data.len() as c_int,
80 as c_int, 80 as c_int,
@ -133,7 +208,6 @@ impl WhisperContext {
if ret == -1 { if ret == -1 {
Err(WhisperError::InvalidMelBands) Err(WhisperError::InvalidMelBands)
} else if ret == 0 { } else if ret == 0 {
self.spectrogram_initialized = true;
Ok(()) Ok(())
} else { } else {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
@ -152,19 +226,22 @@ impl WhisperContext {
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)` /// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)`
pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> { pub fn encode(&self, key: &K, offset: usize, threads: usize) -> Result<(), WhisperError> {
if !self.spectrogram_initialized {
return Err(WhisperError::SpectrogramNotInitialized);
}
if threads < 1 { if threads < 1 {
return Err(WhisperError::InvalidThreadCount); return Err(WhisperError::InvalidThreadCount);
} }
let ret = let state_ptr = self.get_state_ptr(key)?;
unsafe { whisper_rs_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) }; let ret = unsafe {
whisper_rs_sys::whisper_encode_with_state(
self.ctx,
state_ptr,
offset as c_int,
threads as c_int,
)
};
if ret == -1 { if ret == -1 {
Err(WhisperError::UnableToCalculateEvaluation) Err(WhisperError::UnableToCalculateEvaluation)
} else if ret == 0 { } else if ret == 0 {
self.encode_complete = true;
Ok(()) Ok(())
} else { } else {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
@ -187,20 +264,20 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)` /// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
pub fn decode( pub fn decode(
&mut self, &self,
key: &K,
tokens: &[WhisperToken], tokens: &[WhisperToken],
n_past: usize, n_past: usize,
threads: usize, threads: usize,
) -> Result<(), WhisperError> { ) -> Result<(), WhisperError> {
if !self.encode_complete {
return Err(WhisperError::EncodeNotComplete);
}
if threads < 1 { if threads < 1 {
return Err(WhisperError::InvalidThreadCount); return Err(WhisperError::InvalidThreadCount);
} }
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_decode( whisper_rs_sys::whisper_decode_with_state(
self.ctx, self.ctx,
state_ptr,
tokens.as_ptr(), tokens.as_ptr(),
tokens.len() as c_int, tokens.len() as c_int,
n_past as c_int, n_past as c_int,
@ -210,7 +287,6 @@ impl WhisperContext {
if ret == -1 { if ret == -1 {
Err(WhisperError::UnableToCalculateEvaluation) Err(WhisperError::UnableToCalculateEvaluation)
} else if ret == 0 { } else if ret == 0 {
self.decode_once = true;
Ok(()) Ok(())
} else { } else {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
@ -228,7 +304,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);` /// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);`
pub fn tokenize( pub fn tokenize(
&mut self, &self,
text: &str, text: &str,
max_tokens: usize, max_tokens: usize,
) -> Result<Vec<WhisperToken>, WhisperError> { ) -> Result<Vec<WhisperToken>, WhisperError> {
@ -265,20 +341,21 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)` /// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)`
pub fn lang_detect( pub fn lang_detect(
&mut self, &self,
key: &K,
offset_ms: usize, offset_ms: usize,
threads: usize, threads: usize,
) -> Result<Vec<f32>, WhisperError> { ) -> Result<Vec<f32>, WhisperError> {
if !self.spectrogram_initialized {
return Err(WhisperError::SpectrogramNotInitialized);
}
if threads < 1 { if threads < 1 {
return Err(WhisperError::InvalidThreadCount); return Err(WhisperError::InvalidThreadCount);
} }
let state_ptr = self.get_state_ptr(key)?;
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1]; let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_lang_auto_detect( whisper_rs_sys::whisper_lang_auto_detect_with_state(
self.ctx, self.ctx,
state_ptr,
offset_ms as c_int, offset_ms as c_int,
threads as c_int, threads as c_int,
lang_probs.as_mut_ptr(), lang_probs.as_mut_ptr(),
@ -296,7 +373,10 @@ impl WhisperContext {
// and abort, as this will cause Undefined Behavior // and abort, as this will cause Undefined Behavior
// might get here due to the unwind being caught by a user-installed panic handler // might get here due to the unwind being caught by a user-installed panic handler
if lang_probs.len() != ret as usize { if lang_probs.len() != ret as usize {
eprintln!("lang_probs length mismatch: this is a bug in whisper.cpp, aborting"); eprintln!(
"lang_probs length mismatch: this is a bug in whisper.cpp, \
aborting to avoid Undefined Behavior"
);
std::process::abort(); std::process::abort();
} }
Ok(lang_probs) Ok(lang_probs)
@ -307,13 +387,13 @@ impl WhisperContext {
/// Get the mel spectrogram length. /// Get the mel spectrogram length.
/// ///
/// # Returns /// # Returns
/// c_int /// Ok(c_int) on success, Err(WhisperError) on failure.
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_n_len (struct whisper_context * ctx)` /// `int whisper_n_len_from_state(struct whisper_context * ctx)`
#[inline] #[inline]
pub fn n_len(&self) -> c_int { pub fn n_len(&self, key: &K) -> Result<c_int, WhisperError> {
unsafe { whisper_rs_sys::whisper_n_len(self.ctx) } Ok(unsafe { whisper_rs_sys::whisper_n_len_from_state(self.get_state_ptr(key)?) })
} }
/// Get n_vocab. /// Get n_vocab.
@ -334,7 +414,7 @@ impl WhisperContext {
/// c_int /// c_int
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_n_text_ctx (struct whisper_context * ctx)` /// `int whisper_n_text_ctx (struct whisper_context * ctx);`
#[inline] #[inline]
pub fn n_text_ctx(&self) -> c_int { pub fn n_text_ctx(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) } unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) }
@ -346,7 +426,7 @@ impl WhisperContext {
/// c_int /// c_int
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_n_audio_ctx (struct whisper_context * ctx)` /// `int whisper_n_audio_ctx (struct whisper_context * ctx);`
#[inline] #[inline]
pub fn n_audio_ctx(&self) -> c_int { pub fn n_audio_ctx(&self) -> c_int {
unsafe { whisper_rs_sys::whisper_n_audio_ctx(self.ctx) } unsafe { whisper_rs_sys::whisper_n_audio_ctx(self.ctx) }
@ -361,6 +441,150 @@ impl WhisperContext {
unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 } unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 }
} }
/// Returns the value `whisper_model_n_vocab` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_vocab (struct whisper_context * ctx);`
#[inline]
pub fn model_n_vocab(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_vocab(ctx) }
}
/// Returns the value `whisper_model_n_audio_ctx` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_audio_ctx (struct whisper_context * ctx);`
#[inline]
pub fn model_n_audio_ctx(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_audio_ctx(ctx) }
}
/// Returns the value `whisper_model_n_audio_state` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_audio_state(struct whisper_context * ctx);`
#[inline]
pub fn model_n_audio_state(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_audio_state(ctx) }
}
/// Returns the value `whisper_model_n_audio_head` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_audio_head (struct whisper_context * ctx);`
#[inline]
pub fn model_n_audio_head(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_audio_head(ctx) }
}
/// Returns the value `whisper_model_n_audio_layer` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_audio_layer(struct whisper_context * ctx);`
#[inline]
pub fn model_n_audio_layer(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_audio_layer(ctx) }
}
/// Returns the value `whisper_model_n_text_ctx` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_text_ctx (struct whisper_context * ctx);`
#[inline]
pub fn model_n_text_ctx(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_text_ctx(ctx) }
}
/// Returns the value `whisper_model_n_text_state` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_text_state (struct whisper_context * ctx);`
#[inline]
pub fn model_n_text_state(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_text_state(ctx) }
}
/// Returns the value `whisper_model_n_text_head` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_text_head (struct whisper_context * ctx);`
#[inline]
pub fn model_n_text_head(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_text_head(ctx) }
}
/// Returns the value `whisper_model_n_text_layer` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_text_layer (struct whisper_context * ctx);`
#[inline]
pub fn model_n_text_layer(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_text_layer(ctx) }
}
/// Returns the value `whisper_model_n_mels` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_n_mels (struct whisper_context * ctx);`
#[inline]
pub fn model_n_mels(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_n_mels(ctx) }
}
/// Returns the value `whisper_model_f16` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_f16 (struct whisper_context * ctx);`
#[inline]
pub fn model_f16(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_f16(ctx) }
}
/// Returns the value `whisper_model_type` reports for this context.
///
/// # Returns
/// c_int
///
/// # C++ equivalent
/// `int whisper_model_type (struct whisper_context * ctx);`
#[inline]
pub fn model_type(&self) -> c_int {
    let ctx = self.ctx;
    // SAFETY: the constructors reject a null `ctx`, so a non-null pointer is passed here.
    unsafe { whisper_rs_sys::whisper_model_type(ctx) }
}
// logit functions // logit functions
/// Get the logits obtained from the last call to [WhisperContext::decode]. /// Get the logits obtained from the last call to [WhisperContext::decode].
/// The logits for the last token are stored in the last row of the matrix. /// The logits for the last token are stored in the last row of the matrix.
@ -376,18 +600,16 @@ impl WhisperContext {
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `float * whisper_get_logits(struct whisper_context * ctx)` /// `float * whisper_get_logits(struct whisper_context * ctx)`
pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> { pub fn get_logits(&self, key: &K, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> {
if !self.spectrogram_initialized { let state_ptr = self.get_state_ptr(key)?;
return Err(WhisperError::SpectrogramNotInitialized);
}
let ret = unsafe { whisper_rs_sys::whisper_get_logits(self.ctx) }; let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(state_ptr) };
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
let mut logits = Vec::new(); let mut logits = Vec::new();
let n_vocab = self.n_vocab(); let n_vocab = self.n_vocab();
let n_tokens = self.full_n_tokens(segment); let n_tokens = self.full_n_tokens(key, segment)?;
for i in 0..n_tokens { for i in 0..n_tokens {
let mut row = Vec::new(); let mut row = Vec::new();
for j in 0..n_vocab { for j in 0..n_vocab {
@ -421,6 +643,21 @@ impl WhisperContext {
Ok(r_str.to_string()) Ok(r_str.to_string())
} }
/// Returns the string produced by `whisper_model_type_readable` for this context.
/// The function is exposed by the C++ API but not documented there.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
/// A null pointer from whisper.cpp yields Err(WhisperError::NullPointer).
///
/// # C++ equivalent
/// `const char * whisper_model_type_readable(struct whisper_context * ctx);`
pub fn model_type_readable(&self) -> Result<String, WhisperError> {
    let raw = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) };
    if raw.is_null() {
        Err(WhisperError::NullPointer)
    } else {
        // `raw` was checked to be non-null just above; invalid UTF-8 is propagated via `?`.
        let text = unsafe { CStr::from_ptr(raw) }.to_str()?;
        Ok(text.to_owned())
    }
}
/// Get the ID of the eot token. /// Get the ID of the eot token.
/// ///
/// # C++ equivalent /// # C++ equivalent
@ -519,51 +756,15 @@ impl WhisperContext {
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)` /// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> { pub fn full(&self, key: &K, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int) whisper_rs_sys::whisper_full_with_state(
};
if ret == -1 {
Err(WhisperError::UnableToCalculateSpectrogram)
} else if ret == 7 {
Err(WhisperError::FailedToEncode)
} else if ret == 8 {
Err(WhisperError::FailedToDecode)
} else if ret == 0 {
Ok(ret)
} else {
Err(WhisperError::GenericError(ret))
}
}
/// Split the input audio into chunks and delegate to [WhisperContext::full].
///
/// It seems this approach can offer some speedup in some cases,
/// however, the accuracy can be worse at the start and end of chunks.
///
/// # Arguments
/// * params: [crate::FullParams] struct.
/// * pcm: PCM audio data.
/// * n_processors: Number of threads to use.
///
/// # Returns
/// Ok(c_int) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `int whisper_full_parallel(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples, int n_processors)`
pub fn full_parallel(
&mut self,
params: FullParams,
data: &[f32],
n_processors: c_int,
) -> Result<c_int, WhisperError> {
let ret = unsafe {
whisper_rs_sys::whisper_full_parallel(
self.ctx, self.ctx,
state_ptr,
params.fp, params.fp,
data.as_ptr(), data.as_ptr(),
data.len() as c_int, data.len() as c_int,
n_processors,
) )
}; };
if ret == -1 { if ret == -1 {
@ -573,8 +774,6 @@ impl WhisperContext {
} else if ret == 8 { } else if ret == 8 {
Err(WhisperError::FailedToDecode) Err(WhisperError::FailedToDecode)
} else if ret == 0 { } else if ret == 0 {
// note 0 is returned on success and also when initializing other contexts fails,
// causing some audio to not be processed
Ok(ret) Ok(ret)
} else { } else {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
@ -587,8 +786,17 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_full_n_segments(struct whisper_context * ctx)` /// `int whisper_full_n_segments(struct whisper_context * ctx)`
#[inline] #[inline]
pub fn full_n_segments(&self) -> c_int { pub fn full_n_segments(&self, key: &K) -> Result<c_int, WhisperError> {
unsafe { whisper_rs_sys::whisper_full_n_segments(self.ctx) } Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.get_state_ptr(key)?) })
}
/// Language ID associated with the provided state.
///
/// # C++ equivalent
/// `int whisper_full_lang_id_from_state(struct whisper_state * state);`
#[inline]
pub fn full_lang_id_from_state(&self, key: &K) -> Result<c_int, WhisperError> {
Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.get_state_ptr(key)?) })
} }
/// Get the start time of the specified segment. /// Get the start time of the specified segment.
@ -599,8 +807,13 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)` /// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
#[inline] #[inline]
pub fn full_get_segment_t0(&self, segment: c_int) -> i64 { pub fn full_get_segment_t0(&self, key: &K, segment: c_int) -> Result<i64, WhisperError> {
unsafe { whisper_rs_sys::whisper_full_get_segment_t0(self.ctx, segment) } Ok(unsafe {
whisper_rs_sys::whisper_full_get_segment_t0_from_state(
self.get_state_ptr(key)?,
segment,
)
})
} }
/// Get the end time of the specified segment. /// Get the end time of the specified segment.
@ -611,8 +824,13 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)` /// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
#[inline] #[inline]
pub fn full_get_segment_t1(&self, segment: c_int) -> i64 { pub fn full_get_segment_t1(&self, key: &K, segment: c_int) -> Result<i64, WhisperError> {
unsafe { whisper_rs_sys::whisper_full_get_segment_t1(self.ctx, segment) } Ok(unsafe {
whisper_rs_sys::whisper_full_get_segment_t1_from_state(
self.get_state_ptr(key)?,
segment,
)
})
} }
/// Get the text of the specified segment. /// Get the text of the specified segment.
@ -625,8 +843,11 @@ impl WhisperContext {
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> { pub fn full_get_segment_text(&self, key: &K, segment: c_int) -> Result<String, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_full_get_segment_text(self.ctx, segment) }; let state_ptr = self.get_state_ptr(key)?;
let ret =
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(state_ptr, segment) };
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
@ -646,8 +867,10 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)` /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
#[inline] #[inline]
pub fn full_n_tokens(&self, segment: c_int) -> c_int { pub fn full_n_tokens(&self, key: &K, segment: c_int) -> Result<c_int, WhisperError> {
unsafe { whisper_rs_sys::whisper_full_n_tokens(self.ctx, segment) } Ok(unsafe {
whisper_rs_sys::whisper_full_n_tokens_from_state(self.get_state_ptr(key)?, segment)
})
} }
/// Get the token text of the specified token in the specified segment. /// Get the token text of the specified token in the specified segment.
@ -663,10 +886,16 @@ impl WhisperContext {
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_text( pub fn full_get_token_text(
&self, &self,
key: &K,
segment: c_int, segment: c_int,
token: c_int, token: c_int,
) -> Result<String, WhisperError> { ) -> Result<String, WhisperError> {
let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text(self.ctx, segment, token) }; let state_ptr = self.get_state_ptr(key)?;
let ret = unsafe {
whisper_rs_sys::whisper_full_get_token_text_from_state(
self.ctx, state_ptr, segment, token,
)
};
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
@ -686,8 +915,19 @@ impl WhisperContext {
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)` /// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_id(&self, segment: c_int, token: c_int) -> WhisperToken { pub fn full_get_token_id(
unsafe { whisper_rs_sys::whisper_full_get_token_id(self.ctx, segment, token) } &self,
key: &K,
segment: c_int,
token: c_int,
) -> Result<WhisperToken, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_get_token_id_from_state(
self.get_state_ptr(key)?,
segment,
token,
)
})
} }
/// Get token data for the specified token in the specified segment. /// Get token data for the specified token in the specified segment.
@ -702,8 +942,19 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)` /// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)`
#[inline] #[inline]
pub fn full_get_token_data(&self, segment: c_int, token: c_int) -> WhisperTokenData { pub fn full_get_token_data(
unsafe { whisper_rs_sys::whisper_full_get_token_data(self.ctx, segment, token) } &self,
key: &K,
segment: c_int,
token: c_int,
) -> Result<WhisperTokenData, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_get_token_data_from_state(
self.get_state_ptr(key)?,
segment,
token,
)
})
} }
/// Get the probability of the specified token in the specified segment. /// Get the probability of the specified token in the specified segment.
@ -718,12 +969,23 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
#[inline] #[inline]
pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> f32 { pub fn full_get_token_prob(
unsafe { whisper_rs_sys::whisper_full_get_token_p(self.ctx, segment, token) } &self,
key: &K,
segment: c_int,
token: c_int,
) -> Result<f32, WhisperError> {
Ok(unsafe {
whisper_rs_sys::whisper_full_get_token_p_from_state(
self.get_state_ptr(key)?,
segment,
token,
)
})
} }
} }
impl Drop for WhisperContext { impl<K: Hash + Eq> Drop for WhisperContext<K> {
#[inline] #[inline]
fn drop(&mut self) { fn drop(&mut self) {
unsafe { whisper_rs_sys::whisper_free(self.ctx) }; unsafe { whisper_rs_sys::whisper_free(self.ctx) };
@ -732,6 +994,5 @@ impl Drop for WhisperContext {
// following implementations are safe // following implementations are safe
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
// concurrent usage is prevented by &mut self on methods that modify the struct unsafe impl<K: Hash + Eq> Send for WhisperContext<K> {}
unsafe impl Send for WhisperContext {} unsafe impl<K: Hash + Eq> Sync for WhisperContext<K> {}
unsafe impl Sync for WhisperContext {}

View file

@ -175,6 +175,15 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.max_len = max_len; self.fp.max_len = max_len;
} }
/// # EXPERIMENTAL
///
/// Set whether timestamps should be split on words instead of characters.
///
/// The flag is written unchanged to the `split_on_word` field of the underlying parameters.
/// NOTE(review): upstream `whisper.h` appears to describe this option as splitting on words
/// rather than on tokens, in combination with `max_len` — confirm and align the wording.
///
/// Defaults to false.
pub fn set_split_on_word(&mut self, split_on_word: bool) {
self.fp.split_on_word = split_on_word;
}
/// # EXPERIMENTAL /// # EXPERIMENTAL
/// ///
/// Set maximum tokens per segment. 0 means no limit. /// Set maximum tokens per segment. 0 means no limit.
@ -243,6 +252,14 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.suppress_blank = suppress_blank; self.fp.suppress_blank = suppress_blank;
} }
/// Set suppress_non_speech_tokens.
///
/// The flag is written unchanged to the `suppress_non_speech_tokens` field of the underlying parameters.
/// See <https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253>
/// for more information.
///
/// Defaults to false.
pub fn set_suppress_non_speech_tokens(&mut self, suppress_non_speech_tokens: bool) {
self.fp.suppress_non_speech_tokens = suppress_non_speech_tokens;
}
/// Set initial decoding temperature. See https://ai.stackexchange.com/a/32478 for more information. /// Set initial decoding temperature. See https://ai.stackexchange.com/a/32478 for more information.
/// ///
/// Defaults to 0.0. /// Defaults to 0.0.
@ -290,7 +307,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.logprob_thold = logprob_thold; self.fp.logprob_thold = logprob_thold;
} }
/// Set no_speech_thold. Currently (as of v1.2.0) not implemented. /// Set no_speech_thold. Currently (as of v1.3.0) not implemented.
/// ///
/// Defaults to 0.6. /// Defaults to 0.6.
pub fn set_no_speech_thold(&mut self, no_speech_thold: f32) { pub fn set_no_speech_thold(&mut self, no_speech_thold: f32) {
@ -325,7 +342,35 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.new_segment_callback_user_data = user_data; self.fp.new_segment_callback_user_data = user_data;
} }
/// Set the callback for progress updates.
///
/// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so).
/// It is still a C callback.
///
/// The user data handed to the callback is configured separately through
/// `set_progress_callback_user_data`.
///
/// # Safety
/// Do not use this function unless you know what you are doing.
/// * Be careful not to mutate the state of the whisper_context pointer returned in the callback.
///   This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library.
///
/// Defaults to None.
pub unsafe fn set_progress_callback(
&mut self,
progress_callback: crate::WhisperProgressCallback,
) {
// Stored as-is in the FFI parameter struct; nothing is wrapped or validated here.
self.fp.progress_callback = progress_callback;
}
/// Set the user data to be passed to the progress callback.
///
/// The raw pointer is written unchanged to the `progress_callback_user_data` field of the
/// underlying parameters.
///
/// # Safety
/// See the safety notes for `set_progress_callback`.
/// NOTE(review): the pointer presumably has to stay valid for as long as the callback can be
/// invoked — confirm against the C library.
///
/// Defaults to None.
pub unsafe fn set_progress_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
self.fp.progress_callback_user_data = user_data;
}
/// Set the callback that is called each time before the encoder begins.
/// ///
/// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so).
/// It is still a C callback. /// It is still a C callback.
@ -355,6 +400,37 @@ impl<'a, 'b> FullParams<'a, 'b> {
) { ) {
self.fp.encoder_begin_callback_user_data = user_data; self.fp.encoder_begin_callback_user_data = user_data;
} }
/// Set the callback that is called by each decoder to filter obtained logits.
///
/// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so).
/// It is still a C callback.
///
/// The user data handed to the callback is configured separately through
/// `set_filter_logits_callback_user_data`.
///
/// # Safety
/// Do not use this function unless you know what you are doing.
/// * Be careful not to mutate the state of the whisper_context pointer returned in the callback.
///   This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library.
///
/// Defaults to None.
pub unsafe fn set_filter_logits_callback(
&mut self,
logits_filter_callback: crate::WhisperLogitsFilterCallback,
) {
// Stored as-is in the FFI parameter struct; nothing is wrapped or validated here.
self.fp.logits_filter_callback = logits_filter_callback;
}
/// Set the user data to be passed to the logits filter callback.
///
/// The raw pointer is written unchanged to the `logits_filter_callback_user_data` field of the
/// underlying parameters.
///
/// # Safety
/// See the safety notes for `set_filter_logits_callback`.
/// NOTE(review): the pointer presumably has to stay valid for as long as the callback can be
/// invoked — confirm against the C library.
///
/// Defaults to None.
pub unsafe fn set_filter_logits_callback_user_data(
&mut self,
user_data: *mut std::ffi::c_void,
) {
self.fp.logits_filter_callback_user_data = user_data;
}
} }
// following implementations are safe // following implementations are safe

26
src/whisper_state.rs Normal file
View file

@ -0,0 +1,26 @@
/// Rustified pointer to a Whisper state.
///
/// Wraps a raw `whisper_state` pointer. When the wrapper is dropped, the pointer is passed to
/// `whisper_free_state`.
#[derive(Debug)]
pub struct WhisperState {
// Raw pointer supplied to `WhisperState::new`; handed to `whisper_free_state` on drop.
ptr: *mut whisper_rs_sys::whisper_state,
}
// NOTE(review): these impls assert that the wrapped state may be moved to and shared between
// threads. Nothing in this file establishes that; confirm against the guarantees of the
// underlying C library and against how callers obtain the raw pointer through `as_ptr`.
unsafe impl Send for WhisperState {}
unsafe impl Sync for WhisperState {}
impl Drop for WhisperState {
    /// Hands the wrapped pointer to `whisper_free_state` when the wrapper goes out of scope.
    fn drop(&mut self) {
        let raw = self.ptr;
        // SAFETY: NOTE(review) — relies on `ptr` being the pointer given to `new` and on it not
        // having been freed elsewhere; confirm against the callers of `new`.
        unsafe { whisper_rs_sys::whisper_free_state(raw) };
    }
}
impl WhisperState {
pub(crate) unsafe fn new(ptr: *mut whisper_rs_sys::whisper_state) -> Self {
Self { ptr }
}
pub(crate) fn as_ptr(&self) -> *mut whisper_rs_sys::whisper_state {
self.ptr
}
}