From 8d462f5dd3e2f792e60e2538e32440d44d85a3e4 Mon Sep 17 00:00:00 2001 From: Chris Raethke Date: Mon, 20 Nov 2023 20:42:34 +1000 Subject: [PATCH] Expose WhisperContextParameters and update examples --- README.md | 7 +++++-- examples/audio_transcription.rs | 8 +++++--- examples/basic_use.rs | 7 +++++-- examples/full_usage/src/main.rs | 7 +++++-- src/lib.rs | 1 + 5 files changed, 21 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 49ab481..bda4c20 100644 --- a/README.md +++ b/README.md @@ -15,13 +15,16 @@ cargo run --example audio_transcription ``` ```rust -use whisper_rs::{WhisperContext, FullParams, SamplingStrategy}; +use whisper_rs::{WhisperContext, WhisperContextParameters, FullParams, SamplingStrategy}; fn main() { let path_to_model = std::env::args().nth(1).unwrap(); // load a context and model - let ctx = WhisperContext::new(&path_to_model).expect("failed to load model"); + let ctx = WhisperContext::new_with_params( + path_to_model, + WhisperContextParameters::default() + ).expect("failed to load model"); // create a params object let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 }); diff --git a/examples/audio_transcription.rs b/examples/audio_transcription.rs index 71dec68..68fdf2f 100644 --- a/examples/audio_transcription.rs +++ b/examples/audio_transcription.rs @@ -4,13 +4,15 @@ use hound; use std::fs::File; use std::io::Write; -use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; +use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; /// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout. fn main() -> Result<(), &'static str> { // Load a context and model. - let ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin") - .expect("failed to load model"); + let ctx = WhisperContext::new_with_params( + "example/path/to/model/whisper.cpp/models/ggml-base.en.bin", + WhisperContextParameters::default() + ).expect("failed to load model"); // Create a state let mut state = ctx.create_state().expect("failed to create key"); diff --git a/examples/basic_use.rs b/examples/basic_use.rs index f0d81e0..56860a0 100644 --- a/examples/basic_use.rs +++ b/examples/basic_use.rs @@ -1,13 +1,16 @@ #![allow(clippy::uninlined_format_args)] -use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; +use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; // note that running this example will not do anything, as it is just a // demonstration of how to use the library, and actual usage requires // more dependencies than the base library. pub fn usage() -> Result<(), &'static str> { // load a context and model - let ctx = WhisperContext::new("path/to/model").expect("failed to load model"); + let ctx = WhisperContext::new_with_params( + "path/to/model", + WhisperContextParameters::default() + ).expect("failed to load model"); // make a state let mut state = ctx.create_state().expect("failed to create state"); diff --git a/examples/full_usage/src/main.rs b/examples/full_usage/src/main.rs index c5eca17..5826a5b 100644 --- a/examples/full_usage/src/main.rs +++ b/examples/full_usage/src/main.rs @@ -2,7 +2,7 @@ use hound::{SampleFormat, WavReader}; use std::path::Path; -use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; +use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters}; fn parse_wav_file(path: &Path) -> Vec { let reader = WavReader::open(path).expect("failed to read file"); @@ -45,7 +45,10 @@ fn main() { let original_samples = parse_wav_file(audio_path); let samples = whisper_rs::convert_integer_to_float_audio(&original_samples); - let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model"); + let ctx = WhisperContext::new_with_params( + &whisper_path.to_string_lossy(), + WhisperContextParameters::default() + ).expect("failed to open model"); let mut state = ctx.create_state().expect("failed to create key"); let mut params = FullParams::new(SamplingStrategy::default()); params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress)); diff --git a/src/lib.rs b/src/lib.rs index 6b7754c..ed17306 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,6 +13,7 @@ pub use error::WhisperError; pub use standalone::*; pub use utilities::*; pub use whisper_ctx::WhisperContext; +pub use whisper_ctx::WhisperContextParameters; pub use whisper_params::{FullParams, SamplingStrategy}; pub use whisper_state::WhisperState;