Merge remote-tracking branch 'origin/master'
This commit is contained in:
commit
7350453ffe
3 changed files with 116 additions and 1 deletions
|
|
@ -4,12 +4,14 @@ Rust bindings to [whisper.cpp](https://github.com/ggerganov/whisper.cpp/)
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
```rust
|
```rust
|
||||||
|
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy};
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// load a context and model
|
// load a context and model
|
||||||
let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model");
|
let mut ctx = WhisperContext::new("path/to/model").expect("failed to load model");
|
||||||
|
|
||||||
// create a params object
|
// create a params object
|
||||||
let mut params = FullParams::new(DecodeStrategy::Greedy { n_past: 0 });
|
let mut params = FullParams::new(SamplingStrategy::Greedy { n_past: 0 });
|
||||||
|
|
||||||
// assume we have a buffer of audio data
|
// assume we have a buffer of audio data
|
||||||
// here we'll make a fake one, floating point samples, 32 bit, 16KHz, mono
|
// here we'll make a fake one, floating point samples, 32 bit, 16KHz, mono
|
||||||
|
|
|
||||||
86
examples/audio_transcription.rs
Normal file
86
examples/audio_transcription.rs
Normal file
|
|
@ -0,0 +1,86 @@
|
||||||
|
// This example is not going to build in this folder.
|
||||||
|
// You need to copy this code into your project and add the whisper_rs and hound dependencies to your Cargo.toml
|
||||||
|
|
||||||
|
use std::fs::File;
|
||||||
|
use std::io::Write;
|
||||||
|
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
|
||||||
|
|
||||||
|
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
|
||||||
|
fn main() {
|
||||||
|
// Load a context and model.
|
||||||
|
let mut ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
|
||||||
|
.expect("failed to load model");
|
||||||
|
|
||||||
|
// Create a params object for running the model.
|
||||||
|
// Currently, only the Greedy sampling strategy is implemented, with BeamSearch as a WIP.
|
||||||
|
// The number of past samples to consider defaults to 0.
|
||||||
|
let mut params = FullParams::new(SamplingStrategy::Greedy { n_past: 0 });
|
||||||
|
|
||||||
|
// Edit params as needed.
|
||||||
|
// Set the number of threads to use to 1.
|
||||||
|
params.set_n_threads(1);
|
||||||
|
// Enable translation.
|
||||||
|
params.set_translate(true);
|
||||||
|
// Set the language to translate to to English.
|
||||||
|
params.set_language("en");
|
||||||
|
// Disable anything that prints to stdout.
|
||||||
|
params.set_print_special(false);
|
||||||
|
params.set_print_progress(false);
|
||||||
|
params.set_print_realtime(false);
|
||||||
|
params.set_print_timestamps(false);
|
||||||
|
|
||||||
|
// Open the audio file.
|
||||||
|
let mut reader = hound::WavReader::open("audio.wav").expect("failed to open file");
|
||||||
|
let hound::WavSpec {
|
||||||
|
channels,
|
||||||
|
sample_rate,
|
||||||
|
bits_per_sample,
|
||||||
|
..
|
||||||
|
} = reader.spec();
|
||||||
|
|
||||||
|
// Convert the audio to floating point samples.
|
||||||
|
let mut audio = whisper_rs::convert_integer_to_float_audio(
|
||||||
|
&reader
|
||||||
|
.samples::<i16>()
|
||||||
|
.map(|s| s.expect("invalid sample"))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Convert audio to 16KHz mono f32 samples, as required by the model.
|
||||||
|
// These utilities are provided for convenience, but can be replaced with custom conversion logic.
|
||||||
|
// SIMD variants of these functions are also available on nightly Rust (see the docs).
|
||||||
|
if channels == 2 {
|
||||||
|
audio = whisper_rs::convert_stereo_to_mono_audio(&audio);
|
||||||
|
} else if channels != 1 {
|
||||||
|
panic!(">2 channels unsupported");
|
||||||
|
}
|
||||||
|
|
||||||
|
if sample_rate != 16000 {
|
||||||
|
panic!("sample rate must be 16KHz");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the model.
|
||||||
|
ctx.full(params, &audio[..]).expect("failed to run model");
|
||||||
|
|
||||||
|
// Create a file to write the transcript to.
|
||||||
|
let mut file = File::create("transcript.txt").expect("failed to create file");
|
||||||
|
|
||||||
|
// Iterate through the segments of the transcript.
|
||||||
|
let num_segments = ctx.full_n_segments();
|
||||||
|
for i in 0..num_segments {
|
||||||
|
// Get the transcribed text and timestamps for the current segment.
|
||||||
|
let segment = ctx.full_get_segment_text(i).expect("failed to get segment");
|
||||||
|
let start_timestamp = ctx.full_get_segment_t0(i);
|
||||||
|
let end_timestamp = ctx.full_get_segment_t1(i);
|
||||||
|
|
||||||
|
// Print the segment to stdout.
|
||||||
|
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
|
||||||
|
|
||||||
|
// Format the segment information as a string.
|
||||||
|
let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment);
|
||||||
|
|
||||||
|
// Write the segment information to the file.
|
||||||
|
file.write_all(line.as_bytes())
|
||||||
|
.expect("failed to write to file");
|
||||||
|
}
|
||||||
|
}
|
||||||
27
sys/build.rs
27
sys/build.rs
|
|
@ -6,6 +6,16 @@ use std::env;
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
|
let target = env::var("TARGET").unwrap();
|
||||||
|
// Link C++ standard library
|
||||||
|
if let Some(cpp_stdlib) = get_cpp_link_stdlib(&target) {
|
||||||
|
println!("cargo:rustc-link-lib=dylib={}", cpp_stdlib);
|
||||||
|
}
|
||||||
|
// Link macOS Accelerate framework for matrix calculations
|
||||||
|
if target.contains("apple") {
|
||||||
|
println!("cargo:rustc-link-lib=framework=Accelerate");
|
||||||
|
}
|
||||||
|
|
||||||
println!("cargo:rustc-link-search={}", env::var("OUT_DIR").unwrap());
|
println!("cargo:rustc-link-search={}", env::var("OUT_DIR").unwrap());
|
||||||
println!("cargo:rustc-link-lib=static=whisper");
|
println!("cargo:rustc-link-lib=static=whisper");
|
||||||
println!("cargo:rerun-if-changed=wrapper.h");
|
println!("cargo:rerun-if-changed=wrapper.h");
|
||||||
|
|
@ -68,3 +78,20 @@ fn main() {
|
||||||
.status()
|
.status()
|
||||||
.expect("Failed to clean whisper build directory");
|
.expect("Failed to clean whisper build directory");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// From https://github.com/alexcrichton/cc-rs/blob/fba7feded71ee4f63cfe885673ead6d7b4f2f454/src/lib.rs#L2462
/// Picks the name of the C++ standard library to link for `target`.
///
/// `target` is matched by substring, and the first matching rule wins:
///
/// * contains `msvc` -> `None` (no library name is returned)
/// * contains `apple`, `freebsd` or `openbsd` -> `Some("c++")`
/// * contains `android` -> `Some("c++_shared")`
/// * anything else -> `Some("stdc++")`
fn get_cpp_link_stdlib(target: &str) -> Option<&'static str> {
    // Targets that all map to the same library name "c++".
    const LIBCXX_TARGETS: [&str; 3] = ["apple", "freebsd", "openbsd"];

    if target.contains("msvc") {
        None
    } else if LIBCXX_TARGETS.iter().any(|os| target.contains(*os)) {
        Some("c++")
    } else if target.contains("android") {
        Some("c++_shared")
    } else {
        Some("stdc++")
    }
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue