diff --git a/Cargo.toml b/Cargo.toml
index e92fefd..2f34200 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
 
 [package]
 name = "whisper-rs"
-version = "0.10.1"
+version = "0.11.0"
 edition = "2021"
 description = "Rust bindings for whisper.cpp"
 license = "Unlicense"
@@ -15,14 +15,16 @@ repository = "https://github.com/tazz4843/whisper-rs"
 
 [dependencies]
 whisper-rs-sys = { path = "sys", version = "0.8" }
+log = { version = "0.4", optional = true }
+tracing = { version = "0.1", optional = true }
 
 [dev-dependencies]
 hound = "3.5.0"
+rand = "0.8.4"
 
 [features]
 default = []
 
-simd = []
 coreml = ["whisper-rs-sys/coreml"]
 cuda = ["whisper-rs-sys/cuda", "_gpu"]
 opencl = ["whisper-rs-sys/opencl"]
@@ -30,6 +32,8 @@ openblas = ["whisper-rs-sys/openblas"]
 metal = ["whisper-rs-sys/metal", "_gpu"]
 _gpu = []
 test-with-tiny-model = []
+whisper-cpp-log = ["dep:log"]
+whisper-cpp-tracing = ["dep:tracing"]
 
 [package.metadata.docs.rs]
-features = ["simd"]
+features = ["whisper-cpp-log", "whisper-cpp-tracing"]
diff --git a/src/error.rs b/src/error.rs
index 51ab0bc..6fe7420 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -41,6 +41,10 @@ pub enum WhisperError {
     FailedToCreateState,
     /// No samples were provided.
     NoSamples,
+    /// Input and output slices were not the same length.
+    InputOutputLengthMismatch { input_len: usize, output_len: usize },
+    /// Input slice was not an even number of samples.
+    HalfSampleMissing(usize),
 }
 
 impl From<Utf8Error> for WhisperError {
@@ -109,6 +113,24 @@ impl std::fmt::Display for WhisperError {
                 c_int
             ),
             NoSamples => write!(f, "Input sample buffer was empty."),
+            InputOutputLengthMismatch {
+                output_len,
+                input_len,
+            } => {
+                write!(
+                    f,
+                    "Input and output slices were not the same length. Input: {}, Output: {}",
+                    input_len, output_len
+                )
+            }
+            HalfSampleMissing(size) => {
+                write!(
+                    f,
+                    "Input slice was not an even number of samples, got {}, expected {}",
+                    size,
+                    size + 1
+                )
+            }
         }
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index d93e588..bdd326f 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,5 @@
 #![allow(clippy::uninlined_format_args)]
+#![cfg_attr(test, feature(test))]
 
 mod error;
 mod standalone;
@@ -7,15 +8,26 @@ mod whisper_ctx;
 mod whisper_grammar;
 mod whisper_params;
 mod whisper_state;
+#[cfg(feature = "whisper-cpp-log")]
+mod whisper_sys_log;
+#[cfg(feature = "whisper-cpp-tracing")]
+mod whisper_sys_tracing;
+
+static LOG_TRAMPOLINE_INSTALL: Once = Once::new();
 
 pub use error::WhisperError;
 pub use standalone::*;
+use std::sync::Once;
 pub use utilities::*;
 pub use whisper_ctx::WhisperContext;
 pub use whisper_ctx::WhisperContextParameters;
 pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
 pub use whisper_params::{FullParams, SamplingStrategy};
 pub use whisper_state::WhisperState;
+#[cfg(feature = "whisper-cpp-log")]
+pub use whisper_sys_log::install_whisper_log_trampoline;
+#[cfg(feature = "whisper-cpp-tracing")]
+pub use whisper_sys_tracing::install_whisper_tracing_trampoline;
 
 pub type WhisperSysContext = whisper_rs_sys::whisper_context;
 pub type WhisperSysState = whisper_rs_sys::whisper_state;
diff --git a/src/utilities.rs b/src/utilities.rs
index f8bc554..8dfc045 100644
--- a/src/utilities.rs
+++ b/src/utilities.rs
@@ -1,33 +1,59 @@
+use crate::WhisperError;
+
 /// Convert an array of 16 bit mono audio samples to a vector of 32 bit floats.
 ///
-/// This variant does not use SIMD instructions.
-///
 /// # Arguments
 /// * `samples` - The array of 16 bit mono audio samples.
+/// * `output` - The slice of 32 bit floats to write the converted samples to.
 ///
-/// # Returns
-/// A vector of 32 bit floats.
-pub fn convert_integer_to_float_audio(samples: &[i16]) -> Vec<f32> {
-    let mut floats = Vec::with_capacity(samples.len());
-    for sample in samples {
-        floats.push(*sample as f32 / 32768.0);
+/// # Errors
+/// * [`WhisperError::InputOutputLengthMismatch`] if `samples.len() != output.len()`
+///
+/// # Examples
+/// ```
+/// # use whisper_rs::convert_integer_to_float_audio;
+/// let samples = [0i16; 1024];
+/// let mut output = vec![0.0f32; samples.len()];
+/// convert_integer_to_float_audio(&samples, &mut output).expect("input and output lengths should be equal");
+/// ```
+pub fn convert_integer_to_float_audio(
+    samples: &[i16],
+    output: &mut [f32],
+) -> Result<(), WhisperError> {
+    if samples.len() != output.len() {
+        return Err(WhisperError::InputOutputLengthMismatch {
+            input_len: samples.len(),
+            output_len: output.len(),
+        });
     }
-    floats
+
+    for (input, output) in samples.iter().zip(output.iter_mut()) {
+        *output = *input as f32 / 32768.0;
+    }
+
+    Ok(())
 }
 
-/// Convert 32 bit floating point stereo PCM audio to 32 bit floating point mono PCM audio.
-///
-/// This variant does not use SIMD instructions.
+/// Convert 32-bit floating point stereo PCM audio to 32-bit floating point mono PCM audio.
 ///
 /// # Arguments
-/// * `samples` - The array of 32 bit floating point stereo PCM audio samples.
+/// * `samples` - The array of 32-bit floating point stereo PCM audio samples.
+///
+/// # Errors
+/// * if `samples.len()` is odd
 ///
 /// # Returns
-/// A vector of 32 bit floating point mono PCM audio samples.
-pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, &'static str> {
+/// A vector of 32-bit floating point mono PCM audio samples.
+///
+/// # Examples
+/// ```
+/// # use whisper_rs::convert_stereo_to_mono_audio;
+/// let samples = [0.0f32; 1024];
+/// let mono = convert_stereo_to_mono_audio(&samples).expect("should be no half samples missing");
+/// ```
+pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, WhisperError> {
     if samples.len() & 1 != 0 {
-        return Err("The stereo audio vector has an odd number of samples. \
-            This means a half-sample is missing somewhere");
+        return Err(WhisperError::HalfSampleMissing(samples.len()));
     }
 
     Ok(samples
@@ -36,16 +62,52 @@ pub fn convert_stereo_to_mono_audio(samples: &[f32]) -> Result<Vec<f32>, &'stati
         .collect())
 }
 
-#[cfg(feature = "simd")]
 #[cfg(test)]
 mod test {
     use super::*;
+    use rand::distributions::{Distribution, Standard};
+    use rand::Rng;
+    use std::hint::black_box;
+
+    extern crate test;
+
+    fn random_sample_data<T>() -> Vec<T>
+    where
+        Standard: Distribution<T>,
+    {
+        const SAMPLE_SIZE: usize = 1_048_576;
+
+        let mut rng = rand::thread_rng();
+        let mut samples = Vec::with_capacity(SAMPLE_SIZE);
+        for _ in 0..SAMPLE_SIZE {
+            samples.push(rng.gen::<T>());
+        }
+        samples
+    }
 
     #[test]
     pub fn assert_stereo_to_mono_err() {
-        // fake some sample data
-        let samples = (0u16..1029).map(f32::from).collect::<Vec<f32>>();
-        let mono = convert_stereo_to_mono_audio(&samples);
+        let samples = random_sample_data::<f32>();
+        // SAMPLE_SIZE is even, so drop one sample to get the odd length that must be rejected.
+        let mono = convert_stereo_to_mono_audio(&samples[1..]);
         assert!(mono.is_err());
     }
+
+    #[bench]
+    pub fn bench_stereo_to_mono(b: &mut test::Bencher) {
+        let samples = random_sample_data::<f32>();
+        b.iter(|| black_box(convert_stereo_to_mono_audio(black_box(&samples))));
+    }
+
+    #[bench]
+    pub fn bench_integer_to_float(b: &mut test::Bencher) {
+        let samples = random_sample_data::<i16>();
+        let mut output = vec![0.0f32; samples.len()];
+        b.iter(|| {
+            black_box(convert_integer_to_float_audio(
+                black_box(&samples),
+                black_box(&mut output),
+            ))
+        });
+    }
 }
diff --git a/src/whisper_sys_log.rs b/src/whisper_sys_log.rs
new file mode 100644
index 0000000..9b5be22
--- /dev/null
+++ b/src/whisper_sys_log.rs
@@ -0,0 +1,44 @@
+use log::{debug, error, info, warn};
+use whisper_rs_sys::ggml_log_level;
+
+unsafe extern "C" fn whisper_cpp_log_trampoline(
+    level: ggml_log_level,
+    text: *const std::os::raw::c_char,
+    _: *mut std::os::raw::c_void, // user_data
+) {
+    if text.is_null() {
+        error!("whisper_cpp_log_trampoline: text is nullptr");
+        // Bail out: passing a null pointer to `CStr::from_ptr` below is undefined behaviour.
+        return;
+    }
+
+    // SAFETY: we must trust whisper.cpp that it will not pass us a string that does not satisfy
+    // from_ptr's requirements.
+    let log_str = unsafe { std::ffi::CStr::from_ptr(text) }.to_string_lossy();
+    // whisper.cpp gives newlines at the end of its log messages, so we trim them
+    let trimmed = log_str.trim();
+
+    match level {
+        whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_DEBUG => debug!("{}", trimmed),
+        whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO => info!("{}", trimmed),
+        whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN => warn!("{}", trimmed),
+        whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_ERROR => error!("{}", trimmed),
+        _ => {
+            warn!(
+                "whisper_cpp_log_trampoline: unknown log level {}: message: {}",
+                level, trimmed
+            )
+        }
+    }
+}
+
+/// Shortcut utility to redirect all whisper.cpp logging to the `log` crate.
+///
+/// Filter for logs from the `whisper-rs` crate to see all log output from whisper.cpp.
+///
+/// You should only call this once (subsequent calls have no ill effect).
+pub fn install_whisper_log_trampoline() {
+    crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe {
+        whisper_rs_sys::whisper_log_set(Some(whisper_cpp_log_trampoline), std::ptr::null_mut())
+    });
+}
diff --git a/src/whisper_sys_tracing.rs b/src/whisper_sys_tracing.rs
new file mode 100644
index 0000000..6c6d316
--- /dev/null
+++ b/src/whisper_sys_tracing.rs
@@ -0,0 +1,44 @@
+use tracing::{debug, error, info, warn};
+use whisper_rs_sys::ggml_log_level;
+
+unsafe extern "C" fn whisper_cpp_tracing_trampoline(
+    level: ggml_log_level,
+    text: *const std::os::raw::c_char,
+    _: *mut std::os::raw::c_void, // user_data
+) {
+    if text.is_null() {
+        error!("whisper_cpp_tracing_trampoline: text is nullptr");
+        // Bail out: passing a null pointer to `CStr::from_ptr` below is undefined behaviour.
+        return;
+    }
+
+    // SAFETY: we must trust whisper.cpp that it will not pass us a string that does not satisfy
+    // from_ptr's requirements.
+    let log_str = unsafe { std::ffi::CStr::from_ptr(text) }.to_string_lossy();
+    // whisper.cpp gives newlines at the end of its log messages, so we trim them
+    let trimmed = log_str.trim();
+
+    match level {
+        whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_DEBUG => debug!("{}", trimmed),
+        whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_INFO => info!("{}", trimmed),
+        whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_WARN => warn!("{}", trimmed),
+        whisper_rs_sys::ggml_log_level_GGML_LOG_LEVEL_ERROR => error!("{}", trimmed),
+        _ => {
+            warn!(
+                "whisper_cpp_tracing_trampoline: unknown log level {}: message: {}",
+                level, trimmed
+            )
+        }
+    }
+}
+
+/// Shortcut utility to redirect all whisper.cpp logging to the `tracing` crate.
+///
+/// Filter for logs from the `whisper-rs` crate to see all log output from whisper.cpp.
+///
+/// You should only call this once (subsequent calls have no effect).
+pub fn install_whisper_tracing_trampoline() {
+    crate::LOG_TRAMPOLINE_INSTALL.call_once(|| unsafe {
+        whisper_rs_sys::whisper_log_set(Some(whisper_cpp_tracing_trampoline), std::ptr::null_mut())
+    });
+}