New wyoming-whisper-rs binary crate implementing the Wyoming protocol over TCP, making whisper-rs usable with Home Assistant's voice pipeline. Includes nix flake devshell with Vulkan, ROCm/hipBLAS, clippy, and rustfmt support. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
151 lines
4 KiB
Rust
151 lines
4 KiB
Rust
use std::sync::Arc;
|
|
|
|
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
|
|
|
|
use crate::error::Error;
|
|
|
|
/// Settings that control a single call to [`transcribe`].
///
/// Derives `Debug` so the active configuration can be logged, and `Clone` so
/// it can be handed to several connection handlers.
#[derive(Debug, Clone)]
pub struct TranscribeConfig {
    /// Language code forwarded to whisper. `None` requests automatic
    /// language detection.
    pub language: Option<String>,
    /// Beam width used by the beam-search sampling strategy.
    pub beam_size: i32,
    /// Number of threads to decode with. Values `<= 0` leave the library
    /// default in place.
    pub threads: i32,
}
|
|
|
|
/// Accumulates raw PCM bytes together with the format they were captured in.
///
/// Chunks are appended as they arrive; once the stream ends the buffer is
/// consumed and converted to the sample layout the transcriber works on.
pub struct AudioBuffer {
    /// Raw sample bytes, in arrival order. Interpreted as little-endian
    /// `i16` samples when the buffer is converted.
    data: Vec<u8>,
    /// Sample rate of `data` (compared against 16000 on conversion).
    rate: u32,
    /// Bytes per sample; only `2` (16-bit audio) is accepted on conversion.
    width: u16,
    /// Number of interleaved channels; only `1` and `2` are accepted on
    /// conversion.
    channels: u16,
}
|
|
|
|
impl AudioBuffer {
|
|
pub fn new(rate: u32, width: u16, channels: u16) -> Self {
|
|
Self {
|
|
data: Vec::new(),
|
|
rate,
|
|
width,
|
|
channels,
|
|
}
|
|
}
|
|
|
|
pub fn append(&mut self, chunk: &[u8]) {
|
|
self.data.extend_from_slice(chunk);
|
|
}
|
|
|
|
pub fn into_f32_16khz_mono(self) -> Result<Vec<f32>, Error> {
|
|
if self.width != 2 {
|
|
return Err(Error::InvalidAudio(format!(
|
|
"expected 16-bit audio (width=2), got width={}",
|
|
self.width
|
|
)));
|
|
}
|
|
|
|
if !self.data.len().is_multiple_of(2) {
|
|
return Err(Error::InvalidAudio(
|
|
"audio data has odd number of bytes for 16-bit samples".into(),
|
|
));
|
|
}
|
|
|
|
// Interpret as i16 little-endian
|
|
let samples_i16: Vec<i16> = self
|
|
.data
|
|
.chunks_exact(2)
|
|
.map(|c| i16::from_le_bytes([c[0], c[1]]))
|
|
.collect();
|
|
|
|
// Convert i16 -> f32
|
|
let mut samples_f32 = vec![0.0f32; samples_i16.len()];
|
|
whisper_rs::convert_integer_to_float_audio(&samples_i16, &mut samples_f32)
|
|
.map_err(|e| Error::InvalidAudio(format!("i16 to f32 conversion failed: {e}")))?;
|
|
|
|
// Convert stereo to mono if needed
|
|
let mono = if self.channels == 2 {
|
|
let mut mono = vec![0.0f32; samples_f32.len() / 2];
|
|
whisper_rs::convert_stereo_to_mono_audio(&samples_f32, &mut mono)
|
|
.map_err(|e| Error::InvalidAudio(format!("stereo to mono failed: {e}")))?;
|
|
mono
|
|
} else if self.channels == 1 {
|
|
samples_f32
|
|
} else {
|
|
return Err(Error::InvalidAudio(format!(
|
|
"unsupported channel count: {}",
|
|
self.channels
|
|
)));
|
|
};
|
|
|
|
// Resample if not 16kHz
|
|
if self.rate == 16000 {
|
|
Ok(mono)
|
|
} else {
|
|
Ok(resample(&mono, self.rate, 16000))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Resamples `input` from `from_rate` to `to_rate` using linear
/// interpolation between neighbouring samples.
///
/// Returns a copy of `input` when the rates are equal or `input` is empty.
/// Returns an empty vector when either rate is zero: a zero source rate
/// would otherwise yield an unbounded output length and panic on
/// allocation, and a zero target rate has no samples to produce.
///
/// No low-pass filtering is applied, so downsampling may alias.
fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate || input.is_empty() {
        return input.to_vec();
    }
    if from_rate == 0 || to_rate == 0 {
        return Vec::new();
    }

    // Distance, in source samples, between two consecutive output samples.
    let ratio = f64::from(from_rate) / f64::from(to_rate);
    let output_len = (input.len() as f64 / ratio).ceil() as usize;
    // `input` is non-empty here, so the subtraction cannot underflow.
    let last = input.len() - 1;

    (0..output_len)
        .map(|i| {
            let src_pos = i as f64 * ratio;
            let idx = src_pos as usize;
            let sample = if idx < last {
                let frac = src_pos - idx as f64;
                f64::from(input[idx]) * (1.0 - frac) + f64::from(input[idx + 1]) * frac
            } else {
                // At or past the final sample there is no right-hand
                // neighbour to interpolate with; hold the last value.
                f64::from(input[last])
            };
            sample as f32
        })
        .collect()
}
|
|
|
|
pub fn transcribe(
|
|
ctx: &Arc<WhisperContext>,
|
|
config: &TranscribeConfig,
|
|
audio: Vec<f32>,
|
|
) -> Result<String, Error> {
|
|
let mut state = ctx.create_state()?;
|
|
|
|
let mut params = FullParams::new(SamplingStrategy::BeamSearch {
|
|
beam_size: config.beam_size,
|
|
patience: -1.0,
|
|
});
|
|
|
|
if let Some(ref lang) = config.language {
|
|
params.set_language(Some(lang));
|
|
} else {
|
|
params.set_language(None);
|
|
params.set_detect_language(true);
|
|
}
|
|
|
|
if config.threads > 0 {
|
|
params.set_n_threads(config.threads);
|
|
}
|
|
|
|
params.set_print_special(false);
|
|
params.set_print_progress(false);
|
|
params.set_print_realtime(false);
|
|
params.set_print_timestamps(false);
|
|
params.set_no_context(true);
|
|
params.set_single_segment(false);
|
|
|
|
state.full(params, &audio)?;
|
|
|
|
let mut text = String::new();
|
|
for segment in state.as_iter() {
|
|
if let Ok(s) = segment.to_str_lossy() {
|
|
text.push_str(&s);
|
|
}
|
|
}
|
|
|
|
Ok(text.trim().to_string())
|
|
}
|