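Changed convert_stereo_to_mono_audio to return a Result (an odd number of samples is now reported as an error instead of silently dropping the last half-sample)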

This commit is contained in:
James Bruska 2023-03-27 11:49:13 -04:00
parent 30ff41989b
commit d8271e31d0
3 changed files with 43 additions and 17 deletions

View file

@@ -7,7 +7,7 @@ use std::io::Write;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout. /// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
fn main() { fn main() -> Result<(), &'static str> {
// Load a context and model. // Load a context and model.
let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin") let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
.expect("failed to load model"); .expect("failed to load model");
@@ -52,7 +52,7 @@ fn main() {
// These utilities are provided for convenience, but can be replaced with custom conversion logic. // These utilities are provided for convenience, but can be replaced with custom conversion logic.
// SIMD variants of these functions are also available on nightly Rust (see the docs). // SIMD variants of these functions are also available on nightly Rust (see the docs).
if channels == 2 { if channels == 2 {
audio = whisper_rs::convert_stereo_to_mono_audio(&audio); audio = whisper_rs::convert_stereo_to_mono_audio(&audio)?;
} else if channels != 1 { } else if channels != 1 {
panic!(">2 channels unsupported"); panic!(">2 channels unsupported");
} }
@@ -85,4 +85,5 @@ fn main() {
file.write_all(line.as_bytes()) file.write_all(line.as_bytes())
.expect("failed to write to file"); .expect("failed to write to file");
} }
Ok(())
} }

View file

@@ -5,7 +5,7 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
// note that running this example will not do anything, as it is just a // note that running this example will not do anything, as it is just a
// demonstration of how to use the library, and actual usage requires // demonstration of how to use the library, and actual usage requires
// more dependencies than the base library. // more dependencies than the base library.
pub fn usage() { pub fn usage() -> Result<(), &'static str> {
// load a context and model // load a context and model
let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model"); let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model");
@@ -38,7 +38,7 @@ pub fn usage() {
// SIMD variants of these functions are also available, but only on nightly Rust: see the docs // SIMD variants of these functions are also available, but only on nightly Rust: see the docs
let audio_data = whisper_rs::convert_stereo_to_mono_audio( let audio_data = whisper_rs::convert_stereo_to_mono_audio(
&whisper_rs::convert_integer_to_float_audio(&audio_data), &whisper_rs::convert_integer_to_float_audio(&audio_data),
); )?;
// now we can run the model // now we can run the model
ctx.full(params, &audio_data[..]) ctx.full(params, &audio_data[..])
@@ -52,6 +52,8 @@ pub fn usage() {
let end_timestamp = ctx.full_get_segment_t1(i); let end_timestamp = ctx.full_get_segment_t1(i);
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
} }
Ok(())
} }
fn main() { fn main() {

View file

@@ -54,7 +54,6 @@ pub fn convert_integer_to_float_audio_simd(samples: &[i16]) -> Vec<f32> {
/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio. /// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio.
/// ///
/// If there are an odd number of samples, the last half-sample is dropped.
/// This variant does not use SIMD instructions. /// This variant does not use SIMD instructions.
/// ///
/// # Arguments /// # Arguments
@@ -62,16 +61,20 @@ pub fn convert_integer_to_float_audio_simd(samples: &[i16]) -> Vec<f32> {
/// ///
/// # Returns /// # Returns
/// A vector of 32 bit floating point mono PCM audio samples. /// A vector of 32 bit floating point mono PCM audio samples.
pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Vec<f32> { pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, &'static str> {
samples if samples.len() & 1 != 0 {
return Err("The stereo audio vector has an odd number of samples. \
This means a half-sample is missing somewhere");
}
Ok(samples
.chunks_exact(2) .chunks_exact(2)
.map(|x| (x[0] + x[1]) / 2.0) .map(|x| (x[0] + x[1]) / 2.0)
.collect() .collect())
} }
/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio. /// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio.
/// ///
/// If there are an odd number of samples, the last half-sample is dropped.
/// This variant uses SIMD instructions, and as such is only available on /// This variant uses SIMD instructions, and as such is only available on
/// nightly Rust. /// nightly Rust.
/// ///
@@ -81,7 +84,7 @@ pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Vec<f32> {
/// # Returns /// # Returns
/// A vector of 32 bit floating point mono PCM audio samples. /// A vector of 32 bit floating point mono PCM audio samples.
#[cfg(feature = "simd")] #[cfg(feature = "simd")]
pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> { pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Result<Vec<f32>, &'static str> {
let mut mono = Vec::with_capacity(samples.len() / 2); let mut mono = Vec::with_capacity(samples.len() / 2);
let div_array = f32x16::splat(2.0); let div_array = f32x16::splat(2.0);
@@ -105,9 +108,9 @@ pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> {
// Handle the remainder. // Handle the remainder.
// do this normally because it's only a few samples and the overhead of // do this normally because it's only a few samples and the overhead of
// converting to SIMD is not worth it. // converting to SIMD is not worth it.
mono.extend(convert_stereo_to_mono_audio(remainder)); mono.extend(convert_stereo_to_mono_audio(remainder)?);
mono Ok(mono)
} }
#[cfg(feature = "simd")] #[cfg(feature = "simd")]
@@ -116,12 +119,32 @@ mod test {
use super::*; use super::*;
#[test] #[test]
pub fn assert_stereo_to_mono_simd() { pub fn assert_stereo_to_mono_err() {
// fake some sample data, of 1028 elements // fake some sample data
let mut samples = Vec::with_capacity(1028); let samples = (0u16..1029).map(f32::from).collect::<Vec<f32>>();
for i in 0..1029 { let mono = convert_stereo_to_mono_audio(&samples);
samples.push(i as f32); assert!(mono.is_err());
} }
}
#[cfg(feature = "simd")]
#[cfg(test)]
mod test_simd {
use super::*;
#[test]
pub fn assert_stereo_to_mono_simd() {
// fake some sample data
let samples = (0u16..1028).map(f32::from).collect::<Vec<f32>>();
let mono_simd = convert_stereo_to_mono_audio_simd(&samples);
let mono = convert_stereo_to_mono_audio(&samples);
assert_eq!(mono_simd, mono);
}
#[test]
pub fn assert_stereo_to_mono_simd_err() {
// fake some sample data
let samples = (0u16..1029).map(f32::from).collect::<Vec<f32>>();
let mono_simd = convert_stereo_to_mono_audio_simd(&samples); let mono_simd = convert_stereo_to_mono_audio_simd(&samples);
let mono = convert_stereo_to_mono_audio(&samples); let mono = convert_stereo_to_mono_audio(&samples);
assert_eq!(mono_simd, mono); assert_eq!(mono_simd, mono);