Add Wyoming protocol ASR server and nix devshell
New wyoming-whisper-rs binary crate implementing the Wyoming protocol over TCP, making whisper-rs usable with Home Assistant's voice pipeline. Includes nix flake devshell with Vulkan, ROCm/hipBLAS, clippy, and rustfmt support.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
parent d38738df8d
commit 50fdb08a38
12 changed files with 840 additions and 1 deletions
1
.envrc
Normal file
@@ -0,0 +1 @@
use flake
97
CLAUDE.md
Normal file
@@ -0,0 +1,97 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

whisper-rs provides safe Rust bindings to [whisper.cpp](https://github.com/ggerganov/whisper.cpp), a C++ speech recognition library. It's a two-crate workspace:

- **`whisper-rs`** (root) — Safe public API
- **`whisper-rs-sys`** (`sys/`) — FFI bindings generated via bindgen, with a CMake-based build of the whisper.cpp submodule

The upstream C++ source lives in `sys/whisper.cpp/` (git submodule — clone with `--recursive`).

## Build Commands

```bash
cargo build                                # Default build
cargo build --release --features vulkan    # With Vulkan GPU support
cargo build --release --features hipblas   # With AMD ROCm support
cargo build --release --features cuda      # With NVIDIA CUDA support
cargo test                                 # Run all tests
cargo fmt                                  # Format code
cargo clippy                               # Lint
```

**Running examples** (require a GGML model file and a WAV audio file):
```bash
cargo run --example basic_use -- model.bin audio.wav
cargo run --example audio_transcription -- model.bin audio.wav
cargo run --example vad -- model.bin audio.wav output.wav
```

**Skipping bindgen** (use pre-generated bindings): set `WHISPER_DONT_GENERATE_BINDINGS=1`.

All `WHISPER_*` and `CMAKE_*` env vars are forwarded to the CMake build.

## Architecture

```
whisper-rs (safe Rust API)
  → whisper-rs-sys (bindgen FFI)
    → whisper.cpp (C++ submodule, built via CMake)
      → GGML (tensor library with CPU/GPU backends)
```

**Key types and their relationships:**

- `WhisperContext` — Arc-wrapped model handle, thread-safe, created from a model file
- `WhisperState` — Inference state created from a context; multiple states can share one context
- `FullParams` — Transcription configuration (sampling strategy, language, callbacks)
- `WhisperSegment` / `WhisperToken` — Result types with timestamps and probabilities

**Core flow:** `WhisperContext::new_with_params()` → `ctx.create_state()` → `state.full(params, &audio_data)` → iterate segments via `state.as_iter()`

**Module layout in `src/`:**
- `whisper_ctx.rs` — Raw context wrapper (`WhisperInnerContext`)
- `whisper_ctx_wrapper.rs` — Safe public `WhisperContext`
- `whisper_state/` — State management, segments, tokens, iterators
- `whisper_params.rs` — `FullParams`, `SamplingStrategy` (Greedy / BeamSearch)
- `whisper_vad.rs` — Voice Activity Detection
- `whisper_grammar.rs` — GBNF grammar-constrained decoding
- `vulkan.rs` — Vulkan device enumeration (behind `vulkan` feature)

## Build System (sys/build.rs)

The build script:
1. Copies `whisper.cpp` sources into `OUT_DIR`
2. Runs bindgen on `wrapper.h` to generate FFI bindings (falls back to `src/bindings.rs` on failure)
3. Configures and builds whisper.cpp via CMake with feature-dependent flags (`GGML_CUDA`, `GGML_HIP`, `GGML_VULKAN`, `GGML_METAL`, etc.)
4. Statically links: `whisper`, `ggml`, `ggml-base`, `ggml-cpu`, plus backend-specific libs

## Feature Flags

| Feature | Purpose |
|---------|---------|
| `cuda` | NVIDIA GPU (needs CUDA toolkit) |
| `hipblas` | AMD GPU via ROCm |
| `metal` | Apple Metal GPU |
| `vulkan` | Vulkan GPU |
| `openblas` | OpenBLAS acceleration (requires `BLAS_INCLUDE_DIRS` env var) |
| `openmp` | OpenMP threading |
| `coreml` | Apple CoreML |
| `intel-sycl` | Intel SYCL |
| `raw-api` | Re-export `whisper-rs-sys` types publicly |
| `log_backend` | Route C++ logs to the `log` crate |
| `tracing_backend` | Route C++ logs to the `tracing` crate |

## Nix Development Environment

The `flake.nix` provides a devshell with all dependencies for Vulkan and ROCm/hipBLAS builds. Use `direnv allow` or `nix develop` to enter the environment. Key env vars (`LIBCLANG_PATH`, `BINDGEN_EXTRA_CLANG_ARGS`, `HIP_PATH`, `VULKAN_SDK`) are set automatically.

## PR Conventions

Per `.github/PULL_REQUEST_TEMPLATE.md`:
- Run `cargo fmt` and `cargo clippy` before submitting
- Self-review code for legibility
- No GenAI-generated code in PRs
2
Cargo.toml
@@ -1,5 +1,5 @@
 [workspace]
-members = ["sys"]
+members = ["sys", "wyoming-whisper-rs"]
 exclude = ["examples/full_usage"]
 
 [package]
27
flake.lock
generated
Normal file
@@ -0,0 +1,27 @@
{
  "nodes": {
    "nixpkgs": {
      "locked": {
        "lastModified": 1771482645,
        "narHash": "sha256-MpAKyXfJRDTgRU33Hja+G+3h9ywLAJJNRq4Pjbb4dQs=",
        "owner": "NixOS",
        "repo": "nixpkgs",
        "rev": "724cf38d99ba81fbb4a347081db93e2e3a9bc2ae",
        "type": "github"
      },
      "original": {
        "owner": "NixOS",
        "ref": "nixpkgs-unstable",
        "repo": "nixpkgs",
        "type": "github"
      }
    },
    "root": {
      "inputs": {
        "nixpkgs": "nixpkgs"
      }
    }
  },
  "root": "root",
  "version": 7
}
56
flake.nix
Normal file
@@ -0,0 +1,56 @@
{
  inputs = {
    nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
  };

  outputs = { nixpkgs, ... }:
    let
      systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ];
      forAllSystems = f: nixpkgs.lib.genAttrs systems (system: f nixpkgs.legacyPackages.${system});
    in
    {
      devShells = forAllSystems (pkgs: {
        default = pkgs.mkShell {
          nativeBuildInputs = [
            pkgs.rustc
            pkgs.cargo
            pkgs.clippy
            pkgs.rustfmt
            pkgs.cmake
            pkgs.pkg-config
            pkgs.shaderc
            pkgs.rocmPackages.llvm.clang
          ];

          buildInputs = [
            pkgs.libclang.lib
            pkgs.openssl
            pkgs.vulkan-headers
            pkgs.vulkan-loader
            pkgs.rocmPackages.clr
            pkgs.rocmPackages.hipblas
            pkgs.rocmPackages.rocblas
            pkgs.rocmPackages.rocm-runtime
          ];

          env = {
            LIBCLANG_PATH = "${pkgs.libclang.lib}/lib";
            VULKAN_SDK = "${pkgs.vulkan-headers}";
            HIP_PATH = "${pkgs.rocmPackages.clr}";
            BINDGEN_EXTRA_CLANG_ARGS = builtins.toString [
              "-isystem ${pkgs.libclang.lib}/lib/clang/${pkgs.lib.versions.major pkgs.libclang.version}/include"
              "-isystem ${pkgs.glibc.dev}/include"
            ];
          };

          LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [
            pkgs.vulkan-loader
            pkgs.rocmPackages.clr
            pkgs.rocmPackages.hipblas
            pkgs.rocmPackages.rocblas
            pkgs.rocmPackages.rocm-runtime
          ];
        };
      });
    };
}
27
wyoming-whisper-rs/Cargo.toml
Normal file
@@ -0,0 +1,27 @@
[package]
name = "wyoming-whisper-rs"
version = "0.1.0"
edition = "2021"
description = "Wyoming protocol ASR server powered by whisper-rs"
license = "Unlicense"

[dependencies]
whisper-rs = { path = "..", features = ["tracing_backend"] }
tokio = { version = "1", features = ["rt-multi-thread", "net", "io-util", "macros", "signal", "sync"] }
clap = { version = "4", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
thiserror = "2"
anyhow = "1"

[features]
default = []
cuda = ["whisper-rs/cuda"]
hipblas = ["whisper-rs/hipblas"]
metal = ["whisper-rs/metal"]
vulkan = ["whisper-rs/vulkan"]
openblas = ["whisper-rs/openblas"]
openmp = ["whisper-rs/openmp"]
coreml = ["whisper-rs/coreml"]
49
wyoming-whisper-rs/src/cli.rs
Normal file
@@ -0,0 +1,49 @@
use clap::Parser;
use std::path::PathBuf;

#[derive(Parser, Debug)]
#[command(
    name = "wyoming-whisper-rs",
    about = "Wyoming protocol ASR server powered by whisper-rs"
)]
pub struct Args {
    /// Path to the GGML whisper model file
    #[arg(long)]
    pub model: PathBuf,

    /// Host address to bind to
    #[arg(long, default_value = "0.0.0.0")]
    pub host: String,

    /// Port to listen on
    #[arg(long, default_value_t = 10300)]
    pub port: u16,

    /// Language code (e.g. "en", "de"). Omit for auto-detection
    #[arg(long)]
    pub language: Option<String>,

    /// Beam search size
    #[arg(long, default_value_t = 5)]
    pub beam_size: i32,

    /// Number of threads for whisper inference (0 = whisper default)
    #[arg(long, default_value_t = 0)]
    pub threads: i32,

    /// GPU device index
    #[arg(long, default_value_t = 0)]
    pub gpu_device: i32,

    /// Disable GPU acceleration
    #[arg(long)]
    pub no_gpu: bool,

    /// Maximum concurrent transcriptions
    #[arg(long, default_value_t = 1)]
    pub max_concurrent: usize,

    /// Model name reported in Wyoming info
    #[arg(long, default_value = "whisper")]
    pub model_name: String,
}
22
wyoming-whisper-rs/src/error.rs
Normal file
@@ -0,0 +1,22 @@
use std::io;

#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),

    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),

    #[error("Whisper error: {0}")]
    Whisper(#[from] whisper_rs::WhisperError),

    #[error("payload too large: {size} bytes (max {max})")]
    PayloadTooLarge { size: usize, max: usize },

    #[error("missing field: {0}")]
    MissingField(&'static str),

    #[error("invalid audio format: {0}")]
    InvalidAudio(String),
}
111
wyoming-whisper-rs/src/main.rs
Normal file
@@ -0,0 +1,111 @@
mod cli;
mod error;
mod protocol;
mod session;
mod transcribe;

use std::sync::Arc;

use clap::Parser;
use serde_json::json;
use tokio::net::TcpListener;
use tracing::{error, info};
use whisper_rs::{WhisperContext, WhisperContextParameters};

use crate::cli::Args;
use crate::session::SessionConfig;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let args = Args::parse();

    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")),
        )
        .init();

    // Validate model path
    if !args.model.exists() {
        error!(path = %args.model.display(), "model file not found");
        std::process::exit(1);
    }

    // Validate language if specified
    if let Some(ref lang) = args.language {
        if whisper_rs::get_lang_id(lang).is_none() {
            error!(language = lang, "unknown language code");
            std::process::exit(1);
        }
    }

    // Route whisper.cpp logs through tracing
    whisper_rs::install_logging_hooks();

    // Load model
    info!(model = %args.model.display(), "loading whisper model");
    let mut ctx_params = WhisperContextParameters::default();
    if args.no_gpu {
        ctx_params.use_gpu(false);
    }
    ctx_params.gpu_device(args.gpu_device);

    let model_path = args.model.to_string_lossy().to_string();
    let ctx = WhisperContext::new_with_params(&model_path, ctx_params)?;
    let ctx = Arc::new(ctx);
    info!("model loaded");

    // Build language list
    let max_id = whisper_rs::get_lang_max_id();
    let languages: Vec<serde_json::Value> = (0..=max_id)
        .filter_map(|id| {
            whisper_rs::get_lang_str(id).map(|code| {
                json!({
                    "name": code,
                    "description": whisper_rs::get_lang_str_full(id).unwrap_or(code),
                })
            })
        })
        .collect();

    let semaphore = Arc::new(tokio::sync::Semaphore::new(args.max_concurrent));

    let session_config = Arc::new(SessionConfig {
        ctx,
        semaphore,
        model_name: args.model_name,
        languages,
        default_language: args.language,
        beam_size: args.beam_size,
        threads: args.threads,
    });

    let bind_addr = format!("{}:{}", args.host, args.port);
    let listener = TcpListener::bind(&bind_addr).await?;
    info!(address = %bind_addr, "wyoming server listening");

    loop {
        tokio::select! {
            result = listener.accept() => {
                let (stream, addr) = result?;
                info!(peer = %addr, "new connection");
                let config = Arc::clone(&session_config);

                tokio::spawn(async move {
                    let (reader, writer) = stream.into_split();
                    if let Err(e) = session::run(config, reader, writer).await {
                        error!(peer = %addr, error = %e, "session error");
                    }
                    info!(peer = %addr, "connection closed");
                });
            }
            _ = tokio::signal::ctrl_c() => {
                info!("shutting down");
                break;
            }
        }
    }

    Ok(())
}
90
wyoming-whisper-rs/src/protocol.rs
Normal file
@@ -0,0 +1,90 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};

use crate::error::Error;

const MAX_PAYLOAD: usize = 100 * 1024 * 1024; // 100 MB

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Event {
    #[serde(rename = "type")]
    pub event_type: String,

    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,

    #[serde(default, skip_serializing_if = "is_zero")]
    pub data_length: usize,

    #[serde(default, skip_serializing_if = "is_zero")]
    pub payload_length: usize,

    #[serde(skip)]
    pub payload: Option<Vec<u8>>,
}

fn is_zero(v: &usize) -> bool {
    *v == 0
}

impl Event {
    pub fn new(event_type: impl Into<String>) -> Self {
        Self {
            event_type: event_type.into(),
            data: None,
            data_length: 0,
            payload_length: 0,
            payload: None,
        }
    }

    pub fn with_data(mut self, data: Value) -> Self {
        let serialized = serde_json::to_string(&data).unwrap_or_default();
        self.data_length = serialized.len();
        self.data = Some(data);
        self
    }
}

pub async fn read_event<R: tokio::io::AsyncRead + Unpin>(
    reader: &mut BufReader<R>,
) -> Result<Option<Event>, Error> {
    let mut line = String::new();
    let n = reader.read_line(&mut line).await?;
    if n == 0 {
        return Ok(None);
    }

    let mut event: Event = serde_json::from_str(line.trim())?;

    if event.payload_length > 0 {
        if event.payload_length > MAX_PAYLOAD {
            return Err(Error::PayloadTooLarge {
                size: event.payload_length,
                max: MAX_PAYLOAD,
            });
        }
        let mut buf = vec![0u8; event.payload_length];
        reader.read_exact(&mut buf).await?;
        event.payload = Some(buf);
    }

    Ok(Some(event))
}

pub async fn write_event<W: tokio::io::AsyncWrite + Unpin>(
    writer: &mut W,
    event: &Event,
) -> Result<(), Error> {
    let json = serde_json::to_string(event)?;
    writer.write_all(json.as_bytes()).await?;
    writer.write_all(b"\n").await?;

    if let Some(ref payload) = event.payload {
        writer.write_all(payload).await?;
    }

    writer.flush().await?;
    Ok(())
}
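The framing handled by `read_event` and `write_event` in `protocol.rs` (one JSON header line, then length-prefixed binary sections) can be sketched with the standard library alone. This is a minimal illustration, not the crate's code: `Frame`, `length_field`, `write_frame`, and `read_frame` are hypothetical names, the naive `length_field` stands in for a real JSON parser such as serde_json, and the sketch assumes that `data_length` bytes of extra JSON follow the header before the payload, which is worth checking against the Wyoming protocol description. It also omits the `MAX_PAYLOAD` bound that the real reader enforces.

```rust
use std::io::{self, BufRead, Read, Write};

/// One decoded frame: the header line, the extra data bytes, and the binary payload.
#[derive(Debug, PartialEq)]
struct Frame {
    header: String,
    data: Vec<u8>,
    payload: Vec<u8>,
}

/// Hypothetical helper: extracts an unsigned integer field from a compact JSON header.
/// Returns 0 when the key is absent. A real implementation would parse the JSON.
fn length_field(header: &str, key: &str) -> usize {
    let needle = format!("\"{key}\":");
    match header.find(&needle) {
        Some(pos) => header[pos + needle.len()..]
            .trim_start()
            .chars()
            .take_while(|c| c.is_ascii_digit())
            .collect::<String>()
            .parse()
            .unwrap_or(0),
        None => 0,
    }
}

/// Writes the header line, then `data`, then `payload`, in that order.
/// Assumes `event_type` needs no JSON escaping.
fn write_frame<W: Write>(
    w: &mut W,
    event_type: &str,
    data: &[u8],
    payload: &[u8],
) -> io::Result<()> {
    writeln!(
        w,
        "{{\"type\":\"{}\",\"data_length\":{},\"payload_length\":{}}}",
        event_type,
        data.len(),
        payload.len()
    )?;
    w.write_all(data)?;
    w.write_all(payload)?;
    w.flush()
}

/// Reads one frame; returns `Ok(None)` on a clean end of stream.
fn read_frame<R: BufRead>(r: &mut R) -> io::Result<Option<Frame>> {
    let mut header = String::new();
    if r.read_line(&mut header)? == 0 {
        return Ok(None);
    }
    let header = header.trim_end().to_string();

    // Both sections are sized by the header, so neither needs a delimiter.
    let mut data = vec![0u8; length_field(&header, "data_length")];
    r.read_exact(&mut data)?;
    let mut payload = vec![0u8; length_field(&header, "payload_length")];
    r.read_exact(&mut payload)?;

    Ok(Some(Frame { header, data, payload }))
}
```

Writing a frame into a `Vec<u8>` and reading it back through `std::io::Cursor` round-trips the header, data, and payload; a second read on the exhausted cursor yields `None`, mirroring how `read_event` reports a client disconnect.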
208
wyoming-whisper-rs/src/session.rs
Normal file
@@ -0,0 +1,208 @@
use std::sync::Arc;

use serde_json::{json, Value};
use tokio::io::{AsyncRead, AsyncWrite, BufReader};
use tokio::sync::Semaphore;
use tracing::{debug, warn};
use whisper_rs::WhisperContext;

use crate::error::Error;
use crate::protocol::{read_event, write_event, Event};
use crate::transcribe::{AudioBuffer, TranscribeConfig};

pub struct SessionConfig {
    pub ctx: Arc<WhisperContext>,
    pub semaphore: Arc<Semaphore>,
    pub model_name: String,
    pub languages: Vec<Value>,
    pub default_language: Option<String>,
    pub beam_size: i32,
    pub threads: i32,
}

enum State {
    Idle,
    AwaitingAudioStart {
        language: Option<String>,
    },
    Streaming {
        buffer: AudioBuffer,
        language: Option<String>,
    },
}

impl State {
    fn name(&self) -> &'static str {
        match self {
            State::Idle => "idle",
            State::AwaitingAudioStart { .. } => "awaiting_audio_start",
            State::Streaming { .. } => "streaming",
        }
    }
}

pub async fn run<R, W>(config: Arc<SessionConfig>, reader: R, mut writer: W) -> Result<(), Error>
where
    R: AsyncRead + Unpin,
    W: AsyncWrite + Unpin,
{
    let mut reader = BufReader::new(reader);
    let mut state = State::Idle;

    loop {
        let event = match read_event(&mut reader).await? {
            Some(e) => e,
            None => {
                debug!("client disconnected");
                return Ok(());
            }
        };

        debug!(event_type = %event.event_type, "received event");

        match event.event_type.as_str() {
            "describe" => {
                let info = Event::new("info").with_data(json!({
                    "type": "asr",
                    "name": config.model_name,
                    "description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION),
                    "installed": true,
                    "attribution": {
                        "name": "whisper.cpp",
                        "url": "https://github.com/ggerganov/whisper.cpp"
                    },
                    "models": [{
                        "name": config.model_name,
                        "description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION),
                        "installed": true,
                        "attribution": {
                            "name": "whisper.cpp",
                            "url": "https://github.com/ggerganov/whisper.cpp"
                        },
                        "languages": config.languages,
                    }],
                    "version": env!("CARGO_PKG_VERSION"),
                }));
                write_event(&mut writer, &info).await?;
            }

            "transcribe" => {
                let language = event
                    .data
                    .as_ref()
                    .and_then(|d| d.get("language"))
                    .and_then(|v| v.as_str())
                    .map(String::from)
                    .or_else(|| config.default_language.clone());

                state = State::AwaitingAudioStart { language };
            }

            "audio-start" => {
                let (rate, width, channels) = parse_audio_start(&event)?;
                let language = match state {
                    State::AwaitingAudioStart { language } => language,
                    _ => config.default_language.clone(),
                };

                state = State::Streaming {
                    buffer: AudioBuffer::new(rate, width, channels),
                    language,
                };
            }

            "audio-chunk" => {
                if let State::Streaming { ref mut buffer, .. } = state {
                    if let Some(ref payload) = event.payload {
                        buffer.append(payload);
                    }
                } else {
                    warn!(state = state.name(), "audio-chunk received in wrong state");
                }
            }

            "audio-stop" => {
                if let State::Streaming { buffer, language } = state {
                    state = State::Idle;

                    let tc = TranscribeConfig {
                        language,
                        beam_size: config.beam_size,
                        threads: config.threads,
                    };

                    match do_transcribe(&config.ctx, &config.semaphore, tc, buffer).await {
                        Ok(text) => {
                            let transcript = Event::new("transcript").with_data(json!({
                                "text": text,
                            }));
                            write_event(&mut writer, &transcript).await?;
                        }
                        Err(e) => {
                            warn!(error = %e, "transcription failed");
                            let err = Event::new("error").with_data(json!({
                                "text": format!("transcription failed: {e}"),
                            }));
                            write_event(&mut writer, &err).await?;
                        }
                    }
                } else {
                    warn!(state = state.name(), "audio-stop received in wrong state");
                    state = State::Idle;
                }
            }

            other => {
                warn!(event_type = other, "unknown event type");
                let err = Event::new("error").with_data(json!({
                    "text": format!("unknown event type: {other}"),
                }));
                write_event(&mut writer, &err).await?;
            }
        }
    }
}

fn parse_audio_start(event: &Event) -> Result<(u32, u16, u16), Error> {
    let data = event
        .data
        .as_ref()
        .ok_or(Error::MissingField("data in audio-start"))?;

    let rate = data
        .get("rate")
        .and_then(|v| v.as_u64())
        .ok_or(Error::MissingField("rate"))? as u32;

    let width = data
        .get("width")
        .and_then(|v| v.as_u64())
        .ok_or(Error::MissingField("width"))? as u16;

    let channels = data
        .get("channels")
        .and_then(|v| v.as_u64())
        .ok_or(Error::MissingField("channels"))? as u16;

    Ok((rate, width, channels))
}

async fn do_transcribe(
    ctx: &Arc<WhisperContext>,
    semaphore: &Arc<Semaphore>,
    config: TranscribeConfig,
    buffer: AudioBuffer,
) -> Result<String, Error> {
    let audio = buffer.into_f32_16khz_mono()?;

    if audio.is_empty() {
        return Ok(String::new());
    }

    let _permit = semaphore.acquire().await.expect("semaphore closed");
    let ctx = Arc::clone(ctx);

    tokio::task::spawn_blocking(move || crate::transcribe::transcribe(&ctx, &config, audio))
        .await
        .expect("transcription task panicked")
}
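The event handling in `session::run` amounts to a small state machine over `Idle`, `AwaitingAudioStart`, and `Streaming`. A rough sketch of those transitions as a pure function is shown here, with I/O and whisper calls stripped out. The names `step` and `Outcome` are hypothetical, and the sketch tracks only a byte count instead of the real `AudioBuffer` and language fields.

```rust
/// Simplified session state; the real `Streaming` variant also carries the
/// audio buffer and the requested language.
#[derive(Debug, PartialEq)]
enum State {
    Idle,
    AwaitingAudioStart,
    Streaming { bytes: usize },
}

/// What the caller should do after a transition (hypothetical type).
#[derive(Debug, PartialEq)]
enum Outcome {
    Nothing,
    SendInfo,
    Transcribe { bytes: usize },
    Warn,
    SendError,
}

/// Applies one event to the current state and returns the next state
/// together with the side effect the session loop would perform.
fn step(state: State, event_type: &str, payload_len: usize) -> (State, Outcome) {
    match (event_type, state) {
        // `describe` is answered from any state and leaves it unchanged.
        ("describe", s) => (s, Outcome::SendInfo),
        ("transcribe", _) => (State::AwaitingAudioStart, Outcome::Nothing),
        // `audio-start` begins a fresh buffer regardless of the previous state.
        ("audio-start", _) => (State::Streaming { bytes: 0 }, Outcome::Nothing),
        ("audio-chunk", State::Streaming { bytes }) => (
            State::Streaming { bytes: bytes + payload_len },
            Outcome::Nothing,
        ),
        // A chunk outside `Streaming` is logged and dropped.
        ("audio-chunk", s) => (s, Outcome::Warn),
        ("audio-stop", State::Streaming { bytes }) => (State::Idle, Outcome::Transcribe { bytes }),
        ("audio-stop", _) => (State::Idle, Outcome::Warn),
        // Unknown events get an error reply and do not change state.
        (_, s) => (s, Outcome::SendError),
    }
}
```

Feeding `transcribe`, `audio-start`, two 320-byte `audio-chunk` events, and `audio-stop` through `step` ends in `State::Idle` with `Outcome::Transcribe { bytes: 640 }`. Keeping the transitions in a pure function like this is one way to unit-test the protocol logic without sockets or a loaded model.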
151
wyoming-whisper-rs/src/transcribe.rs
Normal file
@@ -0,0 +1,151 @@
use std::sync::Arc;

use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};

use crate::error::Error;

pub struct TranscribeConfig {
    pub language: Option<String>,
    pub beam_size: i32,
    pub threads: i32,
}

pub struct AudioBuffer {
    data: Vec<u8>,
    rate: u32,
    width: u16,
    channels: u16,
}

impl AudioBuffer {
    pub fn new(rate: u32, width: u16, channels: u16) -> Self {
        Self {
            data: Vec::new(),
            rate,
            width,
            channels,
        }
    }

    pub fn append(&mut self, chunk: &[u8]) {
        self.data.extend_from_slice(chunk);
    }

    pub fn into_f32_16khz_mono(self) -> Result<Vec<f32>, Error> {
        if self.width != 2 {
            return Err(Error::InvalidAudio(format!(
                "expected 16-bit audio (width=2), got width={}",
                self.width
            )));
        }

        if !self.data.len().is_multiple_of(2) {
            return Err(Error::InvalidAudio(
                "audio data has odd number of bytes for 16-bit samples".into(),
            ));
        }

        // Interpret as i16 little-endian
        let samples_i16: Vec<i16> = self
            .data
            .chunks_exact(2)
            .map(|c| i16::from_le_bytes([c[0], c[1]]))
            .collect();

        // Convert i16 -> f32
        let mut samples_f32 = vec![0.0f32; samples_i16.len()];
        whisper_rs::convert_integer_to_float_audio(&samples_i16, &mut samples_f32)
            .map_err(|e| Error::InvalidAudio(format!("i16 to f32 conversion failed: {e}")))?;

        // Convert stereo to mono if needed
        let mono = if self.channels == 2 {
            let mut mono = vec![0.0f32; samples_f32.len() / 2];
            whisper_rs::convert_stereo_to_mono_audio(&samples_f32, &mut mono)
                .map_err(|e| Error::InvalidAudio(format!("stereo to mono failed: {e}")))?;
            mono
        } else if self.channels == 1 {
            samples_f32
        } else {
            return Err(Error::InvalidAudio(format!(
                "unsupported channel count: {}",
                self.channels
            )));
        };

        // Resample if not 16kHz
        if self.rate == 16000 {
            Ok(mono)
        } else {
            Ok(resample(&mono, self.rate, 16000))
        }
    }
}

/// Simple linear interpolation resampler.
fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if from_rate == to_rate || input.is_empty() {
        return input.to_vec();
    }

    let ratio = from_rate as f64 / to_rate as f64;
    let output_len = ((input.len() as f64) / ratio).ceil() as usize;
    let mut output = Vec::with_capacity(output_len);

    for i in 0..output_len {
        let src_pos = i as f64 * ratio;
        let idx = src_pos as usize;
        let frac = src_pos - idx as f64;

        let sample = if idx + 1 < input.len() {
            input[idx] as f64 * (1.0 - frac) + input[idx + 1] as f64 * frac
        } else {
            input[idx.min(input.len() - 1)] as f64
        };

        output.push(sample as f32);
    }

    output
}

pub fn transcribe(
    ctx: &Arc<WhisperContext>,
    config: &TranscribeConfig,
    audio: Vec<f32>,
) -> Result<String, Error> {
    let mut state = ctx.create_state()?;

    let mut params = FullParams::new(SamplingStrategy::BeamSearch {
        beam_size: config.beam_size,
        patience: -1.0,
    });

    if let Some(ref lang) = config.language {
        params.set_language(Some(lang));
    } else {
        params.set_language(None);
        params.set_detect_language(true);
    }

    if config.threads > 0 {
        params.set_n_threads(config.threads);
    }

    params.set_print_special(false);
    params.set_print_progress(false);
    params.set_print_realtime(false);
    params.set_print_timestamps(false);
    params.set_no_context(true);
    params.set_single_segment(false);

    state.full(params, &audio)?;

    let mut text = String::new();
    for segment in state.as_iter() {
        if let Ok(s) = segment.to_str_lossy() {
            text.push_str(&s);
        }
    }

    Ok(text.trim().to_string())
}
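The first two stages of `AudioBuffer::into_f32_16khz_mono` (decoding 16-bit little-endian PCM to `f32`, then downmixing stereo to mono) rely on whisper-rs helpers whose internals are not part of this diff. The sketch below shows what such a pipeline typically computes, assuming the common convention of dividing by 32768 and averaging the left and right channels; the actual helpers may differ in detail. `pcm16_le_to_f32` and `stereo_to_mono` are hypothetical names used only for illustration.

```rust
/// Decodes little-endian 16-bit PCM bytes into f32 samples in [-1.0, 1.0).
/// Assumes the usual 1/32768 scaling.
fn pcm16_le_to_f32(bytes: &[u8]) -> Result<Vec<f32>, String> {
    if bytes.len() % 2 != 0 {
        return Err("odd number of bytes for 16-bit samples".to_string());
    }
    Ok(bytes
        .chunks_exact(2)
        .map(|c| i16::from_le_bytes([c[0], c[1]]) as f32 / 32768.0)
        .collect())
}

/// Averages interleaved stereo frames (L, R, L, R, ...) into mono samples.
fn stereo_to_mono(samples: &[f32]) -> Result<Vec<f32>, String> {
    if samples.len() % 2 != 0 {
        return Err("stereo input must contain an even number of samples".to_string());
    }
    Ok(samples
        .chunks_exact(2)
        .map(|lr| (lr[0] + lr[1]) / 2.0)
        .collect())
}
```

For example, the bytes `[0x00, 0x40, 0x00, 0xC0]` decode to the samples `0.5` and `-0.5`; treated as one stereo frame, they average to a single mono sample of `0.0`. An odd-length byte slice is rejected, matching the check in `into_f32_16khz_mono`.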