Fix lang_detect function

This commit is contained in:
arizhih 2024-07-10 10:24:25 +02:00
parent 84522742da
commit bbdc8a07ef
2 changed files with 12 additions and 22 deletions

View file

@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
[package]
name = "whisper-rs"
version = "0.12.0"
version = "0.12.1"
edition = "2021"
description = "Rust bindings for whisper.cpp"
license = "Unlicense"
@ -14,7 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
whisper-rs-sys = { path = "sys", version = "0.10.0" }
whisper-rs-sys = { path = "sys", version = "0.10.1" }
log = { version = "0.4", optional = true }
tracing = { version = "0.1", optional = true }

View file

@ -225,11 +225,16 @@ impl WhisperState {
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
///
/// # Returns
/// `Ok(Vec<f32>)` on success, `Err(WhisperError)` on failure.
/// `Ok((i32, Vec<f32>))` on success where the i32 is detected language id and Vec<f32>
/// is array with the probabilities of all languages, `Err(WhisperError)` on failure.
///
/// # C++ equivalent
/// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)`
pub fn lang_detect(&self, offset_ms: usize, threads: usize) -> Result<Vec<f32>, WhisperError> {
pub fn lang_detect(
&self,
offset_ms: usize,
threads: usize,
) -> Result<(i32, Vec<f32>), WhisperError> {
if threads < 1 {
return Err(WhisperError::InvalidThreadCount);
}
@ -244,25 +249,10 @@ impl WhisperState {
lang_probs.as_mut_ptr(),
)
};
if ret == -1 {
Err(WhisperError::UnableToCalculateEvaluation)
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
assert_eq!(
ret as usize,
lang_probs.len(),
"lang_probs length mismatch: this is a bug in whisper.cpp"
);
// if we're still running, double check that the length is correct, otherwise print to stderr
// and abort, as this will cause Undefined Behavior
// might get here due to the unwind being caught by a user-installed panic handler
if lang_probs.len() != ret as usize {
eprintln!(
"lang_probs length mismatch: this is a bug in whisper.cpp, \
aborting to avoid Undefined Behavior"
);
std::process::abort();
}
Ok(lang_probs)
Ok((ret as i32, lang_probs))
}
}