diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..998e28e --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,10 @@ +# Version 0.2.0 (2022-10-28) +* Update upstream whisper.cpp to 2c281d190b7ec351b8128ba386d110f100993973. +* Fix breaking changes in update, which cascade to users: + * `DecodeStrategy` has been renamed to `SamplingStrategy` + * `WhisperContext::sample_best`'s signature has changed: `needs_timestamp` has been removed. +* New features + * `WhisperContext::full_n_tokens` + * `WhisperContext::full_get_token_text` + * `WhisperContext::full_get_token_id` + * `WhisperContext::full_get_token_prob` diff --git a/Cargo.toml b/Cargo.toml index 30cf26b..7ace970 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ members = ["sys"] [package] name = "whisper-rs" -version = "0.1.3" +version = "0.2.0" edition = "2021" description = "Rust bindings for whisper.cpp" license = "Unlicense" @@ -13,7 +13,7 @@ repository = "https://github.com/tazz4843/whisper-rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -whisper-rs-sys = { path = "sys", version = "0.1" } +whisper-rs-sys = { path = "sys", version = "0.2" } [features] simd = [] diff --git a/README.md b/README.md index 9d3fdfd..a04d5fb 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,14 @@ See the docs: https://docs.rs/whisper-rs/ for more details. * Windows/macOS/Android aren't working! * I don't have a way to test these platforms, so I can't really help you. * If you can get it working, please open a PR! +* I get a panic during binding generation build! + * You can attempt to fix it yourself, or you can set the `WHISPER_DONT_GENERATE_BINDINGS` environment variable. + This skips attempting to build the bindings whatsoever and copies the existing ones. They may be out of date, + but it's better than nothing. + * `WHISPER_DONT_GENERATE_BINDINGS=1 cargo build` + * If you can fix the issue, please open a PR! 
+* M1 build info: + * See [this pull request](https://github.com/tazz4843/whisper-rs/pull/2) for more info. ## License [Unlicense](LICENSE) diff --git a/examples/basic_use.rs b/examples/basic_use.rs index 2ff5ba1..b021c93 100644 --- a/examples/basic_use.rs +++ b/examples/basic_use.rs @@ -1,4 +1,4 @@ -use whisper_rs::{DecodeStrategy, FullParams, WhisperContext}; +use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; // note that running this example will not do anything, as it is just a // demonstration of how to use the library, and actual usage requires @@ -10,7 +10,7 @@ pub fn usage() { // create a params object // note that currently the only implemented strategy is Greedy, BeamSearch is a WIP // n_past defaults to 0 - let mut params = FullParams::new(DecodeStrategy::Greedy { n_past: 0 }); + let mut params = FullParams::new(SamplingStrategy::Greedy { n_past: 0 }); // edit things as needed // here we set the number of threads to use to 1 diff --git a/src/lib.rs b/src/lib.rs index 7bb2fd6..fcc614c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -10,6 +10,7 @@ pub use error::WhisperError; pub use standalone::*; pub use utilities::*; pub use whisper_ctx::WhisperContext; -pub use whisper_params::{DecodeStrategy, FullParams}; +pub use whisper_params::{FullParams, SamplingStrategy}; pub type WhisperToken = std::ffi::c_int; +pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback; diff --git a/src/standalone.rs b/src/standalone.rs index e111568..3f36e71 100644 --- a/src/standalone.rs +++ b/src/standalone.rs @@ -1,7 +1,7 @@ //! Standalone functions that have no associated type. use crate::WhisperToken; -use std::ffi::{c_int, CString}; +use std::ffi::{c_int, CStr, CString}; /// Return the id of the specified language, returns -1 if not found /// @@ -42,3 +42,13 @@ pub fn token_translate() -> WhisperToken { pub fn token_transcribe() -> WhisperToken { unsafe { whisper_rs_sys::whisper_token_transcribe() } } + +/// Return system information as a string.
+/// +/// # C++ equivalent +/// `const char * whisper_print_system_info()` +pub fn print_system_info() -> &'static str { + let c_buf = unsafe { whisper_rs_sys::whisper_print_system_info() }; + let c_str = unsafe { CStr::from_ptr(c_buf) }; + c_str.to_str().unwrap() +} \ No newline at end of file diff --git a/src/utilities.rs b/src/utilities.rs index 47b3374..4d210b8 100644 --- a/src/utilities.rs +++ b/src/utilities.rs @@ -111,6 +111,7 @@ pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> { mono } +#[cfg(feature = "simd")] #[cfg(test)] mod test { use super::*; diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 7c3ee96..514f172 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -194,11 +194,11 @@ impl WhisperContext { /// /// # C++ equivalent /// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)` - pub fn sample_best(&mut self, needs_timestamp: bool) -> Result<WhisperToken, WhisperError> { + pub fn sample_best(&mut self) -> Result<WhisperToken, WhisperError> { if !self.decode_once { return Err(WhisperError::DecodeNotComplete); } - let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx, needs_timestamp) }; + let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx) }; Ok(ret) } @@ -363,7 +363,7 @@ impl WhisperContext { unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) } } - /// Print performance statistics to stdout. + /// Print performance statistics to stderr. /// /// # C++ equivalent /// `void whisper_print_timings(struct whisper_context * ctx)` @@ -446,6 +446,90 @@ impl WhisperContext { let r_str = c_str.to_str()?; Ok(r_str.to_string()) } + + /// Get number of tokens in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// + /// # Returns + /// Ok(c_int) on success, Err(WhisperError) on failure.
+ /// + /// # C++ equivalent + /// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)` + pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> { + let ret = unsafe { whisper_rs_sys::whisper_full_n_tokens(self.ctx, segment) }; + if ret < 0 { + Err(WhisperError::GenericError(ret)) + } else { + Ok(ret as c_int) + } + } + + /// Get the token text of the specified token in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// Ok(String) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn full_get_token_text( + &self, + segment: c_int, + token: c_int, + ) -> Result<String, WhisperError> { + let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text(self.ctx, segment, token) }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + let c_str = unsafe { CStr::from_ptr(ret) }; + let r_str = c_str.to_str()?; + Ok(r_str.to_string()) + } + + /// Get the token ID of the specified token in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// Ok(WhisperToken) on success, Err(WhisperError) on failure. + /// + /// # C++ equivalent + /// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)` + pub fn full_get_token_id( + &self, + segment: c_int, + token: c_int, + ) -> Result<WhisperToken, WhisperError> { + let ret = unsafe { whisper_rs_sys::whisper_full_get_token_id(self.ctx, segment, token) }; + if ret < 0 { + Err(WhisperError::GenericError(ret)) + } else { + Ok(ret as WhisperToken) + } + } + + /// Get the probability of the specified token in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index.
+ /// + /// # Returns + /// f32 + /// + /// # C++ equivalent + /// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)` + #[inline] + pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> f32 { + unsafe { whisper_rs_sys::whisper_full_get_token_p(self.ctx, segment, token) } + } } impl Drop for WhisperContext { diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 7d36c78..6121a8c 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -1,7 +1,7 @@ -use std::ffi::c_int; +use std::ffi::{c_int, CString}; use std::marker::PhantomData; -pub enum DecodeStrategy { +pub enum SamplingStrategy { Greedy { n_past: c_int, }, @@ -20,26 +20,30 @@ pub struct FullParams<'a> { impl<'a> FullParams<'a> { /// Create a new set of parameters for the decoder. - pub fn new(decode_strategy: DecodeStrategy) -> FullParams<'a> { + pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a> { let mut fp = unsafe { - whisper_rs_sys::whisper_full_default_params(match decode_strategy { - DecodeStrategy::Greedy { .. } => 0, - DecodeStrategy::BeamSearch { .. } => 1, + whisper_rs_sys::whisper_full_default_params(match sampling_strategy { + SamplingStrategy::Greedy { .. } => { + whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY + } + SamplingStrategy::BeamSearch { .. 
} => { + whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH + } } as _) }; - match decode_strategy { - DecodeStrategy::Greedy { n_past } => { - fp.__bindgen_anon_1.greedy.n_past = n_past; + match sampling_strategy { + SamplingStrategy::Greedy { n_past } => { + fp.greedy.n_past = n_past; } - DecodeStrategy::BeamSearch { + SamplingStrategy::BeamSearch { n_past, beam_width, n_best, } => { - fp.__bindgen_anon_1.beam_search.n_past = n_past; - fp.__bindgen_anon_1.beam_search.beam_width = beam_width; - fp.__bindgen_anon_1.beam_search.n_best = n_best; + fp.beam_search.n_past = n_past; + fp.beam_search.beam_width = beam_width; + fp.beam_search.n_best = n_best; } } @@ -109,7 +113,36 @@ impl<'a> FullParams<'a> { /// /// Defaults to "en". pub fn set_language(&mut self, language: &'a str) { - self.fp.language = language.as_ptr() as *const _; + let c_lang = CString::new(language).expect("Language contains null byte"); + self.fp.language = c_lang.into_raw() as *const _; + } + + /// Set the callback for new segments. + /// + /// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so). + /// It is still a C callback. + /// + /// # Safety + /// Do not use this function unless you know what you are doing. + /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. + /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. + /// + /// Defaults to None. + pub unsafe fn set_new_segment_callback( + &mut self, + new_segment_callback: crate::WhisperNewSegmentCallback, + ) { + self.fp.new_segment_callback = new_segment_callback; + } + + /// Set the user data to be passed to the new segment callback. + /// + /// # Safety + /// See the safety notes for `set_new_segment_callback`. + /// + /// Defaults to None. 
+ pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { + self.fp.new_segment_callback_user_data = user_data; } } diff --git a/sys/Cargo.toml b/sys/Cargo.toml index 5479cf3..45ac4fd 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -13,4 +13,4 @@ links = "whisper" [dependencies] [build-dependencies] -bindgen = "0.60" +bindgen = "0.61" diff --git a/sys/build.rs b/sys/build.rs index 54a20f8..e7a60aa 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -8,29 +8,37 @@ fn main() { println!("cargo:rustc-link-lib=static=whisper"); println!("cargo:rerun-if-changed=wrapper.h"); - let bindings = bindgen::Builder::default() - .header("wrapper.h") - .clang_arg("-I./whisper.cpp") - .parse_callbacks(Box::new(bindgen::CargoCallbacks)) - .generate(); + if env::var("WHISPER_DONT_GENERATE_BINDINGS").is_ok() { + let _: u64 = std::fs::copy( + "src/bindings.rs", + env::var("OUT_DIR").unwrap() + "/bindings.rs", + ) + .expect("Failed to copy bindings.rs"); + } else { + let bindings = bindgen::Builder::default() + .header("wrapper.h") + .clang_arg("-I./whisper.cpp") + .parse_callbacks(Box::new(bindgen::CargoCallbacks)) + .generate(); - match bindings { - Ok(b) => { - let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); - b.write_to_file(out_path.join("bindings.rs")) - .expect("Couldn't write bindings!"); + match bindings { + Ok(b) => { + let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); + b.write_to_file(out_path.join("bindings.rs")) + .expect("Couldn't write bindings!"); + } + Err(e) => { + println!("cargo:warning=Unable to generate bindings: {}", e); + println!("cargo:warning=Using bundled bindings.rs, which may be out of date"); + // copy src/bindings.rs to OUT_DIR + std::fs::copy( + "src/bindings.rs", + env::var("OUT_DIR").unwrap() + "/bindings.rs", + ) + .expect("Unable to copy bindings.rs"); + } } - Err(e) => { - println!("cargo:warning=Unable to generate bindings: {}", e); - println!("cargo:warning=Using bundled bindings.rs, 
which may be out of date"); - // copy src/bindings.rs to OUT_DIR - std::fs::copy( - "src/bindings.rs", - env::var("OUT_DIR").unwrap() + "/bindings.rs", - ) - .expect("Unable to copy bindings.rs"); - } - } + }; // stop if we're on docs.rs if env::var("DOCS_RS").is_ok() {