Update whisper.cpp to v1.5 and add all structs

This commit is contained in:
Niko 2023-11-16 18:56:50 -07:00
parent 73e33a182c
commit 8690d35deb
No known key found for this signature in database
GPG key ID: 3861E636EA1E0E2B
10 changed files with 4437 additions and 103 deletions

View file

@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
[package] [package]
name = "whisper-rs" name = "whisper-rs"
version = "0.9.0-rc.2" version = "0.10.0"
edition = "2021" edition = "2021"
description = "Rust bindings for whisper.cpp" description = "Rust bindings for whisper.cpp"
license = "Unlicense" license = "Unlicense"
@ -14,7 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
whisper-rs-sys = { path = "sys", version = "0.7" } whisper-rs-sys = { path = "sys", version = "0.8" }
[dev-dependencies] [dev-dependencies]
hound = "3.5.0" hound = "3.5.0"
@ -24,10 +24,11 @@ default = []
simd = [] simd = []
coreml = ["whisper-rs-sys/coreml"] coreml = ["whisper-rs-sys/coreml"]
cuda = ["whisper-rs-sys/cuda"] cuda = ["whisper-rs-sys/cuda", "_gpu"]
opencl = ["whisper-rs-sys/opencl"] opencl = ["whisper-rs-sys/opencl"]
openblas = ["whisper-rs-sys/openblas"] openblas = ["whisper-rs-sys/openblas"]
metal = ["whisper-rs-sys/metal"] metal = ["whisper-rs-sys/metal", "_gpu"]
_gpu = []
test-with-tiny-model = [] test-with-tiny-model = []
[package.metadata.docs.rs] [package.metadata.docs.rs]

View file

@ -5,6 +5,7 @@ mod error;
mod standalone; mod standalone;
mod utilities; mod utilities;
mod whisper_ctx; mod whisper_ctx;
mod whisper_grammar;
mod whisper_params; mod whisper_params;
mod whisper_state; mod whisper_state;
@ -25,4 +26,4 @@ pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_cal
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback; pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback; pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback;
pub type WhisperAbortCallback = whisper_rs_sys::whisper_abort_callback; pub type WhisperAbortCallback = whisper_rs_sys::whisper_abort_callback;
pub type WhisperLogCallback = whisper_rs_sys::whisper_log_callback; pub type WhisperLogCallback = whisper_rs_sys::ggml_log_callback;

View file

@ -55,10 +55,16 @@ pub fn get_lang_str(id: i32) -> Option<&'static str> {
/// Callback to control logging output: default behaviour is to print to stderr. /// Callback to control logging output: default behaviour is to print to stderr.
/// ///
/// # Safety
/// The callback must be safe to call from C (i.e. no panicking, no unwinding, etc).
///
/// # C++ equivalent /// # C++ equivalent
/// `void whisper_set_log_callback(whisper_log_callback callback);` /// `void whisper_set_log_callback(whisper_log_callback callback);`
pub unsafe fn set_log_callback(callback: whisper_rs_sys::whisper_log_callback) { pub unsafe fn set_log_callback(
unsafe { whisper_rs_sys::whisper_set_log_callback(callback) } log_callback: crate::WhisperLogCallback,
user_data: *mut std::ffi::c_void,
) {
unsafe { whisper_rs_sys::whisper_log_set(log_callback, user_data) }
} }
/// Print system information. /// Print system information.

View file

@ -5,7 +5,7 @@ use std::ffi::{c_int, CStr, CString};
/// Safe Rust wrapper around a Whisper context. /// Safe Rust wrapper around a Whisper context.
/// ///
/// You likely want to create this with [WhisperContext::new], /// You likely want to create this with [WhisperContext::new_with_params],
/// then run a full transcription with [WhisperContext::full]. /// then run a full transcription with [WhisperContext::full].
#[derive(Debug)] #[derive(Debug)]
pub struct WhisperContext { pub struct WhisperContext {
@ -13,6 +13,63 @@ pub struct WhisperContext {
} }
impl WhisperContext { impl WhisperContext {
/// Create a new WhisperContext from a file, with parameters.
///
/// # Arguments
/// * path: The path to the model file.
/// * parameters: A parameter struct containing the parameters to use.
///
/// # Returns
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_file_with_params_no_state(const char * path_model, struct whisper_context_params params);`
pub fn new_with_params(
path: &str,
parameters: WhisperContextParameters,
) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?;
let ctx = unsafe {
whisper_rs_sys::whisper_init_from_file_with_params_no_state(
path_cstr.as_ptr(),
parameters.to_c_struct(),
)
};
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
Ok(Self { ctx })
}
}
/// Create a new WhisperContext from a buffer.
///
/// # Arguments
/// * buffer: The buffer containing the model.
///
/// # Returns
/// Ok(Self) on success, Err(WhisperError) on failure.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer_with_params_no_state(void * buffer, size_t buffer_size, struct whisper_context_params params);`
pub fn new_from_buffer_with_params(
buffer: &[u8],
parameters: WhisperContextParameters,
) -> Result<Self, WhisperError> {
let ctx = unsafe {
whisper_rs_sys::whisper_init_from_buffer_with_params_no_state(
buffer.as_ptr() as _,
buffer.len(),
parameters.to_c_struct(),
)
};
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
Ok(Self { ctx })
}
}
/// Create a new WhisperContext from a file. /// Create a new WhisperContext from a file.
/// ///
/// # Arguments /// # Arguments
@ -22,7 +79,8 @@ impl WhisperContext {
/// Ok(Self) on success, Err(WhisperError) on failure. /// Ok(Self) on success, Err(WhisperError) on failure.
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `struct whisper_context * whisper_init_from_file(const char * path_model);` /// `struct whisper_context * whisper_init_from_file_no_state(const char * path_model)`
#[deprecated = "Use `new_with_params` instead"]
pub fn new(path: &str) -> Result<Self, WhisperError> { pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?; let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) }; let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) };
@ -42,7 +100,8 @@ impl WhisperContext {
/// Ok(Self) on success, Err(WhisperError) on failure. /// Ok(Self) on success, Err(WhisperError) on failure.
/// ///
/// # C++ equivalent /// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);` /// `struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size)`
#[deprecated = "Use `new_from_buffer_with_params` instead"]
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> { pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
let ctx = unsafe { let ctx = unsafe {
whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len()) whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len())
@ -472,6 +531,37 @@ impl Drop for WhisperContext {
unsafe impl Send for WhisperContext {} unsafe impl Send for WhisperContext {}
unsafe impl Sync for WhisperContext {} unsafe impl Sync for WhisperContext {}
pub struct WhisperContextParameters {
/// Use GPU if available.
///
/// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
/// (in that case, GPU is always enabled).
pub use_gpu: bool,
}
#[allow(clippy::derivable_impls)] // this impl cannot be derived
impl Default for WhisperContextParameters {
fn default() -> Self {
Self {
use_gpu: cfg!(feature = "_gpu"),
}
}
}
impl WhisperContextParameters {
pub fn new() -> Self {
Self::default()
}
pub fn use_gpu(mut self, use_gpu: bool) -> Self {
self.use_gpu = use_gpu;
self
}
fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
whisper_rs_sys::whisper_context_params {
use_gpu: self.use_gpu,
}
}
}
#[cfg(test)] #[cfg(test)]
#[cfg(feature = "test-with-tiny-model")] #[cfg(feature = "test-with-tiny-model")]
mod test_with_tiny_model { mod test_with_tiny_model {

79
src/whisper_grammar.rs Normal file
View file

@ -0,0 +1,79 @@
use whisper_rs_sys::{
whisper_gretype_WHISPER_GRETYPE_ALT, whisper_gretype_WHISPER_GRETYPE_CHAR,
whisper_gretype_WHISPER_GRETYPE_CHAR_ALT, whisper_gretype_WHISPER_GRETYPE_CHAR_NOT,
whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER, whisper_gretype_WHISPER_GRETYPE_END,
whisper_gretype_WHISPER_GRETYPE_RULE_REF,
};
#[repr(u32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum WhisperGrammarElementType {
/// End of rule definition
End = whisper_gretype_WHISPER_GRETYPE_END,
/// Start of alternate definition for a rule
Alternate = whisper_gretype_WHISPER_GRETYPE_ALT,
/// Non-terminal element: reference to another rule
RuleReference = whisper_gretype_WHISPER_GRETYPE_RULE_REF,
/// Terminal element: character (code point)
Character = whisper_gretype_WHISPER_GRETYPE_CHAR,
/// Inverse of a character(s)
NotCharacter = whisper_gretype_WHISPER_GRETYPE_CHAR_NOT,
/// Modifies a preceding [Self::Character] to be an inclusive range
CharacterRangeUpper = whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER,
/// Modifies a preceding [Self::Character] to add an alternate character to match
CharacterAlternate = whisper_gretype_WHISPER_GRETYPE_CHAR_ALT,
}
impl From<whisper_rs_sys::whisper_gretype> for WhisperGrammarElementType {
fn from(value: whisper_rs_sys::whisper_gretype) -> Self {
assert!(
(0..=6).contains(&value),
"Invalid WhisperGrammarElementType value: {}",
value
);
#[allow(non_upper_case_globals)] // weird place to trigger this
match value {
whisper_gretype_WHISPER_GRETYPE_END => WhisperGrammarElementType::End,
whisper_gretype_WHISPER_GRETYPE_ALT => WhisperGrammarElementType::Alternate,
whisper_gretype_WHISPER_GRETYPE_RULE_REF => WhisperGrammarElementType::RuleReference,
whisper_gretype_WHISPER_GRETYPE_CHAR => WhisperGrammarElementType::Character,
whisper_gretype_WHISPER_GRETYPE_CHAR_NOT => WhisperGrammarElementType::NotCharacter,
whisper_gretype_WHISPER_GRETYPE_CHAR_RNG_UPPER => {
WhisperGrammarElementType::CharacterRangeUpper
}
whisper_gretype_WHISPER_GRETYPE_CHAR_ALT => {
WhisperGrammarElementType::CharacterAlternate
}
_ => unreachable!(),
}
}
}
impl From<WhisperGrammarElementType> for whisper_rs_sys::whisper_gretype {
fn from(value: WhisperGrammarElementType) -> Self {
value as Self
}
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub struct WhisperGrammarElement {
pub element_type: WhisperGrammarElementType,
pub value: u32,
}
impl WhisperGrammarElement {
pub fn new(element_type: WhisperGrammarElementType, value: u32) -> Self {
Self {
element_type,
value,
}
}
pub fn to_c_type(self) -> whisper_rs_sys::whisper_grammar_element {
whisper_rs_sys::whisper_grammar_element {
type_: self.element_type.into(),
value: self.value,
}
}
}

View file

@ -1,3 +1,4 @@
use crate::whisper_grammar::WhisperGrammarElement;
use std::ffi::{c_float, c_int, CString}; use std::ffi::{c_float, c_int, CString};
use std::marker::PhantomData; use std::marker::PhantomData;
use whisper_rs_sys::whisper_token; use whisper_rs_sys::whisper_token;
@ -24,6 +25,7 @@ pub struct FullParams<'a, 'b> {
pub(crate) fp: whisper_rs_sys::whisper_full_params, pub(crate) fp: whisper_rs_sys::whisper_full_params,
phantom_lang: PhantomData<&'a str>, phantom_lang: PhantomData<&'a str>,
phantom_tokens: PhantomData<&'b [c_int]>, phantom_tokens: PhantomData<&'b [c_int]>,
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
progess_callback_safe: Option<Box<dyn FnMut(i32)>>, progess_callback_safe: Option<Box<dyn FnMut(i32)>>,
} }
@ -58,6 +60,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
fp, fp,
phantom_lang: PhantomData, phantom_lang: PhantomData,
phantom_tokens: PhantomData, phantom_tokens: PhantomData,
grammar: None,
progess_callback_safe: None, progess_callback_safe: None,
} }
} }
@ -104,6 +107,13 @@ impl<'a, 'b> FullParams<'a, 'b> {
self.fp.no_context = no_context; self.fp.no_context = no_context;
} }
/// Do not generate timestamps.
///
/// Defaults to false.
pub fn set_no_timestamps(&mut self, no_timestamps: bool) {
self.fp.no_timestamps = no_timestamps;
}
/// Force single segment output. This may be useful for streaming. /// Force single segment output. This may be useful for streaming.
/// ///
/// Defaults to false. /// Defaults to false.
@ -529,6 +539,46 @@ impl<'a, 'b> FullParams<'a, 'b> {
pub unsafe fn set_abort_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) { pub unsafe fn set_abort_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
self.fp.abort_callback_user_data = user_data; self.fp.abort_callback_user_data = user_data;
} }
/// Enable an array of grammar elements to be passed to the whisper model.
///
/// Defaults to an empty vector.
pub fn set_grammar(&mut self, grammar: Option<&[WhisperGrammarElement]>) {
if let Some(grammar) = grammar {
// convert to c types
let inner = grammar.iter().map(|e| e.to_c_type()).collect::<Vec<_>>();
// turn into ptr and len
let grammar_ptr = inner.as_ptr() as *mut _;
let grammar_len = inner.len();
self.grammar = Some(inner);
// set the grammar
self.fp.grammar_rules = grammar_ptr;
self.fp.n_grammar_rules = grammar_len;
} else {
self.grammar = None;
self.fp.grammar_rules = std::ptr::null_mut();
self.fp.n_grammar_rules = 0;
self.fp.i_start_rule = 0;
}
}
/// Set the start grammar rule. Does nothing if no grammar is set.
///
/// Defaults to 0.
pub fn set_start_rule(&mut self, start_rule: usize) {
if self.grammar.is_some() {
self.fp.i_start_rule = start_rule;
}
}
/// Set grammar penalty.
///
/// Defaults to 100.0.
pub fn set_grammar_penalty(&mut self, grammar_penalty: f32) {
self.fp.grammar_penalty = grammar_penalty;
}
} }
// following implementations are safe // following implementations are safe

View file

@ -1,6 +1,6 @@
[package] [package]
name = "whisper-rs-sys" name = "whisper-rs-sys"
version = "0.7.3" version = "0.8.0"
edition = "2021" edition = "2021"
description = "Rust bindings for whisper.cpp (FFI bindings)" description = "Rust bindings for whisper.cpp (FFI bindings)"
license = "Unlicense" license = "Unlicense"
@ -43,6 +43,6 @@ metal = []
[build-dependencies] [build-dependencies]
cmake = "0.1" cmake = "0.1"
bindgen = "0.68" bindgen = "0.69"
cfg-if = "1" cfg-if = "1"
fs_extra = "1.3" fs_extra = "1.3"

View file

@ -78,7 +78,7 @@ fn main() {
let bindings = bindgen::Builder::default() let bindings = bindgen::Builder::default()
.header("wrapper.h") .header("wrapper.h")
.clang_arg("-I./whisper.cpp") .clang_arg("-I./whisper.cpp")
.parse_callbacks(Box::new(bindgen::CargoCallbacks)) .parse_callbacks(Box::new(bindgen::CargoCallbacks::new()))
.generate(); .generate();
match bindings { match bindings {

File diff suppressed because it is too large Load diff

@ -1 +1 @@
Subproject commit 1b775cdd68843fcfe331fc32ceb0d915c73a3cbd Subproject commit e8d5638b7c19e60cdf03d6928a0c8933cb31d2ad