Merge branch 'master' of github.com:/tazz4843/whisper-rs
This commit is contained in:
commit
41736c1f0f
6 changed files with 205 additions and 22 deletions
|
|
@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
|
||||||
|
|
||||||
[package]
|
[package]
|
||||||
name = "whisper-rs"
|
name = "whisper-rs"
|
||||||
version = "0.10.1"
|
version = "0.11.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Rust bindings for whisper.cpp"
|
description = "Rust bindings for whisper.cpp"
|
||||||
license = "Unlicense"
|
license = "Unlicense"
|
||||||
|
|
@ -15,14 +15,16 @@ repository = "https://github.com/tazz4843/whisper-rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
whisper-rs-sys = { path = "sys", version = "0.8" }
|
whisper-rs-sys = { path = "sys", version = "0.8" }
|
||||||
|
log = { version = "0.4", optional = true }
|
||||||
|
tracing = { version = "0.1", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
hound = "3.5.0"
|
hound = "3.5.0"
|
||||||
|
rand = "0.8.4"
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = []
|
||||||
|
|
||||||
simd = []
|
|
||||||
coreml = ["whisper-rs-sys/coreml"]
|
coreml = ["whisper-rs-sys/coreml"]
|
||||||
cuda = ["whisper-rs-sys/cuda", "_gpu"]
|
cuda = ["whisper-rs-sys/cuda", "_gpu"]
|
||||||
opencl = ["whisper-rs-sys/opencl"]
|
opencl = ["whisper-rs-sys/opencl"]
|
||||||
|
|
@ -30,6 +32,8 @@ openblas = ["whisper-rs-sys/openblas"]
|
||||||
metal = ["whisper-rs-sys/metal", "_gpu"]
|
metal = ["whisper-rs-sys/metal", "_gpu"]
|
||||||
_gpu = []
|
_gpu = []
|
||||||
test-with-tiny-model = []
|
test-with-tiny-model = []
|
||||||
|
whisper-cpp-log = ["dep:log"]
|
||||||
|
whisper-cpp-tracing = ["dep:tracing"]
|
||||||
|
|
||||||
[package.metadata.docs.rs]
|
[package.metadata.docs.rs]
|
||||||
features = ["simd"]
|
features = ["simd"]
|
||||||
|
|
|
||||||
22
src/error.rs
22
src/error.rs
|
|
@ -44,6 +44,10 @@ pub enum WhisperError {
|
||||||
FailedToCreateState,
|
FailedToCreateState,
|
||||||
/// No samples were provided.
|
/// No samples were provided.
|
||||||
NoSamples,
|
NoSamples,
|
||||||
|
/// Input and output slices were not the same length.
|
||||||
|
InputOutputLengthMismatch { input_len: usize, output_len: usize },
|
||||||
|
/// Input slice was not an even number of samples.
|
||||||
|
HalfSampleMissing(usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Utf8Error> for WhisperError {
|
impl From<Utf8Error> for WhisperError {
|
||||||
|
|
@ -112,6 +116,24 @@ impl std::fmt::Display for WhisperError {
|
||||||
c_int
|
c_int
|
||||||
),
|
),
|
||||||
NoSamples => write!(f, "Input sample buffer was empty."),
|
NoSamples => write!(f, "Input sample buffer was empty."),
|
||||||
|
InputOutputLengthMismatch {
|
||||||
|
output_len,
|
||||||
|
input_len,
|
||||||
|
} => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Input and output slices were not the same length. Input: {}, Output: {}",
|
||||||
|
input_len, output_len
|
||||||
|
)
|
||||||
|
}
|
||||||
|
HalfSampleMissing(size) => {
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Input slice was not an even number of samples, got {}, expected {}",
|
||||||
|
size,
|
||||||
|
size + 1
|
||||||
|
)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
12
src/lib.rs
12
src/lib.rs
|
|
@ -1,4 +1,5 @@
|
||||||
#![allow(clippy::uninlined_format_args)]
|
#![allow(clippy::uninlined_format_args)]
|
||||||
|
#![cfg_attr(test, feature(test))]
|
||||||
|
|
||||||
mod error;
|
mod error;
|
||||||
mod standalone;
|
mod standalone;
|
||||||
|
|
@ -7,15 +8,26 @@ mod whisper_ctx;
|
||||||
mod whisper_grammar;
|
mod whisper_grammar;
|
||||||
mod whisper_params;
|
mod whisper_params;
|
||||||
mod whisper_state;
|
mod whisper_state;
|
||||||
|
#[cfg(feature = "whisper-cpp-log")]
|
||||||
|
mod whisper_sys_log;
|
||||||
|
#[cfg(feature = "whisper-cpp-tracing")]
|
||||||
|
mod whisper_sys_tracing;
|
||||||
|
|
||||||
|
static LOG_TRAMPOLINE_INSTALL: Once = Once::new();
|
||||||
|
|
||||||
pub use error::WhisperError;
|
pub use error::WhisperError;
|
||||||
pub use standalone::*;
|
pub use standalone::*;
|
||||||
|
use std::sync::Once;
|
||||||
pub use utilities::*;
|
pub use utilities::*;
|
||||||
pub use whisper_ctx::WhisperContext;
|
pub use whisper_ctx::WhisperContext;
|
||||||
pub use whisper_ctx::WhisperContextParameters;
|
pub use whisper_ctx::WhisperContextParameters;
|
||||||
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
|
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
|
||||||
pub use whisper_params::{FullParams, SamplingStrategy};
|
pub use whisper_params::{FullParams, SamplingStrategy};
|
||||||
pub use whisper_state::WhisperState;
|
pub use whisper_state::WhisperState;
|
||||||
|
#[cfg(feature = "whisper-cpp-log")]
|
||||||
|
pub use whisper_sys_log::install_whisper_log_trampoline;
|
||||||
|
#[cfg(feature = "whisper-cpp-tracing")]
|
||||||
|
pub use whisper_sys_tracing::install_whisper_tracing_trampoline;
|
||||||
|
|
||||||
pub type WhisperSysContext = whisper_rs_sys::whisper_context;
|
pub type WhisperSysContext = whisper_rs_sys::whisper_context;
|
||||||
pub type WhisperSysState = whisper_rs_sys::whisper_state;
|
pub type WhisperSysState = whisper_rs_sys::whisper_state;
|
||||||
|
|
|
||||||
101
src/utilities.rs
101
src/utilities.rs
|
|
@ -1,33 +1,59 @@
|
||||||
|
use crate::WhisperError;
|
||||||
|
|
||||||
/// Convert an array of 16 bit mono audio samples to a vector of 32 bit floats.
|
/// Convert an array of 16 bit mono audio samples to a vector of 32 bit floats.
|
||||||
///
|
///
|
||||||
/// This variant does not use SIMD instructions.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `samples` - The array of 16 bit mono audio samples.
|
/// * `samples` - The array of 16 bit mono audio samples.
|
||||||
|
/// * `output` - The vector of 32 bit floats to write the converted samples to.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Panics
|
||||||
/// A vector of 32 bit floats.
|
/// * if `samples.len != output.len()`
|
||||||
pub fn convert_integer_to_float_audio(samples: &[i16]) -> Vec<f32> {
|
///
|
||||||
let mut floats = Vec::with_capacity(samples.len());
|
/// # Examples
|
||||||
for sample in samples {
|
/// ```
|
||||||
floats.push(*sample as f32 / 32768.0);
|
/// # use whisper_rs::convert_integer_to_float_audio;
|
||||||
|
/// let samples = [0i16; 1024];
|
||||||
|
/// let mut output = vec![0.0f32; samples.len()];
|
||||||
|
/// convert_integer_to_float_audio(&samples, &mut output).expect("input and output lengths should be equal");
|
||||||
|
/// ```
|
||||||
|
pub fn convert_integer_to_float_audio(
|
||||||
|
samples: &[i16],
|
||||||
|
output: &mut [f32],
|
||||||
|
) -> Result<(), WhisperError> {
|
||||||
|
if samples.len() != output.len() {
|
||||||
|
return Err(WhisperError::InputOutputLengthMismatch {
|
||||||
|
input_len: samples.len(),
|
||||||
|
output_len: output.len(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
floats
|
|
||||||
|
for (input, output) in samples.iter().zip(output.iter_mut()) {
|
||||||
|
*output = *input as f32 / 32768.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio.
|
/// Convert 32-bit floating point stereo PCM audio to 32-bit floating point mono PCM audio.
|
||||||
///
|
|
||||||
/// This variant does not use SIMD instructions.
|
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `samples` - The array of 32 bit floating point stereo PCM audio samples.
|
/// * `samples` - The array of 32-bit floating point stereo PCM audio samples.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// * if `samples.len()` is odd
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// A vector of 32 bit floating point mono PCM audio samples.
|
/// A vector of 32-bit floating point mono PCM audio samples.
|
||||||
pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, &'static str> {
|
///
|
||||||
|
/// # Examples
|
||||||
|
/// ```
|
||||||
|
/// # use whisper_rs::convert_stereo_to_mono_audio;
|
||||||
|
/// let samples = [0.0f32; 1024];
|
||||||
|
/// let mono = convert_stereo_to_mono_audio(&samples).expect("should be no half samples missing");
|
||||||
|
/// ```
|
||||||
|
pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, WhisperError> {
|
||||||
if samples.len() & 1 != 0 {
|
if samples.len() & 1 != 0 {
|
||||||
return Err("The stereo audio vector has an odd number of samples. \
|
return Err(WhisperError::HalfSampleMissing(samples.len()));
|
||||||
This means a half-sample is missing somewhere");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(samples
|
Ok(samples
|
||||||
|
|
@ -36,16 +62,51 @@ pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, &'stati
|
||||||
.collect())
|
.collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "simd")]
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use rand::distributions::{Distribution, Standard};
|
||||||
|
use rand::Rng;
|
||||||
|
use std::hint::black_box;
|
||||||
|
|
||||||
|
extern crate test;
|
||||||
|
|
||||||
|
fn random_sample_data<T>() -> Vec<T>
|
||||||
|
where
|
||||||
|
Standard: Distribution<T>,
|
||||||
|
{
|
||||||
|
const SAMPLE_SIZE: usize = 1_048_576;
|
||||||
|
|
||||||
|
let mut rng = rand::thread_rng();
|
||||||
|
let mut samples = Vec::with_capacity(SAMPLE_SIZE);
|
||||||
|
for _ in 0..SAMPLE_SIZE {
|
||||||
|
samples.push(rng.gen::<T>());
|
||||||
|
}
|
||||||
|
samples
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
pub fn assert_stereo_to_mono_err() {
|
pub fn assert_stereo_to_mono_err() {
|
||||||
// fake some sample data
|
let samples = random_sample_data::<f32>();
|
||||||
let samples = (0u16..1029).map(f32::from).collect::<Vec<f32>>();
|
|
||||||
let mono = convert_stereo_to_mono_audio(&samples);
|
let mono = convert_stereo_to_mono_audio(&samples);
|
||||||
assert!(mono.is_err());
|
assert!(mono.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
pub fn bench_stereo_to_mono(b: &mut test::Bencher) {
|
||||||
|
let samples = random_sample_data::<f32>();
|
||||||
|
b.iter(|| black_box(convert_stereo_to_mono_audio(black_box(&samples))));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[bench]
|
||||||
|
pub fn bench_integer_to_float(b: &mut test::Bencher) {
|
||||||
|
let samples = random_sample_data::<i16>();
|
||||||
|
let mut output = vec![0.0f32; samples.len()];
|
||||||
|
b.iter(|| {
|
||||||
|
black_box(convert_integer_to_float_audio(
|
||||||
|
black_box(&samples),
|
||||||
|
black_box(&mut output),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
42
src/whisper_sys_log.rs
Normal file
42
src/whisper_sys_log.rs
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
use log::{debug, error, info, warn};
|
||||||
|
use whisper_rs_sys::ggml_log_level;
|
||||||
|
|
||||||
|
unsafe extern "C" fn whisper_cpp_log_trampoline(
|
||||||
|
level: ggml_log_level,
|
||||||
|
text: *const std::os::raw::c_char,
|
||||||
|
_: *mut std::os::raw::c_void, // user_data
|
||||||
|
) {
|
||||||
|
if text.is_null() {
|
||||||
|
error!("whisper_cpp_log_trampoline: text is nullptr");
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: we must trust whisper.cpp that it will not pass us a string that does not satisfy
|
||||||
|
// from_ptr's requirements.
|
||||||
|
let log_str = unsafe { std::ffi::CStr::from_ptr(text) }.to_string_lossy();
|
||||||
|
// whisper.cpp gives newlines at the end of its log messages, so we trim them
|
||||||
|
let trimmed = log_str.trim();
|
||||||
|
|
||||||
|
match level {
|
||||||
|
whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_DEBUG => debug!("{}", trimmed),
|
||||||
|
whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO => info!("{}", trimmed),
|
||||||
|
whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN => warn!("{}", trimmed),
|
||||||
|
whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_ERROR => error!("{}", trimmed),
|
||||||
|
_ => {
|
||||||
|
warn!(
|
||||||
|
"whisper_cpp_log_trampoline: unknown log level {}: message: {}",
|
||||||
|
level, trimmed
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shortcut utility to redirect all whisper.cpp logging to the `log` crate.
|
||||||
|
///
|
||||||
|
/// Filter for logs from the `whisper-rs` crate to see all log output from whisper.cpp.
|
||||||
|
///
|
||||||
|
/// You should only call this once (subsequent calls have no ill effect).
|
||||||
|
pub fn install_whisper_log_trampoline() {
|
||||||
|
crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe {
|
||||||
|
whisper_rs_sys::whisper_log_set(Some(whisper_cpp_log_trampoline), std::ptr::null_mut())
|
||||||
|
});
|
||||||
|
}
|
||||||
42
src/whisper_sys_tracing.rs
Normal file
42
src/whisper_sys_tracing.rs
Normal file
|
|
@ -0,0 +1,42 @@
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
use whisper_rs_sys::ggml_log_level;
|
||||||
|
|
||||||
|
unsafe extern "C" fn whisper_cpp_tracing_trampoline(
|
||||||
|
level: ggml_log_level,
|
||||||
|
text: *const std::os::raw::c_char,
|
||||||
|
_: *mut std::os::raw::c_void, // user_data
|
||||||
|
) {
|
||||||
|
if text.is_null() {
|
||||||
|
error!("whisper_cpp_tracing_trampoline: text is nullptr");
|
||||||
|
}
|
||||||
|
|
||||||
|
// SAFETY: we must trust whisper.cpp that it will not pass us a string that does not satisfy
|
||||||
|
// from_ptr's requirements.
|
||||||
|
let log_str = unsafe { std::ffi::CStr::from_ptr(text) }.to_string_lossy();
|
||||||
|
// whisper.cpp gives newlines at the end of its log messages, so we trim them
|
||||||
|
let trimmed = log_str.trim();
|
||||||
|
|
||||||
|
match level {
|
||||||
|
whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_DEBUG => debug!("{}", trimmed),
|
||||||
|
whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO => info!("{}", trimmed),
|
||||||
|
whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN => warn!("{}", trimmed),
|
||||||
|
whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_ERROR => error!("{}", trimmed),
|
||||||
|
_ => {
|
||||||
|
warn!(
|
||||||
|
"whisper_cpp_tracing_trampoline: unknown log level {}: message: {}",
|
||||||
|
level, trimmed
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Shortcut utility to redirect all whisper.cpp logging to the `tracing` crate.
|
||||||
|
///
|
||||||
|
/// Filter for logs from the `whisper-rs` crate to see all log output from whisper.cpp.
|
||||||
|
///
|
||||||
|
/// You should only call this once (subsequent calls have no effect).
|
||||||
|
pub fn install_whisper_tracing_trampoline() {
|
||||||
|
crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe {
|
||||||
|
whisper_rs_sys::whisper_log_set(Some(whisper_cpp_tracing_trampoline), std::ptr::null_mut())
|
||||||
|
});
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue