From 50fdb08a382685cc46cb42db69b30bf1186af00a Mon Sep 17 00:00:00 2001 From: Harald Hoyer Date: Tue, 24 Feb 2026 11:44:03 +0100 Subject: [PATCH] Add Wyoming protocol ASR server and nix devshell New wyoming-whisper-rs binary crate implementing the Wyoming protocol over TCP, making whisper-rs usable with Home Assistant's voice pipeline. Includes nix flake devshell with Vulkan, ROCm/hipBLAS, clippy, and rustfmt support. Co-Authored-By: Claude Opus 4.6 --- .envrc | 1 + CLAUDE.md | 97 +++++++++++++ Cargo.toml | 2 +- flake.lock | 27 ++++ flake.nix | 56 ++++++++ wyoming-whisper-rs/Cargo.toml | 27 ++++ wyoming-whisper-rs/src/cli.rs | 49 +++++++ wyoming-whisper-rs/src/error.rs | 22 +++ wyoming-whisper-rs/src/main.rs | 111 ++++++++++++++ wyoming-whisper-rs/src/protocol.rs | 90 ++++++++++++ wyoming-whisper-rs/src/session.rs | 208 +++++++++++++++++++++++++++ wyoming-whisper-rs/src/transcribe.rs | 151 +++++++++++++++++++ 12 files changed, 840 insertions(+), 1 deletion(-) create mode 100644 .envrc create mode 100644 CLAUDE.md create mode 100644 flake.lock create mode 100644 flake.nix create mode 100644 wyoming-whisper-rs/Cargo.toml create mode 100644 wyoming-whisper-rs/src/cli.rs create mode 100644 wyoming-whisper-rs/src/error.rs create mode 100644 wyoming-whisper-rs/src/main.rs create mode 100644 wyoming-whisper-rs/src/protocol.rs create mode 100644 wyoming-whisper-rs/src/session.rs create mode 100644 wyoming-whisper-rs/src/transcribe.rs diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..2c44afd --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,97 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +whisper-rs provides safe Rust bindings to [whisper.cpp](https://github.com/ggerganov/whisper.cpp), a C++ speech recognition library. It's a two-crate workspace: + +- **`whisper-rs`** (root) — Safe public API +- **`whisper-rs-sys`** (`sys/`) — FFI bindings generated via bindgen, with a CMake-based build of the whisper.cpp submodule + +The upstream C++ source lives in `sys/whisper.cpp/` (git submodule — clone with `--recursive`). + +## Build Commands + +```bash +cargo build # Default build +cargo build --release --features vulkan # With Vulkan GPU support +cargo build --release --features hipblas # With AMD ROCm support +cargo build --release --features cuda # With NVIDIA CUDA support +cargo test # Run all tests +cargo fmt # Format code +cargo clippy # Lint +``` + +**Running examples** (require a GGML model file and a WAV audio file): +```bash +cargo run --example basic_use -- model.bin audio.wav +cargo run --example audio_transcription -- model.bin audio.wav +cargo run --example vad -- model.bin audio.wav output.wav +``` + +**Skipping bindgen** (use pre-generated bindings): set `WHISPER_DONT_GENERATE_BINDINGS=1`. + +All `WHISPER_*` and `CMAKE_*` env vars are forwarded to the CMake build. + +## Architecture + +``` +whisper-rs (safe Rust API) + → whisper-rs-sys (bindgen FFI) + → whisper.cpp (C++ submodule, built via CMake) + → GGML (tensor library with CPU/GPU backends) +``` + +**Key types and their relationships:** + +- `WhisperContext` — Arc-wrapped model handle, thread-safe, created from a model file +- `WhisperState` — Inference state created from a context; multiple states can share one context +- `FullParams` — Transcription configuration (sampling strategy, language, callbacks) +- `WhisperSegment` / `WhisperToken` — Result types with timestamps and probabilities + +**Core flow:** `WhisperContext::new_with_params()` → `ctx.create_state()` → `state.full(params, &audio_data)` → iterate segments via `state.as_iter()` + +**Module layout in `src/`:** +- `whisper_ctx.rs` — Raw context wrapper (`WhisperInnerContext`) +- `whisper_ctx_wrapper.rs` — Safe public `WhisperContext` +- `whisper_state/` — State management, segments, tokens, iterators +- `whisper_params.rs` — `FullParams`, `SamplingStrategy` (Greedy / BeamSearch) +- `whisper_vad.rs` — Voice Activity Detection +- `whisper_grammar.rs` — GBNF grammar-constrained decoding +- `vulkan.rs` — Vulkan device enumeration (behind `vulkan` feature) + +## Build System (sys/build.rs) + +The build script: +1. Copies `whisper.cpp` sources into `OUT_DIR` +2. Runs bindgen on `wrapper.h` to generate FFI bindings (falls back to `src/bindings.rs` on failure) +3. Configures and builds whisper.cpp via CMake with feature-dependent flags (`GGML_CUDA`, `GGML_HIP`, `GGML_VULKAN`, `GGML_METAL`, etc.) +4. Statically links: `whisper`, `ggml`, `ggml-base`, `ggml-cpu`, plus backend-specific libs + +## Feature Flags + +| Feature | Purpose | +|---------|---------| +| `cuda` | NVIDIA GPU (needs CUDA toolkit) | +| `hipblas` | AMD GPU via ROCm | +| `metal` | Apple Metal GPU | +| `vulkan` | Vulkan GPU | +| `openblas` | OpenBLAS acceleration (requires `BLAS_INCLUDE_DIRS` env var) | +| `openmp` | OpenMP threading | +| `coreml` | Apple CoreML | +| `intel-sycl` | Intel SYCL | +| `raw-api` | Re-export `whisper-rs-sys` types publicly | +| `log_backend` | Route C++ logs to the `log` crate | +| `tracing_backend` | Route C++ logs to the `tracing` crate | + +## Nix Development Environment + +The `flake.nix` provides a devshell with all dependencies for Vulkan and ROCm/hipBLAS builds. Use `direnv allow` or `nix develop` to enter the environment. Key env vars (`LIBCLANG_PATH`, `BINDGEN_EXTRA_CLANG_ARGS`, `HIP_PATH`, `VULKAN_SDK`) are set automatically. + +## PR Conventions + +Per `.github/PULL_REQUEST_TEMPLATE.md`: +- Run `cargo fmt` and `cargo clippy` before submitting +- Self-review code for legibility +- No GenAI-generated code in PRs diff --git a/Cargo.toml b/Cargo.toml index 24849b6..685b1ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["sys"] +members = ["sys", "wyoming-whisper-rs"] exclude = ["examples/full_usage"] [package] diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..516965b --- /dev/null +++ b/flake.lock @@ -0,0 +1,27 @@ +{ + "nodes": { + "nixpkgs": { + "locked": { + "lastModified": 1771482645, + "narHash": "sha256-MpAKyXfJRDTgRU33Hja+G+3h9ywLAJJNRq4Pjbb4dQs=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "724cf38d99ba81fbb4a347081db93e2e3a9bc2ae", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..340736e --- /dev/null +++ b/flake.nix @@ -0,0 +1,56 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + }; + + outputs = { nixpkgs, ... }: + let + systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; + forAllSystems = f: nixpkgs.lib.genAttrs systems (system: f nixpkgs.legacyPackages.${system}); + in + { + devShells = forAllSystems (pkgs: { + default = pkgs.mkShell { + nativeBuildInputs = [ + pkgs.rustc + pkgs.cargo + pkgs.clippy + pkgs.rustfmt + pkgs.cmake + pkgs.pkg-config + pkgs.shaderc + pkgs.rocmPackages.llvm.clang + ]; + + buildInputs = [ + pkgs.libclang.lib + pkgs.openssl + pkgs.vulkan-headers + pkgs.vulkan-loader + pkgs.rocmPackages.clr + pkgs.rocmPackages.hipblas + pkgs.rocmPackages.rocblas + pkgs.rocmPackages.rocm-runtime + ]; + + env = { + LIBCLANG_PATH = "${pkgs.libclang.lib}/lib"; + VULKAN_SDK = "${pkgs.vulkan-headers}"; + HIP_PATH = "${pkgs.rocmPackages.clr}"; + BINDGEN_EXTRA_CLANG_ARGS = builtins.toString [ + "-isystem ${pkgs.libclang.lib}/lib/clang/${pkgs.lib.versions.major pkgs.libclang.version}/include" + "-isystem ${pkgs.glibc.dev}/include" + ]; + }; + + LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [ + pkgs.vulkan-loader + pkgs.rocmPackages.clr + pkgs.rocmPackages.hipblas + pkgs.rocmPackages.rocblas + pkgs.rocmPackages.rocm-runtime + ]; + }; + }); + }; +} diff --git a/wyoming-whisper-rs/Cargo.toml b/wyoming-whisper-rs/Cargo.toml new file mode 100644 index 0000000..c6c55d8 --- /dev/null +++ b/wyoming-whisper-rs/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "wyoming-whisper-rs" +version = "0.1.0" +edition = "2021" +description = "Wyoming protocol ASR server powered by whisper-rs" +license = "Unlicense" + +[dependencies] +whisper-rs = { path = "..", features = ["tracing_backend"] } +tokio = { version = "1", features = ["rt-multi-thread", "net", "io-util", "macros", "signal", "sync"] } +clap = { version = "4", features = ["derive"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +thiserror = "2" +anyhow = "1" + +[features] +default = [] +cuda = ["whisper-rs/cuda"] +hipblas = ["whisper-rs/hipblas"] +metal = ["whisper-rs/metal"] +vulkan = ["whisper-rs/vulkan"] +openblas = ["whisper-rs/openblas"] +openmp = ["whisper-rs/openmp"] +coreml = ["whisper-rs/coreml"] diff --git a/wyoming-whisper-rs/src/cli.rs b/wyoming-whisper-rs/src/cli.rs new file mode 100644 index 0000000..e4b7328 --- /dev/null +++ b/wyoming-whisper-rs/src/cli.rs @@ -0,0 +1,49 @@ +use clap::Parser; +use std::path::PathBuf; + +#[derive(Parser, Debug)] +#[command( + name = "wyoming-whisper-rs", + about = "Wyoming protocol ASR server powered by whisper-rs" +)] +pub struct Args { + /// Path to the GGML whisper model file + #[arg(long)] + pub model: PathBuf, + + /// Host address to bind to + #[arg(long, default_value = "0.0.0.0")] + pub host: String, + + /// Port to listen on + #[arg(long, default_value_t = 10300)] + pub port: u16, + + /// Language code (e.g. "en", "de"). Omit for auto-detection + #[arg(long)] + pub language: Option, + + /// Beam search size + #[arg(long, default_value_t = 5)] + pub beam_size: i32, + + /// Number of threads for whisper inference (0 = whisper default) + #[arg(long, default_value_t = 0)] + pub threads: i32, + + /// GPU device index + #[arg(long, default_value_t = 0)] + pub gpu_device: i32, + + /// Disable GPU acceleration + #[arg(long)] + pub no_gpu: bool, + + /// Maximum concurrent transcriptions + #[arg(long, default_value_t = 1)] + pub max_concurrent: usize, + + /// Model name reported in Wyoming info + #[arg(long, default_value = "whisper")] + pub model_name: String, +} diff --git a/wyoming-whisper-rs/src/error.rs b/wyoming-whisper-rs/src/error.rs new file mode 100644 index 0000000..bd83731 --- /dev/null +++ b/wyoming-whisper-rs/src/error.rs @@ -0,0 +1,22 @@ +use std::io; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("I/O error: {0}")] + Io(#[from] io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), + + #[error("Whisper error: {0}")] + Whisper(#[from] whisper_rs::WhisperError), + + #[error("payload too large: {size} bytes (max {max})")] + PayloadTooLarge { size: usize, max: usize }, + + #[error("missing field: {0}")] + MissingField(&'static str), + + #[error("invalid audio format: {0}")] + InvalidAudio(String), +} diff --git a/wyoming-whisper-rs/src/main.rs b/wyoming-whisper-rs/src/main.rs new file mode 100644 index 0000000..404ccce --- /dev/null +++ b/wyoming-whisper-rs/src/main.rs @@ -0,0 +1,111 @@ +mod cli; +mod error; +mod protocol; +mod session; +mod transcribe; + +use std::sync::Arc; + +use clap::Parser; +use serde_json::json; +use tokio::net::TcpListener; +use tracing::{error, info}; +use whisper_rs::{WhisperContext, WhisperContextParameters}; + +use crate::cli::Args; +use crate::session::SessionConfig; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + // Validate model path + if !args.model.exists() { + error!(path = %args.model.display(), "model file not found"); + std::process::exit(1); + } + + // Validate language if specified + if let Some(ref lang) = args.language { + if whisper_rs::get_lang_id(lang).is_none() { + error!(language = lang, "unknown language code"); + std::process::exit(1); + } + } + + // Route whisper.cpp logs through tracing + whisper_rs::install_logging_hooks(); + + // Load model + info!(model = %args.model.display(), "loading whisper model"); + let mut ctx_params = WhisperContextParameters::default(); + if args.no_gpu { + ctx_params.use_gpu(false); + } + ctx_params.gpu_device(args.gpu_device); + + let model_path = args.model.to_string_lossy().to_string(); + let ctx = WhisperContext::new_with_params(&model_path, ctx_params)?; + let ctx = Arc::new(ctx); + info!("model loaded"); + + // Build language list + let max_id = whisper_rs::get_lang_max_id(); + let languages: Vec = (0..=max_id) + .filter_map(|id| { + whisper_rs::get_lang_str(id).map(|code| { + json!({ + "name": code, + "description": whisper_rs::get_lang_str_full(id).unwrap_or(code), + }) + }) + }) + .collect(); + + let semaphore = Arc::new(tokio::sync::Semaphore::new(args.max_concurrent)); + + let session_config = Arc::new(SessionConfig { + ctx, + semaphore, + model_name: args.model_name, + languages, + default_language: args.language, + beam_size: args.beam_size, + threads: args.threads, + }); + + let bind_addr = format!("{}:{}", args.host, args.port); + let listener = TcpListener::bind(&bind_addr).await?; + info!(address = %bind_addr, "wyoming server listening"); + + loop { + tokio::select! { + result = listener.accept() => { + let (stream, addr) = result?; + info!(peer = %addr, "new connection"); + let config = Arc::clone(&session_config); + + tokio::spawn(async move { + let (reader, writer) = stream.into_split(); + if let Err(e) = session::run(config, reader, writer).await { + error!(peer = %addr, error = %e, "session error"); + } + info!(peer = %addr, "connection closed"); + }); + } + _ = tokio::signal::ctrl_c() => { + info!("shutting down"); + break; + } + } + } + + Ok(()) +} diff --git a/wyoming-whisper-rs/src/protocol.rs b/wyoming-whisper-rs/src/protocol.rs new file mode 100644 index 0000000..02ddbe1 --- /dev/null +++ b/wyoming-whisper-rs/src/protocol.rs @@ -0,0 +1,90 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; + +use crate::error::Error; + +const MAX_PAYLOAD: usize = 100 * 1024 * 1024; // 100 MB + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Event { + #[serde(rename = "type")] + pub event_type: String, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub data: Option, + + #[serde(default, skip_serializing_if = "is_zero")] + pub data_length: usize, + + #[serde(default, skip_serializing_if = "is_zero")] + pub payload_length: usize, + + #[serde(skip)] + pub payload: Option>, +} + +fn is_zero(v: &usize) -> bool { + *v == 0 +} + +impl Event { + pub fn new(event_type: impl Into) -> Self { + Self { + event_type: event_type.into(), + data: None, + data_length: 0, + payload_length: 0, + payload: None, + } + } + + pub fn with_data(mut self, data: Value) -> Self { + let serialized = serde_json::to_string(&data).unwrap_or_default(); + self.data_length = serialized.len(); + self.data = Some(data); + self + } +} + +pub async fn read_event( + reader: &mut BufReader, +) -> Result, Error> { + let mut line = String::new(); + let n = reader.read_line(&mut line).await?; + if n == 0 { + return Ok(None); + } + + let mut event: Event = serde_json::from_str(line.trim())?; + + if event.payload_length > 0 { + if event.payload_length > MAX_PAYLOAD { + return Err(Error::PayloadTooLarge { + size: event.payload_length, + max: MAX_PAYLOAD, + }); + } + let mut buf = vec![0u8; event.payload_length]; + reader.read_exact(&mut buf).await?; + event.payload = Some(buf); + } + + Ok(Some(event)) +} + +pub async fn write_event( + writer: &mut W, + event: &Event, +) -> Result<(), Error> { + let json = serde_json::to_string(event)?; + writer.write_all(json.as_bytes()).await?; + writer.write_all(b"\n").await?; + + if let Some(ref payload) = event.payload { + writer.write_all(payload).await?; + } + + writer.flush().await?; + Ok(()) +} diff --git a/wyoming-whisper-rs/src/session.rs b/wyoming-whisper-rs/src/session.rs new file mode 100644 index 0000000..5c719d9 --- /dev/null +++ b/wyoming-whisper-rs/src/session.rs @@ -0,0 +1,208 @@ +use std::sync::Arc; + +use serde_json::{json, Value}; +use tokio::io::{AsyncRead, AsyncWrite, BufReader}; +use tokio::sync::Semaphore; +use tracing::{debug, warn}; +use whisper_rs::WhisperContext; + +use crate::error::Error; +use crate::protocol::{read_event, write_event, Event}; +use crate::transcribe::{AudioBuffer, TranscribeConfig}; + +pub struct SessionConfig { + pub ctx: Arc, + pub semaphore: Arc, + pub model_name: String, + pub languages: Vec, + pub default_language: Option, + pub beam_size: i32, + pub threads: i32, +} + +enum State { + Idle, + AwaitingAudioStart { + language: Option, + }, + Streaming { + buffer: AudioBuffer, + language: Option, + }, +} + +impl State { + fn name(&self) -> &'static str { + match self { + State::Idle => "idle", + State::AwaitingAudioStart { .. } => "awaiting_audio_start", + State::Streaming { .. } => "streaming", + } + } +} + +pub async fn run(config: Arc, reader: R, mut writer: W) -> Result<(), Error> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut reader = BufReader::new(reader); + let mut state = State::Idle; + + loop { + let event = match read_event(&mut reader).await? { + Some(e) => e, + None => { + debug!("client disconnected"); + return Ok(()); + } + }; + + debug!(event_type = %event.event_type, "received event"); + + match event.event_type.as_str() { + "describe" => { + let info = Event::new("info").with_data(json!({ + "type": "asr", + "name": config.model_name, + "description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION), + "installed": true, + "attribution": { + "name": "whisper.cpp", + "url": "https://github.com/ggerganov/whisper.cpp" + }, + "models": [{ + "name": config.model_name, + "description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION), + "installed": true, + "attribution": { + "name": "whisper.cpp", + "url": "https://github.com/ggerganov/whisper.cpp" + }, + "languages": config.languages, + }], + "version": env!("CARGO_PKG_VERSION"), + })); + write_event(&mut writer, &info).await?; + } + + "transcribe" => { + let language = event + .data + .as_ref() + .and_then(|d| d.get("language")) + .and_then(|v| v.as_str()) + .map(String::from) + .or_else(|| config.default_language.clone()); + + state = State::AwaitingAudioStart { language }; + } + + "audio-start" => { + let (rate, width, channels) = parse_audio_start(&event)?; + let language = match state { + State::AwaitingAudioStart { language } => language, + _ => config.default_language.clone(), + }; + + state = State::Streaming { + buffer: AudioBuffer::new(rate, width, channels), + language, + }; + } + + "audio-chunk" => { + if let State::Streaming { ref mut buffer, .. } = state { + if let Some(ref payload) = event.payload { + buffer.append(payload); + } + } else { + warn!(state = state.name(), "audio-chunk received in wrong state"); + } + } + + "audio-stop" => { + if let State::Streaming { buffer, language } = state { + state = State::Idle; + + let tc = TranscribeConfig { + language, + beam_size: config.beam_size, + threads: config.threads, + }; + + match do_transcribe(&config.ctx, &config.semaphore, tc, buffer).await { + Ok(text) => { + let transcript = Event::new("transcript").with_data(json!({ + "text": text, + })); + write_event(&mut writer, &transcript).await?; + } + Err(e) => { + warn!(error = %e, "transcription failed"); + let err = Event::new("error").with_data(json!({ + "text": format!("transcription failed: {e}"), + })); + write_event(&mut writer, &err).await?; + } + } + } else { + warn!(state = state.name(), "audio-stop received in wrong state"); + state = State::Idle; + } + } + + other => { + warn!(event_type = other, "unknown event type"); + let err = Event::new("error").with_data(json!({ + "text": format!("unknown event type: {other}"), + })); + write_event(&mut writer, &err).await?; + } + } + } +} + +fn parse_audio_start(event: &Event) -> Result<(u32, u16, u16), Error> { + let data = event + .data + .as_ref() + .ok_or(Error::MissingField("data in audio-start"))?; + + let rate = data + .get("rate") + .and_then(|v| v.as_u64()) + .ok_or(Error::MissingField("rate"))? as u32; + + let width = data + .get("width") + .and_then(|v| v.as_u64()) + .ok_or(Error::MissingField("width"))? as u16; + + let channels = data + .get("channels") + .and_then(|v| v.as_u64()) + .ok_or(Error::MissingField("channels"))? as u16; + + Ok((rate, width, channels)) +} + +async fn do_transcribe( + ctx: &Arc, + semaphore: &Arc, + config: TranscribeConfig, + buffer: AudioBuffer, +) -> Result { + let audio = buffer.into_f32_16khz_mono()?; + + if audio.is_empty() { + return Ok(String::new()); + } + + let _permit = semaphore.acquire().await.expect("semaphore closed"); + let ctx = Arc::clone(ctx); + + tokio::task::spawn_blocking(move || crate::transcribe::transcribe(&ctx, &config, audio)) + .await + .expect("transcription task panicked") +} diff --git a/wyoming-whisper-rs/src/transcribe.rs b/wyoming-whisper-rs/src/transcribe.rs new file mode 100644 index 0000000..e7220d7 --- /dev/null +++ b/wyoming-whisper-rs/src/transcribe.rs @@ -0,0 +1,151 @@ +use std::sync::Arc; + +use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; + +use crate::error::Error; + +pub struct TranscribeConfig { + pub language: Option, + pub beam_size: i32, + pub threads: i32, +} + +pub struct AudioBuffer { + data: Vec, + rate: u32, + width: u16, + channels: u16, +} + +impl AudioBuffer { + pub fn new(rate: u32, width: u16, channels: u16) -> Self { + Self { + data: Vec::new(), + rate, + width, + channels, + } + } + + pub fn append(&mut self, chunk: &[u8]) { + self.data.extend_from_slice(chunk); + } + + pub fn into_f32_16khz_mono(self) -> Result, Error> { + if self.width != 2 { + return Err(Error::InvalidAudio(format!( + "expected 16-bit audio (width=2), got width={}", + self.width + ))); + } + + if !self.data.len().is_multiple_of(2) { + return Err(Error::InvalidAudio( + "audio data has odd number of bytes for 16-bit samples".into(), + )); + } + + // Interpret as i16 little-endian + let samples_i16: Vec = self + .data + .chunks_exact(2) + .map(|c| i16::from_le_bytes([c[0], c[1]])) + .collect(); + + // Convert i16 -> f32 + let mut samples_f32 = vec![0.0f32; samples_i16.len()]; + whisper_rs::convert_integer_to_float_audio(&samples_i16, &mut samples_f32) + .map_err(|e| Error::InvalidAudio(format!("i16 to f32 conversion failed: {e}")))?; + + // Convert stereo to mono if needed + let mono = if self.channels == 2 { + let mut mono = vec![0.0f32; samples_f32.len() / 2]; + whisper_rs::convert_stereo_to_mono_audio(&samples_f32, &mut mono) + .map_err(|e| Error::InvalidAudio(format!("stereo to mono failed: {e}")))?; + mono + } else if self.channels == 1 { + samples_f32 + } else { + return Err(Error::InvalidAudio(format!( + "unsupported channel count: {}", + self.channels + ))); + }; + + // Resample if not 16kHz + if self.rate == 16000 { + Ok(mono) + } else { + Ok(resample(&mono, self.rate, 16000)) + } + } +} + +/// Simple linear interpolation resampler. +fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Vec { + if from_rate == to_rate || input.is_empty() { + return input.to_vec(); + } + + let ratio = from_rate as f64 / to_rate as f64; + let output_len = ((input.len() as f64) / ratio).ceil() as usize; + let mut output = Vec::with_capacity(output_len); + + for i in 0..output_len { + let src_pos = i as f64 * ratio; + let idx = src_pos as usize; + let frac = src_pos - idx as f64; + + let sample = if idx + 1 < input.len() { + input[idx] as f64 * (1.0 - frac) + input[idx + 1] as f64 * frac + } else { + input[idx.min(input.len() - 1)] as f64 + }; + + output.push(sample as f32); + } + + output +} + +pub fn transcribe( + ctx: &Arc, + config: &TranscribeConfig, + audio: Vec, +) -> Result { + let mut state = ctx.create_state()?; + + let mut params = FullParams::new(SamplingStrategy::BeamSearch { + beam_size: config.beam_size, + patience: -1.0, + }); + + if let Some(ref lang) = config.language { + params.set_language(Some(lang)); + } else { + params.set_language(None); + params.set_detect_language(true); + } + + if config.threads > 0 { + params.set_n_threads(config.threads); + } + + params.set_print_special(false); + params.set_print_progress(false); + params.set_print_realtime(false); + params.set_print_timestamps(false); + params.set_no_context(true); + params.set_single_segment(false); + + state.full(params, &audio)?; + + let mut text = String::new(); + for segment in state.as_iter() { + if let Ok(s) = segment.to_str_lossy() { + text.push_str(&s); + } + } + + Ok(text.trim().to_string()) +}