diff --git a/Cargo.toml b/Cargo.toml index 1de051c..117dabf 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ exclude = ["examples/full_usage"] [package] name = "whisper-rs" -version = "0.4.0" +version = "0.5.0" edition = "2021" description = "Rust bindings for whisper.cpp" license = "Unlicense" @@ -16,8 +16,11 @@ repository = "https://github.com/tazz4843/whisper-rs" [dependencies] whisper-rs-sys = { path = "sys", version = "0.3" } +[dev-dependencies] +hound = "3.5.0" + [features] simd = [] [package.metadata.docs.rs] -features = ["simd"] \ No newline at end of file +features = ["simd"] diff --git a/examples/audio_transcription.rs b/examples/audio_transcription.rs index bf6b3d8..7ab716d 100644 --- a/examples/audio_transcription.rs +++ b/examples/audio_transcription.rs @@ -1,12 +1,13 @@ // This example is not going to build in this folder. -// You need to copy this code into your project and add the whisper_rs dependency in your cargo.toml +// You need to copy this code into your project and add the dependencies whisper_rs and hound in your cargo.toml +use hound; use std::fs::File; use std::io::Write; use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; /// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout. -fn main() { +fn main() -> Result<(), &'static str> { // Load a context and model. let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin") .expect("failed to load model"); @@ -14,7 +15,7 @@ fn main() { // Create a params object for running the model. // Currently, only the Greedy sampling strategy is implemented, with BeamSearch as a WIP. // The number of past samples to consider defaults to 0. - let mut params = FullParams::new(SamplingStrategy::Greedy { n_past: 0 }); + let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 }); // Edit params as needed. // Set the number of threads to use to 1. 
@@ -22,7 +23,7 @@ fn main() { // Enable translation. params.set_translate(true); // Set the language to translate to to English. - params.set_language("en"); + params.set_language(Some("en")); // Disable anything that prints to stdout. params.set_print_special(false); params.set_print_progress(false); @@ -31,6 +32,7 @@ fn main() { // Open the audio file. let mut reader = hound::WavReader::open("audio.wav").expect("failed to open file"); + #[allow(unused_variables)] let hound::WavSpec { channels, sample_rate, @@ -50,7 +52,7 @@ fn main() { // These utilities are provided for convenience, but can be replaced with custom conversion logic. // SIMD variants of these functions are also available on nightly Rust (see the docs). if channels == 2 { - audio = whisper_rs::convert_stereo_to_mono_audio(&audio); + audio = whisper_rs::convert_stereo_to_mono_audio(&audio)?; } else if channels != 1 { panic!(">2 channels unsupported"); } @@ -83,4 +85,5 @@ fn main() { file.write_all(line.as_bytes()) .expect("failed to write to file"); } + Ok(()) } diff --git a/examples/basic_use.rs b/examples/basic_use.rs index 8d0f219..727deba 100644 --- a/examples/basic_use.rs +++ b/examples/basic_use.rs @@ -5,7 +5,7 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; // note that running this example will not do anything, as it is just a // demonstration of how to use the library, and actual usage requires // more dependencies than the base library. 
-pub fn usage() { +pub fn usage() -> Result<(), &'static str> { // load a context and model let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model"); @@ -38,7 +38,7 @@ pub fn usage() { // SIMD variants of these functions are also available, but only on nightly Rust: see the docs let audio_data = whisper_rs::convert_stereo_to_mono_audio( &whisper_rs::convert_integer_to_float_audio(&audio_data), - ); + )?; // now we can run the model ctx.full(params, &audio_data[..]) @@ -52,6 +52,8 @@ pub fn usage() { let end_timestamp = ctx.full_get_segment_t1(i); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); } + + Ok(()) } fn main() { diff --git a/src/utilities.rs b/src/utilities.rs index 4d210b8..b976475 100644 --- a/src/utilities.rs +++ b/src/utilities.rs @@ -61,12 +61,16 @@ pub fn convert_integer_to_float_audio_simd(samples: &[i16]) -> Vec<f32> { /// /// # Returns /// A vector of 32 bit floating point mono PCM audio samples. -pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Vec<f32> { - let mut mono = Vec::with_capacity(samples.len() / 2); - for i in (0..samples.len()).step_by(2) { - mono.push((samples[i] + samples[i + 1]) / 2.0); +pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, &'static str> { + if samples.len() & 1 != 0 { + return Err("The stereo audio vector has an odd number of samples. \ + This means a half-sample is missing somewhere"); } - mono + + Ok(samples + .chunks_exact(2) + .map(|x| (x[0] + x[1]) / 2.0) + .collect()) } /// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio. @@ -80,7 +84,7 @@ pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Vec<f32> { /// # Returns /// A vector of 32 bit floating point mono PCM audio samples. 
#[cfg(feature = "simd")] -pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> { +pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Result<Vec<f32>, &'static str> { let mut mono = Vec::with_capacity(samples.len() / 2); let div_array = f32x16::splat(2.0); @@ -104,11 +108,9 @@ pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> { // Handle the remainder. // do this normally because it's only a few samples and the overhead of // converting to SIMD is not worth it. - for i in (0..remainder.len()).step_by(2) { - mono.push((remainder[i] + remainder[i + 1]) / 2.0); - } + mono.extend(convert_stereo_to_mono_audio(remainder)?); - mono + Ok(mono) } #[cfg(feature = "simd")] @@ -116,13 +118,33 @@ pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> { mod test { use super::*; + #[test] + pub fn assert_stereo_to_mono_err() { + // fake some sample data + let samples = (0u16..1029).map(f32::from).collect::<Vec<f32>>(); + let mono = convert_stereo_to_mono_audio(&samples); + assert!(mono.is_err()); + } +} + +#[cfg(feature = "simd")] +#[cfg(test)] +mod test_simd { + use super::*; + #[test] pub fn assert_stereo_to_mono_simd() { - // fake some sample data, of 1028 elements - let mut samples = Vec::with_capacity(1028); - for i in 0..1028 { - samples.push(i as f32); - } + // fake some sample data + let samples = (0u16..1028).map(f32::from).collect::<Vec<f32>>(); + let mono_simd = convert_stereo_to_mono_audio_simd(&samples); + let mono = convert_stereo_to_mono_audio(&samples); + assert_eq!(mono_simd, mono); + } + + #[test] + pub fn assert_stereo_to_mono_simd_err() { + // fake some sample data + let samples = (0u16..1029).map(f32::from).collect::<Vec<f32>>(); let mono_simd = convert_stereo_to_mono_audio_simd(&samples); let mono = convert_stereo_to_mono_audio(&samples); assert_eq!(mono_simd, mono);