diff --git a/examples/full_usage/src/main.rs b/examples/full_usage/src/main.rs index c3352fd..9ad5355 100644 --- a/examples/full_usage/src/main.rs +++ b/examples/full_usage/src/main.rs @@ -1,10 +1,10 @@ #![allow(clippy::uninlined_format_args)] use hound::{SampleFormat, WavReader}; -use std::path::Path; +use std::path::{Path, PathBuf}; use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; -fn parse_wav_file(path: &Path) -> Vec { +fn parse_wav_file(path: PathBuf) -> Vec { let reader = WavReader::open(path).expect("failed to read file"); if reader.spec().channels != 1 { @@ -27,20 +27,22 @@ fn parse_wav_file(path: &Path) -> Vec { } fn main() { - let arg1 = std::env::args() - .nth(1) - .expect("first argument should be path to WAV file"); - let audio_path = Path::new(&arg1); - if !audio_path.exists() { - panic!("audio file doesn't exist"); - } - let arg2 = std::env::args() - .nth(2) - .expect("second argument should be path to Whisper model"); - let whisper_path = Path::new(&arg2); + let whisper_path = PathBuf::from( + std::env::args() + .nth(1) + .expect("first argument should be path to audio file"), + ); if !whisper_path.exists() { panic!("whisper file doesn't exist") } + let audio_path = PathBuf::from( + std::env::args() + .nth(2) + .expect("second argument should be path to whisper model file"), + ); + if !audio_path.exists() { + panic!("audio file doesn't exist"); + } let original_samples = parse_wav_file(audio_path); let mut samples = vec![0.0f32; original_samples.len()]; @@ -53,7 +55,10 @@ fn main() { ) .expect("failed to open model"); let mut state = ctx.create_state().expect("failed to create key"); - let mut params = FullParams::new(SamplingStrategy::default()); + let mut params = FullParams::new(SamplingStrategy::BeamSearch { + beam_size: 5, + patience: -1.0, + }); params.set_initial_prompt("experience"); params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress)); @@ -63,19 +68,9 @@ fn main() { .expect("failed to convert samples"); let et = std::time::Instant::now(); - let num_segments = state - .full_n_segments() - .expect("failed to get number of segments"); - for i in 0..num_segments { - let segment = state - .full_get_segment_text(i) - .expect("failed to get segment"); - let start_timestamp = state - .full_get_segment_t0(i) - .expect("failed to get start timestamp"); - let end_timestamp = state - .full_get_segment_t1(i) - .expect("failed to get end timestamp"); + for segment in state.as_iter() { + let start_timestamp = segment.start_timestamp(); + let end_timestamp = segment.end_timestamp(); println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment); } println!("took {}ms", (et - st).as_millis());