diff --git a/src/error.rs b/src/error.rs index 592e6e8..ef87c00 100644 --- a/src/error.rs +++ b/src/error.rs @@ -35,6 +35,8 @@ pub enum WhisperError { NullPointer, /// Generic whisper error. Varies depending on the function. GenericError(c_int), + /// Whisper failed to convert the provided text into tokens. + InvalidText, } impl From<c_int> for WhisperError { diff --git a/src/standalone.rs b/src/standalone.rs index 3f36e71..f5e7842 100644 --- a/src/standalone.rs +++ b/src/standalone.rs @@ -26,6 +26,34 @@ pub fn get_lang_id(lang: &str) -> Option<i32> { } } +/// Return the ID of the maximum language (ie the number of languages - 1) +/// +/// # Returns +/// i32 +/// +/// # C++ equivalent +/// `int whisper_lang_max_id()` +pub fn get_lang_max_id() -> i32 { + unsafe { whisper_rs_sys::whisper_lang_max_id() } +} + +/// Get the short string of the specified language id (e.g. 2 -> "de"). +/// +/// # Returns +/// The short string of the language, None if not found. +/// +/// # C++ equivalent +/// `const char * whisper_lang_str(int id)` +pub fn get_lang_str(id: i32) -> Option<&'static str> { + let c_buf = unsafe { whisper_rs_sys::whisper_lang_str(id) }; + if c_buf.is_null() { + None + } else { + let c_str = unsafe { CStr::from_ptr(c_buf) }; + Some(c_str.to_str().unwrap()) + } +} + // task tokens /// Get the ID of the translate task token. /// @@ -51,4 +79,4 @@ pub fn print_system_info() -> &'static str { let c_buf = unsafe { whisper_rs_sys::whisper_print_system_info() }; let c_str = unsafe { CStr::from_ptr(c_buf) }; c_str.to_str().unwrap() -} \ No newline at end of file +} diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 6d00f96..d5f1f0f 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -19,7 +19,7 @@ pub struct WhisperContext { } impl WhisperContext { - /// Create a new WhisperContext. + /// Create a new WhisperContext from a file. /// /// # Arguments /// * path: The path to the model file. 
@@ -28,10 +28,10 @@ impl WhisperContext { /// Ok(Self) on success, Err(WhisperError) on failure. /// /// # C++ equivalent - /// `struct whisper_context * whisper_init(const char * path_model);` + /// `struct whisper_context * whisper_init_from_file(const char * path_model);` pub fn new(path: &str) -> Result<Self, WhisperError> { let path_cstr = CString::new(path)?; - let ctx = unsafe { whisper_rs_sys::whisper_init(path_cstr.as_ptr()) }; + let ctx = unsafe { whisper_rs_sys::whisper_init_from_file(path_cstr.as_ptr()) }; if ctx.is_null() { Err(WhisperError::InitError) } else { @@ -44,6 +44,33 @@ impl WhisperContext { } } + /// Create a new WhisperContext from a buffer. + /// + /// # Arguments + /// * buffer: The buffer containing the model. + /// + /// # Returns + /// Ok(Self) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);` + pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> { + let ctx = + unsafe { whisper_rs_sys::whisper_init_from_buffer(buffer.as_ptr() as _, buffer.len()) }; + if ctx.is_null() { + Err(WhisperError::InitError) + } else { + Ok(Self { + ctx, + spectrogram_initialized: false, + encode_complete: false, + decode_once: false, + }) + } + } + + // we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does + /// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram. /// The resulting spectrogram is stored in the context transparently. /// @@ -190,40 +217,90 @@ impl WhisperContext { } } - // Token sampling functions - /// Return the token with the highest probability. - /// Make sure to call [WhisperContext::decode] first. + /// Convert the provided text into tokens. /// /// # Arguments - /// * needs_timestamp + /// * text: The text to convert. /// /// # Returns - /// Ok(WhisperToken) on success, Err(WhisperError) on failure. + /// Ok(Vec<WhisperToken>) on success, Err(WhisperError) on failure. 
/// /// # C++ equivalent - /// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)` - pub fn sample_best(&mut self) -> Result<WhisperToken, WhisperError> { - if !self.decode_once { - return Err(WhisperError::DecodeNotComplete); + /// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);` + pub fn tokenize( + &mut self, + text: &str, + max_tokens: usize, + ) -> Result<Vec<WhisperToken>, WhisperError> { + // allocate at least max_tokens to ensure the memory is valid + let mut tokens: Vec<WhisperToken> = Vec::with_capacity(max_tokens); + let ret = unsafe { + whisper_rs_sys::whisper_tokenize( + self.ctx, + text.as_ptr() as *const _, + tokens.as_mut_ptr(), + max_tokens as c_int, + ) + }; + if ret == -1 { + Err(WhisperError::InvalidText) + } else { + // SAFETY: when ret != -1, we know that the length of the vector is at least ret tokens + unsafe { tokens.set_len(ret as usize) }; + Ok(tokens) } - let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx) }; - Ok(ret) } - /// Return the token with the most probable timestamp. - /// Make sure to call [WhisperContext::decode] first. + // Language functions + /// Use mel data at offset_ms to try and auto-detect the spoken language + /// Make sure to call pcm_to_mel() or set_mel() first + /// + /// # Arguments + /// * offset_ms: The offset in milliseconds to use for the language detection. + /// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise. /// /// # Returns - /// Ok(WhisperToken) on success, Err(WhisperError) on failure. + /// Ok(Vec<f32>) on success, Err(WhisperError) on failure. 
/// /// # C++ equivalent - /// `whisper_token whisper_sample_timestamp(struct whisper_context * ctx)` - pub fn sample_timestamp(&mut self, is_initial: bool) -> Result<WhisperToken, WhisperError> { - if !self.decode_once { - return Err(WhisperError::DecodeNotComplete); + /// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)` + pub fn lang_detect( + &mut self, + offset_ms: usize, + threads: usize, + ) -> Result<Vec<f32>, WhisperError> { + if !self.spectrogram_initialized { + return Err(WhisperError::SpectrogramNotInitialized); + } + if threads < 1 { + return Err(WhisperError::InvalidThreadCount); + } + let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1]; + let ret = unsafe { + whisper_rs_sys::whisper_lang_auto_detect( + self.ctx, + offset_ms as c_int, + threads as c_int, + lang_probs.as_mut_ptr(), + ) + }; + if ret == -1 { + Err(WhisperError::UnableToCalculateEvaluation) + } else { + assert_eq!( + ret as usize, + lang_probs.len(), + "lang_probs length mismatch: this is a bug in whisper.cpp" + ); + // if we're still running, double check that the length is correct, otherwise print to stderr + // and abort, as this will cause Undefined Behavior + // might get here due to the unwind being caught by a user-installed panic handler + if lang_probs.len() != ret as usize { + eprintln!("lang_probs length mismatch: this is a bug in whisper.cpp, aborting"); + std::process::abort(); + } + Ok(lang_probs) } - let ret = unsafe { whisper_rs_sys::whisper_sample_timestamp(self.ctx, is_initial) }; - Ok(ret) } // model attributes @@ -263,6 +340,18 @@ impl WhisperContext { unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) } } + /// Get n_audio_ctx. 
+ /// + /// # Returns + /// c_int + /// + /// # C++ equivalent + /// `int whisper_n_audio_ctx (struct whisper_context * ctx)` + #[inline] + pub fn n_audio_ctx(&self) -> c_int { + unsafe { whisper_rs_sys::whisper_n_audio_ctx(self.ctx) } + } + /// Does this model support multiple languages? /// /// # C++ equivalent @@ -272,25 +361,46 @@ impl WhisperContext { unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 } } - /// The probabilities for the next token. - /// Make sure to call [WhisperContext::decode] first. + // logit functions + /// Get the logits obtained from the last call to [WhisperContext::decode]. + /// The logits for the last token are stored in the last row of the matrix. + /// + /// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it + /// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible. + /// + /// # Arguments + /// * segment: The segment to fetch data for. /// /// # Returns - /// Ok(*const f32) on success, Err(WhisperError) on failure. + /// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab. 
/// /// # C++ equivalent - /// `float * whisper_get_probs(struct whisper_context * ctx)` - pub fn get_probs(&mut self) -> Result<*const f32, WhisperError> { - if !self.decode_once { - return Err(WhisperError::DecodeNotComplete); + /// `float * whisper_get_logits(struct whisper_context * ctx)` + pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> { + if !self.spectrogram_initialized { + return Err(WhisperError::SpectrogramNotInitialized); } - let ret = unsafe { whisper_rs_sys::whisper_get_probs(self.ctx) }; + + let ret = unsafe { whisper_rs_sys::whisper_get_logits(self.ctx) }; if ret.is_null() { return Err(WhisperError::NullPointer); } - Ok(ret) + let mut logits = Vec::new(); + let n_vocab = self.n_vocab(); + let n_tokens = self.full_n_tokens(segment); + for i in 0..n_tokens { + let mut row = Vec::new(); + for j in 0..n_vocab { + let idx = (i * n_vocab) + j; + let val = unsafe { *ret.offset(idx as isize) }; + row.push(val); + } + logits.push(row); + } + Ok(logits) } + // token functions /// Convert a token ID to a string. /// /// # Arguments @@ -311,7 +421,6 @@ impl WhisperContext { Ok(r_str.to_string()) } - // special tokens /// Get the ID of the eot token. /// /// # C++ equivalent @@ -366,6 +475,18 @@ impl WhisperContext { unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) } } + /// Get the ID of a specified language token + /// + /// # Arguments + /// * lang_id: ID of the language + /// + /// # C++ equivalent + /// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)` + #[inline] + pub fn token_lang(&self, lang_id: c_int) -> WhisperToken { + unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) } + } + /// Print performance statistics to stderr. 
/// /// # C++ equivalent diff --git a/src/whisper_params.rs b/src/whisper_params.rs index f3c1045..c9beee7 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -1,16 +1,16 @@ -use std::ffi::{c_int, CString}; +use std::ffi::{c_float, c_int, CString}; use std::marker::PhantomData; use whisper_rs_sys::whisper_token; pub enum SamplingStrategy { Greedy { - n_past: c_int, + best_of: c_int, }, /// not implemented yet, results of using this unknown BeamSearch { - n_past: c_int, - beam_width: c_int, - n_best: c_int, + beam_size: c_int, + // not implemented in whisper.cpp as of this writing (v1.2.0) + patience: c_float, }, } @@ -35,17 +35,15 @@ impl<'a, 'b> FullParams<'a, 'b> { }; match sampling_strategy { - SamplingStrategy::Greedy { n_past } => { - fp.greedy.n_past = n_past; + SamplingStrategy::Greedy { best_of } => { + fp.greedy.best_of = best_of; } SamplingStrategy::BeamSearch { - n_past, - beam_width, - n_best, + beam_size, + patience, } => { - fp.beam_search.n_past = n_past; - fp.beam_search.beam_width = beam_width; - fp.beam_search.n_best = n_best; + fp.beam_search.beam_size = beam_size; + fp.beam_search.patience = patience; } } @@ -63,7 +61,7 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.n_threads = n_threads; } - /// Set n_max_text_ctx. + /// Max tokens to use from past text as prompt for the decoder /// /// Defaults to 16384. pub fn set_n_max_text_ctx(&mut self, n_max_text_ctx: c_int) { @@ -91,7 +89,7 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.translate = translate; } - /// Set no_context. Usage unknown. + /// Do not use past transcription (if any) as initial prompt for the decoder. /// /// Defaults to false. pub fn set_no_context(&mut self, no_context: bool) { @@ -105,7 +103,7 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.single_segment = single_segment; } - /// Set print_special. Usage unknown. + /// Print special tokens (e.g. , , , etc.) /// /// Defaults to false. 
pub fn set_print_special(&mut self, print_special: bool) { @@ -119,14 +117,17 @@ impl<'a, 'b> FullParams<'a, 'b> { self.fp.print_progress = print_progress; } - /// Set print_realtime. Usage unknown. + /// Print results from within whisper.cpp. + /// Try to use the callback methods instead: [set_new_segment_callback](FullParams::set_new_segment_callback), + /// [set_new_segment_callback_user_data](FullParams::set_new_segment_callback_user_data). /// /// Defaults to false. pub fn set_print_realtime(&mut self, print_realtime: bool) { self.fp.print_realtime = print_realtime; } - /// Set whether to print timestamps. + /// Print timestamps for each text segment when printing realtime. Only has an effect if + /// [set_print_realtime](FullParams::set_print_realtime) is set to true. /// /// Defaults to true. pub fn set_print_timestamps(&mut self, print_timestamps: bool) { @@ -181,6 +182,7 @@ impl<'a, 'b> FullParams<'a, 'b> { /// # EXPERIMENTAL /// /// Speed up audio ~2x by using phase vocoder. + /// Note that this can significantly reduce the accuracy of the transcription. /// /// Defaults to false. pub fn set_speed_up(&mut self, speed_up: bool) { @@ -190,6 +192,7 @@ impl<'a, 'b> FullParams<'a, 'b> { /// # EXPERIMENTAL /// /// Overwrite the audio context size. 0 = default. + /// As with [set_speed_up](FullParams::set_speed_up), this can significantly reduce the accuracy of the transcription. /// /// Defaults to 0. pub fn set_audio_ctx(&mut self, audio_ctx: c_int) { @@ -215,10 +218,78 @@ impl<'a, 'b> FullParams<'a, 'b> { /// Set the target language. /// + /// For auto-detection, set this to either "auto" or None. + /// /// Defaults to "en". 
- pub fn set_language(&mut self, language: &'a str) { - let c_lang = CString::new(language).expect("Language contains null byte"); - self.fp.language = c_lang.into_raw() as *const _; + pub fn set_language(&mut self, language: Option<&'a str>) { + self.fp.language = match language { + Some(language) => CString::new(language) + .expect("Language contains null byte") + .into_raw() as *const _, + None => std::ptr::null(), + }; + } + + /// Set suppress_blank. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89 + /// for more information. + /// + /// Defaults to true. + pub fn set_suppress_blank(&mut self, suppress_blank: bool) { + self.fp.suppress_blank = suppress_blank; + } + + /// Set initial decoding temperature. See https://ai.stackexchange.com/a/32478 for more information. + /// + /// Defaults to 0.0. + pub fn set_temperature(&mut self, temperature: f32) { + self.fp.temperature = temperature; + } + + /// Set max_initial_ts. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97 + /// for more information. + /// + /// Defaults to 1.0. + pub fn set_max_initial_ts(&mut self, max_initial_ts: f32) { + self.fp.max_initial_ts = max_initial_ts; + } + + /// Set length_penalty. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267 + /// for more information. + /// + /// Defaults to -1.0. + pub fn set_length_penalty(&mut self, length_penalty: f32) { + self.fp.length_penalty = length_penalty; + } + + /// Set temperature_inc. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 + /// for more information. + /// + /// Defaults to 0.2. + pub fn set_temperature_inc(&mut self, temperature_inc: f32) { + self.fp.temperature_inc = temperature_inc; + } + + /// Set entropy_thold. Similar to OpenAI's compression_ratio_threshold. 
+ /// See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 for more information. + /// + /// Defaults to 2.4. + pub fn set_entropy_thold(&mut self, entropy_thold: f32) { + self.fp.entropy_thold = entropy_thold; + } + + /// Set logprob_thold. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 + /// for more information. + /// + /// Defaults to -1.0. + pub fn set_logprob_thold(&mut self, logprob_thold: f32) { + self.fp.logprob_thold = logprob_thold; + } + + /// Set no_speech_thold. Currently (as of v1.2.0) not implemented. + /// + /// Defaults to 0.6. + pub fn set_no_speech_thold(&mut self, no_speech_thold: f32) { + self.fp.no_speech_thold = no_speech_thold; } /// Set the callback for new segments.