diff --git a/examples/vad.rs b/examples/vad.rs index d4b207a..24c193e 100644 --- a/examples/vad.rs +++ b/examples/vad.rs @@ -1,5 +1,6 @@ -use hound::SampleFormat; +use hound::{SampleFormat, WavSpec, WavWriter}; use std::io::Read; +use std::time::Instant; use whisper_rs::{WhisperVadContext, WhisperVadContextParams, WhisperVadParams, WhisperVadSegment}; fn main() { @@ -9,14 +10,15 @@ fn main() { let wav_path = std::env::args() .nth(2) .expect("Please specify path to WAV file as argument 2"); + let dest_path = std::env::args() + .nth(3) + .expect("Please specify output path as argument 3"); let wav_reader = hound::WavReader::open(wav_path).expect("failed to open wav file"); - assert_eq!( - wav_reader.spec().sample_rate, - 16000, - "expected 16kHz sample rate" - ); - assert_eq!(wav_reader.spec().channels, 1, "expected mono audio"); + let input_sample_rate = wav_reader.spec().sample_rate; + let input_channels = wav_reader.spec().channels; + assert_eq!(input_sample_rate, 16000, "expected 16kHz sample rate"); + assert_eq!(input_channels, 1, "expected mono audio"); let samples = decode_to_float(wav_reader); @@ -31,18 +33,39 @@ fn main() { WhisperVadContext::new(&model_path, vad_ctx_params).expect("failed to load model"); let vad_params = WhisperVadParams::new(); + let st = Instant::now(); let result = vad_ctx .segments_from_samples(vad_params, &samples) .expect("failed to run VAD"); + let et = Instant::now(); + let dt = et.duration_since(st); + println!("took {:?} to run the VAD model", dt); + let mut output = WavWriter::create( + dest_path, + WavSpec { + channels: input_channels, + sample_rate: 16000, + bits_per_sample: 32, + sample_format: SampleFormat::Float, + }, + ) + .expect("failed to open output file"); for WhisperVadSegment { start, end } in result { - println!( - "detected speech between {}s and {}s", - // each segment is in centiseconds so must be modified - start / 100.0, - end / 100.0 - ); + // convert from centiseconds to seconds + let start_ts = start / 100.0; + let end_ts = end / 100.0; + println!("detected speech between {}s and {}s", start_ts, end_ts); + + let start_sample_idx = (start_ts * input_sample_rate as f32) as usize; + let end_sample_idx = (end_ts * input_sample_rate as f32) as usize; + for sample in &samples[start_sample_idx..end_sample_idx] { + output + .write_sample(*sample) + .expect("failed to write sample"); + } } + output.finalize().expect("failed to finalize dest file"); } fn decode_to_float(rdr: hound::WavReader) -> Vec {