Merge pull request #21 from bruskajp/master

Fix panic in stereo to mono audio conversion
This commit is contained in:
0/0 2023-03-27 16:56:26 +00:00 committed by GitHub
commit 34260d4aed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 54 additions and 24 deletions

View file

@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
[package] [package]
name = "whisper-rs" name = "whisper-rs"
version = "0.4.0" version = "0.5.0"
edition = "2021" edition = "2021"
description = "Rust bindings for whisper.cpp" description = "Rust bindings for whisper.cpp"
license = "Unlicense" license = "Unlicense"
@ -16,6 +16,9 @@ repository = "https://github.com/tazz4843/whisper-rs"
[dependencies] [dependencies]
whisper-rs-sys = { path = "sys", version = "0.3" } whisper-rs-sys = { path = "sys", version = "0.3" }
[dev-dependencies]
hound = "3.5.0"
[features] [features]
simd = [] simd = []

View file

@ -1,12 +1,13 @@
// This example is not going to build in this folder. // This example is not going to build in this folder.
// You need to copy this code into your project and add the whisper-rs dependency to your Cargo.toml // You need to copy this code into your project and add the dependencies whisper-rs and hound to your Cargo.toml
use hound;
use std::fs::File; use std::fs::File;
use std::io::Write; use std::io::Write;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout. /// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
fn main() { fn main() -> Result<(), &'static str> {
// Load a context and model. // Load a context and model.
let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin") let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
.expect("failed to load model"); .expect("failed to load model");
@ -14,7 +15,7 @@ fn main() {
// Create a params object for running the model. // Create a params object for running the model.
// Currently, only the Greedy sampling strategy is implemented, with BeamSearch as a WIP. // Currently, only the Greedy sampling strategy is implemented, with BeamSearch as a WIP.
// The number of past samples to consider defaults to 0. // The number of past samples to consider defaults to 0.
let mut params = FullParams::new(SamplingStrategy::Greedy { n_past: 0 }); let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 0 });
// Edit params as needed. // Edit params as needed.
// Set the number of threads to use to 1. // Set the number of threads to use to 1.
@ -22,7 +23,7 @@ fn main() {
// Enable translation. // Enable translation.
params.set_translate(true); params.set_translate(true);
// Set the target language of the translation to English. // Set the target language of the translation to English.
params.set_language("en"); params.set_language(Some("en"));
// Disable anything that prints to stdout. // Disable anything that prints to stdout.
params.set_print_special(false); params.set_print_special(false);
params.set_print_progress(false); params.set_print_progress(false);
@ -31,6 +32,7 @@ fn main() {
// Open the audio file. // Open the audio file.
let mut reader = hound::WavReader::open("audio.wav").expect("failed to open file"); let mut reader = hound::WavReader::open("audio.wav").expect("failed to open file");
#[allow(unused_variables)]
let hound::WavSpec { let hound::WavSpec {
channels, channels,
sample_rate, sample_rate,
@ -50,7 +52,7 @@ fn main() {
// These utilities are provided for convenience, but can be replaced with custom conversion logic. // These utilities are provided for convenience, but can be replaced with custom conversion logic.
// SIMD variants of these functions are also available on nightly Rust (see the docs). // SIMD variants of these functions are also available on nightly Rust (see the docs).
if channels == 2 { if channels == 2 {
audio = whisper_rs::convert_stereo_to_mono_audio(&audio); audio = whisper_rs::convert_stereo_to_mono_audio(&audio)?;
} else if channels != 1 { } else if channels != 1 {
panic!(">2 channels unsupported"); panic!(">2 channels unsupported");
} }
@ -83,4 +85,5 @@ fn main() {
file.write_all(line.as_bytes()) file.write_all(line.as_bytes())
.expect("failed to write to file"); .expect("failed to write to file");
} }
Ok(())
} }

View file

@ -5,7 +5,7 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
// note that running this example will not do anything, as it is just a // note that running this example will not do anything, as it is just a
// demonstration of how to use the library, and actual usage requires // demonstration of how to use the library, and actual usage requires
// more dependencies than the base library. // more dependencies than the base library.
pub fn usage() { pub fn usage() -> Result<(), &'static str> {
// load a context and model // load a context and model
let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model"); let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model");
@ -38,7 +38,7 @@ pub fn usage() {
// SIMD variants of these functions are also available, but only on nightly Rust: see the docs // SIMD variants of these functions are also available, but only on nightly Rust: see the docs
let audio_data = whisper_rs::convert_stereo_to_mono_audio( let audio_data = whisper_rs::convert_stereo_to_mono_audio(
&whisper_rs::convert_integer_to_float_audio(&audio_data), &whisper_rs::convert_integer_to_float_audio(&audio_data),
); )?;
// now we can run the model // now we can run the model
ctx.full(params, &audio_data[..]) ctx.full(params, &audio_data[..])
@ -52,6 +52,8 @@ pub fn usage() {
let end_timestamp = ctx.full_get_segment_t1(i); let end_timestamp = ctx.full_get_segment_t1(i);
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
} }
Ok(())
} }
fn main() { fn main() {

View file

@ -61,12 +61,16 @@ pub fn convert_integer_to_float_audio_simd(samples: &[i16]) -> Vec<f32> {
/// ///
/// # Returns /// # Returns
/// A vector of 32 bit floating point mono PCM audio samples. /// A vector of 32 bit floating point mono PCM audio samples.
pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Vec<f32> { pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, &'static str> {
let mut mono = Vec::with_capacity(samples.len() / 2); if samples.len() & 1 != 0 {
for i in (0..samples.len()).step_by(2) { return Err("The stereo audio vector has an odd number of samples. \
mono.push((samples[i] + samples[i + 1]) / 2.0); This means a half-sample is missing somewhere");
} }
mono
Ok(samples
.chunks_exact(2)
.map(|x| (x[0] + x[1]) / 2.0)
.collect())
} }
/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio. /// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio.
@ -80,7 +84,7 @@ pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Vec<f32> {
/// # Returns /// # Returns
/// A vector of 32 bit floating point mono PCM audio samples. /// A vector of 32 bit floating point mono PCM audio samples.
#[cfg(feature = "simd")] #[cfg(feature = "simd")]
pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> { pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Result<Vec<f32>, &'static str> {
let mut mono = Vec::with_capacity(samples.len() / 2); let mut mono = Vec::with_capacity(samples.len() / 2);
let div_array = f32x16::splat(2.0); let div_array = f32x16::splat(2.0);
@ -104,11 +108,9 @@ pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> {
// Handle the remainder. // Handle the remainder.
// do this normally because it's only a few samples and the overhead of // do this normally because it's only a few samples and the overhead of
// converting to SIMD is not worth it. // converting to SIMD is not worth it.
for i in (0..remainder.len()).step_by(2) { mono.extend(convert_stereo_to_mono_audio(remainder)?);
mono.push((remainder[i] + remainder[i + 1]) / 2.0);
}
mono Ok(mono)
} }
#[cfg(feature = "simd")] #[cfg(feature = "simd")]
@ -116,13 +118,33 @@ pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> {
mod test { mod test {
use super::*; use super::*;
#[test]
pub fn assert_stereo_to_mono_err() {
// Regression test for the panic this commit fixes: interleaved stereo data
// with an odd number of samples must be rejected with an Err.
// fake some sample data: 0u16..1029 yields 1029 values, an odd count on purpose
let samples = (0u16..1029).map(f32::from).collect::<Vec<f32>>();
let mono = convert_stereo_to_mono_audio(&samples);
// the conversion reports the odd length through its Result
assert!(mono.is_err());
}
}
#[cfg(feature = "simd")]
#[cfg(test)]
mod test_simd {
use super::*;
#[test] #[test]
pub fn assert_stereo_to_mono_simd() { pub fn assert_stereo_to_mono_simd() {
// fake some sample data, of 1028 elements // fake some sample data
let mut samples = Vec::with_capacity(1028); let samples = (0u16..1028).map(f32::from).collect::<Vec<f32>>();
for i in 0..1028 { let mono_simd = convert_stereo_to_mono_audio_simd(&samples);
samples.push(i as f32); let mono = convert_stereo_to_mono_audio(&samples);
} assert_eq!(mono_simd, mono);
}
#[test]
pub fn assert_stereo_to_mono_simd_err() {
// fake some sample data
let samples = (0u16..1029).map(f32::from).collect::<Vec<f32>>();
let mono_simd = convert_stereo_to_mono_audio_simd(&samples); let mono_simd = convert_stereo_to_mono_audio_simd(&samples);
let mono = convert_stereo_to_mono_audio(&samples); let mono = convert_stereo_to_mono_audio(&samples);
assert_eq!(mono_simd, mono); assert_eq!(mono_simd, mono);