update to latest whisper.cpp
This commit is contained in:
parent
0da39195c0
commit
edac524756
4 changed files with 277 additions and 55 deletions
|
|
@ -35,6 +35,8 @@ pub enum WhisperError {
|
||||||
NullPointer,
|
NullPointer,
|
||||||
/// Generic whisper error. Varies depending on the function.
|
/// Generic whisper error. Varies depending on the function.
|
||||||
GenericError(c_int),
|
GenericError(c_int),
|
||||||
|
/// Whisper failed to convert the provided text into tokens.
|
||||||
|
InvalidText,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Utf8Error> for WhisperError {
|
impl From<Utf8Error> for WhisperError {
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,34 @@ pub fn get_lang_id(lang: &str) -> Option<c_int> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Return the ID of the maximum language (ie the number of languages - 1)
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// i32
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_lang_max_id()`
|
||||||
|
pub fn get_lang_max_id() -> i32 {
|
||||||
|
unsafe { whisper_rs_sys::whisper_lang_max_id() }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the short string of the specified language id (e.g. 2 -> "de").
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// The short string of the language, None if not found.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_lang_str(int id)`
|
||||||
|
pub fn get_lang_str(id: i32) -> Option<&'static str> {
|
||||||
|
let c_buf = unsafe { whisper_rs_sys::whisper_lang_str(id) };
|
||||||
|
if c_buf.is_null() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let c_str = unsafe { CStr::from_ptr(c_buf) };
|
||||||
|
Some(c_str.to_str().unwrap())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// task tokens
|
// task tokens
|
||||||
/// Get the ID of the translate task token.
|
/// Get the ID of the translate task token.
|
||||||
///
|
///
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,7 @@ pub struct WhisperContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl WhisperContext {
|
impl WhisperContext {
|
||||||
/// Create a new WhisperContext.
|
/// Create a new WhisperContext from a file.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * path: The path to the model file.
|
/// * path: The path to the model file.
|
||||||
|
|
@ -28,10 +28,10 @@ impl WhisperContext {
|
||||||
/// Ok(Self) on success, Err(WhisperError) on failure.
|
/// Ok(Self) on success, Err(WhisperError) on failure.
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `struct whisper_context * whisper_init(const char * path_model);`
|
/// `struct whisper_context * whisper_init_from_file(const char * path_model);`
|
||||||
pub fn new(path: &str) -> Result<Self, WhisperError> {
|
pub fn new(path: &str) -> Result<Self, WhisperError> {
|
||||||
let path_cstr = CString::new(path)?;
|
let path_cstr = CString::new(path)?;
|
||||||
let ctx = unsafe { whisper_rs_sys::whisper_init(path_cstr.as_ptr()) };
|
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file(path_cstr.as_ptr()) };
|
||||||
if ctx.is_null() {
|
if ctx.is_null() {
|
||||||
Err(WhisperError::InitError)
|
Err(WhisperError::InitError)
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -44,6 +44,33 @@ impl WhisperContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a new WhisperContext from a buffer.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * buffer: The buffer containing the model.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(Self) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `struct whisper_context * whisper_init_from_buffer(const char * buffer, int n_bytes);`
|
||||||
|
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
|
||||||
|
let ctx =
|
||||||
|
unsafe { whisper_rs_sys::whisper_init_from_buffer(buffer.as_ptr() as _, buffer.len()) };
|
||||||
|
if ctx.is_null() {
|
||||||
|
Err(WhisperError::InitError)
|
||||||
|
} else {
|
||||||
|
Ok(Self {
|
||||||
|
ctx,
|
||||||
|
spectrogram_initialized: false,
|
||||||
|
encode_complete: false,
|
||||||
|
decode_once: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does
|
||||||
|
|
||||||
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
|
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
|
||||||
/// The resulting spectrogram is stored in the context transparently.
|
/// The resulting spectrogram is stored in the context transparently.
|
||||||
///
|
///
|
||||||
|
|
@ -190,40 +217,90 @@ impl WhisperContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Token sampling functions
|
/// Convert the provided text into tokens.
|
||||||
/// Return the token with the highest probability.
|
|
||||||
/// Make sure to call [WhisperContext::decode] first.
|
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * needs_timestamp
|
/// * text: The text to convert.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
|
/// Ok(Vec<WhisperToken>) on success, Err(WhisperError) on failure.
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_sample_best(struct whisper_context * ctx, bool need_timestamp)`
|
/// `int whisper_tokenize(struct whisper_context * ctx, const char * text, whisper_token * tokens, int n_max_tokens);`
|
||||||
pub fn sample_best(&mut self) -> Result<WhisperTokenData, WhisperError> {
|
pub fn tokenize(
|
||||||
if !self.decode_once {
|
&mut self,
|
||||||
return Err(WhisperError::DecodeNotComplete);
|
text: &str,
|
||||||
|
max_tokens: usize,
|
||||||
|
) -> Result<Vec<WhisperToken>, WhisperError> {
|
||||||
|
// allocate at least max_tokens to ensure the memory is valid
|
||||||
|
let mut tokens: Vec<WhisperToken> = Vec::with_capacity(max_tokens);
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_tokenize(
|
||||||
|
self.ctx,
|
||||||
|
text.as_ptr() as *const _,
|
||||||
|
tokens.as_mut_ptr(),
|
||||||
|
max_tokens as c_int,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::InvalidText)
|
||||||
|
} else {
|
||||||
|
// SAFETY: when ret != -1, we know that the length of the vector is at least ret tokens
|
||||||
|
unsafe { tokens.set_len(ret as usize) };
|
||||||
|
Ok(tokens)
|
||||||
}
|
}
|
||||||
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx) };
|
|
||||||
Ok(ret)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return the token with the most probable timestamp.
|
// Language functions
|
||||||
/// Make sure to call [WhisperContext::decode] first.
|
/// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||||
|
/// Make sure to call pcm_to_mel() or set_mel() first
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * offset_ms: The offset in milliseconds to use for the language detection.
|
||||||
|
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// Ok(WhisperToken) on success, Err(WhisperError) on failure.
|
/// Ok(Vec<f32>) on success, Err(WhisperError) on failure.
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `whisper_token whisper_sample_timestamp(struct whisper_context * ctx)`
|
/// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)`
|
||||||
pub fn sample_timestamp(&mut self, is_initial: bool) -> Result<WhisperTokenData, WhisperError> {
|
pub fn lang_detect(
|
||||||
if !self.decode_once {
|
&mut self,
|
||||||
return Err(WhisperError::DecodeNotComplete);
|
offset_ms: usize,
|
||||||
|
threads: usize,
|
||||||
|
) -> Result<Vec<f32>, WhisperError> {
|
||||||
|
if !self.spectrogram_initialized {
|
||||||
|
return Err(WhisperError::SpectrogramNotInitialized);
|
||||||
|
}
|
||||||
|
if threads < 1 {
|
||||||
|
return Err(WhisperError::InvalidThreadCount);
|
||||||
|
}
|
||||||
|
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_lang_auto_detect(
|
||||||
|
self.ctx,
|
||||||
|
offset_ms as c_int,
|
||||||
|
threads as c_int,
|
||||||
|
lang_probs.as_mut_ptr(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::UnableToCalculateEvaluation)
|
||||||
|
} else {
|
||||||
|
assert_eq!(
|
||||||
|
ret as usize,
|
||||||
|
lang_probs.len(),
|
||||||
|
"lang_probs length mismatch: this is a bug in whisper.cpp"
|
||||||
|
);
|
||||||
|
// if we're still running, double check that the length is correct, otherwise print to stderr
|
||||||
|
// and abort, as this will cause Undefined Behavior
|
||||||
|
// might get here due to the unwind being caught by a user-installed panic handler
|
||||||
|
if lang_probs.len() != ret as usize {
|
||||||
|
eprintln!("lang_probs length mismatch: this is a bug in whisper.cpp, aborting");
|
||||||
|
std::process::abort();
|
||||||
|
}
|
||||||
|
Ok(lang_probs)
|
||||||
}
|
}
|
||||||
let ret = unsafe { whisper_rs_sys::whisper_sample_timestamp(self.ctx, is_initial) };
|
|
||||||
Ok(ret)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// model attributes
|
// model attributes
|
||||||
|
|
@ -263,6 +340,18 @@ impl WhisperContext {
|
||||||
unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get n_audio_ctx.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// c_int
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_n_audio_ctx (struct whisper_context * ctx)`
|
||||||
|
#[inline]
|
||||||
|
pub fn n_audio_ctx(&self) -> c_int {
|
||||||
|
unsafe { whisper_rs_sys::whisper_n_audio_ctx(self.ctx) }
|
||||||
|
}
|
||||||
|
|
||||||
/// Does this model support multiple languages?
|
/// Does this model support multiple languages?
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
|
|
@ -272,25 +361,46 @@ impl WhisperContext {
|
||||||
unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 }
|
unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The probabilities for the next token.
|
// logit functions
|
||||||
/// Make sure to call [WhisperContext::decode] first.
|
/// Get the logits obtained from the last call to [WhisperContext::decode].
|
||||||
|
/// The logits for the last token are stored in the last row of the matrix.
|
||||||
|
///
|
||||||
|
/// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it
|
||||||
|
/// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: The segment to fetch data for.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// Ok(*const f32) on success, Err(WhisperError) on failure.
|
/// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab.
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `float * whisper_get_probs(struct whisper_context * ctx)`
|
/// `float * whisper_get_logits(struct whisper_context * ctx)`
|
||||||
pub fn get_probs(&mut self) -> Result<*const f32, WhisperError> {
|
pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> {
|
||||||
if !self.decode_once {
|
if !self.spectrogram_initialized {
|
||||||
return Err(WhisperError::DecodeNotComplete);
|
return Err(WhisperError::SpectrogramNotInitialized);
|
||||||
}
|
}
|
||||||
let ret = unsafe { whisper_rs_sys::whisper_get_probs(self.ctx) };
|
|
||||||
|
let ret = unsafe { whisper_rs_sys::whisper_get_logits(self.ctx) };
|
||||||
if ret.is_null() {
|
if ret.is_null() {
|
||||||
return Err(WhisperError::NullPointer);
|
return Err(WhisperError::NullPointer);
|
||||||
}
|
}
|
||||||
Ok(ret)
|
let mut logits = Vec::new();
|
||||||
|
let n_vocab = self.n_vocab();
|
||||||
|
let n_tokens = self.full_n_tokens(segment);
|
||||||
|
for i in 0..n_tokens {
|
||||||
|
let mut row = Vec::new();
|
||||||
|
for j in 0..n_vocab {
|
||||||
|
let idx = (i * n_vocab) + j;
|
||||||
|
let val = unsafe { *ret.offset(idx as isize) };
|
||||||
|
row.push(val);
|
||||||
|
}
|
||||||
|
logits.push(row);
|
||||||
|
}
|
||||||
|
Ok(logits)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// token functions
|
||||||
/// Convert a token ID to a string.
|
/// Convert a token ID to a string.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
|
@ -311,7 +421,6 @@ impl WhisperContext {
|
||||||
Ok(r_str.to_string())
|
Ok(r_str.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
// special tokens
|
|
||||||
/// Get the ID of the eot token.
|
/// Get the ID of the eot token.
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
|
|
@ -366,6 +475,18 @@ impl WhisperContext {
|
||||||
unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the ID of a specified language token
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * lang_id: ID of the language
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `whisper_token whisper_token_lang(struct whisper_context * ctx, int lang_id)`
|
||||||
|
#[inline]
|
||||||
|
pub fn token_lang(&self, lang_id: c_int) -> WhisperToken {
|
||||||
|
unsafe { whisper_rs_sys::whisper_token_lang(self.ctx, lang_id) }
|
||||||
|
}
|
||||||
|
|
||||||
/// Print performance statistics to stderr.
|
/// Print performance statistics to stderr.
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,16 @@
|
||||||
use std::ffi::{c_int, CString};
|
use std::ffi::{c_float, c_int, CString};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use whisper_rs_sys::whisper_token;
|
use whisper_rs_sys::whisper_token;
|
||||||
|
|
||||||
pub enum SamplingStrategy {
|
pub enum SamplingStrategy {
|
||||||
Greedy {
|
Greedy {
|
||||||
n_past: c_int,
|
best_of: c_int,
|
||||||
},
|
},
|
||||||
/// not implemented yet, results of using this unknown
|
/// not implemented yet, results of using this unknown
|
||||||
BeamSearch {
|
BeamSearch {
|
||||||
n_past: c_int,
|
beam_size: c_int,
|
||||||
beam_width: c_int,
|
// not implemented in whisper.cpp as of this writing (v1.2.0)
|
||||||
n_best: c_int,
|
patience: c_float,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -35,17 +35,15 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
};
|
};
|
||||||
|
|
||||||
match sampling_strategy {
|
match sampling_strategy {
|
||||||
SamplingStrategy::Greedy { n_past } => {
|
SamplingStrategy::Greedy { best_of } => {
|
||||||
fp.greedy.n_past = n_past;
|
fp.greedy.best_of = best_of;
|
||||||
}
|
}
|
||||||
SamplingStrategy::BeamSearch {
|
SamplingStrategy::BeamSearch {
|
||||||
n_past,
|
beam_size,
|
||||||
beam_width,
|
patience,
|
||||||
n_best,
|
|
||||||
} => {
|
} => {
|
||||||
fp.beam_search.n_past = n_past;
|
fp.beam_search.beam_size = beam_size;
|
||||||
fp.beam_search.beam_width = beam_width;
|
fp.beam_search.patience = patience;
|
||||||
fp.beam_search.n_best = n_best;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -63,7 +61,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
self.fp.n_threads = n_threads;
|
self.fp.n_threads = n_threads;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set n_max_text_ctx.
|
/// Max tokens to use from past text as prompt for the decoder
|
||||||
///
|
///
|
||||||
/// Defaults to 16384.
|
/// Defaults to 16384.
|
||||||
pub fn set_n_max_text_ctx(&mut self, n_max_text_ctx: c_int) {
|
pub fn set_n_max_text_ctx(&mut self, n_max_text_ctx: c_int) {
|
||||||
|
|
@ -91,7 +89,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
self.fp.translate = translate;
|
self.fp.translate = translate;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set no_context. Usage unknown.
|
/// Do not use past transcription (if any) as initial prompt for the decoder.
|
||||||
///
|
///
|
||||||
/// Defaults to false.
|
/// Defaults to false.
|
||||||
pub fn set_no_context(&mut self, no_context: bool) {
|
pub fn set_no_context(&mut self, no_context: bool) {
|
||||||
|
|
@ -105,7 +103,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
self.fp.single_segment = single_segment;
|
self.fp.single_segment = single_segment;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set print_special. Usage unknown.
|
/// Print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
||||||
///
|
///
|
||||||
/// Defaults to false.
|
/// Defaults to false.
|
||||||
pub fn set_print_special(&mut self, print_special: bool) {
|
pub fn set_print_special(&mut self, print_special: bool) {
|
||||||
|
|
@ -119,14 +117,17 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
self.fp.print_progress = print_progress;
|
self.fp.print_progress = print_progress;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set print_realtime. Usage unknown.
|
/// Print results from within whisper.cpp.
|
||||||
|
/// Try to use the callback methods instead: [set_new_segment_callback](FullParams::set_new_segment_callback),
|
||||||
|
/// [set_new_segment_callback_user_data](FullParams::set_new_segment_callback_user_data).
|
||||||
///
|
///
|
||||||
/// Defaults to false.
|
/// Defaults to false.
|
||||||
pub fn set_print_realtime(&mut self, print_realtime: bool) {
|
pub fn set_print_realtime(&mut self, print_realtime: bool) {
|
||||||
self.fp.print_realtime = print_realtime;
|
self.fp.print_realtime = print_realtime;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set whether to print timestamps.
|
/// Print timestamps for each text segment when printing realtime. Only has an effect if
|
||||||
|
/// [set_print_realtime](FullParams::set_print_realtime) is set to true.
|
||||||
///
|
///
|
||||||
/// Defaults to true.
|
/// Defaults to true.
|
||||||
pub fn set_print_timestamps(&mut self, print_timestamps: bool) {
|
pub fn set_print_timestamps(&mut self, print_timestamps: bool) {
|
||||||
|
|
@ -181,6 +182,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
/// # EXPERIMENTAL
|
/// # EXPERIMENTAL
|
||||||
///
|
///
|
||||||
/// Speed up audio ~2x by using phase vocoder.
|
/// Speed up audio ~2x by using phase vocoder.
|
||||||
|
/// Note that this can significantly reduce the accuracy of the transcription.
|
||||||
///
|
///
|
||||||
/// Defaults to false.
|
/// Defaults to false.
|
||||||
pub fn set_speed_up(&mut self, speed_up: bool) {
|
pub fn set_speed_up(&mut self, speed_up: bool) {
|
||||||
|
|
@ -190,6 +192,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
/// # EXPERIMENTAL
|
/// # EXPERIMENTAL
|
||||||
///
|
///
|
||||||
/// Overwrite the audio context size. 0 = default.
|
/// Overwrite the audio context size. 0 = default.
|
||||||
|
/// As with [set_speed_up](FullParams::set_speed_up), this can significantly reduce the accuracy of the transcription.
|
||||||
///
|
///
|
||||||
/// Defaults to 0.
|
/// Defaults to 0.
|
||||||
pub fn set_audio_ctx(&mut self, audio_ctx: c_int) {
|
pub fn set_audio_ctx(&mut self, audio_ctx: c_int) {
|
||||||
|
|
@ -215,10 +218,78 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
|
|
||||||
/// Set the target language.
|
/// Set the target language.
|
||||||
///
|
///
|
||||||
|
/// For auto-detection, set this to either "auto" or None.
|
||||||
|
///
|
||||||
/// Defaults to "en".
|
/// Defaults to "en".
|
||||||
pub fn set_language(&mut self, language: &'a str) {
|
pub fn set_language(&mut self, language: Option<&'a str>) {
|
||||||
let c_lang = CString::new(language).expect("Language contains null byte");
|
self.fp.language = match language {
|
||||||
self.fp.language = c_lang.into_raw() as *const _;
|
Some(language) => CString::new(language)
|
||||||
|
.expect("Language contains null byte")
|
||||||
|
.into_raw() as *const _,
|
||||||
|
None => std::ptr::null(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set suppress_blank. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
|
||||||
|
/// for more information.
|
||||||
|
///
|
||||||
|
/// Defaults to true.
|
||||||
|
pub fn set_suppress_blank(&mut self, suppress_blank: bool) {
|
||||||
|
self.fp.suppress_blank = suppress_blank;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set initial decoding temperature. See https://ai.stackexchange.com/a/32478 for more information.
|
||||||
|
///
|
||||||
|
/// Defaults to 0.0.
|
||||||
|
pub fn set_temperature(&mut self, temperature: f32) {
|
||||||
|
self.fp.temperature = temperature;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set max_initial_ts. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
|
||||||
|
/// for more information.
|
||||||
|
///
|
||||||
|
/// Defaults to 1.0.
|
||||||
|
pub fn set_max_initial_ts(&mut self, max_initial_ts: f32) {
|
||||||
|
self.fp.max_initial_ts = max_initial_ts;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set length_penalty. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
|
||||||
|
/// for more information.
|
||||||
|
///
|
||||||
|
/// Defaults to -1.0.
|
||||||
|
pub fn set_length_penalty(&mut self, length_penalty: f32) {
|
||||||
|
self.fp.length_penalty = length_penalty;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set temperature_inc. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
|
||||||
|
/// for more information.
|
||||||
|
///
|
||||||
|
/// Defaults to 0.2.
|
||||||
|
pub fn set_temperature_inc(&mut self, temperature_inc: f32) {
|
||||||
|
self.fp.temperature_inc = temperature_inc;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set entropy_thold. Similar to OpenAI's compression_ratio_threshold.
|
||||||
|
/// See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278 for more information.
|
||||||
|
///
|
||||||
|
/// Defaults to 2.4.
|
||||||
|
pub fn set_entropy_thold(&mut self, entropy_thold: f32) {
|
||||||
|
self.fp.entropy_thold = entropy_thold;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set logprob_thold. See https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
|
||||||
|
/// for more information.
|
||||||
|
///
|
||||||
|
/// Defaults to -1.0.
|
||||||
|
pub fn set_logprob_thold(&mut self, logprob_thold: f32) {
|
||||||
|
self.fp.logprob_thold = logprob_thold;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set no_speech_thold. Currently (as of v1.2.0) not implemented.
|
||||||
|
///
|
||||||
|
/// Defaults to 0.6.
|
||||||
|
pub fn set_no_speech_thold(&mut self, no_speech_thold: f32) {
|
||||||
|
self.fp.no_speech_thold = no_speech_thold;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the callback for new segments.
|
/// Set the callback for new segments.
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue