Clean up VAD example and write to output file

This commit is contained in:
Niko 2025-09-29 12:40:49 -07:00
parent c2fba5055b
commit 2fa5df779c
No known key found for this signature in database

View file

@ -1,5 +1,6 @@
use hound::SampleFormat; use hound::{SampleFormat, WavSpec, WavWriter};
use std::io::Read; use std::io::Read;
use std::time::Instant;
use whisper_rs::{WhisperVadContext, WhisperVadContextParams, WhisperVadParams, WhisperVadSegment}; use whisper_rs::{WhisperVadContext, WhisperVadContextParams, WhisperVadParams, WhisperVadSegment};
fn main() { fn main() {
@ -9,14 +10,15 @@ fn main() {
let wav_path = std::env::args() let wav_path = std::env::args()
.nth(2) .nth(2)
.expect("Please specify path to WAV file as argument 2"); .expect("Please specify path to WAV file as argument 2");
let dest_path = std::env::args()
.nth(3)
.expect("Please specify output path as argument 3");
let wav_reader = hound::WavReader::open(wav_path).expect("failed to open wav file"); let wav_reader = hound::WavReader::open(wav_path).expect("failed to open wav file");
assert_eq!( let input_sample_rate = wav_reader.spec().sample_rate;
wav_reader.spec().sample_rate, let input_channels = wav_reader.spec().channels;
16000, assert_eq!(input_sample_rate, 16000, "expected 16kHz sample rate");
"expected 16kHz sample rate" assert_eq!(input_channels, 1, "expected mono audio");
);
assert_eq!(wav_reader.spec().channels, 1, "expected mono audio");
let samples = decode_to_float(wav_reader); let samples = decode_to_float(wav_reader);
@ -31,18 +33,39 @@ fn main() {
WhisperVadContext::new(&model_path, vad_ctx_params).expect("failed to load model"); WhisperVadContext::new(&model_path, vad_ctx_params).expect("failed to load model");
let vad_params = WhisperVadParams::new(); let vad_params = WhisperVadParams::new();
let st = Instant::now();
let result = vad_ctx let result = vad_ctx
.segments_from_samples(vad_params, &samples) .segments_from_samples(vad_params, &samples)
.expect("failed to run VAD"); .expect("failed to run VAD");
let et = Instant::now();
let dt = et.duration_since(st);
println!("took {:?} to run the VAD model", dt);
let mut output = WavWriter::create(
dest_path,
WavSpec {
channels: input_channels,
sample_rate: 16000,
bits_per_sample: 32,
sample_format: SampleFormat::Float,
},
)
.expect("failed to open output file");
for WhisperVadSegment { start, end } in result { for WhisperVadSegment { start, end } in result {
println!( // convert from centiseconds to seconds
"detected speech between {}s and {}s", let start_ts = start / 100.0;
// each segment is in centiseconds so must be modified let end_ts = end / 100.0;
start / 100.0, println!("detected speech between {}s and {}s", start_ts, end_ts);
end / 100.0
); let start_sample_idx = (start_ts * input_sample_rate as f32) as usize;
let end_sample_idx = (end_ts * input_sample_rate as f32) as usize;
for sample in &samples[start_sample_idx..end_sample_idx] {
output
.write_sample(*sample)
.expect("failed to write sample");
} }
}
output.finalize().expect("failed to finalize dest file");
} }
fn decode_to_float<T: Read>(rdr: hound::WavReader<T>) -> Vec<f32> { fn decode_to_float<T: Read>(rdr: hound::WavReader<T>) -> Vec<f32> {