Merge remote-tracking branch 'origin/master'
# Conflicts: # sys/src/bindings.rs
This commit is contained in:
commit
2cf2c7f499
11 changed files with 200 additions and 45 deletions
10
CHANGELOG.md
Normal file
10
CHANGELOG.md
Normal file
|
|
@ -0,0 +1,10 @@
|
||||||
|
# Version 0.2.0 (2022-10-28)
|
||||||
|
* Update upstream whisper.cpp to 2c281d190b7ec351b8128ba386d110f100993973.
|
||||||
|
* Adapt to breaking changes from the upstream update, which cascade to users:
|
||||||
|
* `DecodeStrategy` has been renamed to `SamplingStrategy`
|
||||||
|
* `WhisperContext::sample_best`'s signature has changed: `needs_timestamp` has been removed.
|
||||||
|
* New features
|
||||||
|
* `WhisperContext::full_n_tokens`
|
||||||
|
* `WhisperContext::full_get_token_text`
|
||||||
|
* `WhisperContext::full_get_token_id`
|
||||||
|
* `WhisperContext::full_get_token_prob`
|
||||||
|
|
@ -3,7 +3,7 @@ members = ["sys"]
|
||||||
|
|
||||||
[package]
|
[package]
|
||||||
name = "whisper-rs"
|
name = "whisper-rs"
|
||||||
version = "0.1.3"
|
version = "0.2.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Rust bindings for whisper.cpp"
|
description = "Rust bindings for whisper.cpp"
|
||||||
license = "Unlicense"
|
license = "Unlicense"
|
||||||
|
|
@ -13,7 +13,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
whisper-rs-sys = { path = "sys", version = "0.1" }
|
whisper-rs-sys = { path = "sys", version = "0.2" }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
simd = []
|
simd = []
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,14 @@ See the docs: https://docs.rs/whisper-rs/ for more details.
|
||||||
* Windows/macOS/Android aren't working!
|
* Windows/macOS/Android aren't working!
|
||||||
* I don't have a way to test these platforms, so I can't really help you.
|
* I don't have a way to test these platforms, so I can't really help you.
|
||||||
* If you can get it working, please open a PR!
|
* If you can get it working, please open a PR!
|
||||||
|
* I get a panic during binding generation build!
|
||||||
|
* You can attempt to fix it yourself, or you can set the `WHISPER_DONT_GENERATE_BINDINGS` environment variable.
|
||||||
|
This skips binding generation entirely and copies the existing bindings instead. They may be out of date,
|
||||||
|
but it's better than nothing.
|
||||||
|
* `WHISPER_DONT_GENERATE_BINDINGS=1 cargo build`
|
||||||
|
* If you can fix the issue, please open a PR!
|
||||||
|
* M1 build info:
|
||||||
|
* See [this pull request](https://github.com/tazz4843/whisper-rs/pull/2) for more info.
|
||||||
|
|
||||||
## License
|
## License
|
||||||
[Unlicense](LICENSE)
|
[Unlicense](LICENSE)
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use whisper_rs::{DecodeStrategy, FullParams, WhisperContext};
|
use whisper_rs::{FullParams, SamplingStrategy, WhisperContext};
|
||||||
|
|
||||||
// note that running this example will not do anything, as it is just a
|
// note that running this example will not do anything, as it is just a
|
||||||
// demonstration of how to use the library, and actual usage requires
|
// demonstration of how to use the library, and actual usage requires
|
||||||
|
|
@ -10,7 +10,7 @@ pub fn usage() {
|
||||||
// create a params object
|
// create a params object
|
||||||
// note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
|
// note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
|
||||||
// n_past defaults to 0
|
// n_past defaults to 0
|
||||||
let mut params = FullParams::new(DecodeStrategy::Greedy { n_past: 0 });
|
let mut params = FullParams::new(SamplingStrategy::Greedy { n_past: 0 });
|
||||||
|
|
||||||
// edit things as needed
|
// edit things as needed
|
||||||
// here we set the number of threads to use to 1
|
// here we set the number of threads to use to 1
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ pub use error::WhisperError;
|
||||||
pub use standalone::*;
|
pub use standalone::*;
|
||||||
pub use utilities::*;
|
pub use utilities::*;
|
||||||
pub use whisper_ctx::WhisperContext;
|
pub use whisper_ctx::WhisperContext;
|
||||||
pub use whisper_params::{DecodeStrategy, FullParams};
|
pub use whisper_params::{FullParams, SamplingStrategy};
|
||||||
|
|
||||||
pub type WhisperToken = std::ffi::c_int;
|
pub type WhisperToken = std::ffi::c_int;
|
||||||
|
pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback;
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
//! Standalone functions that have no associated type.
|
//! Standalone functions that have no associated type.
|
||||||
|
|
||||||
use crate::WhisperToken;
|
use crate::WhisperToken;
|
||||||
use std::ffi::{c_int, CString};
|
use std::ffi::{c_int, CStr, CString};
|
||||||
|
|
||||||
/// Return the id of the specified language, returns -1 if not found
|
/// Return the id of the specified language, returns -1 if not found
|
||||||
///
|
///
|
||||||
|
|
@ -42,3 +42,13 @@ pub fn token_translate() -> WhisperToken {
|
||||||
pub fn token_transcribe() -> WhisperToken {
|
pub fn token_transcribe() -> WhisperToken {
|
||||||
unsafe { whisper_rs_sys::whisper_token_transcribe() }
|
unsafe { whisper_rs_sys::whisper_token_transcribe() }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return system information as a string (nothing is printed by this function itself).
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_print_system_info()`
|
||||||
|
pub fn print_system_info() -> &'static str {
|
||||||
|
let c_buf = unsafe { whisper_rs_sys::whisper_print_system_info() };
|
||||||
|
let c_str = unsafe { CStr::from_ptr(c_buf) };
|
||||||
|
c_str.to_str().unwrap()
|
||||||
|
}
|
||||||
|
|
@ -111,6 +111,7 @@ pub fn convert_stereo_to_mono_audio_simd(samples: &[f32]) -> Vec<f32> {
|
||||||
mono
|
mono
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "simd")]
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test {
|
mod test {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
|
|
@ -194,11 +194,11 @@ impl WhisperContext {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)`
|
/// `whisper_token whisper_sample_best(struct whisper_context * ctx)`
|
||||||
pub fn sample_best(&mut self, needs_timestamp: bool) -> Result<WhisperToken, WhisperError> {
|
pub fn sample_best(&mut self) -> Result<WhisperToken, WhisperError> {
|
||||||
if !self.decode_once {
|
if !self.decode_once {
|
||||||
return Err(WhisperError::DecodeNotComplete);
|
return Err(WhisperError::DecodeNotComplete);
|
||||||
}
|
}
|
||||||
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx, needs_timestamp) };
|
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx) };
|
||||||
Ok(ret)
|
Ok(ret)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -363,7 +363,7 @@ impl WhisperContext {
|
||||||
unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Print performance statistics to stdout.
|
/// Print performance statistics to stderr.
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `void whisper_print_timings(struct whisper_context * ctx)`
|
/// `void whisper_print_timings(struct whisper_context * ctx)`
|
||||||
|
|
@ -446,6 +446,90 @@ impl WhisperContext {
|
||||||
let r_str = c_str.to_str()?;
|
let r_str = c_str.to_str()?;
|
||||||
Ok(r_str.to_string())
|
Ok(r_str.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get number of tokens in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(c_int) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
|
||||||
|
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
|
||||||
|
let ret = unsafe { whisper_rs_sys::whisper_full_n_tokens(self.ctx, segment) };
|
||||||
|
if ret < 0 {
|
||||||
|
Err(WhisperError::GenericError(ret))
|
||||||
|
} else {
|
||||||
|
Ok(ret as c_int)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the token text of the specified token in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
/// * token: Token index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn full_get_token_text(
|
||||||
|
&self,
|
||||||
|
segment: c_int,
|
||||||
|
token: c_int,
|
||||||
|
) -> Result<String, WhisperError> {
|
||||||
|
let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text(self.ctx, segment, token) };
|
||||||
|
if ret.is_null() {
|
||||||
|
return Err(WhisperError::NullPointer);
|
||||||
|
}
|
||||||
|
let c_str = unsafe { CStr::from_ptr(ret) };
|
||||||
|
let r_str = c_str.to_str()?;
|
||||||
|
Ok(r_str.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the token ID of the specified token in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
/// * token: Token index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn full_get_token_id(
|
||||||
|
&self,
|
||||||
|
segment: c_int,
|
||||||
|
token: c_int,
|
||||||
|
) -> Result<WhisperToken, WhisperError> {
|
||||||
|
let ret = unsafe { whisper_rs_sys::whisper_full_get_token_id(self.ctx, segment, token) };
|
||||||
|
if ret < 0 {
|
||||||
|
Err(WhisperError::GenericError(ret))
|
||||||
|
} else {
|
||||||
|
Ok(ret as WhisperToken)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the probability of the specified token in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
/// * token: Token index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// f32
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
#[inline]
|
||||||
|
pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> f32 {
|
||||||
|
unsafe { whisper_rs_sys::whisper_full_get_token_p(self.ctx, segment, token) }
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for WhisperContext {
|
impl Drop for WhisperContext {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
use std::ffi::c_int;
|
use std::ffi::{c_int, CString};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
pub enum DecodeStrategy {
|
pub enum SamplingStrategy {
|
||||||
Greedy {
|
Greedy {
|
||||||
n_past: c_int,
|
n_past: c_int,
|
||||||
},
|
},
|
||||||
|
|
@ -20,26 +20,30 @@ pub struct FullParams<'a> {
|
||||||
|
|
||||||
impl<'a> FullParams<'a> {
|
impl<'a> FullParams<'a> {
|
||||||
/// Create a new set of parameters for the decoder.
|
/// Create a new set of parameters for the decoder.
|
||||||
pub fn new(decode_strategy: DecodeStrategy) -> FullParams<'a> {
|
pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a> {
|
||||||
let mut fp = unsafe {
|
let mut fp = unsafe {
|
||||||
whisper_rs_sys::whisper_full_default_params(match decode_strategy {
|
whisper_rs_sys::whisper_full_default_params(match sampling_strategy {
|
||||||
DecodeStrategy::Greedy { .. } => 0,
|
SamplingStrategy::Greedy { .. } => {
|
||||||
DecodeStrategy::BeamSearch { .. } => 1,
|
whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY
|
||||||
|
}
|
||||||
|
SamplingStrategy::BeamSearch { .. } => {
|
||||||
|
whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH
|
||||||
|
}
|
||||||
} as _)
|
} as _)
|
||||||
};
|
};
|
||||||
|
|
||||||
match decode_strategy {
|
match sampling_strategy {
|
||||||
DecodeStrategy::Greedy { n_past } => {
|
SamplingStrategy::Greedy { n_past } => {
|
||||||
fp.__bindgen_anon_1.greedy.n_past = n_past;
|
fp.greedy.n_past = n_past;
|
||||||
}
|
}
|
||||||
DecodeStrategy::BeamSearch {
|
SamplingStrategy::BeamSearch {
|
||||||
n_past,
|
n_past,
|
||||||
beam_width,
|
beam_width,
|
||||||
n_best,
|
n_best,
|
||||||
} => {
|
} => {
|
||||||
fp.__bindgen_anon_1.beam_search.n_past = n_past;
|
fp.beam_search.n_past = n_past;
|
||||||
fp.__bindgen_anon_1.beam_search.beam_width = beam_width;
|
fp.beam_search.beam_width = beam_width;
|
||||||
fp.__bindgen_anon_1.beam_search.n_best = n_best;
|
fp.beam_search.n_best = n_best;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -109,7 +113,36 @@ impl<'a> FullParams<'a> {
|
||||||
///
|
///
|
||||||
/// Defaults to "en".
|
/// Defaults to "en".
|
||||||
pub fn set_language(&mut self, language: &'a str) {
|
pub fn set_language(&mut self, language: &'a str) {
|
||||||
self.fp.language = language.as_ptr() as *const _;
|
let c_lang = CString::new(language).expect("Language contains null byte");
|
||||||
|
self.fp.language = c_lang.into_raw() as *const _;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the callback for new segments.
|
||||||
|
///
|
||||||
|
/// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so).
|
||||||
|
/// It is still a C callback.
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
/// Do not use this function unless you know what you are doing.
|
||||||
|
/// * Be careful not to mutate the state of the whisper_context pointer returned in the callback.
|
||||||
|
/// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library.
|
||||||
|
///
|
||||||
|
/// Defaults to None.
|
||||||
|
pub unsafe fn set_new_segment_callback(
|
||||||
|
&mut self,
|
||||||
|
new_segment_callback: crate::WhisperNewSegmentCallback,
|
||||||
|
) {
|
||||||
|
self.fp.new_segment_callback = new_segment_callback;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the user data to be passed to the new segment callback.
|
||||||
|
///
|
||||||
|
/// # Safety
|
||||||
|
/// See the safety notes for `set_new_segment_callback`.
|
||||||
|
///
|
||||||
|
/// Defaults to None.
|
||||||
|
pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
|
||||||
|
self.fp.new_segment_callback_user_data = user_data;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -13,4 +13,4 @@ links = "whisper"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
bindgen = "0.60"
|
bindgen = "0.61"
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,13 @@ fn main() {
|
||||||
println!("cargo:rustc-link-lib=static=whisper");
|
println!("cargo:rustc-link-lib=static=whisper");
|
||||||
println!("cargo:rerun-if-changed=wrapper.h");
|
println!("cargo:rerun-if-changed=wrapper.h");
|
||||||
|
|
||||||
|
if env::var("WHISPER_DONT_GENERATE_BINDINGS").is_ok() {
|
||||||
|
let _: u64 = std::fs::copy(
|
||||||
|
"src/bindings.rs",
|
||||||
|
env::var("OUT_DIR").unwrap() + "/bindings.rs",
|
||||||
|
)
|
||||||
|
.expect("Failed to copy bindings.rs");
|
||||||
|
} else {
|
||||||
let bindings = bindgen::Builder::default()
|
let bindings = bindgen::Builder::default()
|
||||||
.header("wrapper.h")
|
.header("wrapper.h")
|
||||||
.clang_arg("-I./whisper.cpp")
|
.clang_arg("-I./whisper.cpp")
|
||||||
|
|
@ -31,6 +38,7 @@ fn main() {
|
||||||
.expect("Unable to copy bindings.rs");
|
.expect("Unable to copy bindings.rs");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// stop if we're on docs.rs
|
// stop if we're on docs.rs
|
||||||
if env::var("DOCS_RS").is_ok() {
|
if env::var("DOCS_RS").is_ok() {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue