diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.gitignore b/.gitignore index 902abe5..8e7d90a 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ **/target -**/Cargo.lock /.idea /.vscode +/.direnv *.bin *.wav \ No newline at end of file diff --git a/.rustfmt.toml b/.rustfmt.toml new file mode 100644 index 0000000..e69de29 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..2c44afd --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,97 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +whisper-rs provides safe Rust bindings to [whisper.cpp](https://github.com/ggerganov/whisper.cpp), a C++ speech recognition library. It's a two-crate workspace: + +- **`whisper-rs`** (root) — Safe public API +- **`whisper-rs-sys`** (`sys/`) — FFI bindings generated via bindgen, with a CMake-based build of the whisper.cpp submodule + +The upstream C++ source lives in `sys/whisper.cpp/` (git submodule — clone with `--recursive`). + +## Build Commands + +```bash +cargo build # Default build +cargo build --release --features vulkan # With Vulkan GPU support +cargo build --release --features hipblas # With AMD ROCm support +cargo build --release --features cuda # With NVIDIA CUDA support +cargo test # Run all tests +cargo fmt # Format code +cargo clippy # Lint +``` + +**Running examples** (require a GGML model file and a WAV audio file): +```bash +cargo run --example basic_use -- model.bin audio.wav +cargo run --example audio_transcription -- model.bin audio.wav +cargo run --example vad -- model.bin audio.wav output.wav +``` + +**Skipping bindgen** (use pre-generated bindings): set `WHISPER_DONT_GENERATE_BINDINGS=1`. + +All `WHISPER_*` and `CMAKE_*` env vars are forwarded to the CMake build. + +## Architecture + +``` +whisper-rs (safe Rust API) + → whisper-rs-sys (bindgen FFI) + → whisper.cpp (C++ submodule, built via CMake) + → GGML (tensor library with CPU/GPU backends) +``` + +**Key types and their relationships:** + +- `WhisperContext` — Arc-wrapped model handle, thread-safe, created from a model file +- `WhisperState` — Inference state created from a context; multiple states can share one context +- `FullParams` — Transcription configuration (sampling strategy, language, callbacks) +- `WhisperSegment` / `WhisperToken` — Result types with timestamps and probabilities + +**Core flow:** `WhisperContext::new_with_params()` → `ctx.create_state()` → `state.full(params, &audio_data)` → iterate segments via `state.as_iter()` + +**Module layout in `src/`:** +- `whisper_ctx.rs` — Raw context wrapper (`WhisperInnerContext`) +- `whisper_ctx_wrapper.rs` — Safe public `WhisperContext` +- `whisper_state/` — State management, segments, tokens, iterators +- `whisper_params.rs` — `FullParams`, `SamplingStrategy` (Greedy / BeamSearch) +- `whisper_vad.rs` — Voice Activity Detection +- `whisper_grammar.rs` — GBNF grammar-constrained decoding +- `vulkan.rs` — Vulkan device enumeration (behind `vulkan` feature) + +## Build System (sys/build.rs) + +The build script: +1. Copies `whisper.cpp` sources into `OUT_DIR` +2. Runs bindgen on `wrapper.h` to generate FFI bindings (falls back to `src/bindings.rs` on failure) +3. Configures and builds whisper.cpp via CMake with feature-dependent flags (`GGML_CUDA`, `GGML_HIP`, `GGML_VULKAN`, `GGML_METAL`, etc.) +4. Statically links: `whisper`, `ggml`, `ggml-base`, `ggml-cpu`, plus backend-specific libs + +## Feature Flags + +| Feature | Purpose | +|---------|---------| +| `cuda` | NVIDIA GPU (needs CUDA toolkit) | +| `hipblas` | AMD GPU via ROCm | +| `metal` | Apple Metal GPU | +| `vulkan` | Vulkan GPU | +| `openblas` | OpenBLAS acceleration (requires `BLAS_INCLUDE_DIRS` env var) | +| `openmp` | OpenMP threading | +| `coreml` | Apple CoreML | +| `intel-sycl` | Intel SYCL | +| `raw-api` | Re-export `whisper-rs-sys` types publicly | +| `log_backend` | Route C++ logs to the `log` crate | +| `tracing_backend` | Route C++ logs to the `tracing` crate | + +## Nix Development Environment + +The `flake.nix` provides a devshell with all dependencies for Vulkan and ROCm/hipBLAS builds. Use `direnv allow` or `nix develop` to enter the environment. Key env vars (`LIBCLANG_PATH`, `BINDGEN_EXTRA_CLANG_ARGS`, `HIP_PATH`, `VULKAN_SDK`) are set automatically. + +## PR Conventions + +Per `.github/PULL_REQUEST_TEMPLATE.md`: +- Run `cargo fmt` and `cargo clippy` before submitting +- Self-review code for legibility +- No GenAI-generated code in PRs diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..0d0f661 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,862 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 4 + +[[package]] +name = "aho-corasick" +version = "1.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" +dependencies = [ + "memchr", +] + +[[package]] +name = "anstream" +version = "0.6.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" + +[[package]] +name = "anstyle-parse" +version = "0.2.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" +dependencies = [ + "anstyle", + "once_cell_polyfill", + "windows-sys 0.61.2", +] + +[[package]] +name = "anyhow" +version = "1.0.102" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" + +[[package]] +name = "bindgen" +version = "0.71.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" +dependencies = [ + "bitflags", + "cexpr", + "clang-sys", + "itertools", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn", +] + +[[package]] +name = "bitflags" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" + +[[package]] +name = "bytes" +version = "1.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" + +[[package]] +name = "cc" +version = "1.2.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" +dependencies = [ + "find-msvc-tools", + "shlex", +] + +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + +[[package]] +name = "clap" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.55" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" + +[[package]] +name = "cmake" +version = "0.1.57" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" +dependencies = [ + "cc", +] + +[[package]] +name = "colorchoice" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" + +[[package]] +name = "either" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" + +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "getrandom" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "glob" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "hound" +version = "3.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" + +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" + +[[package]] +name = "lazy_static" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" + +[[package]] +name = "libc" +version = "0.2.182" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" + +[[package]] +name = "libloading" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" +dependencies = [ + "cfg-if", + "windows-link", +] + +[[package]] +name = "log" +version = "0.4.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" + +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + +[[package]] +name = "memchr" +version = "2.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" + +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + +[[package]] +name = "mio" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" +dependencies = [ + "libc", + "wasi", + "windows-sys 0.61.2", +] + +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + +[[package]] +name = "nu-ansi-term" +version = "0.50.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" +dependencies = [ + "windows-sys 0.61.2", +] + +[[package]] +name = "once_cell" +version = "1.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" + +[[package]] +name = "once_cell_polyfill" +version = "1.70.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" + +[[package]] +name = "pin-project-lite" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" + +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn", +] + +[[package]] +name = "proc-macro2" +version = "1.0.106" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "regex" +version = "1.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "serde" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" +dependencies = [ + "serde_core", + "serde_derive", +] + +[[package]] +name = "serde_core" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.149" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +dependencies = [ + "itoa", + "memchr", + "serde", + "serde_core", + "zmij", +] + +[[package]] +name = "sharded-slab" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" +dependencies = [ + "lazy_static", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "signal-hook-registry" +version = "1.4.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" +dependencies = [ + "errno", + "libc", +] + +[[package]] +name = "smallvec" +version = "1.15.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" + +[[package]] +name = "socket2" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" +dependencies = [ + "libc", + "windows-sys 0.60.2", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.117" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "thread_local" +version = "1.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "tokio" +version = "1.49.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" +dependencies = [ + "bytes", + "libc", + "mio", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys 0.61.2", +] + +[[package]] +name = "tokio-macros" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing" +version = "0.1.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" +dependencies = [ + "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", +] + +[[package]] +name = "tracing-subscriber" +version = "0.3.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" +dependencies = [ + "matchers", + "nu-ansi-term", + "once_cell", + "regex-automata", + "sharded-slab", + "smallvec", + "thread_local", + "tracing", + "tracing-core", + "tracing-log", +] + +[[package]] +name = "unicode-ident" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + +[[package]] +name = "wasi" +version = "0.11.1+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" + +[[package]] +name = "whisper-rs" +version = "0.15.1" +dependencies = [ + "hound", + "libc", + "log", + "rand", + "tracing", + "whisper-rs-sys", +] + +[[package]] +name = "whisper-rs-sys" +version = "0.14.1" +dependencies = [ + "bindgen", + "cfg-if", + "cmake", + "fs_extra", +] + +[[package]] +name = "windows-link" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" + +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-sys" +version = "0.61.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-targets" +version = "0.53.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" +dependencies = [ + "windows-link", + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" + +[[package]] +name = "windows_i686_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" + +[[package]] +name = "windows_i686_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" + +[[package]] +name = "wyoming-whisper-rs" +version = "0.1.0" +dependencies = [ + "anyhow", + "clap", + "serde", + "serde_json", + "thiserror", + "tokio", + "tracing", + "tracing-subscriber", + "whisper-rs", +] + +[[package]] +name = "zerocopy" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "zmij" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/Cargo.toml b/Cargo.toml index 24849b6..685b1ba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["sys"] +members = ["sys", "wyoming-whisper-rs"] exclude = ["examples/full_usage"] [package] diff --git a/examples/basic_use.rs b/examples/basic_use.rs index 53d2aab..54e7d5d 100644 --- a/examples/basic_use.rs +++ b/examples/basic_use.rs @@ -51,15 +51,16 @@ fn main() { // note that you don't need to use these, you can do it yourself or any other way you want // these are just provided for convenience let mut inter_samples = vec![Default::default(); samples.len()]; + let mut mono_samples = vec![Default::default(); samples.len() / 2]; whisper_rs::convert_integer_to_float_audio(&samples, &mut inter_samples) .expect("failed to convert audio data"); - let samples = whisper_rs::convert_stereo_to_mono_audio(&inter_samples) + whisper_rs::convert_stereo_to_mono_audio(&inter_samples, &mut mono_samples) .expect("failed to convert audio data"); // now we can run the model state - .full(params, &samples[..]) + .full(params, &mono_samples[..]) .expect("failed to run model"); // fetch the results diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..516965b --- /dev/null +++ b/flake.lock @@ -0,0 +1,27 @@ +{ + "nodes": { + "nixpkgs": { + "locked": { + "lastModified": 1771482645, + "narHash": "sha256-MpAKyXfJRDTgRU33Hja+G+3h9ywLAJJNRq4Pjbb4dQs=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "724cf38d99ba81fbb4a347081db93e2e3a9bc2ae", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "nixpkgs": "nixpkgs" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..9c99811 --- /dev/null +++ b/flake.nix @@ -0,0 +1,60 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + }; + + outputs = { self, nixpkgs, ... }: + let + systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; + forAllSystems = f: nixpkgs.lib.genAttrs systems (system: f nixpkgs.legacyPackages.${system}); + in + { + packages = forAllSystems (pkgs: { + default = pkgs.callPackage ./wyoming-whisper-rs/package.nix { src = self; }; + }); + + devShells = forAllSystems (pkgs: { + default = pkgs.mkShell { + nativeBuildInputs = [ + pkgs.rustc + pkgs.cargo + pkgs.clippy + pkgs.rustfmt + pkgs.cmake + pkgs.pkg-config + pkgs.shaderc + pkgs.rocmPackages.llvm.clang + ]; + + buildInputs = [ + pkgs.libclang.lib + pkgs.openssl + pkgs.vulkan-headers + pkgs.vulkan-loader + pkgs.rocmPackages.clr + pkgs.rocmPackages.hipblas + pkgs.rocmPackages.rocblas + pkgs.rocmPackages.rocm-runtime + ]; + + env = { + LIBCLANG_PATH = "${pkgs.libclang.lib}/lib"; + VULKAN_SDK = "${pkgs.vulkan-headers}"; + HIP_PATH = "${pkgs.rocmPackages.clr}"; + BINDGEN_EXTRA_CLANG_ARGS = builtins.toString [ + "-isystem ${pkgs.libclang.lib}/lib/clang/${pkgs.lib.versions.major pkgs.libclang.version}/include" + "-isystem ${pkgs.glibc.dev}/include" + ]; + }; + + LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [ + pkgs.vulkan-loader + pkgs.rocmPackages.clr + pkgs.rocmPackages.hipblas + pkgs.rocmPackages.rocblas + pkgs.rocmPackages.rocm-runtime + ]; + }; + }); + }; +} diff --git a/src/utilities.rs b/src/utilities.rs index 2e11b7c..cf30539 100644 --- a/src/utilities.rs +++ b/src/utilities.rs @@ -51,7 +51,8 @@ pub fn convert_integer_to_float_audio( /// ``` /// # use whisper_rs::convert_stereo_to_mono_audio; /// let samples = [0.0f32; 1024]; -/// let mono = convert_stereo_to_mono_audio(&samples).expect("should be no half samples missing"); +/// let mut mono_samples = [0.0f32; 512]; +/// convert_stereo_to_mono_audio(&samples, &mut mono_samples).expect("should be no half samples missing"); /// ``` pub fn convert_stereo_to_mono_audio(input: &[f32], output: &mut [f32]) -> Result<(), WhisperError> { let (input, []) = input.as_chunks::<2>() else { diff --git a/src/whisper_params.rs b/src/whisper_params.rs index b51b602..846af3b 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -801,7 +801,7 @@ impl<'a, 'b> FullParams<'a, 'b> { /// # Examples /// ``` /// # use whisper_rs::{FullParams, SamplingStrategy}; - /// let mut params = FullParams::new(SamplingStrategy::default()); + /// let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 5 }); /// params.set_initial_prompt("Hello, world!"); /// // ... further usage of params ... /// ``` diff --git a/src/whisper_state/mod.rs b/src/whisper_state/mod.rs index 2210b1f..2b93b4a 100644 --- a/src/whisper_state/mod.rs +++ b/src/whisper_state/mod.rs @@ -280,7 +280,7 @@ impl WhisperState { /// See utilities in the root of this crate for functions to convert audio to this format. /// /// # Returns - /// Ok(c_int) on success, Err(WhisperError) on failure. + /// Ok(()) on success, Err(WhisperError) on failure. /// /// # C++ equivalent /// `int whisper_full_with_state( @@ -289,7 +289,7 @@ impl WhisperState { /// struct whisper_full_params params, /// const float * samples, /// int n_samples)` - pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result { + pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<(), WhisperError> { if data.is_empty() { // can randomly trigger segmentation faults if we don't check this return Err(WhisperError::NoSamples); @@ -311,7 +311,7 @@ impl WhisperState { } else if ret == 8 { Err(WhisperError::FailedToDecode) } else if ret == 0 { - Ok(ret) + Ok(()) } else { Err(WhisperError::GenericError(ret)) } diff --git a/sys/build.rs b/sys/build.rs index 86d23d8..a7f411a 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -102,7 +102,7 @@ fn main() { println!("cargo:rerun-if-changed=wrapper.h"); let out = PathBuf::from(env::var("OUT_DIR").unwrap()); - let whisper_root = out.join("whisper.cpp/"); + let whisper_root = out.join("whisper.cpp"); if !whisper_root.exists() { std::fs::create_dir_all(&whisper_root).unwrap(); diff --git a/sys/src/bindings.rs b/sys/src/bindings.rs index 9f4be14..6ac1bb5 100644 --- a/sys/src/bindings.rs +++ b/sys/src/bindings.rs @@ -183,7 +183,7 @@ pub const __STDC_IEC_60559_COMPLEX__: u32 = 201404; pub const __STDC_ISO_10646__: u32 = 201706; pub const __GNU_LIBRARY__: u32 = 6; pub const __GLIBC__: u32 = 2; -pub const __GLIBC_MINOR__: u32 = 41; +pub const __GLIBC_MINOR__: u32 = 42; pub const _SYS_CDEFS_H: u32 = 1; pub const __glibc_c99_flexarr_available: u32 = 1; pub const __LDOUBLE_REDIRECTS_TO_FLOAT128_ABI: u32 = 0; @@ -302,9 +302,11 @@ pub const GGML_DEFAULT_GRAPH_SIZE: u32 = 2048; pub const GGML_MEM_ALIGN: u32 = 16; pub const GGML_EXIT_SUCCESS: u32 = 0; pub const GGML_EXIT_ABORTED: u32 = 1; +pub const GGML_ROPE_TYPE_NORMAL: u32 = 0; pub const GGML_ROPE_TYPE_NEOX: u32 = 2; pub const GGML_ROPE_TYPE_MROPE: u32 = 8; pub const GGML_ROPE_TYPE_VISION: u32 = 24; +pub const GGML_ROPE_TYPE_IMROPE: u32 = 40; pub const GGML_MROPE_SECTIONS: u32 = 4; pub const GGML_KQ_MASK_PAD: u32 = 64; pub const GGML_N_TASKS_MAX: i32 = -1; @@ -534,7 +536,9 @@ pub struct _IO_FILE { pub _freeres_buf: *mut ::std::os::raw::c_void, pub _prevchain: *mut *mut _IO_FILE, pub _mode: ::std::os::raw::c_int, - pub _unused2: [::std::os::raw::c_char; 20usize], + pub _unused3: ::std::os::raw::c_int, + pub _total_written: __uint64_t, + pub _unused2: [::std::os::raw::c_char; 8usize], } #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { @@ -588,7 +592,10 @@ const _: () = { ["Offset of field: _IO_FILE::_prevchain"] [::std::mem::offset_of!(_IO_FILE, _prevchain) - 184usize]; ["Offset of field: _IO_FILE::_mode"][::std::mem::offset_of!(_IO_FILE, _mode) - 192usize]; - ["Offset of field: _IO_FILE::_unused2"][::std::mem::offset_of!(_IO_FILE, _unused2) - 196usize]; + ["Offset of field: _IO_FILE::_unused3"][::std::mem::offset_of!(_IO_FILE, _unused3) - 196usize]; + ["Offset of field: _IO_FILE::_total_written"] + [::std::mem::offset_of!(_IO_FILE, _total_written) - 200usize]; + ["Offset of field: _IO_FILE::_unused2"][::std::mem::offset_of!(_IO_FILE, _unused2) - 208usize]; }; impl _IO_FILE { #[inline] @@ -1314,79 +1321,85 @@ pub const ggml_op_GGML_OP_SIN: ggml_op = 12; pub const ggml_op_GGML_OP_COS: ggml_op = 13; pub const ggml_op_GGML_OP_SUM: ggml_op = 14; pub const ggml_op_GGML_OP_SUM_ROWS: ggml_op = 15; -pub const ggml_op_GGML_OP_MEAN: ggml_op = 16; -pub const ggml_op_GGML_OP_ARGMAX: ggml_op = 17; -pub const ggml_op_GGML_OP_COUNT_EQUAL: ggml_op = 18; -pub const ggml_op_GGML_OP_REPEAT: ggml_op = 19; -pub const ggml_op_GGML_OP_REPEAT_BACK: ggml_op = 20; -pub const ggml_op_GGML_OP_CONCAT: ggml_op = 21; -pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 22; -pub const ggml_op_GGML_OP_NORM: ggml_op = 23; -pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 24; -pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 25; -pub const ggml_op_GGML_OP_GROUP_NORM: ggml_op = 26; -pub const ggml_op_GGML_OP_L2_NORM: ggml_op = 27; -pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 28; -pub const ggml_op_GGML_OP_MUL_MAT_ID: ggml_op = 29; -pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 30; -pub const ggml_op_GGML_OP_SCALE: ggml_op = 31; -pub const ggml_op_GGML_OP_SET: ggml_op = 32; -pub const ggml_op_GGML_OP_CPY: ggml_op = 33; -pub const ggml_op_GGML_OP_CONT: ggml_op = 34; -pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 35; -pub const ggml_op_GGML_OP_VIEW: ggml_op = 36; -pub const ggml_op_GGML_OP_PERMUTE: ggml_op = 37; -pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 38; -pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 39; -pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 40; -pub const ggml_op_GGML_OP_SET_ROWS: ggml_op = 41; -pub const ggml_op_GGML_OP_DIAG: ggml_op = 42; -pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 43; -pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 44; -pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 45; -pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 46; -pub const ggml_op_GGML_OP_ROPE: ggml_op = 47; -pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 48; -pub const ggml_op_GGML_OP_CLAMP: ggml_op = 49; -pub const ggml_op_GGML_OP_CONV_TRANSPOSE_1D: ggml_op = 50; -pub const ggml_op_GGML_OP_IM2COL: ggml_op = 51; -pub const ggml_op_GGML_OP_IM2COL_BACK: ggml_op = 52; -pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 53; -pub const ggml_op_GGML_OP_CONV_2D_DW: ggml_op = 54; -pub const ggml_op_GGML_OP_CONV_TRANSPOSE_2D: ggml_op = 55; -pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 56; -pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 57; -pub const ggml_op_GGML_OP_POOL_2D_BACK: ggml_op = 58; -pub const ggml_op_GGML_OP_UPSCALE: ggml_op = 59; -pub const ggml_op_GGML_OP_PAD: ggml_op = 60; -pub const ggml_op_GGML_OP_PAD_REFLECT_1D: ggml_op = 61; -pub const ggml_op_GGML_OP_ROLL: ggml_op = 62; -pub const ggml_op_GGML_OP_ARANGE: ggml_op = 63; -pub const ggml_op_GGML_OP_TIMESTEP_EMBEDDING: ggml_op = 64; -pub const ggml_op_GGML_OP_ARGSORT: ggml_op = 65; -pub const ggml_op_GGML_OP_LEAKY_RELU: ggml_op = 66; -pub const ggml_op_GGML_OP_FLASH_ATTN_EXT: ggml_op = 67; -pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 68; -pub const ggml_op_GGML_OP_SSM_CONV: ggml_op = 69; -pub const ggml_op_GGML_OP_SSM_SCAN: ggml_op = 70; -pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 71; -pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 72; -pub const ggml_op_GGML_OP_GET_REL_POS: ggml_op = 73; -pub const ggml_op_GGML_OP_ADD_REL_POS: ggml_op = 74; -pub const ggml_op_GGML_OP_RWKV_WKV6: ggml_op = 75; -pub const ggml_op_GGML_OP_GATED_LINEAR_ATTN: ggml_op = 76; -pub const ggml_op_GGML_OP_RWKV_WKV7: ggml_op = 77; -pub const ggml_op_GGML_OP_UNARY: ggml_op = 78; -pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 79; -pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 80; -pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 81; -pub const ggml_op_GGML_OP_CUSTOM: ggml_op = 82; -pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 83; -pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 84; -pub const ggml_op_GGML_OP_OPT_STEP_ADAMW: ggml_op = 85; -pub const ggml_op_GGML_OP_OPT_STEP_SGD: ggml_op = 86; -pub const ggml_op_GGML_OP_GLU: ggml_op = 87; -pub const ggml_op_GGML_OP_COUNT: ggml_op = 88; +pub const ggml_op_GGML_OP_CUMSUM: ggml_op = 16; +pub const ggml_op_GGML_OP_MEAN: ggml_op = 17; +pub const ggml_op_GGML_OP_ARGMAX: ggml_op = 18; +pub const ggml_op_GGML_OP_COUNT_EQUAL: ggml_op = 19; +pub const ggml_op_GGML_OP_REPEAT: ggml_op = 20; +pub const ggml_op_GGML_OP_REPEAT_BACK: ggml_op = 21; +pub const ggml_op_GGML_OP_CONCAT: ggml_op = 22; +pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 23; +pub const ggml_op_GGML_OP_NORM: ggml_op = 24; +pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 25; +pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 26; +pub const ggml_op_GGML_OP_GROUP_NORM: ggml_op = 27; +pub const ggml_op_GGML_OP_L2_NORM: ggml_op = 28; +pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 29; +pub const ggml_op_GGML_OP_MUL_MAT_ID: ggml_op = 30; +pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 31; +pub const ggml_op_GGML_OP_SCALE: ggml_op = 32; +pub const ggml_op_GGML_OP_SET: ggml_op = 33; +pub const ggml_op_GGML_OP_CPY: ggml_op = 34; +pub const ggml_op_GGML_OP_CONT: ggml_op = 35; +pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 36; +pub const ggml_op_GGML_OP_VIEW: ggml_op = 37; +pub const ggml_op_GGML_OP_PERMUTE: ggml_op = 38; +pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 39; +pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 40; +pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 41; +pub const ggml_op_GGML_OP_SET_ROWS: ggml_op = 42; +pub const ggml_op_GGML_OP_DIAG: ggml_op = 43; +pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 44; +pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 45; +pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 46; +pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 47; +pub const ggml_op_GGML_OP_ROPE: ggml_op = 48; +pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 49; +pub const ggml_op_GGML_OP_CLAMP: ggml_op = 50; +pub const ggml_op_GGML_OP_CONV_TRANSPOSE_1D: ggml_op = 51; +pub const ggml_op_GGML_OP_IM2COL: ggml_op = 52; +pub const ggml_op_GGML_OP_IM2COL_BACK: ggml_op = 53; +pub const ggml_op_GGML_OP_IM2COL_3D: ggml_op = 54; +pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 55; +pub const ggml_op_GGML_OP_CONV_3D: ggml_op = 56; +pub const ggml_op_GGML_OP_CONV_2D_DW: ggml_op = 57; +pub const ggml_op_GGML_OP_CONV_TRANSPOSE_2D: ggml_op = 58; +pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 59; +pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 60; +pub const ggml_op_GGML_OP_POOL_2D_BACK: ggml_op = 61; +pub const ggml_op_GGML_OP_UPSCALE: ggml_op = 62; +pub const ggml_op_GGML_OP_PAD: ggml_op = 63; +pub const ggml_op_GGML_OP_PAD_REFLECT_1D: ggml_op = 64; +pub const ggml_op_GGML_OP_ROLL: ggml_op = 65; +pub const ggml_op_GGML_OP_ARANGE: ggml_op = 66; +pub const ggml_op_GGML_OP_TIMESTEP_EMBEDDING: ggml_op = 67; +pub const ggml_op_GGML_OP_ARGSORT: ggml_op = 68; +pub const ggml_op_GGML_OP_LEAKY_RELU: ggml_op = 69; +pub const ggml_op_GGML_OP_TRI: ggml_op = 70; +pub const ggml_op_GGML_OP_FILL: ggml_op = 71; +pub const ggml_op_GGML_OP_FLASH_ATTN_EXT: ggml_op = 72; +pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 73; +pub const ggml_op_GGML_OP_SSM_CONV: ggml_op = 74; +pub const ggml_op_GGML_OP_SSM_SCAN: ggml_op = 75; +pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 76; +pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 77; +pub const ggml_op_GGML_OP_GET_REL_POS: ggml_op = 78; +pub const ggml_op_GGML_OP_ADD_REL_POS: ggml_op = 79; +pub const ggml_op_GGML_OP_RWKV_WKV6: ggml_op = 80; +pub const ggml_op_GGML_OP_GATED_LINEAR_ATTN: ggml_op = 81; +pub const ggml_op_GGML_OP_RWKV_WKV7: ggml_op = 82; +pub const ggml_op_GGML_OP_SOLVE_TRI: ggml_op = 83; +pub const ggml_op_GGML_OP_UNARY: ggml_op = 84; +pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 85; +pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 86; +pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 87; +pub const ggml_op_GGML_OP_CUSTOM: ggml_op = 88; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 89; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 90; +pub const ggml_op_GGML_OP_OPT_STEP_ADAMW: ggml_op = 91; +pub const ggml_op_GGML_OP_OPT_STEP_SGD: ggml_op = 92; +pub const ggml_op_GGML_OP_GLU: ggml_op = 93; +pub const ggml_op_GGML_OP_COUNT: ggml_op = 94; pub type ggml_op = ::std::os::raw::c_uint; pub const ggml_unary_op_GGML_UNARY_OP_ABS: ggml_unary_op = 0; pub const ggml_unary_op_GGML_UNARY_OP_SGN: ggml_unary_op = 1; @@ -1402,8 +1415,15 @@ pub const ggml_unary_op_GGML_UNARY_OP_SILU: ggml_unary_op = 10; pub const ggml_unary_op_GGML_UNARY_OP_HARDSWISH: ggml_unary_op = 11; pub const ggml_unary_op_GGML_UNARY_OP_HARDSIGMOID: ggml_unary_op = 12; pub const ggml_unary_op_GGML_UNARY_OP_EXP: ggml_unary_op = 13; -pub const ggml_unary_op_GGML_UNARY_OP_GELU_ERF: ggml_unary_op = 14; -pub const ggml_unary_op_GGML_UNARY_OP_COUNT: ggml_unary_op = 15; +pub const ggml_unary_op_GGML_UNARY_OP_EXPM1: ggml_unary_op = 14; +pub const ggml_unary_op_GGML_UNARY_OP_SOFTPLUS: ggml_unary_op = 15; +pub const ggml_unary_op_GGML_UNARY_OP_GELU_ERF: ggml_unary_op = 16; +pub const ggml_unary_op_GGML_UNARY_OP_XIELU: ggml_unary_op = 17; +pub const ggml_unary_op_GGML_UNARY_OP_FLOOR: ggml_unary_op = 18; +pub const ggml_unary_op_GGML_UNARY_OP_CEIL: ggml_unary_op = 19; +pub const ggml_unary_op_GGML_UNARY_OP_ROUND: ggml_unary_op = 20; +pub const ggml_unary_op_GGML_UNARY_OP_TRUNC: ggml_unary_op = 21; +pub const ggml_unary_op_GGML_UNARY_OP_COUNT: ggml_unary_op = 22; pub type ggml_unary_op = ::std::os::raw::c_uint; pub const ggml_glu_op_GGML_GLU_OP_REGLU: ggml_glu_op = 0; pub const ggml_glu_op_GGML_GLU_OP_GEGLU: ggml_glu_op = 1; @@ -1429,6 +1449,11 @@ pub const ggml_tensor_flag_GGML_TENSOR_FLAG_OUTPUT: ggml_tensor_flag = 2; pub const ggml_tensor_flag_GGML_TENSOR_FLAG_PARAM: ggml_tensor_flag = 4; pub const ggml_tensor_flag_GGML_TENSOR_FLAG_LOSS: ggml_tensor_flag = 8; pub type ggml_tensor_flag = ::std::os::raw::c_uint; +pub const ggml_tri_type_GGML_TRI_TYPE_UPPER_DIAG: ggml_tri_type = 0; +pub const ggml_tri_type_GGML_TRI_TYPE_UPPER: ggml_tri_type = 1; +pub const ggml_tri_type_GGML_TRI_TYPE_LOWER_DIAG: ggml_tri_type = 2; +pub const ggml_tri_type_GGML_TRI_TYPE_LOWER: ggml_tri_type = 3; +pub type ggml_tri_type = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_init_params { @@ -1919,6 +1944,18 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_log_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_expm1(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_expm1_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_softplus(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_softplus_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} unsafe extern "C" { pub fn ggml_sin(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } @@ -1937,6 +1974,9 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_sum_rows(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_cumsum(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} unsafe extern "C" { pub fn ggml_mean(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } @@ -2082,6 +2122,41 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_exp_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_floor(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_floor_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_ceil(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_ceil_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_round(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_round_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + #[doc = " Truncates the fractional part of each element in the tensor (towards zero).\n For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0\n Similar to std::trunc in C/C++."] + pub fn ggml_trunc(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_trunc_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_xielu( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + alpha_n: f32, + alpha_p: f32, + beta: f32, + eps: f32, + ) -> *mut ggml_tensor; +} unsafe extern "C" { pub fn ggml_glu( ctx: *mut ggml_context, @@ -2551,6 +2626,15 @@ unsafe extern "C" { max_bias: f32, ) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_soft_max_ext_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + mask: *mut ggml_tensor, + scale: f32, + max_bias: f32, + ) -> *mut ggml_tensor; +} unsafe extern "C" { pub fn ggml_soft_max_add_sinks(a: *mut ggml_tensor, sinks: *mut ggml_tensor); } @@ -2836,6 +2920,41 @@ unsafe extern "C" { d1: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_im2col_3d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + IC: i64, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + s2: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + p2: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + d1: ::std::os::raw::c_int, + d2: ::std::os::raw::c_int, + dst_type: ggml_type, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_conv_3d( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + IC: i64, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + s2: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + p2: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + d1: ::std::os::raw::c_int, + d2: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} unsafe extern "C" { pub fn ggml_conv_2d_sk_p0( ctx: *mut ggml_context, @@ -2897,6 +3016,25 @@ unsafe extern "C" { d1: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_conv_3d_direct( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + s0: ::std::os::raw::c_int, + s1: ::std::os::raw::c_int, + s2: ::std::os::raw::c_int, + p0: ::std::os::raw::c_int, + p1: ::std::os::raw::c_int, + p2: ::std::os::raw::c_int, + d0: ::std::os::raw::c_int, + d1: ::std::os::raw::c_int, + d2: ::std::os::raw::c_int, + n_channels: ::std::os::raw::c_int, + n_batch: ::std::os::raw::c_int, + n_channels_out: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} pub const ggml_op_pool_GGML_OP_POOL_MAX: ggml_op_pool = 0; pub const ggml_op_pool_GGML_OP_POOL_AVG: ggml_op_pool = 1; pub const ggml_op_pool_GGML_OP_POOL_COUNT: ggml_op_pool = 2; @@ -2940,7 +3078,8 @@ unsafe extern "C" { } pub const ggml_scale_mode_GGML_SCALE_MODE_NEAREST: ggml_scale_mode = 0; pub const ggml_scale_mode_GGML_SCALE_MODE_BILINEAR: ggml_scale_mode = 1; -pub const ggml_scale_mode_GGML_SCALE_MODE_COUNT: ggml_scale_mode = 2; +pub const ggml_scale_mode_GGML_SCALE_MODE_BICUBIC: ggml_scale_mode = 2; +pub const ggml_scale_mode_GGML_SCALE_MODE_COUNT: ggml_scale_mode = 3; pub type ggml_scale_mode = ::std::os::raw::c_uint; pub const ggml_scale_flag_GGML_SCALE_FLAG_ALIGN_CORNERS: ggml_scale_flag = 256; pub type ggml_scale_flag = ::std::os::raw::c_uint; @@ -2984,6 +3123,20 @@ unsafe extern "C" { p3: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_pad_ext( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + lp0: ::std::os::raw::c_int, + rp0: ::std::os::raw::c_int, + lp1: ::std::os::raw::c_int, + rp1: ::std::os::raw::c_int, + lp2: ::std::os::raw::c_int, + rp2: ::std::os::raw::c_int, + lp3: ::std::os::raw::c_int, + rp3: ::std::os::raw::c_int, + ) -> *mut ggml_tensor; +} unsafe extern "C" { pub fn ggml_pad_reflect_1d( ctx: *mut ggml_context, @@ -3002,6 +3155,23 @@ unsafe extern "C" { shift3: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_tri( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + type_: ggml_tri_type, + ) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_fill(ctx: *mut ggml_context, a: *mut ggml_tensor, c: f32) -> *mut ggml_tensor; +} +unsafe extern "C" { + pub fn ggml_fill_inplace( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + c: f32, + ) -> *mut ggml_tensor; +} unsafe extern "C" { pub fn ggml_timestep_embedding( ctx: *mut ggml_context, @@ -3173,6 +3343,16 @@ unsafe extern "C" { state: *mut ggml_tensor, ) -> *mut ggml_tensor; } +unsafe extern "C" { + pub fn ggml_solve_tri( + ctx: *mut ggml_context, + a: *mut ggml_tensor, + b: *mut ggml_tensor, + left: bool, + lower: bool, + uni: bool, + ) -> *mut ggml_tensor; +} pub type ggml_custom1_op_t = ::std::option::Option< unsafe extern "C" fn( dst: *mut ggml_tensor, @@ -3884,7 +4064,8 @@ unsafe extern "C" { } pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_CPU: ggml_backend_dev_type = 0; pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_GPU: ggml_backend_dev_type = 1; -pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_ACCEL: ggml_backend_dev_type = 2; +pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_IGPU: ggml_backend_dev_type = 2; +pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_ACCEL: ggml_backend_dev_type = 3; pub type ggml_backend_dev_type = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -3916,11 +4097,12 @@ pub struct ggml_backend_dev_props { pub memory_free: usize, pub memory_total: usize, pub type_: ggml_backend_dev_type, + pub device_id: *const ::std::os::raw::c_char, pub caps: ggml_backend_dev_caps, } #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { - ["Size of ggml_backend_dev_props"][::std::mem::size_of::() - 40usize]; + ["Size of ggml_backend_dev_props"][::std::mem::size_of::() - 56usize]; ["Alignment of ggml_backend_dev_props"] [::std::mem::align_of::() - 8usize]; ["Offset of field: ggml_backend_dev_props::name"] @@ -3933,8 +4115,10 @@ const _: () = { [::std::mem::offset_of!(ggml_backend_dev_props, memory_total) - 24usize]; ["Offset of field: ggml_backend_dev_props::type_"] [::std::mem::offset_of!(ggml_backend_dev_props, type_) - 32usize]; + ["Offset of field: ggml_backend_dev_props::device_id"] + [::std::mem::offset_of!(ggml_backend_dev_props, device_id) - 40usize]; ["Offset of field: ggml_backend_dev_props::caps"] - [::std::mem::offset_of!(ggml_backend_dev_props, caps) - 36usize]; + [::std::mem::offset_of!(ggml_backend_dev_props, caps) - 48usize]; }; unsafe extern "C" { pub fn ggml_backend_dev_name(device: ggml_backend_dev_t) -> *const ::std::os::raw::c_char; @@ -4046,6 +4230,9 @@ const _: () = { pub type ggml_backend_get_features_t = ::std::option::Option< unsafe extern "C" fn(reg: ggml_backend_reg_t) -> *mut ggml_backend_feature, >; +unsafe extern "C" { + pub fn ggml_backend_register(reg: ggml_backend_reg_t); +} unsafe extern "C" { pub fn ggml_backend_device_register(device: ggml_backend_dev_t); } @@ -4144,6 +4331,12 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_backend_sched_get_n_copies(sched: ggml_backend_sched_t) -> ::std::os::raw::c_int; } +unsafe extern "C" { + pub fn ggml_backend_sched_get_buffer_type( + sched: ggml_backend_sched_t, + backend: ggml_backend_t, + ) -> ggml_backend_buffer_type_t; +} unsafe extern "C" { pub fn ggml_backend_sched_get_buffer_size( sched: ggml_backend_sched_t, @@ -4163,6 +4356,9 @@ unsafe extern "C" { node: *mut ggml_tensor, ) -> ggml_backend_t; } +unsafe extern "C" { + pub fn ggml_backend_sched_split_graph(sched: ggml_backend_sched_t, graph: *mut ggml_cgraph); +} unsafe extern "C" { pub fn ggml_backend_sched_alloc_graph( sched: ggml_backend_sched_t, @@ -4469,9 +4665,6 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_cpu_has_vxe() -> ::std::os::raw::c_int; } -unsafe extern "C" { - pub fn ggml_cpu_has_nnpa() -> ::std::os::raw::c_int; -} unsafe extern "C" { pub fn ggml_cpu_has_wasm_simd() -> ::std::os::raw::c_int; } @@ -4548,6 +4741,9 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_cpu_fp32_to_fp32(arg1: *const f32, arg2: *mut f32, arg3: i64); } +unsafe extern "C" { + pub fn ggml_cpu_fp32_to_i32(arg1: *const f32, arg2: *mut i32, arg3: i64); +} unsafe extern "C" { pub fn ggml_cpu_fp32_to_fp16(arg1: *const f32, arg2: *mut ggml_fp16_t, arg3: i64); } @@ -5190,6 +5386,7 @@ pub struct whisper_full_params { pub tdrz_enable: bool, pub suppress_regex: *const ::std::os::raw::c_char, pub initial_prompt: *const ::std::os::raw::c_char, + pub carry_initial_prompt: bool, pub prompt_tokens: *const whisper_token, pub prompt_n_tokens: ::std::os::raw::c_int, pub language: *const ::std::os::raw::c_char, @@ -5256,7 +5453,7 @@ const _: () = { }; #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { - ["Size of whisper_full_params"][::std::mem::size_of::() - 296usize]; + ["Size of whisper_full_params"][::std::mem::size_of::() - 304usize]; ["Alignment of whisper_full_params"][::std::mem::align_of::() - 8usize]; ["Offset of field: whisper_full_params::strategy"] [::std::mem::offset_of!(whisper_full_params, strategy) - 0usize]; @@ -5306,70 +5503,72 @@ const _: () = { [::std::mem::offset_of!(whisper_full_params, suppress_regex) - 64usize]; ["Offset of field: whisper_full_params::initial_prompt"] [::std::mem::offset_of!(whisper_full_params, initial_prompt) - 72usize]; + ["Offset of field: whisper_full_params::carry_initial_prompt"] + [::std::mem::offset_of!(whisper_full_params, carry_initial_prompt) - 80usize]; ["Offset of field: whisper_full_params::prompt_tokens"] - [::std::mem::offset_of!(whisper_full_params, prompt_tokens) - 80usize]; + [::std::mem::offset_of!(whisper_full_params, prompt_tokens) - 88usize]; ["Offset of field: whisper_full_params::prompt_n_tokens"] - [::std::mem::offset_of!(whisper_full_params, prompt_n_tokens) - 88usize]; + [::std::mem::offset_of!(whisper_full_params, prompt_n_tokens) - 96usize]; ["Offset of field: whisper_full_params::language"] - [::std::mem::offset_of!(whisper_full_params, language) - 96usize]; + [::std::mem::offset_of!(whisper_full_params, language) - 104usize]; ["Offset of field: whisper_full_params::detect_language"] - [::std::mem::offset_of!(whisper_full_params, detect_language) - 104usize]; + [::std::mem::offset_of!(whisper_full_params, detect_language) - 112usize]; ["Offset of field: whisper_full_params::suppress_blank"] - [::std::mem::offset_of!(whisper_full_params, suppress_blank) - 105usize]; + [::std::mem::offset_of!(whisper_full_params, suppress_blank) - 113usize]; ["Offset of field: whisper_full_params::suppress_nst"] - [::std::mem::offset_of!(whisper_full_params, suppress_nst) - 106usize]; + [::std::mem::offset_of!(whisper_full_params, suppress_nst) - 114usize]; ["Offset of field: whisper_full_params::temperature"] - [::std::mem::offset_of!(whisper_full_params, temperature) - 108usize]; + [::std::mem::offset_of!(whisper_full_params, temperature) - 116usize]; ["Offset of field: whisper_full_params::max_initial_ts"] - [::std::mem::offset_of!(whisper_full_params, max_initial_ts) - 112usize]; + [::std::mem::offset_of!(whisper_full_params, max_initial_ts) - 120usize]; ["Offset of field: whisper_full_params::length_penalty"] - [::std::mem::offset_of!(whisper_full_params, length_penalty) - 116usize]; + [::std::mem::offset_of!(whisper_full_params, length_penalty) - 124usize]; ["Offset of field: whisper_full_params::temperature_inc"] - [::std::mem::offset_of!(whisper_full_params, temperature_inc) - 120usize]; + [::std::mem::offset_of!(whisper_full_params, temperature_inc) - 128usize]; ["Offset of field: whisper_full_params::entropy_thold"] - [::std::mem::offset_of!(whisper_full_params, entropy_thold) - 124usize]; + [::std::mem::offset_of!(whisper_full_params, entropy_thold) - 132usize]; ["Offset of field: whisper_full_params::logprob_thold"] - [::std::mem::offset_of!(whisper_full_params, logprob_thold) - 128usize]; + [::std::mem::offset_of!(whisper_full_params, logprob_thold) - 136usize]; ["Offset of field: whisper_full_params::no_speech_thold"] - [::std::mem::offset_of!(whisper_full_params, no_speech_thold) - 132usize]; + [::std::mem::offset_of!(whisper_full_params, no_speech_thold) - 140usize]; ["Offset of field: whisper_full_params::greedy"] - [::std::mem::offset_of!(whisper_full_params, greedy) - 136usize]; + [::std::mem::offset_of!(whisper_full_params, greedy) - 144usize]; ["Offset of field: whisper_full_params::beam_search"] - [::std::mem::offset_of!(whisper_full_params, beam_search) - 140usize]; + [::std::mem::offset_of!(whisper_full_params, beam_search) - 148usize]; ["Offset of field: whisper_full_params::new_segment_callback"] - [::std::mem::offset_of!(whisper_full_params, new_segment_callback) - 152usize]; + [::std::mem::offset_of!(whisper_full_params, new_segment_callback) - 160usize]; ["Offset of field: whisper_full_params::new_segment_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, new_segment_callback_user_data) - 160usize]; + [::std::mem::offset_of!(whisper_full_params, new_segment_callback_user_data) - 168usize]; ["Offset of field: whisper_full_params::progress_callback"] - [::std::mem::offset_of!(whisper_full_params, progress_callback) - 168usize]; + [::std::mem::offset_of!(whisper_full_params, progress_callback) - 176usize]; ["Offset of field: whisper_full_params::progress_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, progress_callback_user_data) - 176usize]; + [::std::mem::offset_of!(whisper_full_params, progress_callback_user_data) - 184usize]; ["Offset of field: whisper_full_params::encoder_begin_callback"] - [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback) - 184usize]; + [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback) - 192usize]; ["Offset of field: whisper_full_params::encoder_begin_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback_user_data) - 192usize]; + [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback_user_data) - 200usize]; ["Offset of field: whisper_full_params::abort_callback"] - [::std::mem::offset_of!(whisper_full_params, abort_callback) - 200usize]; + [::std::mem::offset_of!(whisper_full_params, abort_callback) - 208usize]; ["Offset of field: whisper_full_params::abort_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, abort_callback_user_data) - 208usize]; + [::std::mem::offset_of!(whisper_full_params, abort_callback_user_data) - 216usize]; ["Offset of field: whisper_full_params::logits_filter_callback"] - [::std::mem::offset_of!(whisper_full_params, logits_filter_callback) - 216usize]; + [::std::mem::offset_of!(whisper_full_params, logits_filter_callback) - 224usize]; ["Offset of field: whisper_full_params::logits_filter_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, logits_filter_callback_user_data) - 224usize]; + [::std::mem::offset_of!(whisper_full_params, logits_filter_callback_user_data) - 232usize]; ["Offset of field: whisper_full_params::grammar_rules"] - [::std::mem::offset_of!(whisper_full_params, grammar_rules) - 232usize]; + [::std::mem::offset_of!(whisper_full_params, grammar_rules) - 240usize]; ["Offset of field: whisper_full_params::n_grammar_rules"] - [::std::mem::offset_of!(whisper_full_params, n_grammar_rules) - 240usize]; + [::std::mem::offset_of!(whisper_full_params, n_grammar_rules) - 248usize]; ["Offset of field: whisper_full_params::i_start_rule"] - [::std::mem::offset_of!(whisper_full_params, i_start_rule) - 248usize]; + [::std::mem::offset_of!(whisper_full_params, i_start_rule) - 256usize]; ["Offset of field: whisper_full_params::grammar_penalty"] - [::std::mem::offset_of!(whisper_full_params, grammar_penalty) - 256usize]; + [::std::mem::offset_of!(whisper_full_params, grammar_penalty) - 264usize]; ["Offset of field: whisper_full_params::vad"] - [::std::mem::offset_of!(whisper_full_params, vad) - 260usize]; + [::std::mem::offset_of!(whisper_full_params, vad) - 268usize]; ["Offset of field: whisper_full_params::vad_model_path"] - [::std::mem::offset_of!(whisper_full_params, vad_model_path) - 264usize]; + [::std::mem::offset_of!(whisper_full_params, vad_model_path) - 272usize]; ["Offset of field: whisper_full_params::vad_params"] - [::std::mem::offset_of!(whisper_full_params, vad_params) - 272usize]; + [::std::mem::offset_of!(whisper_full_params, vad_params) - 280usize]; }; unsafe extern "C" { pub fn whisper_context_default_params_by_ref() -> *mut whisper_context_params; diff --git a/sys/whisper.cpp b/sys/whisper.cpp index fc45bb8..a88b93f 160000 --- a/sys/whisper.cpp +++ b/sys/whisper.cpp @@ -1 +1 @@ -Subproject commit fc45bb86251f774ef817e89878bb4c2636c8a58f +Subproject commit a88b93f85f08fc6045e5d8a8c3f94b7be0ac8bce diff --git a/wyoming-whisper-rs/Cargo.toml b/wyoming-whisper-rs/Cargo.toml new file mode 100644 index 0000000..c6c55d8 --- /dev/null +++ b/wyoming-whisper-rs/Cargo.toml @@ -0,0 +1,27 @@ +[package] +name = "wyoming-whisper-rs" +version = "0.1.0" +edition = "2021" +description = "Wyoming protocol ASR server powered by whisper-rs" +license = "Unlicense" + +[dependencies] +whisper-rs = { path = "..", features = ["tracing_backend"] } +tokio = { version = "1", features = ["rt-multi-thread", "net", "io-util", "macros", "signal", "sync"] } +clap = { version = "4", features = ["derive"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +thiserror = "2" +anyhow = "1" + +[features] +default = [] +cuda = ["whisper-rs/cuda"] +hipblas = ["whisper-rs/hipblas"] +metal = ["whisper-rs/metal"] +vulkan = ["whisper-rs/vulkan"] +openblas = ["whisper-rs/openblas"] +openmp = ["whisper-rs/openmp"] +coreml = ["whisper-rs/coreml"] diff --git a/wyoming-whisper-rs/package.nix b/wyoming-whisper-rs/package.nix new file mode 100644 index 0000000..9b522ab --- /dev/null +++ b/wyoming-whisper-rs/package.nix @@ -0,0 +1,52 @@ +{ + lib, + rustPlatform, + cmake, + pkg-config, + shaderc, + libclang, + glibc, + vulkan-headers, + vulkan-loader, + src, +}: + +rustPlatform.buildRustPackage { + pname = "wyoming-whisper-rs"; + version = "0.1.0"; + + inherit src; + + cargoLock.lockFile = ../Cargo.lock; + + buildAndTestSubdir = "wyoming-whisper-rs"; + + buildFeatures = [ "vulkan" ]; + + nativeBuildInputs = [ + cmake + pkg-config + shaderc + ]; + + buildInputs = [ + libclang.lib + vulkan-headers + vulkan-loader + ]; + + env = { + LIBCLANG_PATH = "${libclang.lib}/lib"; + VULKAN_SDK = "${vulkan-headers}"; + BINDGEN_EXTRA_CLANG_ARGS = builtins.toString [ + "-isystem ${libclang.lib}/lib/clang/${lib.versions.major libclang.version}/include" + "-isystem ${glibc.dev}/include" + ]; + }; + + meta = { + description = "Wyoming protocol ASR server powered by whisper-rs"; + license = lib.licenses.unlicense; + mainProgram = "wyoming-whisper-rs"; + }; +} diff --git a/wyoming-whisper-rs/src/cli.rs b/wyoming-whisper-rs/src/cli.rs new file mode 100644 index 0000000..e4b7328 --- /dev/null +++ b/wyoming-whisper-rs/src/cli.rs @@ -0,0 +1,49 @@ +use clap::Parser; +use std::path::PathBuf; + +#[derive(Parser, Debug)] +#[command( + name = "wyoming-whisper-rs", + about = "Wyoming protocol ASR server powered by whisper-rs" +)] +pub struct Args { + /// Path to the GGML whisper model file + #[arg(long)] + pub model: PathBuf, + + /// Host address to bind to + #[arg(long, default_value = "0.0.0.0")] + pub host: String, + + /// Port to listen on + #[arg(long, default_value_t = 10300)] + pub port: u16, + + /// Language code (e.g. "en", "de"). Omit for auto-detection + #[arg(long)] + pub language: Option, + + /// Beam search size + #[arg(long, default_value_t = 5)] + pub beam_size: i32, + + /// Number of threads for whisper inference (0 = whisper default) + #[arg(long, default_value_t = 0)] + pub threads: i32, + + /// GPU device index + #[arg(long, default_value_t = 0)] + pub gpu_device: i32, + + /// Disable GPU acceleration + #[arg(long)] + pub no_gpu: bool, + + /// Maximum concurrent transcriptions + #[arg(long, default_value_t = 1)] + pub max_concurrent: usize, + + /// Model name reported in Wyoming info + #[arg(long, default_value = "whisper")] + pub model_name: String, +} diff --git a/wyoming-whisper-rs/src/error.rs b/wyoming-whisper-rs/src/error.rs new file mode 100644 index 0000000..bd83731 --- /dev/null +++ b/wyoming-whisper-rs/src/error.rs @@ -0,0 +1,22 @@ +use std::io; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("I/O error: {0}")] + Io(#[from] io::Error), + + #[error("JSON error: {0}")] + Json(#[from] serde_json::Error), + + #[error("Whisper error: {0}")] + Whisper(#[from] whisper_rs::WhisperError), + + #[error("payload too large: {size} bytes (max {max})")] + PayloadTooLarge { size: usize, max: usize }, + + #[error("missing field: {0}")] + MissingField(&'static str), + + #[error("invalid audio format: {0}")] + InvalidAudio(String), +} diff --git a/wyoming-whisper-rs/src/main.rs b/wyoming-whisper-rs/src/main.rs new file mode 100644 index 0000000..404ccce --- /dev/null +++ b/wyoming-whisper-rs/src/main.rs @@ -0,0 +1,111 @@ +mod cli; +mod error; +mod protocol; +mod session; +mod transcribe; + +use std::sync::Arc; + +use clap::Parser; +use serde_json::json; +use tokio::net::TcpListener; +use tracing::{error, info}; +use whisper_rs::{WhisperContext, WhisperContextParameters}; + +use crate::cli::Args; +use crate::session::SessionConfig; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let args = Args::parse(); + + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), + ) + .init(); + + // Validate model path + if !args.model.exists() { + error!(path = %args.model.display(), "model file not found"); + std::process::exit(1); + } + + // Validate language if specified + if let Some(ref lang) = args.language { + if whisper_rs::get_lang_id(lang).is_none() { + error!(language = lang, "unknown language code"); + std::process::exit(1); + } + } + + // Route whisper.cpp logs through tracing + whisper_rs::install_logging_hooks(); + + // Load model + info!(model = %args.model.display(), "loading whisper model"); + let mut ctx_params = WhisperContextParameters::default(); + if args.no_gpu { + ctx_params.use_gpu(false); + } + ctx_params.gpu_device(args.gpu_device); + + let model_path = args.model.to_string_lossy().to_string(); + let ctx = WhisperContext::new_with_params(&model_path, ctx_params)?; + let ctx = Arc::new(ctx); + info!("model loaded"); + + // Build language list + let max_id = whisper_rs::get_lang_max_id(); + let languages: Vec = (0..=max_id) + .filter_map(|id| { + whisper_rs::get_lang_str(id).map(|code| { + json!({ + "name": code, + "description": whisper_rs::get_lang_str_full(id).unwrap_or(code), + }) + }) + }) + .collect(); + + let semaphore = Arc::new(tokio::sync::Semaphore::new(args.max_concurrent)); + + let session_config = Arc::new(SessionConfig { + ctx, + semaphore, + model_name: args.model_name, + languages, + default_language: args.language, + beam_size: args.beam_size, + threads: args.threads, + }); + + let bind_addr = format!("{}:{}", args.host, args.port); + let listener = TcpListener::bind(&bind_addr).await?; + info!(address = %bind_addr, "wyoming server listening"); + + loop { + tokio::select! { + result = listener.accept() => { + let (stream, addr) = result?; + info!(peer = %addr, "new connection"); + let config = Arc::clone(&session_config); + + tokio::spawn(async move { + let (reader, writer) = stream.into_split(); + if let Err(e) = session::run(config, reader, writer).await { + error!(peer = %addr, error = %e, "session error"); + } + info!(peer = %addr, "connection closed"); + }); + } + _ = tokio::signal::ctrl_c() => { + info!("shutting down"); + break; + } + } + } + + Ok(()) +} diff --git a/wyoming-whisper-rs/src/protocol.rs b/wyoming-whisper-rs/src/protocol.rs new file mode 100644 index 0000000..02ddbe1 --- /dev/null +++ b/wyoming-whisper-rs/src/protocol.rs @@ -0,0 +1,90 @@ +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; + +use crate::error::Error; + +const MAX_PAYLOAD: usize = 100 * 1024 * 1024; // 100 MB + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Event { + #[serde(rename = "type")] + pub event_type: String, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub data: Option, + + #[serde(default, skip_serializing_if = "is_zero")] + pub data_length: usize, + + #[serde(default, skip_serializing_if = "is_zero")] + pub payload_length: usize, + + #[serde(skip)] + pub payload: Option>, +} + +fn is_zero(v: &usize) -> bool { + *v == 0 +} + +impl Event { + pub fn new(event_type: impl Into) -> Self { + Self { + event_type: event_type.into(), + data: None, + data_length: 0, + payload_length: 0, + payload: None, + } + } + + pub fn with_data(mut self, data: Value) -> Self { + let serialized = serde_json::to_string(&data).unwrap_or_default(); + self.data_length = serialized.len(); + self.data = Some(data); + self + } +} + +pub async fn read_event( + reader: &mut BufReader, +) -> Result, Error> { + let mut line = String::new(); + let n = reader.read_line(&mut line).await?; + if n == 0 { + return Ok(None); + } + + let mut event: Event = serde_json::from_str(line.trim())?; + + if event.payload_length > 0 { + if event.payload_length > MAX_PAYLOAD { + return Err(Error::PayloadTooLarge { + size: event.payload_length, + max: MAX_PAYLOAD, + }); + } + let mut buf = vec![0u8; event.payload_length]; + reader.read_exact(&mut buf).await?; + event.payload = Some(buf); + } + + Ok(Some(event)) +} + +pub async fn write_event( + writer: &mut W, + event: &Event, +) -> Result<(), Error> { + let json = serde_json::to_string(event)?; + writer.write_all(json.as_bytes()).await?; + writer.write_all(b"\n").await?; + + if let Some(ref payload) = event.payload { + writer.write_all(payload).await?; + } + + writer.flush().await?; + Ok(()) +} diff --git a/wyoming-whisper-rs/src/session.rs b/wyoming-whisper-rs/src/session.rs new file mode 100644 index 0000000..5c719d9 --- /dev/null +++ b/wyoming-whisper-rs/src/session.rs @@ -0,0 +1,208 @@ +use std::sync::Arc; + +use serde_json::{json, Value}; +use tokio::io::{AsyncRead, AsyncWrite, BufReader}; +use tokio::sync::Semaphore; +use tracing::{debug, warn}; +use whisper_rs::WhisperContext; + +use crate::error::Error; +use crate::protocol::{read_event, write_event, Event}; +use crate::transcribe::{AudioBuffer, TranscribeConfig}; + +pub struct SessionConfig { + pub ctx: Arc, + pub semaphore: Arc, + pub model_name: String, + pub languages: Vec, + pub default_language: Option, + pub beam_size: i32, + pub threads: i32, +} + +enum State { + Idle, + AwaitingAudioStart { + language: Option, + }, + Streaming { + buffer: AudioBuffer, + language: Option, + }, +} + +impl State { + fn name(&self) -> &'static str { + match self { + State::Idle => "idle", + State::AwaitingAudioStart { .. } => "awaiting_audio_start", + State::Streaming { .. } => "streaming", + } + } +} + +pub async fn run(config: Arc, reader: R, mut writer: W) -> Result<(), Error> +where + R: AsyncRead + Unpin, + W: AsyncWrite + Unpin, +{ + let mut reader = BufReader::new(reader); + let mut state = State::Idle; + + loop { + let event = match read_event(&mut reader).await? { + Some(e) => e, + None => { + debug!("client disconnected"); + return Ok(()); + } + }; + + debug!(event_type = %event.event_type, "received event"); + + match event.event_type.as_str() { + "describe" => { + let info = Event::new("info").with_data(json!({ + "type": "asr", + "name": config.model_name, + "description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION), + "installed": true, + "attribution": { + "name": "whisper.cpp", + "url": "https://github.com/ggerganov/whisper.cpp" + }, + "models": [{ + "name": config.model_name, + "description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION), + "installed": true, + "attribution": { + "name": "whisper.cpp", + "url": "https://github.com/ggerganov/whisper.cpp" + }, + "languages": config.languages, + }], + "version": env!("CARGO_PKG_VERSION"), + })); + write_event(&mut writer, &info).await?; + } + + "transcribe" => { + let language = event + .data + .as_ref() + .and_then(|d| d.get("language")) + .and_then(|v| v.as_str()) + .map(String::from) + .or_else(|| config.default_language.clone()); + + state = State::AwaitingAudioStart { language }; + } + + "audio-start" => { + let (rate, width, channels) = parse_audio_start(&event)?; + let language = match state { + State::AwaitingAudioStart { language } => language, + _ => config.default_language.clone(), + }; + + state = State::Streaming { + buffer: AudioBuffer::new(rate, width, channels), + language, + }; + } + + "audio-chunk" => { + if let State::Streaming { ref mut buffer, .. } = state { + if let Some(ref payload) = event.payload { + buffer.append(payload); + } + } else { + warn!(state = state.name(), "audio-chunk received in wrong state"); + } + } + + "audio-stop" => { + if let State::Streaming { buffer, language } = state { + state = State::Idle; + + let tc = TranscribeConfig { + language, + beam_size: config.beam_size, + threads: config.threads, + }; + + match do_transcribe(&config.ctx, &config.semaphore, tc, buffer).await { + Ok(text) => { + let transcript = Event::new("transcript").with_data(json!({ + "text": text, + })); + write_event(&mut writer, &transcript).await?; + } + Err(e) => { + warn!(error = %e, "transcription failed"); + let err = Event::new("error").with_data(json!({ + "text": format!("transcription failed: {e}"), + })); + write_event(&mut writer, &err).await?; + } + } + } else { + warn!(state = state.name(), "audio-stop received in wrong state"); + state = State::Idle; + } + } + + other => { + warn!(event_type = other, "unknown event type"); + let err = Event::new("error").with_data(json!({ + "text": format!("unknown event type: {other}"), + })); + write_event(&mut writer, &err).await?; + } + } + } +} + +fn parse_audio_start(event: &Event) -> Result<(u32, u16, u16), Error> { + let data = event + .data + .as_ref() + .ok_or(Error::MissingField("data in audio-start"))?; + + let rate = data + .get("rate") + .and_then(|v| v.as_u64()) + .ok_or(Error::MissingField("rate"))? as u32; + + let width = data + .get("width") + .and_then(|v| v.as_u64()) + .ok_or(Error::MissingField("width"))? as u16; + + let channels = data + .get("channels") + .and_then(|v| v.as_u64()) + .ok_or(Error::MissingField("channels"))? as u16; + + Ok((rate, width, channels)) +} + +async fn do_transcribe( + ctx: &Arc, + semaphore: &Arc, + config: TranscribeConfig, + buffer: AudioBuffer, +) -> Result { + let audio = buffer.into_f32_16khz_mono()?; + + if audio.is_empty() { + return Ok(String::new()); + } + + let _permit = semaphore.acquire().await.expect("semaphore closed"); + let ctx = Arc::clone(ctx); + + tokio::task::spawn_blocking(move || crate::transcribe::transcribe(&ctx, &config, audio)) + .await + .expect("transcription task panicked") +} diff --git a/wyoming-whisper-rs/src/transcribe.rs b/wyoming-whisper-rs/src/transcribe.rs new file mode 100644 index 0000000..e7220d7 --- /dev/null +++ b/wyoming-whisper-rs/src/transcribe.rs @@ -0,0 +1,151 @@ +use std::sync::Arc; + +use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; + +use crate::error::Error; + +pub struct TranscribeConfig { + pub language: Option, + pub beam_size: i32, + pub threads: i32, +} + +pub struct AudioBuffer { + data: Vec, + rate: u32, + width: u16, + channels: u16, +} + +impl AudioBuffer { + pub fn new(rate: u32, width: u16, channels: u16) -> Self { + Self { + data: Vec::new(), + rate, + width, + channels, + } + } + + pub fn append(&mut self, chunk: &[u8]) { + self.data.extend_from_slice(chunk); + } + + pub fn into_f32_16khz_mono(self) -> Result, Error> { + if self.width != 2 { + return Err(Error::InvalidAudio(format!( + "expected 16-bit audio (width=2), got width={}", + self.width + ))); + } + + if !self.data.len().is_multiple_of(2) { + return Err(Error::InvalidAudio( + "audio data has odd number of bytes for 16-bit samples".into(), + )); + } + + // Interpret as i16 little-endian + let samples_i16: Vec = self + .data + .chunks_exact(2) + .map(|c| i16::from_le_bytes([c[0], c[1]])) + .collect(); + + // Convert i16 -> f32 + let mut samples_f32 = vec![0.0f32; samples_i16.len()]; + whisper_rs::convert_integer_to_float_audio(&samples_i16, &mut samples_f32) + .map_err(|e| Error::InvalidAudio(format!("i16 to f32 conversion failed: {e}")))?; + + // Convert stereo to mono if needed + let mono = if self.channels == 2 { + let mut mono = vec![0.0f32; samples_f32.len() / 2]; + whisper_rs::convert_stereo_to_mono_audio(&samples_f32, &mut mono) + .map_err(|e| Error::InvalidAudio(format!("stereo to mono failed: {e}")))?; + mono + } else if self.channels == 1 { + samples_f32 + } else { + return Err(Error::InvalidAudio(format!( + "unsupported channel count: {}", + self.channels + ))); + }; + + // Resample if not 16kHz + if self.rate == 16000 { + Ok(mono) + } else { + Ok(resample(&mono, self.rate, 16000)) + } + } +} + +/// Simple linear interpolation resampler. +fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Vec { + if from_rate == to_rate || input.is_empty() { + return input.to_vec(); + } + + let ratio = from_rate as f64 / to_rate as f64; + let output_len = ((input.len() as f64) / ratio).ceil() as usize; + let mut output = Vec::with_capacity(output_len); + + for i in 0..output_len { + let src_pos = i as f64 * ratio; + let idx = src_pos as usize; + let frac = src_pos - idx as f64; + + let sample = if idx + 1 < input.len() { + input[idx] as f64 * (1.0 - frac) + input[idx + 1] as f64 * frac + } else { + input[idx.min(input.len() - 1)] as f64 + }; + + output.push(sample as f32); + } + + output +} + +pub fn transcribe( + ctx: &Arc, + config: &TranscribeConfig, + audio: Vec, +) -> Result { + let mut state = ctx.create_state()?; + + let mut params = FullParams::new(SamplingStrategy::BeamSearch { + beam_size: config.beam_size, + patience: -1.0, + }); + + if let Some(ref lang) = config.language { + params.set_language(Some(lang)); + } else { + params.set_language(None); + params.set_detect_language(true); + } + + if config.threads > 0 { + params.set_n_threads(config.threads); + } + + params.set_print_special(false); + params.set_print_progress(false); + params.set_print_realtime(false); + params.set_print_timestamps(false); + params.set_no_context(true); + params.set_single_segment(false); + + state.full(params, &audio)?; + + let mut text = String::new(); + for segment in state.as_iter() { + if let Ok(s) = segment.to_str_lossy() { + text.push_str(&s); + } + } + + Ok(text.trim().to_string()) +}