diff --git a/.envrc b/.envrc deleted file mode 100644 index 3550a30..0000000 --- a/.envrc +++ /dev/null @@ -1 +0,0 @@ -use flake diff --git a/.gitignore b/.gitignore index 8e7d90a..902abe5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ **/target +**/Cargo.lock /.idea /.vscode -/.direnv *.bin *.wav \ No newline at end of file diff --git a/.rustfmt.toml b/.rustfmt.toml deleted file mode 100644 index e69de29..0000000 diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 2c44afd..0000000 --- a/CLAUDE.md +++ /dev/null @@ -1,97 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -whisper-rs provides safe Rust bindings to [whisper.cpp](https://github.com/ggerganov/whisper.cpp), a C++ speech recognition library. It's a two-crate workspace: - -- **`whisper-rs`** (root) — Safe public API -- **`whisper-rs-sys`** (`sys/`) — FFI bindings generated via bindgen, with a CMake-based build of the whisper.cpp submodule - -The upstream C++ source lives in `sys/whisper.cpp/` (git submodule — clone with `--recursive`). - -## Build Commands - -```bash -cargo build # Default build -cargo build --release --features vulkan # With Vulkan GPU support -cargo build --release --features hipblas # With AMD ROCm support -cargo build --release --features cuda # With NVIDIA CUDA support -cargo test # Run all tests -cargo fmt # Format code -cargo clippy # Lint -``` - -**Running examples** (require a GGML model file and a WAV audio file): -```bash -cargo run --example basic_use -- model.bin audio.wav -cargo run --example audio_transcription -- model.bin audio.wav -cargo run --example vad -- model.bin audio.wav output.wav -``` - -**Skipping bindgen** (use pre-generated bindings): set `WHISPER_DONT_GENERATE_BINDINGS=1`. - -All `WHISPER_*` and `CMAKE_*` env vars are forwarded to the CMake build. - -## Architecture - -``` -whisper-rs (safe Rust API) - → whisper-rs-sys (bindgen FFI) - → whisper.cpp (C++ submodule, built via CMake) - → GGML (tensor library with CPU/GPU backends) -``` - -**Key types and their relationships:** - -- `WhisperContext` — Arc-wrapped model handle, thread-safe, created from a model file -- `WhisperState` — Inference state created from a context; multiple states can share one context -- `FullParams` — Transcription configuration (sampling strategy, language, callbacks) -- `WhisperSegment` / `WhisperToken` — Result types with timestamps and probabilities - -**Core flow:** `WhisperContext::new_with_params()` → `ctx.create_state()` → `state.full(params, &audio_data)` → iterate segments via `state.as_iter()` - -**Module layout in `src/`:** -- `whisper_ctx.rs` — Raw context wrapper (`WhisperInnerContext`) -- `whisper_ctx_wrapper.rs` — Safe public `WhisperContext` -- `whisper_state/` — State management, segments, tokens, iterators -- `whisper_params.rs` — `FullParams`, `SamplingStrategy` (Greedy / BeamSearch) -- `whisper_vad.rs` — Voice Activity Detection -- `whisper_grammar.rs` — GBNF grammar-constrained decoding -- `vulkan.rs` — Vulkan device enumeration (behind `vulkan` feature) - -## Build System (sys/build.rs) - -The build script: -1. Copies `whisper.cpp` sources into `OUT_DIR` -2. Runs bindgen on `wrapper.h` to generate FFI bindings (falls back to `src/bindings.rs` on failure) -3. Configures and builds whisper.cpp via CMake with feature-dependent flags (`GGML_CUDA`, `GGML_HIP`, `GGML_VULKAN`, `GGML_METAL`, etc.) -4. Statically links: `whisper`, `ggml`, `ggml-base`, `ggml-cpu`, plus backend-specific libs - -## Feature Flags - -| Feature | Purpose | -|---------|---------| -| `cuda` | NVIDIA GPU (needs CUDA toolkit) | -| `hipblas` | AMD GPU via ROCm | -| `metal` | Apple Metal GPU | -| `vulkan` | Vulkan GPU | -| `openblas` | OpenBLAS acceleration (requires `BLAS_INCLUDE_DIRS` env var) | -| `openmp` | OpenMP threading | -| `coreml` | Apple CoreML | -| `intel-sycl` | Intel SYCL | -| `raw-api` | Re-export `whisper-rs-sys` types publicly | -| `log_backend` | Route C++ logs to the `log` crate | -| `tracing_backend` | Route C++ logs to the `tracing` crate | - -## Nix Development Environment - -The `flake.nix` provides a devshell with all dependencies for Vulkan and ROCm/hipBLAS builds. Use `direnv allow` or `nix develop` to enter the environment. Key env vars (`LIBCLANG_PATH`, `BINDGEN_EXTRA_CLANG_ARGS`, `HIP_PATH`, `VULKAN_SDK`) are set automatically. - -## PR Conventions - -Per `.github/PULL_REQUEST_TEMPLATE.md`: -- Run `cargo fmt` and `cargo clippy` before submitting -- Self-review code for legibility -- No GenAI-generated code in PRs diff --git a/Cargo.lock b/Cargo.lock deleted file mode 100644 index 0d0f661..0000000 --- a/Cargo.lock +++ /dev/null @@ -1,862 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 4 - -[[package]] -name = "aho-corasick" -version = "1.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd31a130427c27518df266943a5308ed92d4b226cc639f5a8f1002816174301" -dependencies = [ - "memchr", -] - -[[package]] -name = "anstream" -version = "0.6.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "43d5b281e737544384e969a5ccad3f1cdd24b48086a0fc1b2a5262a26b8f4f4a" -dependencies = [ - "anstyle", - "anstyle-parse", - "anstyle-query", - "anstyle-wincon", - "colorchoice", - "is_terminal_polyfill", - "utf8parse", -] - -[[package]] -name = "anstyle" -version = "1.0.13" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5192cca8006f1fd4f7237516f40fa183bb07f8fbdfedaa0036de5ea9b0b45e78" - -[[package]] -name = "anstyle-parse" -version = "0.2.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e7644824f0aa2c7b9384579234ef10eb7efb6a0deb83f9630a49594dd9c15c2" -dependencies = [ - "utf8parse", -] - -[[package]] -name = "anstyle-query" -version = "1.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40c48f72fd53cd289104fc64099abca73db4166ad86ea0b4341abe65af83dadc" -dependencies = [ - "windows-sys 0.61.2", -] - -[[package]] -name = "anstyle-wincon" -version = "3.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291e6a250ff86cd4a820112fb8898808a366d8f9f58ce16d1f538353ad55747d" -dependencies = [ - "anstyle", - "once_cell_polyfill", - "windows-sys 0.61.2", -] - -[[package]] -name = "anyhow" -version = "1.0.102" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" - -[[package]] -name = "bindgen" -version = "0.71.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" -dependencies = [ - "bitflags", - "cexpr", - "clang-sys", - "itertools", - "log", - "prettyplease", - "proc-macro2", - "quote", - "regex", - "rustc-hash", - "shlex", - "syn", -] - -[[package]] -name = "bitflags" -version = "2.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "843867be96c8daad0d758b57df9392b6d8d271134fce549de6ce169ff98a92af" - -[[package]] -name = "bytes" -version = "1.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" - -[[package]] -name = "cc" -version = "1.2.56" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aebf35691d1bfb0ac386a69bac2fde4dd276fb618cf8bf4f5318fe285e821bb2" -dependencies = [ - "find-msvc-tools", - "shlex", -] - -[[package]] -name = "cexpr" -version = "0.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" -dependencies = [ - "nom", -] - -[[package]] -name = "cfg-if" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" - -[[package]] -name = "clang-sys" -version = "1.8.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" -dependencies = [ - "glob", - "libc", - "libloading", -] - -[[package]] -name = "clap" -version = "4.5.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" -dependencies = [ - "clap_builder", - "clap_derive", -] - -[[package]] -name = "clap_builder" -version = "4.5.60" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" -dependencies = [ - "anstream", - "anstyle", - "clap_lex", - "strsim", -] - -[[package]] -name = "clap_derive" -version = "4.5.55" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a92793da1a46a5f2a02a6f4c46c6496b28c43638adea8306fcb0caa1634f24e5" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "clap_lex" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a822ea5bc7590f9d40f1ba12c0dc3c2760f3482c6984db1573ad11031420831" - -[[package]] -name = "cmake" -version = "0.1.57" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75443c44cd6b379beb8c5b45d85d0773baf31cce901fe7bb252f4eff3008ef7d" -dependencies = [ - "cc", -] - -[[package]] -name = "colorchoice" -version = "1.0.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" - -[[package]] -name = "either" -version = "1.15.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" - -[[package]] -name = "errno" -version = "0.3.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" -dependencies = [ - "libc", - "windows-sys 0.61.2", -] - -[[package]] -name = "find-msvc-tools" -version = "0.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" - -[[package]] -name = "fs_extra" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" - -[[package]] -name = "getrandom" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff2abc00be7fca6ebc474524697ae276ad847ad0a6b3faa4bcb027e9a4614ad0" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "glob" -version = "0.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0cc23270f6e1808e30a928bdc84dea0b9b4136a8bc82338574f23baf47bbd280" - -[[package]] -name = "heck" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" - -[[package]] -name = "hound" -version = "3.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62adaabb884c94955b19907d60019f4e145d091c75345379e70d1ee696f7854f" - -[[package]] -name = "is_terminal_polyfill" -version = "1.70.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a6cb138bb79a146c1bd460005623e142ef0181e3d0219cb493e02f7d08a35695" - -[[package]] -name = "itertools" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" -dependencies = [ - "either", -] - -[[package]] -name = "itoa" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" - -[[package]] -name = "lazy_static" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" - -[[package]] -name = "libc" -version = "0.2.182" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6800badb6cb2082ffd7b6a67e6125bb39f18782f793520caee8cb8846be06112" - -[[package]] -name = "libloading" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7c4b02199fee7c5d21a5ae7d8cfa79a6ef5bb2fc834d6e9058e89c825efdc55" -dependencies = [ - "cfg-if", - "windows-link", -] - -[[package]] -name = "log" -version = "0.4.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" - -[[package]] -name = "matchers" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" -dependencies = [ - "regex-automata", -] - -[[package]] -name = "memchr" -version = "2.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" - -[[package]] -name = "minimal-lexical" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" - -[[package]] -name = "mio" -version = "1.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" -dependencies = [ - "libc", - "wasi", - "windows-sys 0.61.2", -] - -[[package]] -name = "nom" -version = "7.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" -dependencies = [ - "memchr", - "minimal-lexical", -] - -[[package]] -name = "nu-ansi-term" -version = "0.50.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" -dependencies = [ - "windows-sys 0.61.2", -] - -[[package]] -name = "once_cell" -version = "1.21.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" - -[[package]] -name = "once_cell_polyfill" -version = "1.70.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "384b8ab6d37215f3c5301a95a4accb5d64aa607f1fcb26a11b5303878451b4fe" - -[[package]] -name = "pin-project-lite" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" - -[[package]] -name = "ppv-lite86" -version = "0.2.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" -dependencies = [ - "zerocopy", -] - -[[package]] -name = "prettyplease" -version = "0.2.37" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" -dependencies = [ - "proc-macro2", - "syn", -] - -[[package]] -name = "proc-macro2" -version = "1.0.106" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8fd00f0bb2e90d81d1044c2b32617f68fcb9fa3bb7640c23e9c748e53fb30934" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "21b2ebcf727b7760c461f091f9f0f539b77b8e87f2fd88131e7f1b433b3cece4" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "regex" -version = "1.12.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e10754a14b9137dd7b1e3e5b0493cc9171fdd105e0ab477f51b72e7f3ac0e276" -dependencies = [ - "aho-corasick", - "memchr", - "regex-automata", - "regex-syntax", -] - -[[package]] -name = "regex-automata" -version = "0.4.14" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6e1dd4122fc1595e8162618945476892eefca7b88c52820e74af6262213cae8f" -dependencies = [ - "aho-corasick", - "memchr", - "regex-syntax", -] - -[[package]] -name = "regex-syntax" -version = "0.8.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a96887878f22d7bad8a3b6dc5b7440e0ada9a245242924394987b21cf2210a4c" - -[[package]] -name = "rustc-hash" -version = "2.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" - -[[package]] -name = "serde" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" -dependencies = [ - "serde_core", - "serde_derive", -] - -[[package]] -name = "serde_core" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.228" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "serde_json" -version = "1.0.149" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" -dependencies = [ - "itoa", - "memchr", - "serde", - "serde_core", - "zmij", -] - -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - -[[package]] -name = "shlex" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" - -[[package]] -name = "signal-hook-registry" -version = "1.4.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4db69cba1110affc0e9f7bcd48bbf87b3f4fc7c61fc9155afd4c469eb3d6c1b" -dependencies = [ - "errno", - "libc", -] - -[[package]] -name = "smallvec" -version = "1.15.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03" - -[[package]] -name = "socket2" -version = "0.6.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86f4aa3ad99f2088c990dfa82d367e19cb29268ed67c574d10d0a4bfe71f07e0" -dependencies = [ - "libc", - "windows-sys 0.60.2", -] - -[[package]] -name = "strsim" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" - -[[package]] -name = "syn" -version = "2.0.117" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e665b8803e7b1d2a727f4023456bbbbe74da67099c585258af0ad9c5013b9b99" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "thiserror" -version = "2.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4288b5bcbc7920c07a1149a35cf9590a2aa808e0bc1eafaade0b80947865fbc4" -dependencies = [ - "thiserror-impl", -] - -[[package]] -name = "thiserror-impl" -version = "2.0.18" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ebc4ee7f67670e9b64d05fa4253e753e016c6c95ff35b89b7941d6b856dec1d5" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "thread_local" -version = "1.1.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f60246a4944f24f6e018aa17cdeffb7818b76356965d03b07d6a9886e8962185" -dependencies = [ - "cfg-if", -] - -[[package]] -name = "tokio" -version = "1.49.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72a2903cd7736441aac9df9d7688bd0ce48edccaadf181c3b90be801e81d3d86" -dependencies = [ - "bytes", - "libc", - "mio", - "pin-project-lite", - "signal-hook-registry", - "socket2", - "tokio-macros", - "windows-sys 0.61.2", -] - -[[package]] -name = "tokio-macros" -version = "2.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af407857209536a95c8e56f8231ef2c2e2aff839b22e07a1ffcbc617e9db9fa5" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tracing" -version = "0.1.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" -dependencies = [ - "pin-project-lite", - "tracing-attributes", - "tracing-core", -] - -[[package]] -name = "tracing-attributes" -version = "0.1.31" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7490cfa5ec963746568740651ac6781f701c9c5ea257c58e057f3ba8cf69e8da" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "tracing-core" -version = "0.1.36" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db97caf9d906fbde555dd62fa95ddba9eecfd14cb388e4f491a66d74cd5fb79a" -dependencies = [ - "once_cell", - "valuable", -] - -[[package]] -name = "tracing-log" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" -dependencies = [ - "matchers", - "nu-ansi-term", - "once_cell", - "regex-automata", - "sharded-slab", - "smallvec", - "thread_local", - "tracing", - "tracing-core", - "tracing-log", -] - -[[package]] -name = "unicode-ident" -version = "1.0.24" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" - -[[package]] -name = "utf8parse" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" - -[[package]] -name = "valuable" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" - -[[package]] -name = "wasi" -version = "0.11.1+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" - -[[package]] -name = "whisper-rs" -version = "0.15.1" -dependencies = [ - "hound", - "libc", - "log", - "rand", - "tracing", - "whisper-rs-sys", -] - -[[package]] -name = "whisper-rs-sys" -version = "0.14.1" -dependencies = [ - "bindgen", - "cfg-if", - "cmake", - "fs_extra", -] - -[[package]] -name = "windows-link" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" - -[[package]] -name = "windows-sys" -version = "0.60.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" -dependencies = [ - "windows-targets", -] - -[[package]] -name = "windows-sys" -version = "0.61.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae137229bcbd6cdf0f7b80a31df61766145077ddf49416a728b02cb3921ff3fc" -dependencies = [ - "windows-link", -] - -[[package]] -name = "windows-targets" -version = "0.53.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" -dependencies = [ - "windows-link", - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" - -[[package]] -name = "windows_i686_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" - -[[package]] -name = "windows_i686_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.53.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6bbff5f0aada427a1e5a6da5f1f98158182f26556f345ac9e04d36d0ebed650" - -[[package]] -name = "wyoming-whisper-rs" -version = "0.1.0" -dependencies = [ - "anyhow", - "clap", - "serde", - "serde_json", - "thiserror", - "tokio", - "tracing", - "tracing-subscriber", - "whisper-rs", -] - -[[package]] -name = "zerocopy" -version = "0.8.39" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db6d35d663eadb6c932438e763b262fe1a70987f9ae936e60158176d710cae4a" -dependencies = [ - "zerocopy-derive", -] - -[[package]] -name = "zerocopy-derive" -version = "0.8.39" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4122cd3169e94605190e77839c9a40d40ed048d305bfdc146e7df40ab0f3e517" -dependencies = [ - "proc-macro2", - "quote", - "syn", -] - -[[package]] -name = "zmij" -version = "1.0.21" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b8848ee67ecc8aedbaf3e4122217aff892639231befc6a1b58d29fff4c2cabaa" diff --git a/Cargo.toml b/Cargo.toml index 685b1ba..24849b6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["sys", "wyoming-whisper-rs"] +members = ["sys"] exclude = ["examples/full_usage"] [package] diff --git a/examples/basic_use.rs b/examples/basic_use.rs index 54e7d5d..53d2aab 100644 --- a/examples/basic_use.rs +++ b/examples/basic_use.rs @@ -51,16 +51,15 @@ fn main() { // note that you don't need to use these, you can do it yourself or any other way you want // these are just provided for convenience let mut inter_samples = vec![Default::default(); samples.len()]; - let mut mono_samples = vec![Default::default(); samples.len() / 2]; whisper_rs::convert_integer_to_float_audio(&samples, &mut inter_samples) .expect("failed to convert audio data"); - whisper_rs::convert_stereo_to_mono_audio(&inter_samples, &mut mono_samples) + let samples = whisper_rs::convert_stereo_to_mono_audio(&inter_samples) .expect("failed to convert audio data"); // now we can run the model state - .full(params, &mono_samples[..]) + .full(params, &samples[..]) .expect("failed to run model"); // fetch the results diff --git a/flake.lock b/flake.lock deleted file mode 100644 index 516965b..0000000 --- a/flake.lock +++ /dev/null @@ -1,27 +0,0 @@ -{ - "nodes": { - "nixpkgs": { - "locked": { - "lastModified": 1771482645, - "narHash": "sha256-MpAKyXfJRDTgRU33Hja+G+3h9ywLAJJNRq4Pjbb4dQs=", - "owner": "NixOS", - "repo": "nixpkgs", - "rev": "724cf38d99ba81fbb4a347081db93e2e3a9bc2ae", - "type": "github" - }, - "original": { - "owner": "NixOS", - "ref": "nixpkgs-unstable", - "repo": "nixpkgs", - "type": "github" - } - }, - "root": { - "inputs": { - "nixpkgs": "nixpkgs" - } - } - }, - "root": "root", - "version": 7 -} diff --git a/flake.nix b/flake.nix deleted file mode 100644 index 9c99811..0000000 --- a/flake.nix +++ /dev/null @@ -1,60 +0,0 @@ -{ - inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; - }; - - outputs = { self, nixpkgs, ... }: - let - systems = [ "x86_64-linux" "aarch64-linux" "x86_64-darwin" "aarch64-darwin" ]; - forAllSystems = f: nixpkgs.lib.genAttrs systems (system: f nixpkgs.legacyPackages.${system}); - in - { - packages = forAllSystems (pkgs: { - default = pkgs.callPackage ./wyoming-whisper-rs/package.nix { src = self; }; - }); - - devShells = forAllSystems (pkgs: { - default = pkgs.mkShell { - nativeBuildInputs = [ - pkgs.rustc - pkgs.cargo - pkgs.clippy - pkgs.rustfmt - pkgs.cmake - pkgs.pkg-config - pkgs.shaderc - pkgs.rocmPackages.llvm.clang - ]; - - buildInputs = [ - pkgs.libclang.lib - pkgs.openssl - pkgs.vulkan-headers - pkgs.vulkan-loader - pkgs.rocmPackages.clr - pkgs.rocmPackages.hipblas - pkgs.rocmPackages.rocblas - pkgs.rocmPackages.rocm-runtime - ]; - - env = { - LIBCLANG_PATH = "${pkgs.libclang.lib}/lib"; - VULKAN_SDK = "${pkgs.vulkan-headers}"; - HIP_PATH = "${pkgs.rocmPackages.clr}"; - BINDGEN_EXTRA_CLANG_ARGS = builtins.toString [ - "-isystem ${pkgs.libclang.lib}/lib/clang/${pkgs.lib.versions.major pkgs.libclang.version}/include" - "-isystem ${pkgs.glibc.dev}/include" - ]; - }; - - LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath [ - pkgs.vulkan-loader - pkgs.rocmPackages.clr - pkgs.rocmPackages.hipblas - pkgs.rocmPackages.rocblas - pkgs.rocmPackages.rocm-runtime - ]; - }; - }); - }; -} diff --git a/src/utilities.rs b/src/utilities.rs index cf30539..2e11b7c 100644 --- a/src/utilities.rs +++ b/src/utilities.rs @@ -51,8 +51,7 @@ pub fn convert_integer_to_float_audio( /// ``` /// # use whisper_rs::convert_stereo_to_mono_audio; /// let samples = [0.0f32; 1024]; -/// let mut mono_samples = [0.0f32; 512]; -/// convert_stereo_to_mono_audio(&samples, &mut mono_samples).expect("should be no half samples missing"); +/// let mono = convert_stereo_to_mono_audio(&samples).expect("should be no half samples missing"); /// ``` pub fn convert_stereo_to_mono_audio(input: &[f32], output: &mut [f32]) -> Result<(), WhisperError> { let (input, []) = input.as_chunks::<2>() else { diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 846af3b..b51b602 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -801,7 +801,7 @@ impl<'a, 'b> FullParams<'a, 'b> { /// # Examples /// ``` /// # use whisper_rs::{FullParams, SamplingStrategy}; - /// let mut params = FullParams::new(SamplingStrategy::Greedy { best_of: 5 }); + /// let mut params = FullParams::new(SamplingStrategy::default()); /// params.set_initial_prompt("Hello, world!"); /// // ... further usage of params ... /// ``` diff --git a/src/whisper_state/mod.rs b/src/whisper_state/mod.rs index 2b93b4a..2210b1f 100644 --- a/src/whisper_state/mod.rs +++ b/src/whisper_state/mod.rs @@ -280,7 +280,7 @@ impl WhisperState { /// See utilities in the root of this crate for functions to convert audio to this format. /// /// # Returns - /// Ok(()) on success, Err(WhisperError) on failure. + /// Ok(c_int) on success, Err(WhisperError) on failure. /// /// # C++ equivalent /// `int whisper_full_with_state( @@ -289,7 +289,7 @@ impl WhisperState { /// struct whisper_full_params params, /// const float * samples, /// int n_samples)` - pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<(), WhisperError> { + pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result { if data.is_empty() { // can randomly trigger segmentation faults if we don't check this return Err(WhisperError::NoSamples); @@ -311,7 +311,7 @@ impl WhisperState { } else if ret == 8 { Err(WhisperError::FailedToDecode) } else if ret == 0 { - Ok(()) + Ok(ret) } else { Err(WhisperError::GenericError(ret)) } diff --git a/sys/build.rs b/sys/build.rs index a7f411a..86d23d8 100644 --- a/sys/build.rs +++ b/sys/build.rs @@ -102,7 +102,7 @@ fn main() { println!("cargo:rerun-if-changed=wrapper.h"); let out = PathBuf::from(env::var("OUT_DIR").unwrap()); - let whisper_root = out.join("whisper.cpp"); + let whisper_root = out.join("whisper.cpp/"); if !whisper_root.exists() { std::fs::create_dir_all(&whisper_root).unwrap(); diff --git a/sys/src/bindings.rs b/sys/src/bindings.rs index 6ac1bb5..9f4be14 100644 --- a/sys/src/bindings.rs +++ b/sys/src/bindings.rs @@ -183,7 +183,7 @@ pub const __STDC_IEC_60559_COMPLEX__: u32 = 201404; pub const __STDC_ISO_10646__: u32 = 201706; pub const __GNU_LIBRARY__: u32 = 6; pub const __GLIBC__: u32 = 2; -pub const __GLIBC_MINOR__: u32 = 42; +pub const __GLIBC_MINOR__: u32 = 41; pub const _SYS_CDEFS_H: u32 = 1; pub const __glibc_c99_flexarr_available: u32 = 1; pub const __LDOUBLE_REDIRECTS_TO_FLOAT128_ABI: u32 = 0; @@ -302,11 +302,9 @@ pub const GGML_DEFAULT_GRAPH_SIZE: u32 = 2048; pub const GGML_MEM_ALIGN: u32 = 16; pub const GGML_EXIT_SUCCESS: u32 = 0; pub const GGML_EXIT_ABORTED: u32 = 1; -pub const GGML_ROPE_TYPE_NORMAL: u32 = 0; pub const GGML_ROPE_TYPE_NEOX: u32 = 2; pub const GGML_ROPE_TYPE_MROPE: u32 = 8; pub const GGML_ROPE_TYPE_VISION: u32 = 24; -pub const GGML_ROPE_TYPE_IMROPE: u32 = 40; pub const GGML_MROPE_SECTIONS: u32 = 4; pub const GGML_KQ_MASK_PAD: u32 = 64; pub const GGML_N_TASKS_MAX: i32 = -1; @@ -536,9 +534,7 @@ pub struct _IO_FILE { pub _freeres_buf: *mut ::std::os::raw::c_void, pub _prevchain: *mut *mut _IO_FILE, pub _mode: ::std::os::raw::c_int, - pub _unused3: ::std::os::raw::c_int, - pub _total_written: __uint64_t, - pub _unused2: [::std::os::raw::c_char; 8usize], + pub _unused2: [::std::os::raw::c_char; 20usize], } #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { @@ -592,10 +588,7 @@ const _: () = { ["Offset of field: _IO_FILE::_prevchain"] [::std::mem::offset_of!(_IO_FILE, _prevchain) - 184usize]; ["Offset of field: _IO_FILE::_mode"][::std::mem::offset_of!(_IO_FILE, _mode) - 192usize]; - ["Offset of field: _IO_FILE::_unused3"][::std::mem::offset_of!(_IO_FILE, _unused3) - 196usize]; - ["Offset of field: _IO_FILE::_total_written"] - [::std::mem::offset_of!(_IO_FILE, _total_written) - 200usize]; - ["Offset of field: _IO_FILE::_unused2"][::std::mem::offset_of!(_IO_FILE, _unused2) - 208usize]; + ["Offset of field: _IO_FILE::_unused2"][::std::mem::offset_of!(_IO_FILE, _unused2) - 196usize]; }; impl _IO_FILE { #[inline] @@ -1321,85 +1314,79 @@ pub const ggml_op_GGML_OP_SIN: ggml_op = 12; pub const ggml_op_GGML_OP_COS: ggml_op = 13; pub const ggml_op_GGML_OP_SUM: ggml_op = 14; pub const ggml_op_GGML_OP_SUM_ROWS: ggml_op = 15; -pub const ggml_op_GGML_OP_CUMSUM: ggml_op = 16; -pub const ggml_op_GGML_OP_MEAN: ggml_op = 17; -pub const ggml_op_GGML_OP_ARGMAX: ggml_op = 18; -pub const ggml_op_GGML_OP_COUNT_EQUAL: ggml_op = 19; -pub const ggml_op_GGML_OP_REPEAT: ggml_op = 20; -pub const ggml_op_GGML_OP_REPEAT_BACK: ggml_op = 21; -pub const ggml_op_GGML_OP_CONCAT: ggml_op = 22; -pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 23; -pub const ggml_op_GGML_OP_NORM: ggml_op = 24; -pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 25; -pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 26; -pub const ggml_op_GGML_OP_GROUP_NORM: ggml_op = 27; -pub const ggml_op_GGML_OP_L2_NORM: ggml_op = 28; -pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 29; -pub const ggml_op_GGML_OP_MUL_MAT_ID: ggml_op = 30; -pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 31; -pub const ggml_op_GGML_OP_SCALE: ggml_op = 32; -pub const ggml_op_GGML_OP_SET: ggml_op = 33; -pub const ggml_op_GGML_OP_CPY: ggml_op = 34; -pub const ggml_op_GGML_OP_CONT: ggml_op = 35; -pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 36; -pub const ggml_op_GGML_OP_VIEW: ggml_op = 37; -pub const ggml_op_GGML_OP_PERMUTE: ggml_op = 38; -pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 39; -pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 40; -pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 41; -pub const ggml_op_GGML_OP_SET_ROWS: ggml_op = 42; -pub const ggml_op_GGML_OP_DIAG: ggml_op = 43; -pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 44; -pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 45; -pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 46; -pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 47; -pub const ggml_op_GGML_OP_ROPE: ggml_op = 48; -pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 49; -pub const ggml_op_GGML_OP_CLAMP: ggml_op = 50; -pub const ggml_op_GGML_OP_CONV_TRANSPOSE_1D: ggml_op = 51; -pub const ggml_op_GGML_OP_IM2COL: ggml_op = 52; -pub const ggml_op_GGML_OP_IM2COL_BACK: ggml_op = 53; -pub const ggml_op_GGML_OP_IM2COL_3D: ggml_op = 54; -pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 55; -pub const ggml_op_GGML_OP_CONV_3D: ggml_op = 56; -pub const ggml_op_GGML_OP_CONV_2D_DW: ggml_op = 57; -pub const ggml_op_GGML_OP_CONV_TRANSPOSE_2D: ggml_op = 58; -pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 59; -pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 60; -pub const ggml_op_GGML_OP_POOL_2D_BACK: ggml_op = 61; -pub const ggml_op_GGML_OP_UPSCALE: ggml_op = 62; -pub const ggml_op_GGML_OP_PAD: ggml_op = 63; -pub const ggml_op_GGML_OP_PAD_REFLECT_1D: ggml_op = 64; -pub const ggml_op_GGML_OP_ROLL: ggml_op = 65; -pub const ggml_op_GGML_OP_ARANGE: ggml_op = 66; -pub const ggml_op_GGML_OP_TIMESTEP_EMBEDDING: ggml_op = 67; -pub const ggml_op_GGML_OP_ARGSORT: ggml_op = 68; -pub const ggml_op_GGML_OP_LEAKY_RELU: ggml_op = 69; -pub const ggml_op_GGML_OP_TRI: ggml_op = 70; -pub const ggml_op_GGML_OP_FILL: ggml_op = 71; -pub const ggml_op_GGML_OP_FLASH_ATTN_EXT: ggml_op = 72; -pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 73; -pub const ggml_op_GGML_OP_SSM_CONV: ggml_op = 74; -pub const ggml_op_GGML_OP_SSM_SCAN: ggml_op = 75; -pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 76; -pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 77; -pub const ggml_op_GGML_OP_GET_REL_POS: ggml_op = 78; -pub const ggml_op_GGML_OP_ADD_REL_POS: ggml_op = 79; -pub const ggml_op_GGML_OP_RWKV_WKV6: ggml_op = 80; -pub const ggml_op_GGML_OP_GATED_LINEAR_ATTN: ggml_op = 81; -pub const ggml_op_GGML_OP_RWKV_WKV7: ggml_op = 82; -pub const ggml_op_GGML_OP_SOLVE_TRI: ggml_op = 83; -pub const ggml_op_GGML_OP_UNARY: ggml_op = 84; -pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 85; -pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 86; -pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 87; -pub const ggml_op_GGML_OP_CUSTOM: ggml_op = 88; -pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 89; -pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 90; -pub const ggml_op_GGML_OP_OPT_STEP_ADAMW: ggml_op = 91; -pub const ggml_op_GGML_OP_OPT_STEP_SGD: ggml_op = 92; -pub const ggml_op_GGML_OP_GLU: ggml_op = 93; -pub const ggml_op_GGML_OP_COUNT: ggml_op = 94; +pub const ggml_op_GGML_OP_MEAN: ggml_op = 16; +pub const ggml_op_GGML_OP_ARGMAX: ggml_op = 17; +pub const ggml_op_GGML_OP_COUNT_EQUAL: ggml_op = 18; +pub const ggml_op_GGML_OP_REPEAT: ggml_op = 19; +pub const ggml_op_GGML_OP_REPEAT_BACK: ggml_op = 20; +pub const ggml_op_GGML_OP_CONCAT: ggml_op = 21; +pub const ggml_op_GGML_OP_SILU_BACK: ggml_op = 22; +pub const ggml_op_GGML_OP_NORM: ggml_op = 23; +pub const ggml_op_GGML_OP_RMS_NORM: ggml_op = 24; +pub const ggml_op_GGML_OP_RMS_NORM_BACK: ggml_op = 25; +pub const ggml_op_GGML_OP_GROUP_NORM: ggml_op = 26; +pub const ggml_op_GGML_OP_L2_NORM: ggml_op = 27; +pub const ggml_op_GGML_OP_MUL_MAT: ggml_op = 28; +pub const ggml_op_GGML_OP_MUL_MAT_ID: ggml_op = 29; +pub const ggml_op_GGML_OP_OUT_PROD: ggml_op = 30; +pub const ggml_op_GGML_OP_SCALE: ggml_op = 31; +pub const ggml_op_GGML_OP_SET: ggml_op = 32; +pub const ggml_op_GGML_OP_CPY: ggml_op = 33; +pub const ggml_op_GGML_OP_CONT: ggml_op = 34; +pub const ggml_op_GGML_OP_RESHAPE: ggml_op = 35; +pub const ggml_op_GGML_OP_VIEW: ggml_op = 36; +pub const ggml_op_GGML_OP_PERMUTE: ggml_op = 37; +pub const ggml_op_GGML_OP_TRANSPOSE: ggml_op = 38; +pub const ggml_op_GGML_OP_GET_ROWS: ggml_op = 39; +pub const ggml_op_GGML_OP_GET_ROWS_BACK: ggml_op = 40; +pub const ggml_op_GGML_OP_SET_ROWS: ggml_op = 41; +pub const ggml_op_GGML_OP_DIAG: ggml_op = 42; +pub const ggml_op_GGML_OP_DIAG_MASK_INF: ggml_op = 43; +pub const ggml_op_GGML_OP_DIAG_MASK_ZERO: ggml_op = 44; +pub const ggml_op_GGML_OP_SOFT_MAX: ggml_op = 45; +pub const ggml_op_GGML_OP_SOFT_MAX_BACK: ggml_op = 46; +pub const ggml_op_GGML_OP_ROPE: ggml_op = 47; +pub const ggml_op_GGML_OP_ROPE_BACK: ggml_op = 48; +pub const ggml_op_GGML_OP_CLAMP: ggml_op = 49; +pub const ggml_op_GGML_OP_CONV_TRANSPOSE_1D: ggml_op = 50; +pub const ggml_op_GGML_OP_IM2COL: ggml_op = 51; +pub const ggml_op_GGML_OP_IM2COL_BACK: ggml_op = 52; +pub const ggml_op_GGML_OP_CONV_2D: ggml_op = 53; +pub const ggml_op_GGML_OP_CONV_2D_DW: ggml_op = 54; +pub const ggml_op_GGML_OP_CONV_TRANSPOSE_2D: ggml_op = 55; +pub const ggml_op_GGML_OP_POOL_1D: ggml_op = 56; +pub const ggml_op_GGML_OP_POOL_2D: ggml_op = 57; +pub const ggml_op_GGML_OP_POOL_2D_BACK: ggml_op = 58; +pub const ggml_op_GGML_OP_UPSCALE: ggml_op = 59; +pub const ggml_op_GGML_OP_PAD: ggml_op = 60; +pub const ggml_op_GGML_OP_PAD_REFLECT_1D: ggml_op = 61; +pub const ggml_op_GGML_OP_ROLL: ggml_op = 62; +pub const ggml_op_GGML_OP_ARANGE: ggml_op = 63; +pub const ggml_op_GGML_OP_TIMESTEP_EMBEDDING: ggml_op = 64; +pub const ggml_op_GGML_OP_ARGSORT: ggml_op = 65; +pub const ggml_op_GGML_OP_LEAKY_RELU: ggml_op = 66; +pub const ggml_op_GGML_OP_FLASH_ATTN_EXT: ggml_op = 67; +pub const ggml_op_GGML_OP_FLASH_ATTN_BACK: ggml_op = 68; +pub const ggml_op_GGML_OP_SSM_CONV: ggml_op = 69; +pub const ggml_op_GGML_OP_SSM_SCAN: ggml_op = 70; +pub const ggml_op_GGML_OP_WIN_PART: ggml_op = 71; +pub const ggml_op_GGML_OP_WIN_UNPART: ggml_op = 72; +pub const ggml_op_GGML_OP_GET_REL_POS: ggml_op = 73; +pub const ggml_op_GGML_OP_ADD_REL_POS: ggml_op = 74; +pub const ggml_op_GGML_OP_RWKV_WKV6: ggml_op = 75; +pub const ggml_op_GGML_OP_GATED_LINEAR_ATTN: ggml_op = 76; +pub const ggml_op_GGML_OP_RWKV_WKV7: ggml_op = 77; +pub const ggml_op_GGML_OP_UNARY: ggml_op = 78; +pub const ggml_op_GGML_OP_MAP_CUSTOM1: ggml_op = 79; +pub const ggml_op_GGML_OP_MAP_CUSTOM2: ggml_op = 80; +pub const ggml_op_GGML_OP_MAP_CUSTOM3: ggml_op = 81; +pub const ggml_op_GGML_OP_CUSTOM: ggml_op = 82; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS: ggml_op = 83; +pub const ggml_op_GGML_OP_CROSS_ENTROPY_LOSS_BACK: ggml_op = 84; +pub const ggml_op_GGML_OP_OPT_STEP_ADAMW: ggml_op = 85; +pub const ggml_op_GGML_OP_OPT_STEP_SGD: ggml_op = 86; +pub const ggml_op_GGML_OP_GLU: ggml_op = 87; +pub const ggml_op_GGML_OP_COUNT: ggml_op = 88; pub type ggml_op = ::std::os::raw::c_uint; pub const ggml_unary_op_GGML_UNARY_OP_ABS: ggml_unary_op = 0; pub const ggml_unary_op_GGML_UNARY_OP_SGN: ggml_unary_op = 1; @@ -1415,15 +1402,8 @@ pub const ggml_unary_op_GGML_UNARY_OP_SILU: ggml_unary_op = 10; pub const ggml_unary_op_GGML_UNARY_OP_HARDSWISH: ggml_unary_op = 11; pub const ggml_unary_op_GGML_UNARY_OP_HARDSIGMOID: ggml_unary_op = 12; pub const ggml_unary_op_GGML_UNARY_OP_EXP: ggml_unary_op = 13; -pub const ggml_unary_op_GGML_UNARY_OP_EXPM1: ggml_unary_op = 14; -pub const ggml_unary_op_GGML_UNARY_OP_SOFTPLUS: ggml_unary_op = 15; -pub const ggml_unary_op_GGML_UNARY_OP_GELU_ERF: ggml_unary_op = 16; -pub const ggml_unary_op_GGML_UNARY_OP_XIELU: ggml_unary_op = 17; -pub const ggml_unary_op_GGML_UNARY_OP_FLOOR: ggml_unary_op = 18; -pub const ggml_unary_op_GGML_UNARY_OP_CEIL: ggml_unary_op = 19; -pub const ggml_unary_op_GGML_UNARY_OP_ROUND: ggml_unary_op = 20; -pub const ggml_unary_op_GGML_UNARY_OP_TRUNC: ggml_unary_op = 21; -pub const ggml_unary_op_GGML_UNARY_OP_COUNT: ggml_unary_op = 22; +pub const ggml_unary_op_GGML_UNARY_OP_GELU_ERF: ggml_unary_op = 14; +pub const ggml_unary_op_GGML_UNARY_OP_COUNT: ggml_unary_op = 15; pub type ggml_unary_op = ::std::os::raw::c_uint; pub const ggml_glu_op_GGML_GLU_OP_REGLU: ggml_glu_op = 0; pub const ggml_glu_op_GGML_GLU_OP_GEGLU: ggml_glu_op = 1; @@ -1449,11 +1429,6 @@ pub const ggml_tensor_flag_GGML_TENSOR_FLAG_OUTPUT: ggml_tensor_flag = 2; pub const ggml_tensor_flag_GGML_TENSOR_FLAG_PARAM: ggml_tensor_flag = 4; pub const ggml_tensor_flag_GGML_TENSOR_FLAG_LOSS: ggml_tensor_flag = 8; pub type ggml_tensor_flag = ::std::os::raw::c_uint; -pub const ggml_tri_type_GGML_TRI_TYPE_UPPER_DIAG: ggml_tri_type = 0; -pub const ggml_tri_type_GGML_TRI_TYPE_UPPER: ggml_tri_type = 1; -pub const ggml_tri_type_GGML_TRI_TYPE_LOWER_DIAG: ggml_tri_type = 2; -pub const ggml_tri_type_GGML_TRI_TYPE_LOWER: ggml_tri_type = 3; -pub type ggml_tri_type = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct ggml_init_params { @@ -1944,18 +1919,6 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_log_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_expm1(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_expm1_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_softplus(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_softplus_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} unsafe extern "C" { pub fn ggml_sin(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } @@ -1974,9 +1937,6 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_sum_rows(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_cumsum(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} unsafe extern "C" { pub fn ggml_mean(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } @@ -2122,41 +2082,6 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_exp_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_floor(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_floor_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_ceil(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_ceil_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_round(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_round_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - #[doc = " Truncates the fractional part of each element in the tensor (towards zero).\n For example: trunc(3.7) = 3.0, trunc(-2.9) = -2.0\n Similar to std::trunc in C/C++."] - pub fn ggml_trunc(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_trunc_inplace(ctx: *mut ggml_context, a: *mut ggml_tensor) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_xielu( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - alpha_n: f32, - alpha_p: f32, - beta: f32, - eps: f32, - ) -> *mut ggml_tensor; -} unsafe extern "C" { pub fn ggml_glu( ctx: *mut ggml_context, @@ -2626,15 +2551,6 @@ unsafe extern "C" { max_bias: f32, ) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_soft_max_ext_inplace( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - mask: *mut ggml_tensor, - scale: f32, - max_bias: f32, - ) -> *mut ggml_tensor; -} unsafe extern "C" { pub fn ggml_soft_max_add_sinks(a: *mut ggml_tensor, sinks: *mut ggml_tensor); } @@ -2920,41 +2836,6 @@ unsafe extern "C" { d1: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_im2col_3d( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - b: *mut ggml_tensor, - IC: i64, - s0: ::std::os::raw::c_int, - s1: ::std::os::raw::c_int, - s2: ::std::os::raw::c_int, - p0: ::std::os::raw::c_int, - p1: ::std::os::raw::c_int, - p2: ::std::os::raw::c_int, - d0: ::std::os::raw::c_int, - d1: ::std::os::raw::c_int, - d2: ::std::os::raw::c_int, - dst_type: ggml_type, - ) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_conv_3d( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - b: *mut ggml_tensor, - IC: i64, - s0: ::std::os::raw::c_int, - s1: ::std::os::raw::c_int, - s2: ::std::os::raw::c_int, - p0: ::std::os::raw::c_int, - p1: ::std::os::raw::c_int, - p2: ::std::os::raw::c_int, - d0: ::std::os::raw::c_int, - d1: ::std::os::raw::c_int, - d2: ::std::os::raw::c_int, - ) -> *mut ggml_tensor; -} unsafe extern "C" { pub fn ggml_conv_2d_sk_p0( ctx: *mut ggml_context, @@ -3016,25 +2897,6 @@ unsafe extern "C" { d1: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_conv_3d_direct( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - b: *mut ggml_tensor, - s0: ::std::os::raw::c_int, - s1: ::std::os::raw::c_int, - s2: ::std::os::raw::c_int, - p0: ::std::os::raw::c_int, - p1: ::std::os::raw::c_int, - p2: ::std::os::raw::c_int, - d0: ::std::os::raw::c_int, - d1: ::std::os::raw::c_int, - d2: ::std::os::raw::c_int, - n_channels: ::std::os::raw::c_int, - n_batch: ::std::os::raw::c_int, - n_channels_out: ::std::os::raw::c_int, - ) -> *mut ggml_tensor; -} pub const ggml_op_pool_GGML_OP_POOL_MAX: ggml_op_pool = 0; pub const ggml_op_pool_GGML_OP_POOL_AVG: ggml_op_pool = 1; pub const ggml_op_pool_GGML_OP_POOL_COUNT: ggml_op_pool = 2; @@ -3078,8 +2940,7 @@ unsafe extern "C" { } pub const ggml_scale_mode_GGML_SCALE_MODE_NEAREST: ggml_scale_mode = 0; pub const ggml_scale_mode_GGML_SCALE_MODE_BILINEAR: ggml_scale_mode = 1; -pub const ggml_scale_mode_GGML_SCALE_MODE_BICUBIC: ggml_scale_mode = 2; -pub const ggml_scale_mode_GGML_SCALE_MODE_COUNT: ggml_scale_mode = 3; +pub const ggml_scale_mode_GGML_SCALE_MODE_COUNT: ggml_scale_mode = 2; pub type ggml_scale_mode = ::std::os::raw::c_uint; pub const ggml_scale_flag_GGML_SCALE_FLAG_ALIGN_CORNERS: ggml_scale_flag = 256; pub type ggml_scale_flag = ::std::os::raw::c_uint; @@ -3123,20 +2984,6 @@ unsafe extern "C" { p3: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_pad_ext( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - lp0: ::std::os::raw::c_int, - rp0: ::std::os::raw::c_int, - lp1: ::std::os::raw::c_int, - rp1: ::std::os::raw::c_int, - lp2: ::std::os::raw::c_int, - rp2: ::std::os::raw::c_int, - lp3: ::std::os::raw::c_int, - rp3: ::std::os::raw::c_int, - ) -> *mut ggml_tensor; -} unsafe extern "C" { pub fn ggml_pad_reflect_1d( ctx: *mut ggml_context, @@ -3155,23 +3002,6 @@ unsafe extern "C" { shift3: ::std::os::raw::c_int, ) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_tri( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - type_: ggml_tri_type, - ) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_fill(ctx: *mut ggml_context, a: *mut ggml_tensor, c: f32) -> *mut ggml_tensor; -} -unsafe extern "C" { - pub fn ggml_fill_inplace( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - c: f32, - ) -> *mut ggml_tensor; -} unsafe extern "C" { pub fn ggml_timestep_embedding( ctx: *mut ggml_context, @@ -3343,16 +3173,6 @@ unsafe extern "C" { state: *mut ggml_tensor, ) -> *mut ggml_tensor; } -unsafe extern "C" { - pub fn ggml_solve_tri( - ctx: *mut ggml_context, - a: *mut ggml_tensor, - b: *mut ggml_tensor, - left: bool, - lower: bool, - uni: bool, - ) -> *mut ggml_tensor; -} pub type ggml_custom1_op_t = ::std::option::Option< unsafe extern "C" fn( dst: *mut ggml_tensor, @@ -4064,8 +3884,7 @@ unsafe extern "C" { } pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_CPU: ggml_backend_dev_type = 0; pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_GPU: ggml_backend_dev_type = 1; -pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_IGPU: ggml_backend_dev_type = 2; -pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_ACCEL: ggml_backend_dev_type = 3; +pub const ggml_backend_dev_type_GGML_BACKEND_DEVICE_TYPE_ACCEL: ggml_backend_dev_type = 2; pub type ggml_backend_dev_type = ::std::os::raw::c_uint; #[repr(C)] #[derive(Debug, Copy, Clone)] @@ -4097,12 +3916,11 @@ pub struct ggml_backend_dev_props { pub memory_free: usize, pub memory_total: usize, pub type_: ggml_backend_dev_type, - pub device_id: *const ::std::os::raw::c_char, pub caps: ggml_backend_dev_caps, } #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { - ["Size of ggml_backend_dev_props"][::std::mem::size_of::() - 56usize]; + ["Size of ggml_backend_dev_props"][::std::mem::size_of::() - 40usize]; ["Alignment of ggml_backend_dev_props"] [::std::mem::align_of::() - 8usize]; ["Offset of field: ggml_backend_dev_props::name"] @@ -4115,10 +3933,8 @@ const _: () = { [::std::mem::offset_of!(ggml_backend_dev_props, memory_total) - 24usize]; ["Offset of field: ggml_backend_dev_props::type_"] [::std::mem::offset_of!(ggml_backend_dev_props, type_) - 32usize]; - ["Offset of field: ggml_backend_dev_props::device_id"] - [::std::mem::offset_of!(ggml_backend_dev_props, device_id) - 40usize]; ["Offset of field: ggml_backend_dev_props::caps"] - [::std::mem::offset_of!(ggml_backend_dev_props, caps) - 48usize]; + [::std::mem::offset_of!(ggml_backend_dev_props, caps) - 36usize]; }; unsafe extern "C" { pub fn ggml_backend_dev_name(device: ggml_backend_dev_t) -> *const ::std::os::raw::c_char; @@ -4230,9 +4046,6 @@ const _: () = { pub type ggml_backend_get_features_t = ::std::option::Option< unsafe extern "C" fn(reg: ggml_backend_reg_t) -> *mut ggml_backend_feature, >; -unsafe extern "C" { - pub fn ggml_backend_register(reg: ggml_backend_reg_t); -} unsafe extern "C" { pub fn ggml_backend_device_register(device: ggml_backend_dev_t); } @@ -4331,12 +4144,6 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_backend_sched_get_n_copies(sched: ggml_backend_sched_t) -> ::std::os::raw::c_int; } -unsafe extern "C" { - pub fn ggml_backend_sched_get_buffer_type( - sched: ggml_backend_sched_t, - backend: ggml_backend_t, - ) -> ggml_backend_buffer_type_t; -} unsafe extern "C" { pub fn ggml_backend_sched_get_buffer_size( sched: ggml_backend_sched_t, @@ -4356,9 +4163,6 @@ unsafe extern "C" { node: *mut ggml_tensor, ) -> ggml_backend_t; } -unsafe extern "C" { - pub fn ggml_backend_sched_split_graph(sched: ggml_backend_sched_t, graph: *mut ggml_cgraph); -} unsafe extern "C" { pub fn ggml_backend_sched_alloc_graph( sched: ggml_backend_sched_t, @@ -4665,6 +4469,9 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_cpu_has_vxe() -> ::std::os::raw::c_int; } +unsafe extern "C" { + pub fn ggml_cpu_has_nnpa() -> ::std::os::raw::c_int; +} unsafe extern "C" { pub fn ggml_cpu_has_wasm_simd() -> ::std::os::raw::c_int; } @@ -4741,9 +4548,6 @@ unsafe extern "C" { unsafe extern "C" { pub fn ggml_cpu_fp32_to_fp32(arg1: *const f32, arg2: *mut f32, arg3: i64); } -unsafe extern "C" { - pub fn ggml_cpu_fp32_to_i32(arg1: *const f32, arg2: *mut i32, arg3: i64); -} unsafe extern "C" { pub fn ggml_cpu_fp32_to_fp16(arg1: *const f32, arg2: *mut ggml_fp16_t, arg3: i64); } @@ -5386,7 +5190,6 @@ pub struct whisper_full_params { pub tdrz_enable: bool, pub suppress_regex: *const ::std::os::raw::c_char, pub initial_prompt: *const ::std::os::raw::c_char, - pub carry_initial_prompt: bool, pub prompt_tokens: *const whisper_token, pub prompt_n_tokens: ::std::os::raw::c_int, pub language: *const ::std::os::raw::c_char, @@ -5453,7 +5256,7 @@ const _: () = { }; #[allow(clippy::unnecessary_operation, clippy::identity_op)] const _: () = { - ["Size of whisper_full_params"][::std::mem::size_of::() - 304usize]; + ["Size of whisper_full_params"][::std::mem::size_of::() - 296usize]; ["Alignment of whisper_full_params"][::std::mem::align_of::() - 8usize]; ["Offset of field: whisper_full_params::strategy"] [::std::mem::offset_of!(whisper_full_params, strategy) - 0usize]; @@ -5503,72 +5306,70 @@ const _: () = { [::std::mem::offset_of!(whisper_full_params, suppress_regex) - 64usize]; ["Offset of field: whisper_full_params::initial_prompt"] [::std::mem::offset_of!(whisper_full_params, initial_prompt) - 72usize]; - ["Offset of field: whisper_full_params::carry_initial_prompt"] - [::std::mem::offset_of!(whisper_full_params, carry_initial_prompt) - 80usize]; ["Offset of field: whisper_full_params::prompt_tokens"] - [::std::mem::offset_of!(whisper_full_params, prompt_tokens) - 88usize]; + [::std::mem::offset_of!(whisper_full_params, prompt_tokens) - 80usize]; ["Offset of field: whisper_full_params::prompt_n_tokens"] - [::std::mem::offset_of!(whisper_full_params, prompt_n_tokens) - 96usize]; + [::std::mem::offset_of!(whisper_full_params, prompt_n_tokens) - 88usize]; ["Offset of field: whisper_full_params::language"] - [::std::mem::offset_of!(whisper_full_params, language) - 104usize]; + [::std::mem::offset_of!(whisper_full_params, language) - 96usize]; ["Offset of field: whisper_full_params::detect_language"] - [::std::mem::offset_of!(whisper_full_params, detect_language) - 112usize]; + [::std::mem::offset_of!(whisper_full_params, detect_language) - 104usize]; ["Offset of field: whisper_full_params::suppress_blank"] - [::std::mem::offset_of!(whisper_full_params, suppress_blank) - 113usize]; + [::std::mem::offset_of!(whisper_full_params, suppress_blank) - 105usize]; ["Offset of field: whisper_full_params::suppress_nst"] - [::std::mem::offset_of!(whisper_full_params, suppress_nst) - 114usize]; + [::std::mem::offset_of!(whisper_full_params, suppress_nst) - 106usize]; ["Offset of field: whisper_full_params::temperature"] - [::std::mem::offset_of!(whisper_full_params, temperature) - 116usize]; + [::std::mem::offset_of!(whisper_full_params, temperature) - 108usize]; ["Offset of field: whisper_full_params::max_initial_ts"] - [::std::mem::offset_of!(whisper_full_params, max_initial_ts) - 120usize]; + [::std::mem::offset_of!(whisper_full_params, max_initial_ts) - 112usize]; ["Offset of field: whisper_full_params::length_penalty"] - [::std::mem::offset_of!(whisper_full_params, length_penalty) - 124usize]; + [::std::mem::offset_of!(whisper_full_params, length_penalty) - 116usize]; ["Offset of field: whisper_full_params::temperature_inc"] - [::std::mem::offset_of!(whisper_full_params, temperature_inc) - 128usize]; + [::std::mem::offset_of!(whisper_full_params, temperature_inc) - 120usize]; ["Offset of field: whisper_full_params::entropy_thold"] - [::std::mem::offset_of!(whisper_full_params, entropy_thold) - 132usize]; + [::std::mem::offset_of!(whisper_full_params, entropy_thold) - 124usize]; ["Offset of field: whisper_full_params::logprob_thold"] - [::std::mem::offset_of!(whisper_full_params, logprob_thold) - 136usize]; + [::std::mem::offset_of!(whisper_full_params, logprob_thold) - 128usize]; ["Offset of field: whisper_full_params::no_speech_thold"] - [::std::mem::offset_of!(whisper_full_params, no_speech_thold) - 140usize]; + [::std::mem::offset_of!(whisper_full_params, no_speech_thold) - 132usize]; ["Offset of field: whisper_full_params::greedy"] - [::std::mem::offset_of!(whisper_full_params, greedy) - 144usize]; + [::std::mem::offset_of!(whisper_full_params, greedy) - 136usize]; ["Offset of field: whisper_full_params::beam_search"] - [::std::mem::offset_of!(whisper_full_params, beam_search) - 148usize]; + [::std::mem::offset_of!(whisper_full_params, beam_search) - 140usize]; ["Offset of field: whisper_full_params::new_segment_callback"] - [::std::mem::offset_of!(whisper_full_params, new_segment_callback) - 160usize]; + [::std::mem::offset_of!(whisper_full_params, new_segment_callback) - 152usize]; ["Offset of field: whisper_full_params::new_segment_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, new_segment_callback_user_data) - 168usize]; + [::std::mem::offset_of!(whisper_full_params, new_segment_callback_user_data) - 160usize]; ["Offset of field: whisper_full_params::progress_callback"] - [::std::mem::offset_of!(whisper_full_params, progress_callback) - 176usize]; + [::std::mem::offset_of!(whisper_full_params, progress_callback) - 168usize]; ["Offset of field: whisper_full_params::progress_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, progress_callback_user_data) - 184usize]; + [::std::mem::offset_of!(whisper_full_params, progress_callback_user_data) - 176usize]; ["Offset of field: whisper_full_params::encoder_begin_callback"] - [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback) - 192usize]; + [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback) - 184usize]; ["Offset of field: whisper_full_params::encoder_begin_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback_user_data) - 200usize]; + [::std::mem::offset_of!(whisper_full_params, encoder_begin_callback_user_data) - 192usize]; ["Offset of field: whisper_full_params::abort_callback"] - [::std::mem::offset_of!(whisper_full_params, abort_callback) - 208usize]; + [::std::mem::offset_of!(whisper_full_params, abort_callback) - 200usize]; ["Offset of field: whisper_full_params::abort_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, abort_callback_user_data) - 216usize]; + [::std::mem::offset_of!(whisper_full_params, abort_callback_user_data) - 208usize]; ["Offset of field: whisper_full_params::logits_filter_callback"] - [::std::mem::offset_of!(whisper_full_params, logits_filter_callback) - 224usize]; + [::std::mem::offset_of!(whisper_full_params, logits_filter_callback) - 216usize]; ["Offset of field: whisper_full_params::logits_filter_callback_user_data"] - [::std::mem::offset_of!(whisper_full_params, logits_filter_callback_user_data) - 232usize]; + [::std::mem::offset_of!(whisper_full_params, logits_filter_callback_user_data) - 224usize]; ["Offset of field: whisper_full_params::grammar_rules"] - [::std::mem::offset_of!(whisper_full_params, grammar_rules) - 240usize]; + [::std::mem::offset_of!(whisper_full_params, grammar_rules) - 232usize]; ["Offset of field: whisper_full_params::n_grammar_rules"] - [::std::mem::offset_of!(whisper_full_params, n_grammar_rules) - 248usize]; + [::std::mem::offset_of!(whisper_full_params, n_grammar_rules) - 240usize]; ["Offset of field: whisper_full_params::i_start_rule"] - [::std::mem::offset_of!(whisper_full_params, i_start_rule) - 256usize]; + [::std::mem::offset_of!(whisper_full_params, i_start_rule) - 248usize]; ["Offset of field: whisper_full_params::grammar_penalty"] - [::std::mem::offset_of!(whisper_full_params, grammar_penalty) - 264usize]; + [::std::mem::offset_of!(whisper_full_params, grammar_penalty) - 256usize]; ["Offset of field: whisper_full_params::vad"] - [::std::mem::offset_of!(whisper_full_params, vad) - 268usize]; + [::std::mem::offset_of!(whisper_full_params, vad) - 260usize]; ["Offset of field: whisper_full_params::vad_model_path"] - [::std::mem::offset_of!(whisper_full_params, vad_model_path) - 272usize]; + [::std::mem::offset_of!(whisper_full_params, vad_model_path) - 264usize]; ["Offset of field: whisper_full_params::vad_params"] - [::std::mem::offset_of!(whisper_full_params, vad_params) - 280usize]; + [::std::mem::offset_of!(whisper_full_params, vad_params) - 272usize]; }; unsafe extern "C" { pub fn whisper_context_default_params_by_ref() -> *mut whisper_context_params; diff --git a/sys/whisper.cpp b/sys/whisper.cpp index a88b93f..fc45bb8 160000 --- a/sys/whisper.cpp +++ b/sys/whisper.cpp @@ -1 +1 @@ -Subproject commit a88b93f85f08fc6045e5d8a8c3f94b7be0ac8bce +Subproject commit fc45bb86251f774ef817e89878bb4c2636c8a58f diff --git a/wyoming-whisper-rs/Cargo.toml b/wyoming-whisper-rs/Cargo.toml deleted file mode 100644 index c6c55d8..0000000 --- a/wyoming-whisper-rs/Cargo.toml +++ /dev/null @@ -1,27 +0,0 @@ -[package] -name = "wyoming-whisper-rs" -version = "0.1.0" -edition = "2021" -description = "Wyoming protocol ASR server powered by whisper-rs" -license = "Unlicense" - -[dependencies] -whisper-rs = { path = "..", features = ["tracing_backend"] } -tokio = { version = "1", features = ["rt-multi-thread", "net", "io-util", "macros", "signal", "sync"] } -clap = { version = "4", features = ["derive"] } -serde = { version = "1", features = ["derive"] } -serde_json = "1" -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } -thiserror = "2" -anyhow = "1" - -[features] -default = [] -cuda = ["whisper-rs/cuda"] -hipblas = ["whisper-rs/hipblas"] -metal = ["whisper-rs/metal"] -vulkan = ["whisper-rs/vulkan"] -openblas = ["whisper-rs/openblas"] -openmp = ["whisper-rs/openmp"] -coreml = ["whisper-rs/coreml"] diff --git a/wyoming-whisper-rs/package.nix b/wyoming-whisper-rs/package.nix deleted file mode 100644 index 9b522ab..0000000 --- a/wyoming-whisper-rs/package.nix +++ /dev/null @@ -1,52 +0,0 @@ -{ - lib, - rustPlatform, - cmake, - pkg-config, - shaderc, - libclang, - glibc, - vulkan-headers, - vulkan-loader, - src, -}: - -rustPlatform.buildRustPackage { - pname = "wyoming-whisper-rs"; - version = "0.1.0"; - - inherit src; - - cargoLock.lockFile = ../Cargo.lock; - - buildAndTestSubdir = "wyoming-whisper-rs"; - - buildFeatures = [ "vulkan" ]; - - nativeBuildInputs = [ - cmake - pkg-config - shaderc - ]; - - buildInputs = [ - libclang.lib - vulkan-headers - vulkan-loader - ]; - - env = { - LIBCLANG_PATH = "${libclang.lib}/lib"; - VULKAN_SDK = "${vulkan-headers}"; - BINDGEN_EXTRA_CLANG_ARGS = builtins.toString [ - "-isystem ${libclang.lib}/lib/clang/${lib.versions.major libclang.version}/include" - "-isystem ${glibc.dev}/include" - ]; - }; - - meta = { - description = "Wyoming protocol ASR server powered by whisper-rs"; - license = lib.licenses.unlicense; - mainProgram = "wyoming-whisper-rs"; - }; -} diff --git a/wyoming-whisper-rs/src/cli.rs b/wyoming-whisper-rs/src/cli.rs deleted file mode 100644 index e4b7328..0000000 --- a/wyoming-whisper-rs/src/cli.rs +++ /dev/null @@ -1,49 +0,0 @@ -use clap::Parser; -use std::path::PathBuf; - -#[derive(Parser, Debug)] -#[command( - name = "wyoming-whisper-rs", - about = "Wyoming protocol ASR server powered by whisper-rs" -)] -pub struct Args { - /// Path to the GGML whisper model file - #[arg(long)] - pub model: PathBuf, - - /// Host address to bind to - #[arg(long, default_value = "0.0.0.0")] - pub host: String, - - /// Port to listen on - #[arg(long, default_value_t = 10300)] - pub port: u16, - - /// Language code (e.g. "en", "de"). Omit for auto-detection - #[arg(long)] - pub language: Option, - - /// Beam search size - #[arg(long, default_value_t = 5)] - pub beam_size: i32, - - /// Number of threads for whisper inference (0 = whisper default) - #[arg(long, default_value_t = 0)] - pub threads: i32, - - /// GPU device index - #[arg(long, default_value_t = 0)] - pub gpu_device: i32, - - /// Disable GPU acceleration - #[arg(long)] - pub no_gpu: bool, - - /// Maximum concurrent transcriptions - #[arg(long, default_value_t = 1)] - pub max_concurrent: usize, - - /// Model name reported in Wyoming info - #[arg(long, default_value = "whisper")] - pub model_name: String, -} diff --git a/wyoming-whisper-rs/src/error.rs b/wyoming-whisper-rs/src/error.rs deleted file mode 100644 index bd83731..0000000 --- a/wyoming-whisper-rs/src/error.rs +++ /dev/null @@ -1,22 +0,0 @@ -use std::io; - -#[derive(Debug, thiserror::Error)] -pub enum Error { - #[error("I/O error: {0}")] - Io(#[from] io::Error), - - #[error("JSON error: {0}")] - Json(#[from] serde_json::Error), - - #[error("Whisper error: {0}")] - Whisper(#[from] whisper_rs::WhisperError), - - #[error("payload too large: {size} bytes (max {max})")] - PayloadTooLarge { size: usize, max: usize }, - - #[error("missing field: {0}")] - MissingField(&'static str), - - #[error("invalid audio format: {0}")] - InvalidAudio(String), -} diff --git a/wyoming-whisper-rs/src/main.rs b/wyoming-whisper-rs/src/main.rs deleted file mode 100644 index 404ccce..0000000 --- a/wyoming-whisper-rs/src/main.rs +++ /dev/null @@ -1,111 +0,0 @@ -mod cli; -mod error; -mod protocol; -mod session; -mod transcribe; - -use std::sync::Arc; - -use clap::Parser; -use serde_json::json; -use tokio::net::TcpListener; -use tracing::{error, info}; -use whisper_rs::{WhisperContext, WhisperContextParameters}; - -use crate::cli::Args; -use crate::session::SessionConfig; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let args = Args::parse(); - - tracing_subscriber::fmt() - .with_env_filter( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")), - ) - .init(); - - // Validate model path - if !args.model.exists() { - error!(path = %args.model.display(), "model file not found"); - std::process::exit(1); - } - - // Validate language if specified - if let Some(ref lang) = args.language { - if whisper_rs::get_lang_id(lang).is_none() { - error!(language = lang, "unknown language code"); - std::process::exit(1); - } - } - - // Route whisper.cpp logs through tracing - whisper_rs::install_logging_hooks(); - - // Load model - info!(model = %args.model.display(), "loading whisper model"); - let mut ctx_params = WhisperContextParameters::default(); - if args.no_gpu { - ctx_params.use_gpu(false); - } - ctx_params.gpu_device(args.gpu_device); - - let model_path = args.model.to_string_lossy().to_string(); - let ctx = WhisperContext::new_with_params(&model_path, ctx_params)?; - let ctx = Arc::new(ctx); - info!("model loaded"); - - // Build language list - let max_id = whisper_rs::get_lang_max_id(); - let languages: Vec = (0..=max_id) - .filter_map(|id| { - whisper_rs::get_lang_str(id).map(|code| { - json!({ - "name": code, - "description": whisper_rs::get_lang_str_full(id).unwrap_or(code), - }) - }) - }) - .collect(); - - let semaphore = Arc::new(tokio::sync::Semaphore::new(args.max_concurrent)); - - let session_config = Arc::new(SessionConfig { - ctx, - semaphore, - model_name: args.model_name, - languages, - default_language: args.language, - beam_size: args.beam_size, - threads: args.threads, - }); - - let bind_addr = format!("{}:{}", args.host, args.port); - let listener = TcpListener::bind(&bind_addr).await?; - info!(address = %bind_addr, "wyoming server listening"); - - loop { - tokio::select! { - result = listener.accept() => { - let (stream, addr) = result?; - info!(peer = %addr, "new connection"); - let config = Arc::clone(&session_config); - - tokio::spawn(async move { - let (reader, writer) = stream.into_split(); - if let Err(e) = session::run(config, reader, writer).await { - error!(peer = %addr, error = %e, "session error"); - } - info!(peer = %addr, "connection closed"); - }); - } - _ = tokio::signal::ctrl_c() => { - info!("shutting down"); - break; - } - } - } - - Ok(()) -} diff --git a/wyoming-whisper-rs/src/protocol.rs b/wyoming-whisper-rs/src/protocol.rs deleted file mode 100644 index 02ddbe1..0000000 --- a/wyoming-whisper-rs/src/protocol.rs +++ /dev/null @@ -1,90 +0,0 @@ -use serde::{Deserialize, Serialize}; -use serde_json::Value; -use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; - -use crate::error::Error; - -const MAX_PAYLOAD: usize = 100 * 1024 * 1024; // 100 MB - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Event { - #[serde(rename = "type")] - pub event_type: String, - - #[serde(default, skip_serializing_if = "Option::is_none")] - pub data: Option, - - #[serde(default, skip_serializing_if = "is_zero")] - pub data_length: usize, - - #[serde(default, skip_serializing_if = "is_zero")] - pub payload_length: usize, - - #[serde(skip)] - pub payload: Option>, -} - -fn is_zero(v: &usize) -> bool { - *v == 0 -} - -impl Event { - pub fn new(event_type: impl Into) -> Self { - Self { - event_type: event_type.into(), - data: None, - data_length: 0, - payload_length: 0, - payload: None, - } - } - - pub fn with_data(mut self, data: Value) -> Self { - let serialized = serde_json::to_string(&data).unwrap_or_default(); - self.data_length = serialized.len(); - self.data = Some(data); - self - } -} - -pub async fn read_event( - reader: &mut BufReader, -) -> Result, Error> { - let mut line = String::new(); - let n = reader.read_line(&mut line).await?; - if n == 0 { - return Ok(None); - } - - let mut event: Event = serde_json::from_str(line.trim())?; - - if event.payload_length > 0 { - if event.payload_length > MAX_PAYLOAD { - return Err(Error::PayloadTooLarge { - size: event.payload_length, - max: MAX_PAYLOAD, - }); - } - let mut buf = vec![0u8; event.payload_length]; - reader.read_exact(&mut buf).await?; - event.payload = Some(buf); - } - - Ok(Some(event)) -} - -pub async fn write_event( - writer: &mut W, - event: &Event, -) -> Result<(), Error> { - let json = serde_json::to_string(event)?; - writer.write_all(json.as_bytes()).await?; - writer.write_all(b"\n").await?; - - if let Some(ref payload) = event.payload { - writer.write_all(payload).await?; - } - - writer.flush().await?; - Ok(()) -} diff --git a/wyoming-whisper-rs/src/session.rs b/wyoming-whisper-rs/src/session.rs deleted file mode 100644 index 5c719d9..0000000 --- a/wyoming-whisper-rs/src/session.rs +++ /dev/null @@ -1,208 +0,0 @@ -use std::sync::Arc; - -use serde_json::{json, Value}; -use tokio::io::{AsyncRead, AsyncWrite, BufReader}; -use tokio::sync::Semaphore; -use tracing::{debug, warn}; -use whisper_rs::WhisperContext; - -use crate::error::Error; -use crate::protocol::{read_event, write_event, Event}; -use crate::transcribe::{AudioBuffer, TranscribeConfig}; - -pub struct SessionConfig { - pub ctx: Arc, - pub semaphore: Arc, - pub model_name: String, - pub languages: Vec, - pub default_language: Option, - pub beam_size: i32, - pub threads: i32, -} - -enum State { - Idle, - AwaitingAudioStart { - language: Option, - }, - Streaming { - buffer: AudioBuffer, - language: Option, - }, -} - -impl State { - fn name(&self) -> &'static str { - match self { - State::Idle => "idle", - State::AwaitingAudioStart { .. } => "awaiting_audio_start", - State::Streaming { .. } => "streaming", - } - } -} - -pub async fn run(config: Arc, reader: R, mut writer: W) -> Result<(), Error> -where - R: AsyncRead + Unpin, - W: AsyncWrite + Unpin, -{ - let mut reader = BufReader::new(reader); - let mut state = State::Idle; - - loop { - let event = match read_event(&mut reader).await? { - Some(e) => e, - None => { - debug!("client disconnected"); - return Ok(()); - } - }; - - debug!(event_type = %event.event_type, "received event"); - - match event.event_type.as_str() { - "describe" => { - let info = Event::new("info").with_data(json!({ - "type": "asr", - "name": config.model_name, - "description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION), - "installed": true, - "attribution": { - "name": "whisper.cpp", - "url": "https://github.com/ggerganov/whisper.cpp" - }, - "models": [{ - "name": config.model_name, - "description": format!("whisper-rs {}", whisper_rs::WHISPER_CPP_VERSION), - "installed": true, - "attribution": { - "name": "whisper.cpp", - "url": "https://github.com/ggerganov/whisper.cpp" - }, - "languages": config.languages, - }], - "version": env!("CARGO_PKG_VERSION"), - })); - write_event(&mut writer, &info).await?; - } - - "transcribe" => { - let language = event - .data - .as_ref() - .and_then(|d| d.get("language")) - .and_then(|v| v.as_str()) - .map(String::from) - .or_else(|| config.default_language.clone()); - - state = State::AwaitingAudioStart { language }; - } - - "audio-start" => { - let (rate, width, channels) = parse_audio_start(&event)?; - let language = match state { - State::AwaitingAudioStart { language } => language, - _ => config.default_language.clone(), - }; - - state = State::Streaming { - buffer: AudioBuffer::new(rate, width, channels), - language, - }; - } - - "audio-chunk" => { - if let State::Streaming { ref mut buffer, .. } = state { - if let Some(ref payload) = event.payload { - buffer.append(payload); - } - } else { - warn!(state = state.name(), "audio-chunk received in wrong state"); - } - } - - "audio-stop" => { - if let State::Streaming { buffer, language } = state { - state = State::Idle; - - let tc = TranscribeConfig { - language, - beam_size: config.beam_size, - threads: config.threads, - }; - - match do_transcribe(&config.ctx, &config.semaphore, tc, buffer).await { - Ok(text) => { - let transcript = Event::new("transcript").with_data(json!({ - "text": text, - })); - write_event(&mut writer, &transcript).await?; - } - Err(e) => { - warn!(error = %e, "transcription failed"); - let err = Event::new("error").with_data(json!({ - "text": format!("transcription failed: {e}"), - })); - write_event(&mut writer, &err).await?; - } - } - } else { - warn!(state = state.name(), "audio-stop received in wrong state"); - state = State::Idle; - } - } - - other => { - warn!(event_type = other, "unknown event type"); - let err = Event::new("error").with_data(json!({ - "text": format!("unknown event type: {other}"), - })); - write_event(&mut writer, &err).await?; - } - } - } -} - -fn parse_audio_start(event: &Event) -> Result<(u32, u16, u16), Error> { - let data = event - .data - .as_ref() - .ok_or(Error::MissingField("data in audio-start"))?; - - let rate = data - .get("rate") - .and_then(|v| v.as_u64()) - .ok_or(Error::MissingField("rate"))? as u32; - - let width = data - .get("width") - .and_then(|v| v.as_u64()) - .ok_or(Error::MissingField("width"))? as u16; - - let channels = data - .get("channels") - .and_then(|v| v.as_u64()) - .ok_or(Error::MissingField("channels"))? as u16; - - Ok((rate, width, channels)) -} - -async fn do_transcribe( - ctx: &Arc, - semaphore: &Arc, - config: TranscribeConfig, - buffer: AudioBuffer, -) -> Result { - let audio = buffer.into_f32_16khz_mono()?; - - if audio.is_empty() { - return Ok(String::new()); - } - - let _permit = semaphore.acquire().await.expect("semaphore closed"); - let ctx = Arc::clone(ctx); - - tokio::task::spawn_blocking(move || crate::transcribe::transcribe(&ctx, &config, audio)) - .await - .expect("transcription task panicked") -} diff --git a/wyoming-whisper-rs/src/transcribe.rs b/wyoming-whisper-rs/src/transcribe.rs deleted file mode 100644 index e7220d7..0000000 --- a/wyoming-whisper-rs/src/transcribe.rs +++ /dev/null @@ -1,151 +0,0 @@ -use std::sync::Arc; - -use whisper_rs::{FullParams, SamplingStrategy, WhisperContext}; - -use crate::error::Error; - -pub struct TranscribeConfig { - pub language: Option, - pub beam_size: i32, - pub threads: i32, -} - -pub struct AudioBuffer { - data: Vec, - rate: u32, - width: u16, - channels: u16, -} - -impl AudioBuffer { - pub fn new(rate: u32, width: u16, channels: u16) -> Self { - Self { - data: Vec::new(), - rate, - width, - channels, - } - } - - pub fn append(&mut self, chunk: &[u8]) { - self.data.extend_from_slice(chunk); - } - - pub fn into_f32_16khz_mono(self) -> Result, Error> { - if self.width != 2 { - return Err(Error::InvalidAudio(format!( - "expected 16-bit audio (width=2), got width={}", - self.width - ))); - } - - if !self.data.len().is_multiple_of(2) { - return Err(Error::InvalidAudio( - "audio data has odd number of bytes for 16-bit samples".into(), - )); - } - - // Interpret as i16 little-endian - let samples_i16: Vec = self - .data - .chunks_exact(2) - .map(|c| i16::from_le_bytes([c[0], c[1]])) - .collect(); - - // Convert i16 -> f32 - let mut samples_f32 = vec![0.0f32; samples_i16.len()]; - whisper_rs::convert_integer_to_float_audio(&samples_i16, &mut samples_f32) - .map_err(|e| Error::InvalidAudio(format!("i16 to f32 conversion failed: {e}")))?; - - // Convert stereo to mono if needed - let mono = if self.channels == 2 { - let mut mono = vec![0.0f32; samples_f32.len() / 2]; - whisper_rs::convert_stereo_to_mono_audio(&samples_f32, &mut mono) - .map_err(|e| Error::InvalidAudio(format!("stereo to mono failed: {e}")))?; - mono - } else if self.channels == 1 { - samples_f32 - } else { - return Err(Error::InvalidAudio(format!( - "unsupported channel count: {}", - self.channels - ))); - }; - - // Resample if not 16kHz - if self.rate == 16000 { - Ok(mono) - } else { - Ok(resample(&mono, self.rate, 16000)) - } - } -} - -/// Simple linear interpolation resampler. -fn resample(input: &[f32], from_rate: u32, to_rate: u32) -> Vec { - if from_rate == to_rate || input.is_empty() { - return input.to_vec(); - } - - let ratio = from_rate as f64 / to_rate as f64; - let output_len = ((input.len() as f64) / ratio).ceil() as usize; - let mut output = Vec::with_capacity(output_len); - - for i in 0..output_len { - let src_pos = i as f64 * ratio; - let idx = src_pos as usize; - let frac = src_pos - idx as f64; - - let sample = if idx + 1 < input.len() { - input[idx] as f64 * (1.0 - frac) + input[idx + 1] as f64 * frac - } else { - input[idx.min(input.len() - 1)] as f64 - }; - - output.push(sample as f32); - } - - output -} - -pub fn transcribe( - ctx: &Arc, - config: &TranscribeConfig, - audio: Vec, -) -> Result { - let mut state = ctx.create_state()?; - - let mut params = FullParams::new(SamplingStrategy::BeamSearch { - beam_size: config.beam_size, - patience: -1.0, - }); - - if let Some(ref lang) = config.language { - params.set_language(Some(lang)); - } else { - params.set_language(None); - params.set_detect_language(true); - } - - if config.threads > 0 { - params.set_n_threads(config.threads); - } - - params.set_print_special(false); - params.set_print_progress(false); - params.set_print_realtime(false); - params.set_print_timestamps(false); - params.set_no_context(true); - params.set_single_segment(false); - - state.full(params, &audio)?; - - let mut text = String::new(); - for segment in state.as_iter() { - if let Ok(s) = segment.to_str_lossy() { - text.push_str(&s); - } - } - - Ok(text.trim().to_string()) -}