update upstream whisper.cpp and fix breaking changes
This commit is contained in:
parent
4ccd746949
commit
1632ac11fe
10 changed files with 657 additions and 402 deletions
|
|
@ -10,6 +10,7 @@ pub use error::WhisperError;
|
|||
pub use standalone::*;
|
||||
pub use utilities::*;
|
||||
pub use whisper_ctx::WhisperContext;
|
||||
pub use whisper_params::{DecodeStrategy, FullParams};
|
||||
pub use whisper_params::{FullParams, SamplingStrategy};
|
||||
|
||||
pub type WhisperToken = std::ffi::c_int;
|
||||
pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callback;
|
||||
|
|
|
|||
|
|
@ -42,3 +42,11 @@ pub fn token_translate() -> WhisperToken {
|
|||
pub fn token_transcribe() -> WhisperToken {
|
||||
unsafe { whisper_rs_sys::whisper_token_transcribe() }
|
||||
}
|
||||
|
||||
/// Print system information.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_print_system_info()`
|
||||
pub fn print_system_info() {
|
||||
unsafe { whisper_rs_sys::whisper_print_system_info() };
|
||||
}
|
||||
|
|
|
|||
|
|
@ -194,11 +194,11 @@ impl WhisperContext {
|
|||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)`
|
||||
pub fn sample_best(&mut self, needs_timestamp: bool) -> Result<WhisperToken, WhisperError> {
|
||||
pub fn sample_best(&mut self) -> Result<WhisperToken, WhisperError> {
|
||||
if !self.decode_once {
|
||||
return Err(WhisperError::DecodeNotComplete);
|
||||
}
|
||||
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx, needs_timestamp) };
|
||||
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx) };
|
||||
Ok(ret)
|
||||
}
|
||||
|
||||
|
|
@ -446,6 +446,90 @@ impl WhisperContext {
|
|||
let r_str = c_str.to_str()?;
|
||||
Ok(r_str.to_string())
|
||||
}
|
||||
|
||||
/// Get number of tokens in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(c_int) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
|
||||
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
|
||||
let ret = unsafe { whisper_rs_sys::whisper_full_n_tokens(self.ctx, segment) };
|
||||
if ret < 0 {
|
||||
Err(WhisperError::GenericError(ret))
|
||||
} else {
|
||||
Ok(ret as c_int)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the token text of the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_text(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<String, WhisperError> {
|
||||
let ret = unsafe { whisper_rs_sys::whisper_full_get_token_text(self.ctx, segment, token) };
|
||||
if ret.is_null() {
|
||||
return Err(WhisperError::NullPointer);
|
||||
}
|
||||
let c_str = unsafe { CStr::from_ptr(ret) };
|
||||
let r_str = c_str.to_str()?;
|
||||
Ok(r_str.to_string())
|
||||
}
|
||||
|
||||
/// Get the token ID of the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
pub fn full_get_token_id(
|
||||
&self,
|
||||
segment: c_int,
|
||||
token: c_int,
|
||||
) -> Result<WhisperToken, WhisperError> {
|
||||
let ret = unsafe { whisper_rs_sys::whisper_full_get_token_id(self.ctx, segment, token) };
|
||||
if ret < 0 {
|
||||
Err(WhisperError::GenericError(ret))
|
||||
} else {
|
||||
Ok(ret as WhisperToken)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the probability of the specified token in the specified segment.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * segment: Segment index.
|
||||
/// * token: Token index.
|
||||
///
|
||||
/// # Returns
|
||||
/// f32
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||
#[inline]
|
||||
pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> f32 {
|
||||
unsafe { whisper_rs_sys::whisper_full_get_token_p(self.ctx, segment, token) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WhisperContext {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
use std::ffi::c_int;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
pub enum DecodeStrategy {
|
||||
pub enum SamplingStrategy {
|
||||
Greedy {
|
||||
n_past: c_int,
|
||||
},
|
||||
|
|
@ -20,26 +20,30 @@ pub struct FullParams<'a> {
|
|||
|
||||
impl<'a> FullParams<'a> {
|
||||
/// Create a new set of parameters for the decoder.
|
||||
pub fn new(decode_strategy: DecodeStrategy) -> FullParams<'a> {
|
||||
pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a> {
|
||||
let mut fp = unsafe {
|
||||
whisper_rs_sys::whisper_full_default_params(match decode_strategy {
|
||||
DecodeStrategy::Greedy { .. } => 0,
|
||||
DecodeStrategy::BeamSearch { .. } => 1,
|
||||
whisper_rs_sys::whisper_full_default_params(match sampling_strategy {
|
||||
SamplingStrategy::Greedy { .. } => {
|
||||
whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_GREEDY
|
||||
}
|
||||
SamplingStrategy::BeamSearch { .. } => {
|
||||
whisper_rs_sys::whisper_sampling_strategy_WHISPER_SAMPLING_BEAM_SEARCH
|
||||
}
|
||||
} as _)
|
||||
};
|
||||
|
||||
match decode_strategy {
|
||||
DecodeStrategy::Greedy { n_past } => {
|
||||
fp.__bindgen_anon_1.greedy.n_past = n_past;
|
||||
match sampling_strategy {
|
||||
SamplingStrategy::Greedy { n_past } => {
|
||||
fp.greedy.n_past = n_past;
|
||||
}
|
||||
DecodeStrategy::BeamSearch {
|
||||
SamplingStrategy::BeamSearch {
|
||||
n_past,
|
||||
beam_width,
|
||||
n_best,
|
||||
} => {
|
||||
fp.__bindgen_anon_1.beam_search.n_past = n_past;
|
||||
fp.__bindgen_anon_1.beam_search.beam_width = beam_width;
|
||||
fp.__bindgen_anon_1.beam_search.n_best = n_best;
|
||||
fp.beam_search.n_past = n_past;
|
||||
fp.beam_search.beam_width = beam_width;
|
||||
fp.beam_search.n_best = n_best;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -111,6 +115,34 @@ impl<'a> FullParams<'a> {
|
|||
pub fn set_language(&mut self, language: &'a str) {
    // A Rust `&str` has no trailing NUL byte, so handing `language.as_ptr()` to C as a string
    // lets the C side read past the end of the slice. Build a NUL-terminated copy instead.
    //
    // Only the part before the first interior NUL is kept: that is where a C reader would
    // stop anyway, and it guarantees `CString::new` cannot fail.
    let until_nul = language.split('\0').next().unwrap_or_default();
    let c_language = std::ffi::CString::new(until_nul).unwrap_or_default();
    // `into_raw` intentionally gives up ownership: the pointer stored in the parameter struct
    // must stay valid after this call returns, and nothing visible here could own the buffer.
    // The cost is one small allocation that is never freed per call.
    self.fp.language = c_language.into_raw() as *const _;
}
|
||||
|
||||
/// Set the callback for new segments.
|
||||
///
|
||||
/// Note that this callback has not been Rustified yet (and likely never will be, unless someone else feels the need to do so).
|
||||
/// It is still a C callback.
|
||||
///
|
||||
/// # Safety
|
||||
/// Do not use this function unless you know what you are doing.
|
||||
/// * Be careful not to mutate the state of the whisper_context pointer returned in the callback.
|
||||
/// This could cause undefined behavior, as this violates the thread-safety guarantees of the underlying C library.
|
||||
///
|
||||
/// Defaults to None.
|
||||
pub unsafe fn set_new_segment_callback(
|
||||
&mut self,
|
||||
new_segment_callback: crate::WhisperNewSegmentCallback,
|
||||
) {
|
||||
self.fp.new_segment_callback = new_segment_callback;
|
||||
}
|
||||
|
||||
/// Set the user data to be passed to the new segment callback.
|
||||
///
|
||||
/// # Safety
|
||||
/// See the safety notes for `set_new_segment_callback`.
|
||||
///
|
||||
/// Defaults to None.
|
||||
pub unsafe fn set_new_segment_callback_user_data(&mut self, user_data: *mut std::ffi::c_void) {
|
||||
self.fp.new_segment_callback_user_data = user_data;
|
||||
}
|
||||
}
|
||||
|
||||
// following implementations are safe
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue