Add Wyoming protocol ASR server and nix devshell
New wyoming-whisper-rs binary crate implementing the Wyoming protocol over TCP, making whisper-rs usable with Home Assistant's voice pipeline. Includes nix flake devshell with Vulkan, ROCm/hipBLAS, clippy, and rustfmt support. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
d38738df8d
commit
50fdb08a38
12 changed files with 840 additions and 1 deletions
1
.envrc
Normal file
1
.envrc
Normal file
|
|
@ -0,0 +1 @@
|
|||
use flake
|
||||
97
CLAUDE.md
Normal file
97
CLAUDE.md
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
whisper-rs provides safe Rust bindings to [whisper.cpp](https://github.com/ggerganov/whisper.cpp), a C++ speech recognition library. It's a two-crate workspace:
|
||||
|
||||
- **`whisper-rs`** (root) — Safe public API
|
||||
- **`whisper-rs-sys`** (`sys/`) — FFI bindings generated via bindgen, with a CMake-based build of the whisper.cpp submodule
|
||||
|
||||
The upstream C++ source lives in `sys/whisper.cpp/` (git submodule — clone with `--recursive`).
|
||||
|
||||
## Build Commands
|
||||
|
||||
```bash
|
||||
cargo build # Default build
|
||||
cargo build --release --features vulkan # With Vulkan GPU support
|
||||
cargo build --release --features hipblas # With AMD ROCm support
|
||||
cargo build --release --features cuda # With NVIDIA CUDA support
|
||||
cargo test # Run all tests
|
||||
cargo fmt # Format code
|
||||
cargo clippy # Lint
|
||||
```
|
||||
|
||||
**Running examples** (require a GGML model file and a WAV audio file):
|
||||
```bash
|
||||
cargo run --example basic_use -- model.bin audio.wav
|
||||
cargo run --example audio_transcription -- model.bin audio.wav
|
||||
cargo run --example vad -- model.bin audio.wav output.wav
|
||||
```
|
||||
|
||||
**Skipping bindgen** (use pre-generated bindings): set `WHISPER_DONT_GENERATE_BINDINGS=1`.
|
||||
|
||||
All `WHISPER_*` and `CMAKE_*` env vars are forwarded to the CMake build.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
whisper-rs (safe Rust API)
|
||||
→ whisper-rs-sys (bindgen FFI)
|
||||
→ whisper.cpp (C++ submodule, built via CMake)
|
||||
→ GGML (tensor library with CPU/GPU backends)
|
||||
```
|
||||
|
||||
**Key types and their relationships:**
|
||||
|
||||
- `WhisperContext` — Arc-wrapped model handle, thread-safe, created from a model file
|
||||
- `WhisperState` — Inference state created from a context; multiple states can share one context
|
||||
- `FullParams` — Transcription configuration (sampling strategy, language, callbacks)
|
||||
- `WhisperSegment` / `WhisperToken` — Result types with timestamps and probabilities
|
||||
|
||||
**Core flow:** `WhisperContext::new_with_params()` → `ctx.create_state()` → `state.full(params, &audio_data)` → iterate segments via `state.as_iter()`
|
||||
|
||||
**Module layout in `src/`:**
|
||||
- `whisper_ctx.rs` — Raw context wrapper (`WhisperInnerContext`)
|
||||
- `whisper_ctx_wrapper.rs` — Safe public `WhisperContext`
|
||||
- `whisper_state/` — State management, segments, tokens, iterators
|
||||
- `whisper_params.rs` — `FullParams`, `SamplingStrategy` (Greedy / BeamSearch)
|
||||
- `whisper_vad.rs` — Voice Activity Detection
|
||||
- `whisper_grammar.rs` — GBNF grammar-constrained decoding
|
||||
- `vulkan.rs` — Vulkan device enumeration (behind `vulkan` feature)
|
||||
|
||||
## Build System (sys/build.rs)
|
||||
|
||||
The build script:
|
||||
1. Copies `whisper.cpp` sources into `OUT_DIR`
|
||||
2. Runs bindgen on `wrapper.h` to generate FFI bindings (falls back to `src/bindings.rs` on failure)
|
||||
3. Configures and builds whisper.cpp via CMake with feature-dependent flags (`GGML_CUDA`, `GGML_HIP`, `GGML_VULKAN`, `GGML_METAL`, etc.)
|
||||
4. Statically links: `whisper`, `ggml`, `ggml-base`, `ggml-cpu`, plus backend-specific libs
|
||||
|
||||
## Feature Flags
|
||||
|
||||
| Feature | Purpose |
|
||||
|---------|---------|
|
||||
| `cuda` | NVIDIA GPU (needs CUDA toolkit) |
|
||||
| `hipblas` | AMD GPU via ROCm |
|
||||
| `metal` | Apple Metal GPU |
|
||||
| `vulkan` | Vulkan GPU |
|
||||
| `openblas` | OpenBLAS acceleration (requires `BLAS_INCLUDE_DIRS` env var) |
|
||||
| `openmp` | OpenMP threading |
|
||||
| `coreml` | Apple CoreML |
|
||||
| `intel-sycl` | Intel SYCL |
|
||||
| `raw-api` | Re-export `whisper-rs-sys` types publicly |
|
||||
| `log_backend` | Route C++ logs to the `log` crate |
|
||||
| `tracing_backend` | Route C++ logs to the `tracing` crate |
|
||||
|
||||
## Nix Development Environment
|
||||
|
||||
The `flake.nix` provides a devshell with all dependencies for Vulkan and ROCm/hipBLAS builds. Use `direnv allow` or `nix develop` to enter the environment. Key env vars (`LIBCLANG_PATH`, `BINDGEN_EXTRA_CLANG_ARGS`, `HIP_PATH`, `VULKAN_SDK`) are set automatically.
|
||||
|
||||
## PR Conventions
|
||||
|
||||
Per `.github/PULL_REQUEST_TEMPLATE.md`:
|
||||
- Run `cargo fmt` and `cargo clippy` before submitting
|
||||
- Self-review code for legibility
|
||||
- No GenAI-generated code in PRs
|
||||
|
|
@ -1,5 +1,5 @@
|
|||
[workspace]
|
||||
members = ["sys"]
|
||||
members = ["sys", "wyoming-whisper-rs"]
|
||||
exclude = ["examples/full_usage"]
|
||||
|
||||
[package]
|
||||
|
|
|
|||
27
flake.lock
generated
Normal file
27
flake.lock
generated
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
{
|
||||
"nodes": {
|
||||
"nixpkgs": {
|
||||
"locked": {
|
||||
"lastModified": 1771482645,
|
||||
"narHash": "sha256-MpAKyXfJRDTgRU33Hja+G+3h9ywLAJJNRq4Pjbb4dQs=",
|
||||
"owner": "NixOS",
|
||||
"repo": "nixpkgs",
|
||||
"rev": "724cf38d99ba81fbb4a347081db93e2e3a9bc2ae",
|
||||
"type": "github"
|
||||
},
|
||||
"original": {
|
||||
"owner": "NixOS",
|
||||
"ref": "nixpkgs-unstable",
|
||||
"repo": "nixpkgs",
|
||||
"type": "github"
|
||||
}
|
||||
},
|
||||
"root": {
|
||||
"inputs": {
|
||||
"nixpkgs": "nixpkgs"
|
||||
}
|
||||
}
|
||||
},
|
||||
"root": "root",
|
||||
"version": 7
|
||||
}
|
||||
56
flake.nix
Normal file
56
flake.nix
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
{
|
||||
inputs = {
|
||||
nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
|
||||
};
|
||||
|
||||
outputs = { nixpkgs, ... }:
|
||||
let
|
||||
systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];
|
||||
forAllSystems = f: nixpkgs.lib.genAttrs systems (system: f nixpkgs.legacyPackages.${system});
|
||||
in
|
||||
{
|
||||
devShells = forAllSystems (pkgs: {
|
||||
default = pkgs.mkShell {
|
||||
nativeBuildInputs = [
|
||||
pkgs.rustc
|
||||
pkgs.cargo
|
||||
pkgs.clippy
|
||||
pkgs.rustfmt
|
||||
pkgs.cmake
|
||||
pkgs.pkg-config
|
||||
pkgs.shaderc
|
||||
pkgs.rocmPackages.llvm.clang
|
||||
];
|
||||
|
||||
buildInputs = [
|
||||
pkgs.libclang.lib
|
||||
pkgs.openssl
|
||||
pkgs.vulkan-headers
|
||||
pkgs.vulkan-loader
|
||||
pkgs.rocmPackages.clr
|
||||
pkgs.rocmPackages.hipblas
|
||||
pkgs.rocmPackages.rocblas
|
||||
pkgs.rocmPackages.rocm-runtime
|
||||
];
|
||||
|
||||
env = {
|
||||
LIBCLANG_PATH = "${pkgs.libclang.lib}/lib";
|
||||
VULKAN_SDK = "${pkgs.vulkan-headers}";
|
||||
HIP_PATH = "${pkgs.rocmPackages.clr}";
|
||||
BINDGEN_EXTRA_CLANG_ARGS = builtins.toString [
|
||||
"-isystem ${pkgs.libclang.lib}/lib/clang/${pkgs.lib.versions.major pkgs.libclang.version}/include"
|
||||
"-isystem ${pkgs.glibc.dev}/include"
|
||||
];
|
||||
};
|
||||
|
||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [
|
||||
pkgs.vulkan-loader
|
||||
pkgs.rocmPackages.clr
|
||||
pkgs.rocmPackages.hipblas
|
||||
pkgs.rocmPackages.rocblas
|
||||
pkgs.rocmPackages.rocm-runtime
|
||||
];
|
||||
};
|
||||
});
|
||||
};
|
||||
}
|
||||
27
wyoming-whisper-rs/Cargo.toml
Normal file
27
wyoming-whisper-rs/Cargo.toml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
name = "wyoming-whisper-rs"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
description = "Wyoming protocol ASR server powered by whisper-rs"
|
||||
license = "Unlicense"
|
||||
|
||||
[dependencies]
|
||||
whisper-rs = { path = "..", features = ["tracing_backend"] }
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "net", "io-util", "macros", "signal", "sync"] }
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
thiserror = "2"
|
||||
anyhow = "1"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
cuda = ["whisper-rs/cuda"]
|
||||
hipblas = ["whisper-rs/hipblas"]
|
||||
metal = ["whisper-rs/metal"]
|
||||
vulkan = ["whisper-rs/vulkan"]
|
||||
openblas = ["whisper-rs/openblas"]
|
||||
openmp = ["whisper-rs/openmp"]
|
||||
coreml = ["whisper-rs/coreml"]
|
||||
49
wyoming-whisper-rs/src/cli.rs
Normal file
49
wyoming-whisper-rs/src/cli.rs
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
use clap::Parser;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(
|
||||
name = "wyoming-whisper-rs",
|
||||
about = "Wyoming protocol ASR server powered by whisper-rs"
|
||||
)]
|
||||
pub struct Args {
|
||||
/// Path to the GGML whisper model file
|
||||
#[arg(long)]
|
||||
pub model: PathBuf,
|
||||
|
||||
/// Host address to bind to
|
||||
#[arg(long, default_value = "0.0.0.0")]
|
||||
pub host: String,
|
||||
|
||||
/// Port to listen on
|
||||
#[arg(long, default_value_t = 10300)]
|
||||
pub port: u16,
|
||||
|
||||
/// Language code (e.g. "en", "de"). Omit for auto-detection
|
||||
#[arg(long)]
|
||||
pub language: Option<String>,
|
||||
|
||||
/// Beam search size
|
||||
#[arg(long, default_value_t = 5)]
|
||||
pub beam_size: i32,
|
||||
|
||||
/// Number of threads for whisper inference (0 = whisper default)
|
||||
#[arg(long, default_value_t = 0)]
|
||||
pub threads: i32,
|
||||
|
||||
/// GPU device index
|
||||
#[arg(long, default_value_t = 0)]
|
||||
pub gpu_device: i32,
|
||||
|
||||
/// Disable GPU acceleration
|
||||
#[arg(long)]
|
||||
pub no_gpu: bool,
|
||||
|
||||
/// Maximum concurrent transcriptions
|
||||
#[arg(long, default_value_t = 1)]
|
||||
pub max_concurrent: usize,
|
||||
|
||||
/// Model name reported in Wyoming info
|
||||
#[arg(long, default_value = "whisper")]
|
||||
pub model_name: String,
|
||||
}
|
||||
22
wyoming-whisper-rs/src/error.rs
Normal file
22
wyoming-whisper-rs/src/error.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
use std::io;
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum Error {
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
|
||||
#[error("JSON error: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
#[error("Whisper error: {0}")]
|
||||
Whisper(#[from] whisper_rs::WhisperError),
|
||||
|
||||
#[error("payload too large: {size} bytes (max {max})")]
|
||||
PayloadTooLarge { size: usize, max: usize },
|
||||
|
||||
#[error("missing field: {0}")]
|
||||
MissingField(&'static str),
|
||||
|
||||
#[error("invalid audio format: {0}")]
|
||||
InvalidAudio(String),
|
||||
}
|
||||
111
wyoming-whisper-rs/src/main.rs
Normal file
111
wyoming-whisper-rs/src/main.rs
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
mod cli;
|
||||
mod error;
|
||||
mod protocol;
|
||||
mod session;
|
||||
mod transcribe;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
use serde_json::json;
|
||||
use tokio::net::TcpListener;
|
||||
use tracing::{error, info};
|
||||
use whisper_rs::{WhisperContext, WhisperContextParameters};
|
||||
|
||||
use crate::cli::Args;
|
||||
use crate::session::SessionConfig;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
|
||||
)
|
||||
.init();
|
||||
|
||||
// Validate model path
|
||||
if !args.model.exists() {
|
||||
error!(path = %args.model.display(), "model file not found");
|
||||
std::process::exit(1);
|
||||
}
|
||||
|
||||
// Validate language if specified
|
||||
if let Some(ref lang) = args.language {
|
||||
if whisper_rs::get_lang_id(lang).is_none() {
|
||||
error!(language = lang, "unknown language code");
|
||||
std::process::exit(1);
|
||||
}
|
||||
}
|
||||
|
||||
// Route whisper.cpp logs through tracing
|
||||
whisper_rs::install_logging_hooks();
|
||||
|
||||
// Load model
|
||||
info!(model = %args.model.display(), "loading whisper model");
|
||||
let mut ctx_params = WhisperContextParameters::default();
|
||||
if args.no_gpu {
|
||||
ctx_params.use_gpu(false);
|
||||
}
|
||||
ctx_params.gpu_device(args.gpu_device);
|
||||
|
||||
let model_path = args.model.to_string_lossy().to_string();
|
||||
let ctx = WhisperContext::new_with_params(&model_path, ctx_params)?;
|
||||
let ctx = Arc::new(ctx);
|
||||
info!("model loaded");
|
||||
|
||||
// Build language list
|
||||
let max_id = whisper_rs::get_lang_max_id();
|
||||
let languages: Vec<serde_json::Value> = (0..=max_id)
|
||||
.filter_map(|id| {
|
||||
whisper_rs::get_lang_str(id).map(|code| {
|
||||
json!({
|
||||
"name": code,
|
||||
"description": whisper_rs::get_lang_str_full(id).unwrap_or(code),
|
||||
})
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let semaphore = Arc::new(tokio::sync::Semaphore::new(args.max_concurrent));
|
||||
|
||||
let session_config = Arc::new(SessionConfig {
|
||||
ctx,
|
||||
semaphore,
|
||||
model_name: args.model_name,
|
||||
languages,
|
||||
default_language: args.language,
|
||||
beam_size: args.beam_size,
|
||||
threads: args.threads,
|
||||
});
|
||||
|
||||
let bind_addr = format!("{}:{}", args.host, args.port);
|
||||
let listener = TcpListener::bind(&bind_addr).await?;
|
||||
info!(address = %bind_addr, "wyoming server listening");
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = listener.accept() => {
|
||||
let (stream, addr) = result?;
|
||||
info!(peer = %addr, "new connection");
|
||||
let config = Arc::clone(&session_config);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let (reader, writer) = stream.into_split();
|
||||
if let Err(e) = session::run(config, reader, writer).await {
|
||||
error!(peer = %addr, error = %e, "session error");
|
||||
}
|
||||
info!(peer = %addr, "connection closed");
|
||||
});
|
||||
}
|
||||
_ = tokio::signal::ctrl_c() => {
|
||||
info!("shutting down");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
90
wyoming-whisper-rs/src/protocol.rs
Normal file
90
wyoming-whisper-rs/src/protocol.rs
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
|
||||
|
||||
use crate::error::Error;
|
||||
|
||||
/// Upper bound, in bytes, on a length announced by an event header; checked
/// before a buffer of that size is allocated.
const MAX_PAYLOAD: usize = 100 * 1024 * 1024; // 100 MB
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Event {
|
||||
#[serde(rename = "type")]
|
||||
pub event_type: String,
|
||||
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub data: Option<Value>,
|
||||
|
||||
#[serde(default, skip_serializing_if = "is_zero")]
|
||||
pub data_length: usize,
|
||||
|
||||
#[serde(default, skip_serializing_if = "is_zero")]
|
||||
pub payload_length: usize,
|
||||
|
||||
#[serde(skip)]
|
||||
pub payload: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
fn is_zero(v: &usize) -> bool {
|
||||
*v == 0
|
||||
}
|
||||
|
||||
impl Event {
|
||||
pub fn new(event_type: impl Into<String>) -> Self {
|
||||
Self {
|
||||
event_type: event_type.into(),
|
||||
data: None,
|
||||
data_length: 0,
|
||||
payload_length: 0,
|
||||
payload: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_data(mut self, data: Value) -> Self {
|
||||
let serialized = serde_json::to_string(&data).unwrap_or_default();
|
||||
self.data_length = serialized.len();
|
||||
self.data = Some(data);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn read_event<R: tokio::io::AsyncRead + Unpin>(
|
||||
reader: &mut BufReader<R>,
|
||||
) -> Result<Option<Event>, Error> {
|
||||
let mut line = String::new();
|
||||
let n = reader.read_line(&mut line).await?;
|
||||
if n == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let mut event: Event = serde_json::from_str(line.trim())?;
|
||||
|
||||
if event.payload_length > 0 {
|
||||
if event.payload_length > MAX_PAYLOAD {
|
||||
return Err(Error::PayloadTooLarge {
|
||||
size: event.payload_length,
|
||||
max: MAX_PAYLOAD,
|
||||
});
|
||||
}
|
||||
let mut buf = vec![0u8; event.payload_length];
|
||||
reader.read_exact(&mut buf).await?;
|
||||
event.payload = Some(buf);
|
||||
}
|
||||
|
||||
Ok(Some(event))
|
||||
}
|
||||
|
||||
pub async fn write_event<W: tokio::io::AsyncWrite + Unpin>(
|
||||
writer: &mut W,
|
||||
event: &Event,
|
||||
) -> Result<(), Error> {
|
||||
let json = serde_json::to_string(event)?;
|
||||
writer.write_all(json.as_bytes()).await?;
|
||||
writer.write_all(b"\n").await?;
|
||||
|
||||
if let Some(ref payload) = event.payload {
|
||||
writer.write_all(payload).await?;
|
||||
}
|
||||
|
||||
writer.flush().await?;
|
||||
Ok(())
|
||||
}
|
||||
208
wyoming-whisper-rs/src/session.rs
Normal file
208
wyoming-whisper-rs/src/session.rs
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use serde_json::{json, Value};
|
||||
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
|
||||
use tokio::sync::Semaphore;
|
||||
use tracing::{debug, warn};
|
||||
use whisper_rs::WhisperContext;
|
||||
|
||||
use crate::error::Error;
|
||||
use crate::protocol::{read_event, write_event, Event};
|
||||
use crate::transcribe::{AudioBuffer, TranscribeConfig};
|
||||
|
||||
pub struct SessionConfig {
|
||||
pub ctx: Arc<WhisperContext>,
|
||||
pub semaphore: Arc<Semaphore>,
|
||||
pub model_name: String,
|
||||
pub languages: Vec<Value>,
|
||||
pub default_language: Option<String>,
|
||||
pub beam_size: i32,
|
||||
pub threads: i32,
|
||||
}
|
||||
|
||||
enum State {
|
||||
Idle,
|
||||
AwaitingAudioStart {
|
||||
language: Option<String>,
|
||||
},
|
||||
Streaming {
|
||||
buffer: AudioBuffer,
|
||||
language: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl State {
|
||||
fn name(&self) -> &'static str {
|
||||
match self {
|
||||
State::Idle => "idle",
|
||||
State::AwaitingAudioStart { .. } => "awaiting_audio_start",
|
||||
State::Streaming { .. } => "streaming",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run<R, W>(config: Arc<SessionConfig>, reader: R, mut writer: W) -> Result<(), Error>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let mut reader = BufReader::new(reader);
|
||||
let mut state = State::Idle;
|
||||
|
||||
loop {
|
||||
let event = match read_event(&mut reader).await? {
|
||||
Some(e) => e,
|
||||
None => {
|
||||
debug!("client disconnected");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
debug!(event_type = %event.event_type, "received event");
|
||||
|
||||
match event.event_type.as_str() {
|
||||
"describe" => {
|
||||
let info = Event::new("info").with_data(json!({
|
||||
"type": "asr",
|
||||
"name": config.model_name,
|
||||
"description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION),
|
||||
"installed": true,
|
||||
"attribution": {
|
||||
"name": "whisper.cpp",
|
||||
"url": "https://github.com/ggerganov/whisper.cpp"
|
||||
},
|
||||
"models": [{
|
||||
"name": config.model_name,
|
||||
"description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION),
|
||||
"installed": true,
|
||||
"attribution": {
|
||||
"name": "whisper.cpp",
|
||||
"url": "https://github.com/ggerganov/whisper.cpp"
|
||||
},
|
||||
"languages": config.languages,
|
||||
}],
|
||||
"version": env!("CARGO_PKG_VERSION"),
|
||||
}));
|
||||
write_event(&mut writer, &info).await?;
|
||||
}
|
||||
|
||||
"transcribe" => {
|
||||
let language = event
|
||||
.data
|
||||
.as_ref()
|
||||
.and_then(|d| d.get("language"))
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.or_else(|| config.default_language.clone());
|
||||
|
||||
state = State::AwaitingAudioStart { language };
|
||||
}
|
||||
|
||||
"audio-start" => {
|
||||
let (rate, width, channels) = parse_audio_start(&event)?;
|
||||
let language = match state {
|
||||
State::AwaitingAudioStart { language } => language,
|
||||
_ => config.default_language.clone(),
|
||||
};
|
||||
|
||||
state = State::Streaming {
|
||||
buffer: AudioBuffer::new(rate, width, channels),
|
||||
language,
|
||||
};
|
||||
}
|
||||
|
||||
"audio-chunk" => {
|
||||
if let State::Streaming { ref mut buffer, .. } = state {
|
||||
if let Some(ref payload) = event.payload {
|
||||
buffer.append(payload);
|
||||
}
|
||||
} else {
|
||||
warn!(state = state.name(), "audio-chunk received in wrong state");
|
||||
}
|
||||
}
|
||||
|
||||
"audio-stop" => {
|
||||
if let State::Streaming { buffer, language } = state {
|
||||
state = State::Idle;
|
||||
|
||||
let tc = TranscribeConfig {
|
||||
language,
|
||||
beam_size: config.beam_size,
|
||||
threads: config.threads,
|
||||
};
|
||||
|
||||
match do_transcribe(&config.ctx, &config.semaphore, tc, buffer).await {
|
||||
Ok(text) => {
|
||||
let transcript = Event::new("transcript").with_data(json!({
|
||||
"text": text,
|
||||
}));
|
||||
write_event(&mut writer, &transcript).await?;
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(error = %e, "transcription failed");
|
||||
let err = Event::new("error").with_data(json!({
|
||||
"text": format!("transcription failed: {e}"),
|
||||
}));
|
||||
write_event(&mut writer, &err).await?;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
warn!(state = state.name(), "audio-stop received in wrong state");
|
||||
state = State::Idle;
|
||||
}
|
||||
}
|
||||
|
||||
other => {
|
||||
warn!(event_type = other, "unknown event type");
|
||||
let err = Event::new("error").with_data(json!({
|
||||
"text": format!("unknown event type: {other}"),
|
||||
}));
|
||||
write_event(&mut writer, &err).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_audio_start(event: &Event) -> Result<(u32, u16, u16), Error> {
|
||||
let data = event
|
||||
.data
|
||||
.as_ref()
|
||||
.ok_or(Error::MissingField("data in audio-start"))?;
|
||||
|
||||
let rate = data
|
||||
.get("rate")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or(Error::MissingField("rate"))? as u32;
|
||||
|
||||
let width = data
|
||||
.get("width")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or(Error::MissingField("width"))? as u16;
|
||||
|
||||
let channels = data
|
||||
.get("channels")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or(Error::MissingField("channels"))? as u16;
|
||||
|
||||
Ok((rate, width, channels))
|
||||
}
|
||||
|
||||
async fn do_transcribe(
|
||||
ctx: &Arc<WhisperContext>,
|
||||
semaphore: &Arc<Semaphore>,
|
||||
config: TranscribeConfig,
|
||||
buffer: AudioBuffer,
|
||||
) -> Result<String, Error> {
|
||||
let audio = buffer.into_f32_16khz_mono()?;
|
||||
|
||||
if audio.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
||||
let _permit = semaphore.acquire().await.expect("semaphore closed");
|
||||
let ctx = Arc::clone(ctx);
|
||||
|
||||
tokio::task::spawn_blocking(move || crate::transcribe::transcribe(&ctx, &config, audio))
|
||||
.await
|
||||
.expect("transcription task panicked")
|
||||
}
|
||||
151
wyoming-whisper-rs/src/transcribe.rs
Normal file
151
wyoming-whisper-rs/src/transcribe.rs
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
|
||||
|
||||
use crate::error::Error;
|
||||
|
||||
/// Per-request settings consumed by [`transcribe`].
pub struct TranscribeConfig {
    /// Language code to transcribe in; `None` when no language was chosen.
    pub language: Option<String>,
    /// Beam width passed to the beam-search sampling strategy.
    pub beam_size: i32,
    /// Inference thread count; applied only when greater than zero.
    pub threads: i32,
}
|
||||
|
||||
pub struct AudioBuffer {
|
||||
data: Vec<u8>,
|
||||
rate: u32,
|
||||
width: u16,
|
||||
channels: u16,
|
||||
}
|
||||
|
||||
impl AudioBuffer {
|
||||
pub fn new(rate: u32, width: u16, channels: u16) -> Self {
|
||||
Self {
|
||||
data: Vec::new(),
|
||||
rate,
|
||||
width,
|
||||
channels,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn append(&mut self, chunk: &[u8]) {
|
||||
self.data.extend_from_slice(chunk);
|
||||
}
|
||||
|
||||
pub fn into_f32_16khz_mono(self) -> Result<Vec<f32>, Error> {
|
||||
if self.width != 2 {
|
||||
return Err(Error::InvalidAudio(format!(
|
||||
"expected 16-bit audio (width=2), got width={}",
|
||||
self.width
|
||||
)));
|
||||
}
|
||||
|
||||
if !self.data.len().is_multiple_of(2) {
|
||||
return Err(Error::InvalidAudio(
|
||||
"audio data has odd number of bytes for 16-bit samples".into(),
|
||||
));
|
||||
}
|
||||
|
||||
// Interpret as i16 little-endian
|
||||
let samples_i16: Vec<i16> = self
|
||||
.data
|
||||
.chunks_exact(2)
|
||||
.map(|c| i16::from_le_bytes([c[0], c[1]]))
|
||||
.collect();
|
||||
|
||||
// Convert i16 -> f32
|
||||
let mut samples_f32 = vec![0.0f32; samples_i16.len()];
|
||||
whisper_rs::convert_integer_to_float_audio(&samples_i16, &mut samples_f32)
|
||||
.map_err(|e| Error::InvalidAudio(format!("i16 to f32 conversion failed: {e}")))?;
|
||||
|
||||
// Convert stereo to mono if needed
|
||||
let mono = if self.channels == 2 {
|
||||
let mut mono = vec![0.0f32; samples_f32.len() / 2];
|
||||
whisper_rs::convert_stereo_to_mono_audio(&samples_f32, &mut mono)
|
||||
.map_err(|e| Error::InvalidAudio(format!("stereo to mono failed: {e}")))?;
|
||||
mono
|
||||
} else if self.channels == 1 {
|
||||
samples_f32
|
||||
} else {
|
||||
return Err(Error::InvalidAudio(format!(
|
||||
"unsupported channel count: {}",
|
||||
self.channels
|
||||
)));
|
||||
};
|
||||
|
||||
// Resample if not 16kHz
|
||||
if self.rate == 16000 {
|
||||
Ok(mono)
|
||||
} else {
|
||||
Ok(resample(&mono, self.rate, 16000))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple linear interpolation resampler.
///
/// Produces `ceil(input.len() * to_rate / from_rate)` samples. Each output
/// sample is interpolated between the two input samples surrounding its
/// source position; positions at or past the last input sample repeat that
/// last sample.
///
/// Returns a copy of `input` when the rates are equal or `input` is empty,
/// and an empty vector when either rate is zero. A zero `from_rate` would
/// otherwise make the computed output length unbounded and panic on
/// allocation.
fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate || input.is_empty() {
        return input.to_vec();
    }
    if from_rate == 0 || to_rate == 0 {
        return Vec::new();
    }

    let ratio = f64::from(from_rate) / f64::from(to_rate);
    let output_len = ((input.len() as f64) / ratio).ceil() as usize;
    let last = input.len() - 1;

    (0..output_len)
        .map(|i| {
            let src_pos = i as f64 * ratio;
            let idx = src_pos as usize;
            let frac = src_pos - idx as f64;

            let sample = if idx < last {
                f64::from(input[idx]) * (1.0 - frac) + f64::from(input[idx + 1]) * frac
            } else {
                // At or beyond the final sample there is nothing to blend with.
                f64::from(input[last])
            };
            sample as f32
        })
        .collect()
}
|
||||
|
||||
pub fn transcribe(
|
||||
ctx: &Arc<WhisperContext>,
|
||||
config: &TranscribeConfig,
|
||||
audio: Vec<f32>,
|
||||
) -> Result<String, Error> {
|
||||
let mut state = ctx.create_state()?;
|
||||
|
||||
let mut params = FullParams::new(SamplingStrategy::BeamSearch {
|
||||
beam_size: config.beam_size,
|
||||
patience: -1.0,
|
||||
});
|
||||
|
||||
if let Some(ref lang) = config.language {
|
||||
params.set_language(Some(lang));
|
||||
} else {
|
||||
params.set_language(None);
|
||||
params.set_detect_language(true);
|
||||
}
|
||||
|
||||
if config.threads > 0 {
|
||||
params.set_n_threads(config.threads);
|
||||
}
|
||||
|
||||
params.set_print_special(false);
|
||||
params.set_print_progress(false);
|
||||
params.set_print_realtime(false);
|
||||
params.set_print_timestamps(false);
|
||||
params.set_no_context(true);
|
||||
params.set_single_segment(false);
|
||||
|
||||
state.full(params, &audio)?;
|
||||
|
||||
let mut text = String::new();
|
||||
for segment in state.as_iter() {
|
||||
if let Ok(s) = segment.to_str_lossy() {
|
||||
text.push_str(&s);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(text.trim().to_string())
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue