Merge remote-tracking branch 'origin/master'

# Conflicts:
#	Cargo.toml
This commit is contained in:
Niko 2025-02-24 12:45:02 -07:00
commit c34b566bf5
No known key found for this signature in database
6 changed files with 111 additions and 93 deletions

View file

@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
[package] [package]
name = "whisper-rs" name = "whisper-rs"
version = "0.13.2" version = "0.14.1"
edition = "2021" edition = "2021"
description = "Rust bindings for whisper.cpp" description = "Rust bindings for whisper.cpp"
license = "Unlicense" license = "Unlicense"
@ -14,7 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
whisper-rs-sys = { path = "sys", version = "0.11" } whisper-rs-sys = { path = "sys", version = "0.12" }
log = { version = "0.4", optional = true } log = { version = "0.4", optional = true }
tracing = { version = "0.1", optional = true } tracing = { version = "0.1", optional = true }

View file

@ -52,7 +52,7 @@ pub(crate) use {generic_debug, generic_error, generic_info, generic_trace, gener
// Of course Windows thinks it's a special little shit and // Of course Windows thinks it's a special little shit and
// picks a signed integer for an unsigned type // picks a signed integer for an unsigned type
#[cfg_attr(all(windows, not(target_env = "gnu")), repr(i32))] #[cfg_attr(all(windows, not(target_env = "gnu")), repr(i32))]
pub(crate) enum GGMLLogLevel { pub enum GGMLLogLevel {
None = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_NONE, None = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_NONE,
Info = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO, Info = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO,
Warn = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN, Warn = whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN,

View file

@ -13,6 +13,7 @@ mod whisper_logging_hook;
mod whisper_params; mod whisper_params;
mod whisper_state; mod whisper_state;
pub use common_logging::GGMLLogLevel;
pub use error::WhisperError; pub use error::WhisperError;
pub use standalone::*; pub use standalone::*;
pub use utilities::*; pub use utilities::*;

View file

@ -799,8 +799,8 @@ impl<'a, 'b> FullParams<'a, 'b> {
// following implementations are safe // following implementations are safe
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
// concurrent usage is prevented by &mut self on methods that modify the struct // concurrent usage is prevented by &mut self on methods that modify the struct
unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {} unsafe impl Send for FullParams<'_, '_> {}
unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {} unsafe impl Sync for FullParams<'_, '_> {}
#[cfg(test)] #[cfg(test)]
mod test_whisper_params_initial_prompt { mod test_whisper_params_initial_prompt {

View file

@ -346,25 +346,43 @@ impl WhisperState {
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) }) Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) })
} }
fn full_get_segment_raw(&self, segment: c_int) -> Result<&CStr, WhisperError> {
let ret =
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
unsafe { Ok(CStr::from_ptr(ret)) }
}
/// Get the raw bytes of the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # Returns
/// `Ok(Vec<u8>)` on success, with the returned bytes or
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
///
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_bytes(&self, segment: c_int) -> Result<Vec<u8>, WhisperError> {
Ok(self.full_get_segment_raw(segment)?.to_bytes().to_vec())
}
/// Get the text of the specified segment. /// Get the text of the specified segment.
/// ///
/// # Arguments /// # Arguments
/// * segment: Segment index. /// * segment: Segment index.
/// ///
/// # Returns /// # Returns
/// Ok(String) on success, Err(WhisperError) on failure. /// `Ok(String)` on success, with the UTF-8 validated string, or
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> { pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
let ret = Ok(self.full_get_segment_raw(segment)?.to_str()?.to_string())
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
let r_str = c_str.to_str()?;
Ok(r_str.to_string())
} }
/// Get the text of the specified segment. /// Get the text of the specified segment.
@ -376,38 +394,16 @@ impl WhisperState {
/// * segment: Segment index. /// * segment: Segment index.
/// ///
/// # Returns /// # Returns
/// Ok(String) on success, Err(WhisperError) on failure. /// `Ok(String)` on success, or
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result<String, WhisperError> { pub fn full_get_segment_text_lossy(&self, segment: c_int) -> Result<String, WhisperError> {
let ret = Ok(self
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) }; .full_get_segment_raw(segment)?
if ret.is_null() { .to_string_lossy()
return Err(WhisperError::NullPointer); .to_string())
}
let c_str = unsafe { CStr::from_ptr(ret) };
Ok(c_str.to_string_lossy().to_string())
}
/// Get the bytes of the specified segment.
///
/// # Arguments
/// * segment: Segment index.
///
/// # Returns
/// `Ok(Vec<u8>)` on success, `Err(WhisperError)` on failure.
///
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_bytes(&self, segment: c_int) -> Result<Vec<u8>, WhisperError> {
let ret =
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
Ok(c_str.to_bytes().to_vec())
} }
/// Get number of tokens in the specified segment. /// Get number of tokens in the specified segment.
@ -425,22 +421,7 @@ impl WhisperState {
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) }) Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
} }
/// Get the token text of the specified token in the specified segment. fn full_get_token_raw(&self, segment: c_int, token: c_int) -> Result<&CStr, WhisperError> {
///
/// # Arguments
/// * segment: Segment index.
/// * token: Token index.
///
/// # Returns
/// Ok(String) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_text(
&self,
segment: c_int,
token: c_int,
) -> Result<String, WhisperError> {
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_full_get_token_text_from_state( whisper_rs_sys::whisper_full_get_token_text_from_state(
self.ctx.ctx, self.ctx.ctx,
@ -452,9 +433,53 @@ impl WhisperState {
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
let c_str = unsafe { CStr::from_ptr(ret) }; unsafe { Ok(CStr::from_ptr(ret)) }
let r_str = c_str.to_str()?; }
Ok(r_str.to_string())
/// Get the raw token bytes of the specified token in the specified segment.
///
/// Useful if you're using a language for which whisper is known to split tokens
/// away from UTF-8 character boundaries.
///
/// # Arguments
/// * segment: Segment index.
/// * token: Token index.
///
/// # Returns
/// `Ok(Vec<u8>)` on success, with the returned bytes or
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
///
/// # C++ equivalent
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_bytes(
&self,
segment: c_int,
token: c_int,
) -> Result<Vec<u8>, WhisperError> {
Ok(self.full_get_token_raw(segment, token)?.to_bytes().to_vec())
}
/// Get the token text of the specified token in the specified segment.
///
/// # Arguments
/// * segment: Segment index.
/// * token: Token index.
///
/// # Returns
/// `Ok(String)` on success, with the UTF-8 validated string, or
/// `Err(WhisperError)` on failure (either `NullPointer` or `InvalidUtf8`)
///
/// # C++ equivalent
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
pub fn full_get_token_text(
&self,
segment: c_int,
token: c_int,
) -> Result<String, WhisperError> {
Ok(self
.full_get_token_raw(segment, token)?
.to_str()?
.to_string())
} }
/// Get the token text of the specified token in the specified segment. /// Get the token text of the specified token in the specified segment.
@ -467,7 +492,8 @@ impl WhisperState {
/// * token: Token index. /// * token: Token index.
/// ///
/// # Returns /// # Returns
/// Ok(String) on success, Err(WhisperError) on failure. /// `Ok(String)` on success, or
/// `Err(WhisperError::NullPointer)` on failure (this is the only possible error)
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)` /// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
@ -476,19 +502,10 @@ impl WhisperState {
segment: c_int, segment: c_int,
token: c_int, token: c_int,
) -> Result<String, WhisperError> { ) -> Result<String, WhisperError> {
let ret = unsafe { Ok(self
whisper_rs_sys::whisper_full_get_token_text_from_state( .full_get_token_raw(segment, token)?
self.ctx.ctx, .to_string_lossy()
self.ptr, .to_string())
segment,
token,
)
};
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
let c_str = unsafe { CStr::from_ptr(ret) };
Ok(c_str.to_string_lossy().to_string())
} }
/// Get the token ID of the specified token in the specified segment. /// Get the token ID of the specified token in the specified segment.

View file

@ -1,6 +1,6 @@
[package] [package]
name = "whisper-rs-sys" name = "whisper-rs-sys"
version = "0.11.1" version = "0.12.1"
edition = "2021" edition = "2021"
description = "Rust bindings for whisper.cpp (FFI bindings)" description = "Rust bindings for whisper.cpp (FFI bindings)"
license = "Unlicense" license = "Unlicense"
@ -38,6 +38,6 @@ openmp = []
[build-dependencies] [build-dependencies]
cmake = "0.1" cmake = "0.1"
bindgen = "0.69" bindgen = "0.71"
cfg-if = "1" cfg-if = "1"
fs_extra = "1.3" fs_extra = "1.3"