diff --git a/Cargo.toml b/Cargo.toml index d4eabfd..384ee76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ exclude = ["examples/full_usage"] [package] name = "whisper-rs" -version = "0.13.2" +version = "0.14.1" edition = "2021" description = "Rust bindings for whisper.cpp" license = "Unlicense" @@ -14,7 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -whisper-rs-sys = { path = "sys", version = "0.11" } +whisper-rs-sys = { path = "sys", version = "0.12" } log = { version = "0.4", optional = true } tracing = { version = "0.1", optional = true } diff --git a/src/common_logging.rs b/src/common_logging.rs index e3cffee..bf192ce 100644 --- a/src/common_logging.rs +++ b/src/common_logging.rs @@ -52,7 +52,7 @@ pub(crate) use {generic_debug, generic_error, generic_info, generic_trace, gener // Of course Windows thinks it's a special little shit and // picks a signed integer for an unsigned type #[cfg_attr(all(windows, not(target_env = "gnu")), repr(i32))] -pub(crate) enum GGMLLogLevel { +pub enum GGMLLogLevel { None = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_NONE, Info = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO, Warn = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN, diff --git a/src/lib.rs b/src/lib.rs index 893f069..a6632a9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ mod whisper_logging_hook; mod whisper_params; mod whisper_state; +pub use common_logging::GGMLLogLevel; pub use error::WhisperError; pub use standalone::*; pub use utilities::*; diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 36a709e..1bb2808 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -537,7 +537,7 @@ impl<'a, 'b> FullParams<'a, 'b> { /// # Safety /// Do not use this function unless you know what you are doing. /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. - /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. + /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. /// /// Defaults to None. pub unsafe fn set_progress_callback( @@ -652,7 +652,7 @@ impl<'a, 'b> FullParams<'a, 'b> { /// # Safety /// Do not use this function unless you know what you are doing. /// * Be careful not to mutate the state of the whisper_context pointer returned in the callback. - /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. + /// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library. /// /// Defaults to None. pub unsafe fn set_start_encoder_callback( @@ -799,8 +799,8 @@ impl<'a, 'b> FullParams<'a, 'b> { // following implementations are safe // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 // concurrent usage is prevented by &mut self on methods that modify the struct -unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {} -unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {} +unsafe impl Send for FullParams<'_, '_> {} +unsafe impl Sync for FullParams<'_, '_> {} #[cfg(test)] mod test_whisper_params_initial_prompt { diff --git a/src/whisper_state.rs b/src/whisper_state.rs index c82dd18..22418ee 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -346,25 +346,43 @@ impl WhisperState { Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) }) } + fn full_get_segment_raw(&self, segment: c_int) -> Result<&CStr, WhisperError> { + let ret = + unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) }; + if ret.is_null() { + return Err(WhisperError::NullPointer); + } + unsafe { Ok(CStr::from_ptr(ret)) } + } + + /// Get the raw bytes of the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// + /// # Returns + /// `Ok(Vec)` on success, with the returned bytes or + /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) + /// + /// # C++ equivalent + /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` + pub fn full_get_segment_bytes(&self, segment: c_int) -> Result, WhisperError> { + Ok(self.full_get_segment_raw(segment)?.to_bytes().to_vec()) + } + /// Get the text of the specified segment. /// /// # Arguments /// * segment: Segment index. /// /// # Returns - /// Ok(String) on success, Err(WhisperError) on failure. + /// `Ok(String)` on success, with the UTF-8 validated string, or + /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`) /// /// # C++ equivalent /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` pub fn full_get_segment_text(&self, segment: c_int) -> Result { - let ret = - unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) }; - if ret.is_null() { - return Err(WhisperError::NullPointer); - } - let c_str = unsafe { CStr::from_ptr(ret) }; - let r_str = c_str.to_str()?; - Ok(r_str.to_string()) + Ok(self.full_get_segment_raw(segment)?.to_str()?.to_string()) } /// Get the text of the specified segment. @@ -376,38 +394,16 @@ impl WhisperState { /// * segment: Segment index. /// /// # Returns - /// Ok(String) on success, Err(WhisperError) on failure. + /// `Ok(String)` on success, or + /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) /// /// # C++ equivalent /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result { - let ret = - unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) }; - if ret.is_null() { - return Err(WhisperError::NullPointer); - } - let c_str = unsafe { CStr::from_ptr(ret) }; - Ok(c_str.to_string_lossy().to_string()) - } - - /// Get the bytes of the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// - /// # Returns - /// `Ok(Vec)` on success, `Err(WhisperError)` on failure. - /// - /// # C++ equivalent - /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` - pub fn full_get_segment_bytes(&self, segment: c_int) -> Result, WhisperError> { - let ret = - unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) }; - if ret.is_null() { - return Err(WhisperError::NullPointer); - } - let c_str = unsafe { CStr::from_ptr(ret) }; - Ok(c_str.to_bytes().to_vec()) + Ok(self + .full_get_segment_raw(segment)? + .to_string_lossy() + .to_string()) } /// Get number of tokens in the specified segment. @@ -425,22 +421,7 @@ impl WhisperState { Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) }) } - /// Get the token text of the specified token in the specified segment. - /// - /// # Arguments - /// * segment: Segment index. - /// * token: Token index. - /// - /// # Returns - /// Ok(String) on success, Err(WhisperError) on failure. - /// - /// # C++ equivalent - /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` - pub fn full_get_token_text( - &self, - segment: c_int, - token: c_int, - ) -> Result { + fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> { let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text_from_state( self.ctx.ctx, @@ -452,9 +433,53 @@ impl WhisperState { if ret.is_null() { return Err(WhisperError::NullPointer); } - let c_str = unsafe { CStr::from_ptr(ret) }; - let r_str = c_str.to_str()?; - Ok(r_str.to_string()) + unsafe { Ok(CStr::from_ptr(ret)) } + } + + /// Get the raw token bytes of the specified token in the specified segment. + /// + /// Useful if you're using a language for which whisper is known to split tokens + /// away from UTF-8 character boundaries. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// `Ok(Vec)` on success, with the returned bytes or + /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn full_get_token_bytes( + &self, + segment: c_int, + token: c_int, + ) -> Result, WhisperError> { + Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec()) + } + + /// Get the token text of the specified token in the specified segment. + /// + /// # Arguments + /// * segment: Segment index. + /// * token: Token index. + /// + /// # Returns + /// `Ok(String)` on success, with the UTF-8 validated string, or + /// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`) + /// + /// # C++ equivalent + /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` + pub fn full_get_token_text( + &self, + segment: c_int, + token: c_int, + ) -> Result { + Ok(self + .full_get_token_raw(segment, token)? + .to_str()? + .to_string()) } /// Get the token text of the specified token in the specified segment. @@ -467,7 +492,8 @@ impl WhisperState { /// * token: Token index. /// /// # Returns - /// Ok(String) on success, Err(WhisperError) on failure. + /// `Ok(String)` on success, or + /// `Err(WhisperError::NullPointer)` on failure (this is the only possible error) /// /// # C++ equivalent /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` @@ -476,19 +502,10 @@ impl WhisperState { segment: c_int, token: c_int, ) -> Result { - let ret = unsafe { - whisper_rs_sys::whisper_full_get_token_text_from_state( - self.ctx.ctx, - self.ptr, - segment, - token, - ) - }; - if ret.is_null() { - return Err(WhisperError::NullPointer); - } - let c_str = unsafe { CStr::from_ptr(ret) }; - Ok(c_str.to_string_lossy().to_string()) + Ok(self + .full_get_token_raw(segment, token)? + .to_string_lossy() + .to_string()) } /// Get the token ID of the specified token in the specified segment. diff --git a/sys/Cargo.toml b/sys/Cargo.toml index 0707ec2..c97c9a2 100644 --- a/sys/Cargo.toml +++ b/sys/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "whisper-rs-sys" -version = "0.11.1" +version = "0.12.1" edition = "2021" description = "Rust bindings for whisper.cpp (FFI bindings)" license = "Unlicense" @@ -8,20 +8,20 @@ documentation = "https://docs.rs/whisper-rs-sys" repository = "https://github.com/tazz4843/whisper-rs" links = "whisper" include = [ - "whisper.cpp/bindings/javascript/package-tmpl.json", - "whisper.cpp/bindings/CMakeLists.txt", - "whisper.cpp/CMakeLists.txt", - "whisper.cpp/cmake", - "whisper.cpp/src/**", - "whisper.cpp/include/whisper.h", - "whisper.cpp/ggml/cmake", - "whisper.cpp/ggml/CMakeLists.txt", - "whisper.cpp/ggml/src/**", - "whisper.cpp/ggml/include/*.h", - "whisper.cpp/LICENSE", - "src/*.rs", - "build.rs", - "wrapper.h", + "whisper.cpp/bindings/javascript/package-tmpl.json", + "whisper.cpp/bindings/CMakeLists.txt", + "whisper.cpp/CMakeLists.txt", + "whisper.cpp/cmake", + "whisper.cpp/src/**", + "whisper.cpp/include/whisper.h", + "whisper.cpp/ggml/cmake", + "whisper.cpp/ggml/CMakeLists.txt", + "whisper.cpp/ggml/src/**", + "whisper.cpp/ggml/include/*.h", + "whisper.cpp/LICENSE", + "src/*.rs", + "build.rs", + "wrapper.h", ] # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html @@ -38,6 +38,6 @@ openmp = [] [build-dependencies] cmake = "0.1" -bindgen = "0.69" +bindgen = "0.71" cfg-if = "1" fs_extra = "1.3"