Fix up audio_transcription example
This commit is contained in:
parent
3d70621a51
commit
a9dea32c81
1 changed file with 18 additions and 29 deletions
|
|
@ -8,7 +8,13 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar
|
||||||
|
|
||||||
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
|
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
|
||||||
fn main() -> Result<(), &'static str> {
|
fn main() -> Result<(), &'static str> {
|
||||||
// Load a context and model.
|
let model_path = std::env::args()
|
||||||
|
.nth(1)
|
||||||
|
.expect("Please specify path to model as argument 1");
|
||||||
|
let wav_path = std::env::args()
|
||||||
|
.nth(2)
|
||||||
|
.expect("Please specify path to wav file as argument 2");
|
||||||
|
|
||||||
let mut context_param = WhisperContextParameters::default();
|
let mut context_param = WhisperContextParameters::default();
|
||||||
|
|
||||||
// Enable DTW token level timestamp for known model by using model preset
|
// Enable DTW token level timestamp for known model by using model preset
|
||||||
|
|
@ -37,13 +43,11 @@ fn main() -> Result<(), &'static str> {
|
||||||
aheads: &custom_aheads,
|
aheads: &custom_aheads,
|
||||||
};
|
};
|
||||||
|
|
||||||
let ctx = WhisperContext::new_with_params(
|
// Load a context and model
|
||||||
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
|
let ctx =
|
||||||
context_param,
|
WhisperContext::new_with_params(&model_path, context_param).expect("failed to load model");
|
||||||
)
|
|
||||||
.expect("failed to load model");
|
|
||||||
// Create a state
|
// Create a state
|
||||||
let mut state = ctx.create_state().expect("failed to create key");
|
let mut state = ctx.create_state().expect("failed to create state");
|
||||||
|
|
||||||
// Create a params object for running the model.
|
// Create a params object for running the model.
|
||||||
// The number of past samples to consider defaults to 0.
|
// The number of past samples to consider defaults to 0.
|
||||||
|
|
@ -65,7 +69,7 @@ fn main() -> Result<(), &'static str> {
|
||||||
params.set_token_timestamps(true);
|
params.set_token_timestamps(true);
|
||||||
|
|
||||||
// Open the audio file.
|
// Open the audio file.
|
||||||
let reader = hound::WavReader::open("audio.wav").expect("failed to open file");
|
let reader = hound::WavReader::open(wav_path).expect("failed to open file");
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
let hound::WavSpec {
|
let hound::WavSpec {
|
||||||
channels,
|
channels,
|
||||||
|
|
@ -84,7 +88,6 @@ fn main() -> Result<(), &'static str> {
|
||||||
|
|
||||||
// Convert audio to 16KHz mono f32 samples, as required by the model.
|
// Convert audio to 16KHz mono f32 samples, as required by the model.
|
||||||
// These utilities are provided for convenience, but can be replaced with custom conversion logic.
|
// These utilities are provided for convenience, but can be replaced with custom conversion logic.
|
||||||
// SIMD variants of these functions are also available on nightly Rust (see the docs).
|
|
||||||
if channels == 2 {
|
if channels == 2 {
|
||||||
audio = whisper_rs::convert_stereo_to_mono_audio(&audio).expect("Conversion error");
|
audio = whisper_rs::convert_stereo_to_mono_audio(&audio).expect("Conversion error");
|
||||||
} else if channels != 1 {
|
} else if channels != 1 {
|
||||||
|
|
@ -102,28 +105,14 @@ fn main() -> Result<(), &'static str> {
|
||||||
let mut file = File::create("transcript.txt").expect("failed to create file");
|
let mut file = File::create("transcript.txt").expect("failed to create file");
|
||||||
|
|
||||||
// Iterate through the segments of the transcript.
|
// Iterate through the segments of the transcript.
|
||||||
let num_segments = state
|
for segment in state.as_iter() {
|
||||||
.full_n_segments()
|
|
||||||
.expect("failed to get number of segments");
|
|
||||||
for i in 0..num_segments {
|
|
||||||
// Get the transcribed text and timestamps for the current segment.
|
// Get the transcribed text and timestamps for the current segment.
|
||||||
let segment = state
|
let start_timestamp = segment.start_timestamp();
|
||||||
.full_get_segment_text(i)
|
let end_timestamp = segment.end_timestamp();
|
||||||
.expect("failed to get segment");
|
|
||||||
let start_timestamp = state
|
|
||||||
.full_get_segment_t0(i)
|
|
||||||
.expect("failed to get start timestamp");
|
|
||||||
let end_timestamp = state
|
|
||||||
.full_get_segment_t1(i)
|
|
||||||
.expect("failed to get end timestamp");
|
|
||||||
|
|
||||||
let first_token_dtw_ts = if let Ok(token_count) = state.full_n_tokens(i) {
|
let first_token_dtw_ts = if segment.n_tokens() > 0 {
|
||||||
if token_count > 0 {
|
if let Some(token) = segment.get_token(0) {
|
||||||
if let Ok(token_data) = state.full_get_token_data(i, 0) {
|
token.token_data().map_or(0, |token| token.t_dtw)
|
||||||
token_data.t_dtw
|
|
||||||
} else {
|
|
||||||
-1i64
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
-1i64
|
-1i64
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue