Merge pull request #99 from tazz4843/whisper.cpp-1.5

Update whisper.cpp to v1.5
This commit is contained in:
Niko 2023-11-27 00:36:43 +00:00 committed by GitHub
commit 276bc43a35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 4459 additions and 112 deletions

View file

@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
[package]
name = "whisper-rs"
version = "0.9.0-rc.2"
version = "0.10.0"
edition = "2021"
description = "Rust bindings for whisper.cpp"
license = "Unlicense"
@ -14,7 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
whisper-rs-sys = { path = "sys", version = "0.7" }
whisper-rs-sys = { path = "sys", version = "0.8" }
[dev-dependencies]
hound = "3.5.0"
@ -24,10 +24,11 @@ default = []
simd = []
coreml = ["whisper-rs-sys/coreml"]
cuda = ["whisper-rs-sys/cuda"]
cuda = ["whisper-rs-sys/cuda", "_gpu"]
opencl = ["whisper-rs-sys/opencl"]
openblas = ["whisper-rs-sys/openblas"]
metal = ["whisper-rs-sys/metal"]
metal = ["whisper-rs-sys/metal", "_gpu"]
_gpu = []
test-with-tiny-model = []
[package.metadata.docs.rs]

View file

@ -15,13 +15,16 @@ cargo run --example audio_transcription
```
```rust
use whisper_rs::{WhisperContext, FullParams, SamplingStrategy};
use whisper_rs::{WhisperContext, WhisperContextParameters, FullParams, SamplingStrategy};
fn main() {
let path_to_model = std::env::args().nth(1).unwrap();
// load a context and model
let ctx = WhisperContext::new(&path_to_model).expect("failed to load model");
let ctx = WhisperContext::new_with_params(
&path_to_model,
WhisperContextParameters::default()
).expect("failed to load model");
// create a params object
let params = FullParams::new(SamplingStrategy::Greedy { best_of: 1 });

View file

@ -4,13 +4,16 @@
use hound;
use std::fs::File;
use std::io::Write;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
fn main() -> Result<(), &'static str> {
// Load a context and model.
let ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
.expect("failed to load model");
let ctx = WhisperContext::new_with_params(
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
WhisperContextParameters::default(),
)
.expect("failed to load model");
// Create a state
let mut state = ctx.create_state().expect("failed to create key");

View file

@ -1,13 +1,14 @@
#![allow(clippy::uninlined_format_args)]
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
// note that running this example will not do anything, as it is just a
// demonstration of how to use the library, and actual usage requires
// more dependencies than the base library.
pub fn usage() -> Result<(), &'static str> {
// load a context and model
let ctx = WhisperContext::new("path/to/model").expect("failed to load model");
let ctx = WhisperContext::new_with_params("path/to/model", WhisperContextParameters::default())
.expect("failed to load model");
// make a state
let mut state = ctx.create_state().expect("failed to create state");

View file

@ -2,7 +2,7 @@
use hound::{SampleFormat, WavReader};
use std::path::Path;
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextParameters};
fn parse_wav_file(path: &Path) -> Vec<i16> {
let reader = WavReader::open(path).expect("failed to read file");
@ -45,7 +45,10 @@ fn main() {
let original_samples = parse_wav_file(audio_path);
let samples = whisper_rs::convert_integer_to_float_audio(&original_samples);
let ctx = WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
let ctx = WhisperContext::new_with_params(
&whisper_path.to_string_lossy(),
WhisperContextParameters::default()
).expect("failed to open model");
let mut state = ctx.create_state().expect("failed to create key");
let mut params = FullParams::new(SamplingStrategy::default());
params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress));

View file

@ -5,6 +5,7 @@ mod error;
mod standalone;
mod utilities;
mod whisper_ctx;
mod whisper_grammar;
mod whisper_params;
mod whisper_state;
@ -12,6 +13,8 @@ pub use error::WhisperError;
pub use standalone::*;
pub use utilities::*;
pub use whisper_ctx::WhisperContext;
pub use whisper_ctx::WhisperContextParameters;
pub use whisper_grammar::{WhisperGrammarElement, WhisperGrammarElementType};
pub use whisper_params::{FullParams, SamplingStrategy};
pub use whisper_state::WhisperState;
@ -25,4 +28,4 @@ pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_cal
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback;
pub type WhisperAbortCallback = whisper_rs_sys::whisper_abort_callback;
pub type WhisperLogCallback = whisper_rs_sys::whisper_log_callback;
pub type WhisperLogCallback = whisper_rs_sys::ggml_log_callback;

View file

@ -55,10 +55,16 @@ pub fn get_lang_str(id: i32) -> Option<&'static str> {
/// Callback to control logging output: default behaviour is to print to stderr.
///
/// # Safety
/// The callback must be safe to call from C (i.e. no panicking, no unwinding, etc).
///
/// # C++ equivalent
/// `void whisper_log_set(ggml_log_callback log_callback, void * user_data);`
pub unsafe fn set_log_callback(callback: whisper_rs_sys::whisper_log_callback) {
unsafe { whisper_rs_sys::whisper_set_log_callback(callback) }
pub unsafe fn set_log_callback(
    log_callback: crate::WhisperLogCallback,
    user_data: *mut std::ffi::c_void,
) {
    // Both arguments are handed to `whisper_log_set` unchanged.
    // SAFETY: this function is itself `unsafe`; the caller is responsible for
    // upholding the contract documented on it for the callback and user data.
    unsafe {
        whisper_rs_sys::whisper_log_set(log_callback, user_data);
    }
}
/// Print system information.

View file

@ -5,7 +5,7 @@ use std::ffi::{c_int, CStr, CString};
/// Safe Rust wrapper around a Whisper context.
///
/// You likely want to create this with [WhisperContext::new],
/// You likely want to create this with [WhisperContext::new_with_params],
/// then run a full transcription with [WhisperContext::full].
#[derive(Debug)]
pub struct WhisperContext {
@ -13,6 +13,63 @@ pub struct WhisperContext {
}
impl WhisperContext {
/// Load a model file and build a new WhisperContext using the given parameters.
///
/// # Arguments
/// * path: The path to the model file.
/// * parameters: A parameter struct containing the parameters to use.
///
/// # Returns
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # Errors
/// * The error produced by `CString::new` (converted into a `WhisperError` by `?`)
///   when `path` cannot be turned into a C string.
/// * `WhisperError::InitError` when the C call returns a null pointer.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);`
pub fn new_with_params(
    path: &str,
    parameters: WhisperContextParameters,
) -> Result<Self, WhisperError> {
    let c_path = CString::new(path)?;
    // SAFETY: `c_path` is a NUL-terminated C string that is kept alive until
    // the end of this function, i.e. for the whole duration of the call.
    let raw_ctx = unsafe {
        whisper_rs_sys::whisper_init_from_file_with_params_no_state(
            c_path.as_ptr(),
            parameters.to_c_struct(),
        )
    };

    // A null pointer is the only failure signal checked here.
    if raw_ctx.is_null() {
        return Err(WhisperError::InitError);
    }
    Ok(Self { ctx: raw_ctx })
}
/// Build a new WhisperContext from an in-memory model, using the given parameters.
///
/// # Arguments
/// * buffer: The buffer containing the model.
/// * parameters: A parameter struct containing the parameters to use.
///
/// # Returns
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # Errors
/// `WhisperError::InitError` when the C call returns a null pointer.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);`
pub fn new_from_buffer_with_params(
    buffer: &[u8],
    parameters: WhisperContextParameters,
) -> Result<Self, WhisperError> {
    let model_ptr = buffer.as_ptr();
    let model_len = buffer.len();
    // SAFETY: the pointer and length both come from the borrowed `buffer`,
    // which stays alive for the duration of the call.
    let raw_ctx = unsafe {
        whisper_rs_sys::whisper_init_from_buffer_with_params_no_state(
            model_ptr as _,
            model_len,
            parameters.to_c_struct(),
        )
    };

    // A null pointer is the only failure signal checked here.
    if raw_ctx.is_null() {
        return Err(WhisperError::InitError);
    }
    Ok(Self { ctx: raw_ctx })
}
/// Create a new WhisperContext from a file.
///
/// # Arguments
@ -22,7 +79,8 @@ impl WhisperContext {
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_file(const char * path_model);`
/// `struct whisper_context * whisper_init_from_file_no_state(const char * path_model)`
#[deprecated = "Use `new_with_params` instead"]
pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) };
@ -42,7 +100,8 @@ impl WhisperContext {
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);`
/// `struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size)`
#[deprecated = "Use `new_from_buffer_with_params` instead"]
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
let ctx = unsafe {
whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len())
@ -472,6 +531,37 @@ impl Drop for WhisperContext {
unsafe impl Send for WhisperContext {}
unsafe impl Sync for WhisperContext {}
/// Parameters used when creating a `WhisperContext`
/// (see `WhisperContext::new_with_params`).
///
/// Derives `Debug` and `Clone` so the parameters can be logged and reused
/// for several contexts.
#[derive(Debug, Clone)]
pub struct WhisperContextParameters {
    /// Use GPU if available.
    ///
    /// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
    /// (in that case, GPU is always enabled).
    pub use_gpu: bool,
}
#[allow(clippy::derivable_impls)] // this impl cannot be derived
impl Default for WhisperContextParameters {
    /// Default parameters: `use_gpu` is `true` exactly when the crate was
    /// compiled with the internal `_gpu` feature, and `false` otherwise.
    fn default() -> Self {
        // Evaluated at compile time from the crate's feature set.
        let gpu_feature_enabled = cfg!(feature = "_gpu");
        Self {
            use_gpu: gpu_feature_enabled,
        }
    }
}
impl WhisperContextParameters {
    /// Create parameters with the default settings; identical to `Default::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Set whether the GPU should be used.
    ///
    /// Returns `self` so that further setter calls can be chained.
    pub fn use_gpu(&mut self, use_gpu: bool) -> &mut Self {
        self.use_gpu = use_gpu;
        self
    }

    /// Build the `whisper_context_params` value passed to the C API,
    /// copying every field of `self` into it.
    fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
        let Self { use_gpu } = *self;
        whisper_rs_sys::whisper_context_params { use_gpu }
    }
}
#[cfg(test)]
#[cfg(feature = "test-with-tiny-model")]
mod test_with_tiny_model {

80
src/whisper_grammar.rs Normal file
View file

@ -0,0 +1,80 @@
use whisper_rs_sys::{
whisper_gretype_WHISPER_GRETYPE_ALT, whisper_gretype_WHISPER_GRETYPE_CHAR,
whisper_gretype_WHISPER_GRETYPE_CHAR_ALT, whisper_gretype_WHISPER_GRETYPE_CHAR_NOT,
whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER, whisper_gretype_WHISPER_GRETYPE_END,
whisper_gretype_WHISPER_GRETYPE_RULE_REF,
};
/// Kind of a single grammar element.
///
/// Each variant's discriminant is the matching `whisper_gretype_*` constant from
/// `whisper_rs_sys`, so a variant can be cast directly to the raw C value.
// The integer representation is `u32` everywhere except Windows, where it is `i32`.
// NOTE(review): presumably this mirrors the integer type the generated bindings
// use for the C enum on each platform — confirm.
#[cfg_attr(not(windows), repr(u32))]
#[cfg_attr(windows, repr(i32))] // windows being *special* again
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum WhisperGrammarElementType {
/// End of rule definition
End = whisper_gretype_WHISPER_GRETYPE_END,
/// Start of alternate definition for a rule
Alternate = whisper_gretype_WHISPER_GRETYPE_ALT,
/// Non-terminal element: reference to another rule
RuleReference = whisper_gretype_WHISPER_GRETYPE_RULE_REF,
/// Terminal element: character (code point)
Character = whisper_gretype_WHISPER_GRETYPE_CHAR,
/// Inverse of a character(s)
NotCharacter = whisper_gretype_WHISPER_GRETYPE_CHAR_NOT,
/// Modifies a preceding [Self::Character] to be an inclusive range
CharacterRangeUpper = whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER,
/// Modifies a preceding [Self::Character] to add an alternate character to match
CharacterAlternate = whisper_gretype_WHISPER_GRETYPE_CHAR_ALT,
}
impl From<whisper_rs_sys::whisper_gretype> for WhisperGrammarElementType {
    /// Convert a raw `whisper_gretype` value into the matching variant.
    ///
    /// # Panics
    /// Panics if `value` lies outside `0..=6`.
    fn from(value: whisper_rs_sys::whisper_gretype) -> Self {
        // Reject out-of-range values up front, so the catch-all arm below
        // can never be taken.
        let known_values = 0..=6;
        assert!(
            known_values.contains(&value),
            "Invalid WhisperGrammarElementType value: {}",
            value
        );

        #[allow(non_upper_case_globals)] // the bindgen constants are not upper-case
        match value {
            whisper_gretype_WHISPER_GRETYPE_END => Self::End,
            whisper_gretype_WHISPER_GRETYPE_ALT => Self::Alternate,
            whisper_gretype_WHISPER_GRETYPE_RULE_REF => Self::RuleReference,
            whisper_gretype_WHISPER_GRETYPE_CHAR => Self::Character,
            whisper_gretype_WHISPER_GRETYPE_CHAR_NOT => Self::NotCharacter,
            whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER => Self::CharacterRangeUpper,
            whisper_gretype_WHISPER_GRETYPE_CHAR_ALT => Self::CharacterAlternate,
            _ => unreachable!(),
        }
    }
}
impl From<WhisperGrammarElementType> for whisper_rs_sys::whisper_gretype {
    /// Return the raw `whisper_gretype` value, i.e. the variant's discriminant.
    fn from(value: WhisperGrammarElementType) -> Self {
        value as whisper_rs_sys::whisper_gretype
    }
}
/// A single grammar element: an element type paired with a raw `u32` value.
///
/// Converted to the C `whisper_grammar_element` by [`WhisperGrammarElement::to_c_type`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct WhisperGrammarElement {
/// What kind of element this is.
pub element_type: WhisperGrammarElementType,
/// Raw value attached to the element.
///
/// NOTE(review): presumably a Unicode code point or a rule ID, depending on
/// `element_type` — confirm against whisper.h.
pub value: u32,
}
impl WhisperGrammarElement {
    /// Build a grammar element from its type and raw value.
    pub fn new(element_type: WhisperGrammarElementType, value: u32) -> Self {
        Self { element_type, value }
    }

    /// Convert this element into the C `whisper_grammar_element` struct.
    ///
    /// The element type becomes its raw `whisper_gretype` value; `value` is
    /// copied unchanged.
    pub fn to_c_type(self) -> whisper_rs_sys::whisper_grammar_element {
        let Self {
            element_type,
            value,
        } = self;
        whisper_rs_sys::whisper_grammar_element {
            type_: whisper_rs_sys::whisper_gretype::from(element_type),
            value,
        }
    }
}

View file

@ -1,3 +1,4 @@
use crate::whisper_grammar::WhisperGrammarElement;
use std::ffi::{c_float, c_int, CString};
use std::marker::PhantomData;
use whisper_rs_sys::whisper_token;
@ -24,6 +25,7 @@ pub struct FullParams<'a, 'b> {
pub(crate) fp: whisper_rs_sys::whisper_full_params,
phantom_lang: PhantomData<&'a str>,
phantom_tokens: PhantomData<&'b [c_int]>,
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
}
@ -58,6 +60,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
fp,
phantom_lang: PhantomData,
phantom_tokens: PhantomData,
grammar: None,
progess_callback_safe: None,
}
}
@ -104,6 +107,13 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.no_context = no_context;
}
/// Choose whether timestamp generation is turned off.
///
/// Pass `true` to not generate timestamps.
///
/// Defaults to false.
pub fn set_no_timestamps(&mut self, no_timestamps: bool) {
    let raw_params = &mut self.fp;
    raw_params.no_timestamps = no_timestamps;
}
/// Force single segment output. This may be useful for streaming.
///
/// Defaults to false.
@ -529,6 +539,46 @@ impl<'a, 'b> FullParams<'a, 'b> {
pub unsafe fn set_abort_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
    // Only the raw pointer is stored in the C parameter struct; nothing is
    // dereferenced here.
    let raw_params = &mut self.fp;
    raw_params.abort_callback_user_data = user_data;
}
/// Enable an array of grammar elements to be passed to the whisper model.
///
/// Defaults to an empty vector.
pub fn set_grammar(&mut self, grammar: Option<&[WhisperGrammarElement]>) {
if let Some(grammar) = grammar {
// convert to c types
let inner = grammar.iter().map(|e| e.to_c_type()).collect::<Vec<_>>();
// turn into ptr and len
let grammar_ptr = inner.as_ptr() as *mut _;
let grammar_len = inner.len();
self.grammar = Some(inner);
// set the grammar
self.fp.grammar_rules = grammar_ptr;
self.fp.n_grammar_rules = grammar_len;
} else {
self.grammar = None;
self.fp.grammar_rules = std::ptr::null_mut();
self.fp.n_grammar_rules = 0;
self.fp.i_start_rule = 0;
}
}
/// Select the index of the grammar rule to start from.
///
/// The call is ignored when no grammar is currently set.
///
/// Defaults to 0.
pub fn set_start_rule(&mut self, start_rule: usize) {
    // Without a stored grammar there is nothing for the index to refer to.
    if self.grammar.is_none() {
        return;
    }
    self.fp.i_start_rule = start_rule;
}
/// Set the grammar penalty stored in the C parameter struct.
///
/// Defaults to 100.0.
pub fn set_grammar_penalty(&mut self, grammar_penalty: f32) {
    let raw_params = &mut self.fp;
    raw_params.grammar_penalty = grammar_penalty;
}
}
// following implementations are safe

View file

@ -1,6 +1,6 @@
[package]
name = "whisper-rs-sys"
version = "0.7.3"
version = "0.8.0"
edition = "2021"
description = "Rust bindings for whisper.cpp (FFI bindings)"
license = "Unlicense"
@ -43,6 +43,6 @@ metal = []
[build-dependencies]
cmake = "0.1"
bindgen = "0.68"
bindgen = "0.69"
cfg-if = "1"
fs_extra = "1.3"

View file

@ -78,7 +78,7 @@ fn main() {
let bindings = bindgen::Builder::default()
.header("wrapper.h")
.clang_arg("-I./whisper.cpp")
.parse_callbacks(Box::new(bindgen::CargoCallbacks))
.parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate();
match bindings {

File diff suppressed because it is too large Load diff

@ -1 +1 @@
Subproject commit 1b775cdd68843fcfe331fc32ceb0d915c73a3cbd
Subproject commit d38af151a1ed5378c5a9ae368e767ed22c8ab141