diff --git a/CHANGELOG.md b/CHANGELOG.md index 998e28e..fab056c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,13 @@ +# Version 0.6.0 (2023-04-17) +* Update upstream whisper.cpp to v1.3.0 +* Fix breaking changes in update, which cascade to users: + * `WhisperContext`s now have a generic type parameter, which is a hashable key for a state map. + This allows for a single context to be reused for multiple different states, saving memory. + * You must create a new state upon creation, even if you are using the context only once, by calling `WhisperContext::create_key`. + * Each method that now takes a state now takes a key, which internally is used to look up the state. + * This also turns `WhisperContext` into an entirely immutable object, meaning it can be shared across threads and used concurrently, safely. +* Send feedback on these changes to the PR: https://github.com/tazz4843/whisper-rs/pull/33 + # Version 0.2.0 (2022-10-28) * Update upstream whisper.cpp to 2c281d190b7ec351b8128ba386d110f100993973. * Fix breaking changes in update, which cascade to users: diff --git a/Cargo.toml b/Cargo.toml index 117dabf..ef8c01f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ exclude = ["examples/full_usage"] [package] name = "whisper-rs" -version = "0.5.0" +version = "0.6.0" edition = "2021" description = "Rust bindings for whisper.cpp" license = "Unlicense" @@ -14,7 +14,8 @@ repository = "https://github.com/tazz4843/whisper-rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -whisper-rs-sys = { path = "sys", version = "0.3" } +whisper-rs-sys = { path = "sys", version = "0.4" } +dashmap = "5" [dev-dependencies] hound = "3.5.0" diff --git a/examples/audio_transcription.rs b/examples/audio_transcription.rs index 7ab716d..e389723 100644 --- a/examples/audio_transcription.rs +++ b/examples/audio_transcription.rs @@ -9,11 +9,12 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; /// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout. fn main() -> Result<(), &'static str> { // Load a context and model. - let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin") + let ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin") .expect("failed to load model"); + // Create a single global key. + ctx.create_key(()).expect("failed to create key"); // Create a params object for running the model. - // Currently, only the Greedy sampling strategy is implemented, with BeamSearch as a WIP. // The number of past samples to consider defaults to 0. let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); @@ -62,18 +63,27 @@ fn main() -> Result<(), &'static str> { } // Run the model. - ctx.full(params, &audio[..]).expect("failed to run model"); + ctx.full(&(), params, &audio[..]) + .expect("failed to run model"); // Create a file to write the transcript to. let mut file = File::create("transcript.txt").expect("failed to create file"); // Iterate through the segments of the transcript. - let num_segments = ctx.full_n_segments(); + let num_segments = ctx + .full_n_segments(&()) + .expect("failed to get number of segments"); for i in 0..num_segments { // Get the transcribed text and timestamps for the current segment. - let segment = ctx.full_get_segment_text(i).expect("failed to get segment"); - let start_timestamp = ctx.full_get_segment_t0(i); - let end_timestamp = ctx.full_get_segment_t1(i); + let segment = ctx + .full_get_segment_text(&(), i) + .expect("failed to get segment"); + let start_timestamp = ctx + .full_get_segment_t0(&(), i) + .expect("failed to get start timestamp"); + let end_timestamp = ctx + .full_get_segment_t1(&(), i) + .expect("failed to get end timestamp"); // Print the segment to stdout. println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); diff --git a/examples/basic_use.rs b/examples/basic_use.rs index 727deba..435320a 100644 --- a/examples/basic_use.rs +++ b/examples/basic_use.rs @@ -7,7 +7,10 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; // more dependencies than the base library. pub fn usage() -> Result<(), &'static str> { // load a context and model - let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model"); + let ctx = WhisperContext::new("path/to/model").expect("failed to load model"); + // make a sample key + // here, since we only use this model once, we use a unique global key + ctx.create_key(()).expect("failed to create key"); // create a params object // note that currently the only implemented strategy is Greedy, BeamSearch is a WIP @@ -41,15 +44,24 @@ pub fn usage() -> Result<(), &'static str> { )?; // now we can run the model - ctx.full(params, &audio_data[..]) + // note the key we use here is the one we created above + ctx.full(&(), params, &audio_data[..]) .expect("failed to run model"); // fetch the results - let num_segments = ctx.full_n_segments(); + let num_segments = ctx + .full_n_segments(&()) + .expect("failed to get number of segments"); for i in 0..num_segments { - let segment = ctx.full_get_segment_text(i).expect("failed to get segment"); - let start_timestamp = ctx.full_get_segment_t0(i); - let end_timestamp = ctx.full_get_segment_t1(i); + let segment = ctx + .full_get_segment_text(&(), i) + .expect("failed to get segment"); + let start_timestamp = ctx + .full_get_segment_t0(&(), i) + .expect("failed to get segment start timestamp"); + let end_timestamp = ctx + .full_get_segment_t1(&(), i) + .expect("failed to get segment end timestamp"); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); } diff --git a/examples/full_usage/src/main.rs b/examples/full_usage/src/main.rs index d7da7a9..fa73a92 100644 --- a/examples/full_usage/src/main.rs +++ b/examples/full_usage/src/main.rs @@ -45,18 +45,19 @@ fn main() { let original_samples = parse_wav_file(audio_path); let samples = whisper_rs::convert_integer_to_float_audio(&original_samples); - let mut ctx = + let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model"); + ctx.create_key(()).expect("failed to create key"); let params = FullParams::new(SamplingStrategy::default()); - ctx.full(params, &samples) + ctx.full(&(), params, &samples) .expect("failed to convert samples"); - let num_segments = ctx.full_n_segments(); + let num_segments = ctx.full_n_segments(&()).expect("failed to get number of segments"); for i in 0..num_segments { - let segment = ctx.full_get_segment_text(i).expect("failed to get segment"); - let start_timestamp = ctx.full_get_segment_t0(i); - let end_timestamp = ctx.full_get_segment_t1(i); + let segment = ctx.full_get_segment_text(&(), i).expect("failed to get segment"); + let start_timestamp = ctx.full_get_segment_t0(&(), i).expect("failed to get start timestamp"); + let end_timestamp = ctx.full_get_segment_t1(&(), i).expect("failed to get end timestamp"); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); } } diff --git a/src/error.rs b/src/error.rs index ef87c00..e92841f 100644 --- a/src/error.rs +++ b/src/error.rs @@ -37,6 +37,12 @@ pub enum WhisperError { GenericError(c_int), /// Whisper failed to convert the provided text into tokens. InvalidText, + /// Creating a state pointer failed. Check stderr for more information. + FailedToCreateState, + /// State pointer ID already exists. + StateIdAlreadyExists, + /// State pointer ID does not exist. + StateIdDoesNotExist, } impl From for WhisperError { diff --git a/src/lib.rs b/src/lib.rs index 041fcfa..1962f92 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,6 +6,7 @@ mod standalone; mod utilities; mod whisper_ctx; mod whisper_params; +mod whisper_state; pub use error::WhisperError; pub use standalone::*; @@ -17,3 +18,5 @@ pub type WhisperTokenData = whisper_rs_sys::whisper_token_data; pub type WhisperToken = whisper_rs_sys::whisper_token; pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback; pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback; +pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback; +pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback; diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index d5f1f0f..11ee051 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -1,24 +1,24 @@ use crate::error::WhisperError; use crate::whisper_params::FullParams; +use crate::whisper_state::WhisperState; use crate::{WhisperToken, WhisperTokenData}; +use dashmap::DashMap; use std::ffi::{c_int, CStr, CString}; +use std::hash::Hash; /// Safe Rust wrapper around a Whisper context. /// /// You likely want to create this with [WhisperContext::new], /// then run a full transcription with [WhisperContext::full]. #[derive(Debug)] -pub struct WhisperContext { +pub struct WhisperContext { ctx: *mut whisper_rs_sys::whisper_context, - /// has the spectrogram been initialized in at least one way? - spectrogram_initialized: bool, - /// has the data been encoded? - encode_complete: bool, - /// has decode been called at least once? - decode_once: bool, + + /// Map of state IDs to state objects. + state_map: DashMap, } -impl WhisperContext { +impl WhisperContext { /// Create a new WhisperContext from a file. /// /// # Arguments @@ -31,15 +31,13 @@ impl WhisperContext { /// `struct whisper_context * whisper_init_from_file(const char * path_model);` pub fn new(path: &str) -> Result { let path_cstr = CString::new(path)?; - let ctx = unsafe { whisper_rs_sys::whisper_init_from_file(path_cstr.as_ptr()) }; + let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) }; if ctx.is_null() { Err(WhisperError::InitError) } else { Ok(Self { ctx, - spectrogram_initialized: false, - encode_complete: false, - decode_once: false, + state_map: DashMap::new(), }) } } @@ -55,22 +53,54 @@ impl WhisperContext { /// # C++ equivalent /// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);` pub fn new_from_buffer(buffer: &[u8]) -> Result { - let ctx = - unsafe { whisper_rs_sys::whisper_init_from_buffer(buffer.as_ptr() as _, buffer.len()) }; + let ctx = unsafe { + whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len()) + }; if ctx.is_null() { Err(WhisperError::InitError) } else { Ok(Self { ctx, - spectrogram_initialized: false, - encode_complete: false, - decode_once: false, + state_map: DashMap::new(), }) } } // we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does + /// Create a new state object, ready for use. + /// + /// # Arguments + /// * id: The ID of the state object. Must be unique. + /// + /// # Returns + /// Ok(()) on success, Err(WhisperError) on failure. + /// If the ID is already in use, returns Err(WhisperError::StateIdAlreadyExists). + /// + /// # C++ equivalent + /// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);` + pub fn create_key(&self, id: K) -> Result<(), WhisperError> { + if self.state_map.contains_key(&id) { + return Err(WhisperError::StateIdAlreadyExists); + } + let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) }; + if state.is_null() { + Err(WhisperError::InitError) + } else { + // SAFETY: this is known to be a valid pointer to a `whisper_state` struct + self.state_map + .insert(id, unsafe { WhisperState::new(state) }); + Ok(()) + } + } + + fn get_state_ptr(&self, id: &K) -> Result<*mut whisper_rs_sys::whisper_state, WhisperError> { + self.state_map + .get(id) + .map(|s| s.value().as_ptr()) + .ok_or(WhisperError::StateIdDoesNotExist) + } + /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram. /// The resulting spectrogram is stored in the context transparently. /// @@ -83,13 +113,15 @@ impl WhisperContext { /// /// # C++ equivalent /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)` - pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> { + pub fn pcm_to_mel(&self, key: &K, pcm: &[f32], threads: usize) -> Result<(), WhisperError> { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } + let state_ptr = self.get_state_ptr(key)?; let ret = unsafe { - whisper_rs_sys::whisper_pcm_to_mel( + whisper_rs_sys::whisper_pcm_to_mel_with_state( self.ctx, + state_ptr, pcm.as_ptr(), pcm.len() as c_int, threads as c_int, @@ -98,14 +130,54 @@ impl WhisperContext { if ret == -1 { Err(WhisperError::UnableToCalculateSpectrogram) } else if ret == 0 { - self.spectrogram_initialized = true; Ok(()) } else { Err(WhisperError::GenericError(ret)) } } - /// This can be used to set a custom log mel spectrogram inside the provided whisper context. + /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram. + /// Applies a Phase Vocoder to speed up the audio x2. + /// The resulting spectrogram is stored in the context transparently. + /// + /// # Arguments + /// * pcm: The raw PCM audio. + /// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. + /// + /// # Returns + /// Ok(()) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)` + pub fn pcm_to_mel_phase_vocoder( + &self, + key: &K, + pcm: &[f32], + threads: usize, + ) -> Result<(), WhisperError> { + if threads < 1 { + return Err(WhisperError::InvalidThreadCount); + } + let state_ptr = self.get_state_ptr(key)?; + let ret = unsafe { + whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state( + self.ctx, + state_ptr, + pcm.as_ptr(), + pcm.len() as c_int, + threads as c_int, + ) + }; + if ret == -1 { + Err(WhisperError::UnableToCalculateSpectrogram) + } else if ret == 0 { + Ok(()) + } else { + Err(WhisperError::GenericError(ret)) + } + } + + /// This can be used to set a custom log mel spectrogram inside the provided whisper state. /// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram. /// /// # Note @@ -121,10 +193,13 @@ impl WhisperContext { /// /// # C++ equivalent /// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)` - pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> { + pub fn set_mel(&self, key: &K, data: &[f32]) -> Result<(), WhisperError> { + let state_ptr = self.get_state_ptr(key)?; + let ret = unsafe { - whisper_rs_sys::whisper_set_mel( + whisper_rs_sys::whisper_set_mel_with_state( self.ctx, + state_ptr, data.as_ptr(), data.len() as c_int, 80 as c_int, @@ -133,7 +208,6 @@ impl WhisperContext { if ret == -1 { Err(WhisperError::InvalidMelBands) } else if ret == 0 { - self.spectrogram_initialized = true; Ok(()) } else { Err(WhisperError::GenericError(ret)) @@ -152,19 +226,22 @@ impl WhisperContext { /// /// # C++ equivalent /// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)` - pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> { - if !self.spectrogram_initialized { - return Err(WhisperError::SpectrogramNotInitialized); - } + pub fn encode(&self, key: &K, offset: usize, threads: usize) -> Result<(), WhisperError> { if threads < 1 { return Err(WhisperError::InvalidThreadCount); } - let ret = - unsafe { whisper_rs_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) }; + let state_ptr = self.get_state_ptr(key)?; + let ret = unsafe { + whisper_rs_sys::whisper_encode_with_state( + self.ctx, + state_ptr, + offset as c_int, + threads as c_int, + ) + }; if ret == -1 { Err(WhisperError::UnableToCalculateEvaluation) } else if ret == 0 { - self.encode_complete = true; Ok(()) } else { Err(WhisperError::GenericError(ret)) @@ -187,20 +264,20 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)` pub fn decode( - &mut self, + &self, + key: &K, tokens: &[WhisperToken], n_past: usize, threads: usize, ) -> Result<(), WhisperError> { - if !self.encode_complete { - return Err(WhisperError::EncodeNotComplete); - } if threads < 1 { return Err(WhisperError::InvalidThreadCount); } + let state_ptr = self.get_state_ptr(key)?; let ret = unsafe { - whisper_rs_sys::whisper_decode( + whisper_rs_sys::whisper_decode_with_state( self.ctx, + state_ptr, tokens.as_ptr(), tokens.len() as c_int, n_past as c_int, @@ -210,7 +287,6 @@ impl WhisperContext { if ret == -1 { Err(WhisperError::UnableToCalculateEvaluation) } else if ret == 0 { - self.decode_once = true; Ok(()) } else { Err(WhisperError::GenericError(ret)) @@ -228,7 +304,7 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);` pub fn tokenize( - &mut self, + &self, text: &str, max_tokens: usize, ) -> Result, WhisperError> { @@ -265,20 +341,21 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)` pub fn lang_detect( - &mut self, + &self, + key: &K, offset_ms: usize, threads: usize, ) -> Result, WhisperError> { - if !self.spectrogram_initialized { - return Err(WhisperError::SpectrogramNotInitialized); - } if threads < 1 { return Err(WhisperError::InvalidThreadCount); } + let state_ptr = self.get_state_ptr(key)?; + let mut lang_probs: Vec = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1]; let ret = unsafe { - whisper_rs_sys::whisper_lang_auto_detect( + whisper_rs_sys::whisper_lang_auto_detect_with_state( self.ctx, + state_ptr, offset_ms as c_int, threads as c_int, lang_probs.as_mut_ptr(), @@ -296,7 +373,10 @@ impl WhisperContext { // and abort, as this will cause Undefined Behavior // might get here due to the unwind being caught by a user-installed panic handler if lang_probs.len() != ret as usize { - eprintln!("lang_probs length mismatch: this is a bug in whisper.cpp, aborting"); + eprintln!( + "lang_probs length mismatch: this is a bug in whisper.cpp, \ + aborting to avoid Undefined Behavior" + ); std::process::abort(); } Ok(lang_probs) @@ -307,13 +387,13 @@ impl WhisperContext { /// Get the mel spectrogram length. /// /// # Returns - /// c_int + /// Ok(c_int) on success, Err(WhisperError) on failure. /// /// # C++ equivalent - /// `int whisper_n_len (struct whisper_context * ctx)` + /// `int whisper_n_len_from_state(struct whisper_context * ctx)` #[inline] - pub fn n_len(&self) -> c_int { - unsafe { whisper_rs_sys::whisper_n_len(self.ctx) } + pub fn n_len(&self, key: &K) -> Result { + Ok(unsafe { whisper_rs_sys::whisper_n_len_from_state(self.get_state_ptr(key)?) }) } /// Get n_vocab. @@ -334,7 +414,7 @@ impl WhisperContext { /// c_int /// /// # C++ equivalent - /// `int whisper_n_text_ctx (struct whisper_context * ctx)` + /// `int whisper_n_text_ctx (struct whisper_context * ctx);` #[inline] pub fn n_text_ctx(&self) -> c_int { unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) } @@ -346,7 +426,7 @@ impl WhisperContext { /// c_int /// /// # C++ equivalent - /// `int whisper_n_audio_ctx (struct whisper_context * ctx)` + /// `int whisper_n_audio_ctx (struct whisper_context * ctx);` #[inline] pub fn n_audio_ctx(&self) -> c_int { unsafe { whisper_rs_sys::whisper_n_audio_ctx(self.ctx) } @@ -361,6 +441,150 @@ impl WhisperContext { unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 } } + /// Get model_n_vocab. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_vocab (struct whisper_context * ctx);` + #[inline] + pub fn model_n_vocab(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_vocab(self.ctx) } + } + + /// Get model_n_audio_ctx. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_audio_ctx (struct whisper_context * ctx)` + #[inline] + pub fn model_n_audio_ctx(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_audio_ctx(self.ctx) } + } + + /// Get model_n_audio_state. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_audio_state(struct whisper_context * ctx);` + #[inline] + pub fn model_n_audio_state(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_audio_state(self.ctx) } + } + + /// Get model_n_audio_head. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_audio_head (struct whisper_context * ctx);` + #[inline] + pub fn model_n_audio_head(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_audio_head(self.ctx) } + } + + /// Get model_n_audio_layer. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_audio_layer(struct whisper_context * ctx);` + #[inline] + pub fn model_n_audio_layer(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_audio_layer(self.ctx) } + } + + /// Get model_n_text_ctx. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_text_ctx (struct whisper_context * ctx)` + #[inline] + pub fn model_n_text_ctx(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_text_ctx(self.ctx) } + } + + /// Get model_n_text_state. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_text_state (struct whisper_context * ctx);` + #[inline] + pub fn model_n_text_state(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_text_state(self.ctx) } + } + + /// Get model_n_text_head. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_text_head (struct whisper_context * ctx);` + #[inline] + pub fn model_n_text_head(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_text_head(self.ctx) } + } + + /// Get model_n_text_layer. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_text_layer (struct whisper_context * ctx);` + #[inline] + pub fn model_n_text_layer(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_text_layer(self.ctx) } + } + + /// Get model_n_mels. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_n_mels (struct whisper_context * ctx);` + #[inline] + pub fn model_n_mels(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_n_mels(self.ctx) } + } + + /// Get model_f16. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_f16 (struct whisper_context * ctx);` + #[inline] + pub fn model_f16(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_f16(self.ctx) } + } + + /// Get model_type. + /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_model_type (struct whisper_context * ctx);` + #[inline] + pub fn model_type(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_model_type(self.ctx) } + } + // logit functions /// Get the logits obtained from the last call to [WhisperContext::decode]. /// The logits for the last token are stored in the last row of the matrix. @@ -376,18 +600,16 @@ impl WhisperContext { /// /// # C++ equivalent /// `float * whisper_get_logits(struct whisper_context * ctx)` - pub fn get_logits(&self, segment: c_int) -> Result>, WhisperError> { - if !self.spectrogram_initialized { - return Err(WhisperError::SpectrogramNotInitialized); - } + pub fn get_logits(&self, key: &K, segment: c_int) -> Result>, WhisperError> { + let state_ptr = self.get_state_ptr(key)?; - let ret = unsafe { whisper_rs_sys::whisper_get_logits(self.ctx) }; + let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(state_ptr) }; if ret.is_null() { return Err(WhisperError::NullPointer); } let mut logits = Vec::new(); let n_vocab = self.n_vocab(); - let n_tokens = self.full_n_tokens(segment); + let n_tokens = self.full_n_tokens(key, segment)?; for i in 0..n_tokens { let mut row = Vec::new(); for j in 0..n_vocab { @@ -421,6 +643,21 @@ impl WhisperContext { Ok(r_str.to_string()) } + /// Undocumented but exposed function in the C++ API. + /// `const char * whisper_model_type_readable(struct whisper_context * ctx);` + /// + /// # Returns + /// Ok(String) on success, Err(WhisperError) on failure. + pub fn model_type_readable(&self) -> Result { + let ret = unsafe { whisper_rs_sys::whisper_model_type_readable(self.ctx) }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + let c_str = unsafe { CStr::from_ptr(ret) }; + let r_str = c_str.to_str()?; + Ok(r_str.to_string()) + } + /// Get the ID of the eot token. /// /// # C++ equivalent @@ -519,51 +756,15 @@ impl WhisperContext { /// /// # C++ equivalent /// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)` - pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result { + pub fn full(&self, key: &K, params: FullParams, data: &[f32]) -> Result { + let state_ptr = self.get_state_ptr(key)?; let ret = unsafe { - whisper_rs_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int) - }; - if ret == -1 { - Err(WhisperError::UnableToCalculateSpectrogram) - } else if ret == 7 { - Err(WhisperError::FailedToEncode) - } else if ret == 8 { - Err(WhisperError::FailedToDecode) - } else if ret == 0 { - Ok(ret) - } else { - Err(WhisperError::GenericError(ret)) - } - } - - /// Split the input audio into chunks and delegate to [WhisperContext::full]. - /// - /// It seems this approach can offer some speedup in some cases, - /// however, the accuracy can be worse at the start and end of chunks. - /// - /// # Arguments - /// * params: [crate::FullParams] struct. - /// * pcm: PCM audio data. - /// * n_processors: Number of threads to use. - /// - /// # Returns - /// Ok(c_int) on success, Err(WhisperError) on failure. - /// - /// # C++ equivalent - /// `int whisper_full_parallel(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples, int n_processors)` - pub fn full_parallel( - &mut self, - params: FullParams, - data: &[f32], - n_processors: c_int, - ) -> Result { - let ret = unsafe { - whisper_rs_sys::whisper_full_parallel( + whisper_rs_sys::whisper_full_with_state( self.ctx, + state_ptr, params.fp, data.as_ptr(), data.len() as c_int, - n_processors, ) }; if ret == -1 { @@ -573,8 +774,6 @@ impl WhisperContext { } else if ret == 8 { Err(WhisperError::FailedToDecode) } else if ret == 0 { - // note 0 is returned on success and also when initializing other contexts fails, - // causing some audio to not be processed Ok(ret) } else { Err(WhisperError::GenericError(ret)) @@ -587,8 +786,17 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_full_n_segments(struct whisper_context * ctx)` #[inline] - pub fn full_n_segments(&self) -> c_int { - unsafe { whisper_rs_sys::whisper_full_n_segments(self.ctx) } + pub fn full_n_segments(&self, key: &K) -> Result { + Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.get_state_ptr(key)?) }) + } + + /// Language ID associated with the provided state. + /// + /// # C++ equivalent + /// `int whisper_full_lang_id_from_state(struct whisper_state * state);` + #[inline] + pub fn full_lang_id_from_state(&self, key: &K) -> Result { + Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.get_state_ptr(key)?) }) } /// Get the start time of the specified segment. @@ -599,8 +807,13 @@ impl WhisperContext { /// # C++ equivalent /// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)` #[inline] - pub fn full_get_segment_t0(&self, segment: c_int) -> i64 { - unsafe { whisper_rs_sys::whisper_full_get_segment_t0(self.ctx, segment) } + pub fn full_get_segment_t0(&self, key: &K, segment: c_int) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_get_segment_t0_from_state( + self.get_state_ptr(key)?, + segment, + ) + }) } /// Get the end time of the specified segment. @@ -611,8 +824,13 @@ impl WhisperContext { /// # C++ equivalent /// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)` #[inline] - pub fn full_get_segment_t1(&self, segment: c_int) -> i64 { - unsafe { whisper_rs_sys::whisper_full_get_segment_t1(self.ctx, segment) } + pub fn full_get_segment_t1(&self, key: &K, segment: c_int) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_get_segment_t1_from_state( + self.get_state_ptr(key)?, + segment, + ) + }) } /// Get the text of the specified segment. @@ -625,8 +843,11 @@ impl WhisperContext { /// /// # C++ equivalent /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` - pub fn full_get_segment_text(&self, segment: c_int) -> Result { - let ret = unsafe { whisper_rs_sys::whisper_full_get_segment_text(self.ctx, segment) }; + pub fn full_get_segment_text(&self, key: &K, segment: c_int) -> Result { + let state_ptr = self.get_state_ptr(key)?; + + let ret = + unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(state_ptr, segment) }; if ret.is_null() { return Err(WhisperError::NullPointer); } @@ -646,8 +867,10 @@ impl WhisperContext { /// # C++ equivalent /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)` #[inline] - pub fn full_n_tokens(&self, segment: c_int) -> c_int { - unsafe { whisper_rs_sys::whisper_full_n_tokens(self.ctx, segment) } + pub fn full_n_tokens(&self, key: &K, segment: c_int) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_n_tokens_from_state(self.get_state_ptr(key)?, segment) + }) } /// Get the token text of the specified token in the specified segment. @@ -663,10 +886,16 @@ impl WhisperContext { /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` pub fn full_get_token_text( &self, + key: &K, segment: c_int, token: c_int, ) -> Result { - let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text(self.ctx, segment, token) }; + let state_ptr = self.get_state_ptr(key)?; + let ret = unsafe { + whisper_rs_sys::whisper_full_get_token_text_from_state( + self.ctx, state_ptr, segment, token, + ) + }; if ret.is_null() { return Err(WhisperError::NullPointer); } @@ -686,8 +915,19 @@ impl WhisperContext { /// /// # C++ equivalent /// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)` - pub fn full_get_token_id(&self, segment: c_int, token: c_int) -> WhisperToken { - unsafe { whisper_rs_sys::whisper_full_get_token_id(self.ctx, segment, token) } + pub fn full_get_token_id( + &self, + key: &K, + segment: c_int, + token: c_int, + ) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_get_token_id_from_state( + self.get_state_ptr(key)?, + segment, + token, + ) + }) } /// Get token data for the specified token in the specified segment. @@ -702,8 +942,19 @@ impl WhisperContext { /// # C++ equivalent /// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)` #[inline] - pub fn full_get_token_data(&self, segment: c_int, token: c_int) -> WhisperTokenData { - unsafe { whisper_rs_sys::whisper_full_get_token_data(self.ctx, segment, token) } + pub fn full_get_token_data( + &self, + key: &K, + segment: c_int, + token: c_int, + ) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_get_token_data_from_state( + self.get_state_ptr(key)?, + segment, + token, + ) + }) } /// Get the probability of the specified token in the specified segment. @@ -718,12 +969,23 @@ impl WhisperContext { /// # C++ equivalent /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` #[inline] - pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> f32 { - unsafe { whisper_rs_sys::whisper_full_get_token_p(self.ctx, segment, token) } + pub fn full_get_token_prob( + &self, + key: &K, + segment: c_int, + token: c_int, + ) -> Result { + Ok(unsafe { + whisper_rs_sys::whisper_full_get_token_p_from_state( + self.get_state_ptr(key)?, + segment, + token, + ) + }) } } -impl Drop for WhisperContext { +impl Drop for WhisperContext { #[inline] fn drop(&mut self) { unsafe { whisper_rs_sys::whisper_free(self.ctx) }; @@ -732,6 +994,5 @@ impl Drop for WhisperContext { // following implementations are safe // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 -// concurrent usage is prevented by &mut self on methods that modify the struct -unsafe impl Send for WhisperContext {} -unsafe impl Sync for WhisperContext {} +unsafe impl Send for WhisperContext {} +unsafe impl Sync for WhisperContext {} diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 432ba63..493db09 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -175,6 +175,15 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.max_len = max_len; } + /// # EXPERIMENTAL + /// + /// Should the timestamps be split on words instead of characters? + /// + /// Defaults to false. + pub fn set_split_on_word(&mut self, split_on_word: bool) { + self.fp.split_on_word = split_on_word; + } + /// # EXPERIMENTAL /// /// Set maximum tokens per segment. 0 means no limit. @@ -243,6 +252,14 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.suppress_blank = suppress_blank; } + /// Set suppress_non_speech_tokens. See https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253 + /// for more information. + /// + /// Defaults to false. + pub fn set_suppress_non_speech_tokens(&mut self, suppress_non_speech_tokens: bool) { + self.fp.suppress_non_speech_tokens = suppress_non_speech_tokens; + } + /// Set initial decoding temperature. See https://ai.stackexchange.com/a/32478 for more information. /// /// Defaults to 0.0. @@ -290,7 +307,7 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.logprob_thold = logprob_thold; } - /// Set no_speech_thold. Currently (as of v1.2.0) not implemented. + /// Set no_speech_thold. Currently (as of v1.3.0) not implemented. /// /// Defaults to 0.6. pub fn set_no_speech_thold(&mut self, no_speech_thold: f32) { @@ -325,7 +342,35 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.new_segment_callback_user_data = user_data; } - /// Set the callback for starting the encoder. + /// Set the callback for progress updates. + /// + /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). + /// It is still a C callback. + /// + /// # Safety + /// Do not use this function unless you know what you are doing. + /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. + /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. + /// + /// Defaults to None. + pub unsafe fn set_progress_callback( + &mut self, + progress_callback: crate::WhisperProgressCallback, + ) { + self.fp.progress_callback = progress_callback; + } + + /// Set the user data to be passed to the progress callback. + /// + /// # Safety + /// See the safety notes for `set_progress_callback`. + /// + /// Defaults to None. + pub unsafe fn set_progress_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { + self.fp.progress_callback_user_data = user_data; + } + + /// Set the callback that is called each time before the encoder begins. /// /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). /// It is still a C callback. @@ -355,6 +400,37 @@ impl<'a, 'b> FullParams<'a, 'b> { ) { self.fp.encoder_begin_callback_user_data = user_data; } + + /// Set the callback that is called by each decoder to filter obtained logits. + /// + /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). + /// It is still a C callback. + /// + /// # Safety + /// Do not use this function unless you know what you are doing. + /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. + /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. + /// + /// Defaults to None. + pub unsafe fn set_filter_logits_callback( + &mut self, + logits_filter_callback: crate::WhisperLogitsFilterCallback, + ) { + self.fp.logits_filter_callback = logits_filter_callback; + } + + /// Set the user data to be passed to the logits filter callback. + /// + /// # Safety + /// See the safety notes for `set_filter_logits_callback`. + /// + /// Defaults to None. + pub unsafe fn set_filter_logits_callback_user_data( + &mut self, + user_data: *mut std::ffi::c_void, + ) { + self.fp.logits_filter_callback_user_data = user_data; + } } // following implementations are safe diff --git a/src/whisper_state.rs b/src/whisper_state.rs new file mode 100644 index 0000000..64453e9 --- /dev/null +++ b/src/whisper_state.rs @@ -0,0 +1,26 @@ +/// Rustified pointer to a Whisper state. +#[derive(Debug)] +pub struct WhisperState { + ptr: *mut whisper_rs_sys::whisper_state, +} + +unsafe impl Send for WhisperState {} +unsafe impl Sync for WhisperState {} + +impl Drop for WhisperState { + fn drop(&mut self) { + unsafe { + whisper_rs_sys::whisper_free_state(self.ptr); + } + } +} + +impl WhisperState { + pub(crate) unsafe fn new(ptr: *mut whisper_rs_sys::whisper_state) -> Self { + Self { ptr } + } + + pub(crate) fn as_ptr(&self) -> *mut whisper_rs_sys::whisper_state { + self.ptr + } +} diff --git a/sys/Cargo.toml b/sys/Cargo.toml index 14e421c..1c1eb9d 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -1,19 +1,26 @@ [package] name = "whisper-rs-sys" -version = "0.3.2" +version = "0.4.0" edition = "2021" description = "Rust bindings for whisper.cpp (FFI bindings)" license = "Unlicense" documentation = "https://docs.rs/whisper-rs-sys" repository = "https://github.com/tazz4843/whisper-rs" links = "whisper" -exclude = [ - "whisper.cpp/bindings", - "whisper.cpp/examples", - "whisper.cpp/extra", - "whisper.cpp/models", - "whisper.cpp/samples", - "whisper.cpp/tests", +include = [ + "whisper.cpp/bindings/javascript/package-tmpl.json", + "whisper.cpp/bindings/CMakeLists.txt", + "whisper.cpp/cmake", + "whisper.cpp/coreml", + "whisper.cpp/CMakeLists.txt", + "whisper.cpp/ggml.c", + "whisper.cpp/ggml.h", + "whisper.cpp/LICENSE", + "whisper.cpp/whisper.cpp", + "whisper.cpp/whisper.h", + "src/*.rs", + "build.rs", + "wrapper.h", ] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html diff --git a/sys/build.rs b/sys/build.rs index bd440e1..dc45e30 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -108,7 +108,10 @@ fn main() { } // clean the whisper build directory to prevent Cargo from complaining during crate publish + env::set_current_dir("..").expect("Unable to change directory to whisper.cpp"); _ = std::fs::remove_dir_all("build"); + // for whatever reason this file is generated during build and triggers cargo complaining + _ = std::fs::remove_file("bindings/javascript/package.json"); } // From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462 diff --git a/sys/src/bindings.rs b/sys/src/bindings.rs index 638ce68..fbcb505 100644 --- a/sys/src/bindings.rs +++ b/sys/src/bindings.rs @@ -1,4 +1,4 @@ -/* automatically generated by rust-bindgen 0.61.0 */ +/* automatically generated by rust-bindgen 0.65.1 */ pub const _STDINT_H: u32 = 1; pub const _FEATURES_H: u32 = 1; @@ -17,6 +17,10 @@ pub const __USE_POSIX199506: u32 = 1; pub const __USE_XOPEN2K: u32 = 1; pub const __USE_XOPEN2K8: u32 = 1; pub const _ATFILE_SOURCE: u32 = 1; +pub const __WORDSIZE: u32 = 64; +pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; +pub const __SYSCALL_WORDSIZE: u32 = 64; +pub const __TIMESIZE: u32 = 64; pub const __USE_MISC: u32 = 1; pub const __USE_ATFILE: u32 = 1; pub const __USE_FORTIFY_LEVEL: u32 = 0; @@ -24,31 +28,31 @@ pub const __GLIBC_USE_DEPRECATED_GETS: u32 = 0; pub const __GLIBC_USE_DEPRECATED_SCANF: u32 = 0; pub const _STDC_PREDEF_H: u32 = 1; pub const __STDC_IEC_559__: u32 = 1; +pub const __STDC_IEC_60559_BFP__: u32 = 201404; pub const __STDC_IEC_559_COMPLEX__: u32 = 1; +pub const __STDC_IEC_60559_COMPLEX__: u32 = 201404; pub const __STDC_ISO_10646__: u32 = 201706; pub const __GNU_LIBRARY__: u32 = 6; pub const __GLIBC__: u32 = 2; -pub const __GLIBC_MINOR__: u32 = 31; +pub const __GLIBC_MINOR__: u32 = 37; pub const _SYS_CDEFS_H: u32 = 1; pub const __glibc_c99_flexarr_available: u32 = 1; -pub const __WORDSIZE: u32 = 64; -pub const __WORDSIZE_TIME64_COMPAT32: u32 = 1; -pub const __SYSCALL_WORDSIZE: u32 = 64; -pub const __LONG_DOUBLE_USES_FLOAT128: u32 = 0; +pub const __LDOUBLE_REDIRECTS_TO_FLOAT128_ABI: u32 = 0; pub const __HAVE_GENERIC_SELECTION: u32 = 1; pub const __GLIBC_USE_LIB_EXT2: u32 = 0; pub const __GLIBC_USE_IEC_60559_BFP_EXT: u32 = 0; pub const __GLIBC_USE_IEC_60559_BFP_EXT_C2X: u32 = 0; +pub const __GLIBC_USE_IEC_60559_EXT: u32 = 0; pub const __GLIBC_USE_IEC_60559_FUNCS_EXT: u32 = 0; pub const __GLIBC_USE_IEC_60559_FUNCS_EXT_C2X: u32 = 0; pub const __GLIBC_USE_IEC_60559_TYPES_EXT: u32 = 0; pub const _BITS_TYPES_H: u32 = 1; -pub const __TIMESIZE: u32 = 64; pub const _BITS_TYPESIZES_H: u32 = 1; pub const __OFF_T_MATCHES_OFF64_T: u32 = 1; pub const __INO_T_MATCHES_INO64_T: u32 = 1; pub const __RLIM_T_MATCHES_RLIM64_T: u32 = 1; pub const __STATFS_MATCHES_STATFS64: u32 = 1; +pub const __KERNEL_OLD_TIMEVAL_MATCHES_TIMEVAL64: u32 = 1; pub const __FD_SETSIZE: u32 = 1024; pub const _BITS_TIME64_H: u32 = 1; pub const _BITS_WCHAR_H: u32 = 1; @@ -91,9 +95,9 @@ pub const SIG_ATOMIC_MAX: u32 = 2147483647; pub const SIZE_MAX: i32 = -1; pub const WINT_MIN: u32 = 0; pub const WINT_MAX: u32 = 4294967295; +pub const __bool_true_false_are_defined: u32 = 1; pub const true_: u32 = 1; pub const false_: u32 = 0; -pub const __bool_true_false_are_defined: u32 = 1; pub const WHISPER_SAMPLE_RATE: u32 = 16000; pub const WHISPER_N_FFT: u32 = 400; pub const WHISPER_N_MEL: u32 = 80; @@ -214,6 +218,7 @@ pub type __id_t = ::std::os::raw::c_uint; pub type __time_t = ::std::os::raw::c_long; pub type __useconds_t = ::std::os::raw::c_uint; pub type __suseconds_t = ::std::os::raw::c_long; +pub type __suseconds64_t = ::std::os::raw::c_long; pub type __daddr_t = ::std::os::raw::c_int; pub type __key_t = ::std::os::raw::c_int; pub type __clockid_t = ::std::os::raw::c_int; @@ -257,6 +262,11 @@ pub type uintmax_t = __uintmax_t; pub struct whisper_context { _unused: [u8; 0], } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct whisper_state { + _unused: [u8; 0], +} pub type whisper_token = ::std::os::raw::c_int; #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -459,9 +469,29 @@ extern "C" { extern "C" { pub fn whisper_init(loader: *mut whisper_model_loader) -> *mut whisper_context; } +extern "C" { + pub fn whisper_init_from_file_no_state( + path_model: *const ::std::os::raw::c_char, + ) -> *mut whisper_context; +} +extern "C" { + pub fn whisper_init_from_buffer_no_state( + buffer: *mut ::std::os::raw::c_void, + buffer_size: usize, + ) -> *mut whisper_context; +} +extern "C" { + pub fn whisper_init_no_state(loader: *mut whisper_model_loader) -> *mut whisper_context; +} +extern "C" { + pub fn whisper_init_state(ctx: *mut whisper_context) -> *mut whisper_state; +} extern "C" { pub fn whisper_free(ctx: *mut whisper_context); } +extern "C" { + pub fn whisper_free_state(state: *mut whisper_state); +} extern "C" { pub fn whisper_pcm_to_mel( ctx: *mut whisper_context, @@ -470,6 +500,32 @@ extern "C" { n_threads: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_pcm_to_mel_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_pcm_to_mel_phase_vocoder( + ctx: *mut whisper_context, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_pcm_to_mel_phase_vocoder_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_set_mel( ctx: *mut whisper_context, @@ -478,6 +534,15 @@ extern "C" { n_mel: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_set_mel_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + data: *const f32, + n_len: ::std::os::raw::c_int, + n_mel: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_encode( ctx: *mut whisper_context, @@ -485,6 +550,14 @@ extern "C" { n_threads: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_encode_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + offset: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_decode( ctx: *mut whisper_context, @@ -494,6 +567,16 @@ extern "C" { n_threads: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_decode_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + tokens: *const whisper_token, + n_tokens: ::std::os::raw::c_int, + n_past: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_tokenize( ctx: *mut whisper_context, @@ -519,9 +602,21 @@ extern "C" { lang_probs: *mut f32, ) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_lang_auto_detect_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + offset_ms: ::std::os::raw::c_int, + n_threads: ::std::os::raw::c_int, + lang_probs: *mut f32, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_n_len(ctx: *mut whisper_context) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_n_len_from_state(state: *mut whisper_state) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_n_vocab(ctx: *mut whisper_context) -> ::std::os::raw::c_int; } @@ -534,15 +629,57 @@ extern "C" { extern "C" { pub fn whisper_is_multilingual(ctx: *mut whisper_context) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_model_n_vocab(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_audio_ctx(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_audio_state(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_audio_head(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_audio_layer(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_text_ctx(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_text_state(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_text_head(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_text_layer(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_n_mels(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_f16(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_model_type(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_get_logits(ctx: *mut whisper_context) -> *mut f32; } +extern "C" { + pub fn whisper_get_logits_from_state(state: *mut whisper_state) -> *mut f32; +} extern "C" { pub fn whisper_token_to_str( ctx: *mut whisper_context, token: whisper_token, ) -> *const ::std::os::raw::c_char; } +extern "C" { + pub fn whisper_model_type_readable(ctx: *mut whisper_context) -> *const ::std::os::raw::c_char; +} extern "C" { pub fn whisper_token_eot(ctx: *mut whisper_context) -> whisper_token; } @@ -584,17 +721,39 @@ extern "C" { } pub const whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY: whisper_sampling_strategy = 0; pub const whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH: whisper_sampling_strategy = 1; -#[doc = ""] pub type whisper_sampling_strategy = ::std::os::raw::c_uint; pub type whisper_new_segment_callback = ::std::option::Option< unsafe extern "C" fn( ctx: *mut whisper_context, + state: *mut whisper_state, n_new: ::std::os::raw::c_int, user_data: *mut ::std::os::raw::c_void, ), >; +pub type whisper_progress_callback = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut whisper_context, + state: *mut whisper_state, + progress: ::std::os::raw::c_int, + user_data: *mut ::std::os::raw::c_void, + ), +>; pub type whisper_encoder_begin_callback = ::std::option::Option< - unsafe extern "C" fn(ctx: *mut whisper_context, user_data: *mut ::std::os::raw::c_void) -> bool, + unsafe extern "C" fn( + ctx: *mut whisper_context, + state: *mut whisper_state, + user_data: *mut ::std::os::raw::c_void, + ) -> bool, +>; +pub type whisper_logits_filter_callback = ::std::option::Option< + unsafe extern "C" fn( + ctx: *mut whisper_context, + state: *mut whisper_state, + tokens: *const whisper_token_data, + n_tokens: ::std::os::raw::c_int, + logits: *mut f32, + user_data: *mut ::std::os::raw::c_void, + ), >; #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -615,13 +774,16 @@ pub struct whisper_full_params { pub thold_pt: f32, pub thold_ptsum: f32, pub max_len: ::std::os::raw::c_int, + pub split_on_word: bool, pub max_tokens: ::std::os::raw::c_int, pub speed_up: bool, pub audio_ctx: ::std::os::raw::c_int, + pub initial_prompt: *const ::std::os::raw::c_char, pub prompt_tokens: *const whisper_token, pub prompt_n_tokens: ::std::os::raw::c_int, pub language: *const ::std::os::raw::c_char, pub suppress_blank: bool, + pub suppress_non_speech_tokens: bool, pub temperature: f32, pub max_initial_ts: f32, pub length_penalty: f32, @@ -633,8 +795,12 @@ pub struct whisper_full_params { pub beam_search: whisper_full_params__bindgen_ty_2, pub new_segment_callback: whisper_new_segment_callback, pub new_segment_callback_user_data: *mut ::std::os::raw::c_void, + pub progress_callback: whisper_progress_callback, + pub progress_callback_user_data: *mut ::std::os::raw::c_void, pub encoder_begin_callback: whisper_encoder_begin_callback, pub encoder_begin_callback_user_data: *mut ::std::os::raw::c_void, + pub logits_filter_callback: whisper_logits_filter_callback, + pub logits_filter_callback_user_data: *mut ::std::os::raw::c_void, } #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -721,7 +887,7 @@ fn bindgen_test_layout_whisper_full_params() { let ptr = UNINIT.as_ptr(); assert_eq!( ::std::mem::size_of::(), - 160usize, + 200usize, concat!("Size of: ", stringify!(whisper_full_params)) ); assert_eq!( @@ -890,8 +1056,18 @@ fn bindgen_test_layout_whisper_full_params() { ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).max_tokens) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).split_on_word) as usize - ptr as usize }, 40usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(split_on_word) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).max_tokens) as usize - ptr as usize }, + 44usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -901,7 +1077,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).speed_up) as usize - ptr as usize }, - 44usize, + 48usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -911,7 +1087,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).audio_ctx) as usize - ptr as usize }, - 48usize, + 52usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -920,8 +1096,18 @@ fn bindgen_test_layout_whisper_full_params() { ) ); assert_eq!( - unsafe { ::std::ptr::addr_of!((*ptr).prompt_tokens) as usize - ptr as usize }, + unsafe { ::std::ptr::addr_of!((*ptr).initial_prompt) as usize - ptr as usize }, 56usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(initial_prompt) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).prompt_tokens) as usize - ptr as usize }, + 64usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -931,7 +1117,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).prompt_n_tokens) as usize - ptr as usize }, - 64usize, + 72usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -941,7 +1127,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).language) as usize - ptr as usize }, - 72usize, + 80usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -951,7 +1137,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).suppress_blank) as usize - ptr as usize }, - 80usize, + 88usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -959,9 +1145,19 @@ fn bindgen_test_layout_whisper_full_params() { stringify!(suppress_blank) ) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).suppress_non_speech_tokens) as usize - ptr as usize }, + 89usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(suppress_non_speech_tokens) + ) + ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).temperature) as usize - ptr as usize }, - 84usize, + 92usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -971,7 +1167,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).max_initial_ts) as usize - ptr as usize }, - 88usize, + 96usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -981,7 +1177,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).length_penalty) as usize - ptr as usize }, - 92usize, + 100usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -991,7 +1187,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).temperature_inc) as usize - ptr as usize }, - 96usize, + 104usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1001,7 +1197,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).entropy_thold) as usize - ptr as usize }, - 100usize, + 108usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1011,7 +1207,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).logprob_thold) as usize - ptr as usize }, - 104usize, + 112usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1021,7 +1217,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).no_speech_thold) as usize - ptr as usize }, - 108usize, + 116usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1031,7 +1227,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).greedy) as usize - ptr as usize }, - 112usize, + 120usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1041,7 +1237,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).beam_search) as usize - ptr as usize }, - 116usize, + 124usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1051,7 +1247,7 @@ fn bindgen_test_layout_whisper_full_params() { ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).new_segment_callback) as usize - ptr as usize }, - 128usize, + 136usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1063,7 +1259,7 @@ fn bindgen_test_layout_whisper_full_params() { unsafe { ::std::ptr::addr_of!((*ptr).new_segment_callback_user_data) as usize - ptr as usize }, - 136usize, + 144usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1071,9 +1267,29 @@ fn bindgen_test_layout_whisper_full_params() { stringify!(new_segment_callback_user_data) ) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).progress_callback) as usize - ptr as usize }, + 152usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(progress_callback) + ) + ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).progress_callback_user_data) as usize - ptr as usize }, + 160usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(progress_callback_user_data) + ) + ); assert_eq!( unsafe { ::std::ptr::addr_of!((*ptr).encoder_begin_callback) as usize - ptr as usize }, - 144usize, + 168usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1085,7 +1301,7 @@ fn bindgen_test_layout_whisper_full_params() { unsafe { ::std::ptr::addr_of!((*ptr).encoder_begin_callback_user_data) as usize - ptr as usize }, - 152usize, + 176usize, concat!( "Offset of field: ", stringify!(whisper_full_params), @@ -1093,6 +1309,28 @@ fn bindgen_test_layout_whisper_full_params() { stringify!(encoder_begin_callback_user_data) ) ); + assert_eq!( + unsafe { ::std::ptr::addr_of!((*ptr).logits_filter_callback) as usize - ptr as usize }, + 184usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(logits_filter_callback) + ) + ); + assert_eq!( + unsafe { + ::std::ptr::addr_of!((*ptr).logits_filter_callback_user_data) as usize - ptr as usize + }, + 192usize, + concat!( + "Offset of field: ", + stringify!(whisper_full_params), + "::", + stringify!(logits_filter_callback_user_data) + ) + ); } extern "C" { pub fn whisper_full_default_params(strategy: whisper_sampling_strategy) -> whisper_full_params; @@ -1105,6 +1343,15 @@ extern "C" { n_samples: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_full_with_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + params: whisper_full_params, + samples: *const f32, + n_samples: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_full_parallel( ctx: *mut whisper_context, @@ -1117,30 +1364,63 @@ extern "C" { extern "C" { pub fn whisper_full_n_segments(ctx: *mut whisper_context) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_full_n_segments_from_state(state: *mut whisper_state) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_full_lang_id(ctx: *mut whisper_context) -> ::std::os::raw::c_int; +} +extern "C" { + pub fn whisper_full_lang_id_from_state(state: *mut whisper_state) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_full_get_segment_t0( ctx: *mut whisper_context, i_segment: ::std::os::raw::c_int, ) -> i64; } +extern "C" { + pub fn whisper_full_get_segment_t0_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> i64; +} extern "C" { pub fn whisper_full_get_segment_t1( ctx: *mut whisper_context, i_segment: ::std::os::raw::c_int, ) -> i64; } +extern "C" { + pub fn whisper_full_get_segment_t1_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> i64; +} extern "C" { pub fn whisper_full_get_segment_text( ctx: *mut whisper_context, i_segment: ::std::os::raw::c_int, ) -> *const ::std::os::raw::c_char; } +extern "C" { + pub fn whisper_full_get_segment_text_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} extern "C" { pub fn whisper_full_n_tokens( ctx: *mut whisper_context, i_segment: ::std::os::raw::c_int, ) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_full_n_tokens_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn whisper_full_get_token_text( ctx: *mut whisper_context, @@ -1148,6 +1428,14 @@ extern "C" { i_token: ::std::os::raw::c_int, ) -> *const ::std::os::raw::c_char; } +extern "C" { + pub fn whisper_full_get_token_text_from_state( + ctx: *mut whisper_context, + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} extern "C" { pub fn whisper_full_get_token_id( ctx: *mut whisper_context, @@ -1155,6 +1443,13 @@ extern "C" { i_token: ::std::os::raw::c_int, ) -> whisper_token; } +extern "C" { + pub fn whisper_full_get_token_id_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> whisper_token; +} extern "C" { pub fn whisper_full_get_token_data( ctx: *mut whisper_context, @@ -1162,6 +1457,13 @@ extern "C" { i_token: ::std::os::raw::c_int, ) -> whisper_token_data; } +extern "C" { + pub fn whisper_full_get_token_data_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> whisper_token_data; +} extern "C" { pub fn whisper_full_get_token_p( ctx: *mut whisper_context, @@ -1170,9 +1472,25 @@ extern "C" { ) -> f32; } extern "C" { - #[doc = ""] + pub fn whisper_full_get_token_p_from_state( + state: *mut whisper_state, + i_segment: ::std::os::raw::c_int, + i_token: ::std::os::raw::c_int, + ) -> f32; +} +extern "C" { pub fn whisper_bench_memcpy(n_threads: ::std::os::raw::c_int) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_bench_memcpy_str( + n_threads: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +} extern "C" { pub fn whisper_bench_ggml_mul_mat(n_threads: ::std::os::raw::c_int) -> ::std::os::raw::c_int; } +extern "C" { + pub fn whisper_bench_ggml_mul_mat_str( + n_threads: ::std::os::raw::c_int, + ) -> *const ::std::os::raw::c_char; +}