Update whisper.cpp to v1.7.6
This commit is contained in:
parent
b202069aa8
commit
55e54212f1
6 changed files with 725 additions and 458 deletions
|
|
@ -15,6 +15,7 @@ mod whisper_grammar;
|
||||||
mod whisper_logging_hook;
|
mod whisper_logging_hook;
|
||||||
mod whisper_params;
|
mod whisper_params;
|
||||||
mod whisper_state;
|
mod whisper_state;
|
||||||
|
mod whisper_vad;
|
||||||
|
|
||||||
pub use common_logging::GGMLLogLevel;
|
pub use common_logging::GGMLLogLevel;
|
||||||
pub use error::WhisperError;
|
pub use error::WhisperError;
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
use crate::whisper_grammar::WhisperGrammarElement;
|
use crate::whisper_grammar::WhisperGrammarElement;
|
||||||
|
use crate::whisper_vad::WhisperVadParams;
|
||||||
use std::ffi::{c_char, c_float, c_int, CString};
|
use std::ffi::{c_char, c_float, c_int, CString};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
@ -33,19 +34,20 @@ pub struct SegmentCallbackData {
|
||||||
type SegmentCallbackFn = Box<dyn FnMut(SegmentCallbackData)>;
|
type SegmentCallbackFn = Box<dyn FnMut(SegmentCallbackData)>;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct FullParams<'a, 'b> {
|
pub struct FullParams<'a, 'b, 'c> {
|
||||||
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
||||||
phantom_lang: PhantomData<&'a str>,
|
phantom_lang: PhantomData<&'a str>,
|
||||||
phantom_tokens: PhantomData<&'b [c_int]>,
|
phantom_tokens: PhantomData<&'b [c_int]>,
|
||||||
|
phantom_model_path: PhantomData<&'c str>,
|
||||||
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
|
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
|
||||||
progress_callback_safe: Option<Arc<Box<dyn FnMut(i32)>>>,
|
progress_callback_safe: Option<Arc<Box<dyn FnMut(i32)>>>,
|
||||||
abort_callback_safe: Option<Arc<Box<dyn FnMut() -> bool>>>,
|
abort_callback_safe: Option<Arc<Box<dyn FnMut() -> bool>>>,
|
||||||
segment_calllback_safe: Option<Arc<SegmentCallbackFn>>,
|
segment_calllback_safe: Option<Arc<SegmentCallbackFn>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, 'b> FullParams<'a, 'b> {
|
impl<'a, 'b, 'c> FullParams<'a, 'b, 'c> {
|
||||||
/// Create a new set of parameters for the decoder.
|
/// Create a new set of parameters for the decoder.
|
||||||
pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a, 'b> {
|
pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a, 'b, 'c> {
|
||||||
let mut fp = unsafe {
|
let mut fp = unsafe {
|
||||||
whisper_rs_sys::whisper_full_default_params(match sampling_strategy {
|
whisper_rs_sys::whisper_full_default_params(match sampling_strategy {
|
||||||
SamplingStrategy::Greedy { .. } => {
|
SamplingStrategy::Greedy { .. } => {
|
||||||
|
|
@ -74,6 +76,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
fp,
|
fp,
|
||||||
phantom_lang: PhantomData,
|
phantom_lang: PhantomData,
|
||||||
phantom_tokens: PhantomData,
|
phantom_tokens: PhantomData,
|
||||||
|
phantom_model_path: PhantomData,
|
||||||
grammar: None,
|
grammar: None,
|
||||||
progress_callback_safe: None,
|
progress_callback_safe: None,
|
||||||
abort_callback_safe: None,
|
abort_callback_safe: None,
|
||||||
|
|
@ -800,19 +803,52 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
.expect("Initial prompt contains null byte")
|
.expect("Initial prompt contains null byte")
|
||||||
.into_raw() as *const c_char;
|
.into_raw() as *const c_char;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Enable or disable VAD.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// This method will panic if `vad_model_path` is not set prior to enabling VAD.
|
||||||
|
pub fn enable_vad(&mut self, vad: bool) {
|
||||||
|
if vad && self.fp.vad_model_path.is_null() {
|
||||||
|
panic!("Set a VAD model path before calling enable_vad");
|
||||||
|
}
|
||||||
|
|
||||||
|
self.fp.vad = vad;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the path where a VAD model can be found. Passing `None` will clear it and disable VAD.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// This method will panic if `vad_model_path` contains a null byte.
|
||||||
|
pub fn set_vad_model_path(&mut self, vad_model_path: Option<&str>) {
|
||||||
|
self.fp.vad_model_path = if let Some(vad_model_path) = vad_model_path {
|
||||||
|
CString::new(vad_model_path)
|
||||||
|
.expect("VAD model path contains null byte")
|
||||||
|
.into_raw() as *const c_char
|
||||||
|
} else {
|
||||||
|
self.fp.vad = false;
|
||||||
|
|
||||||
|
std::ptr::null()
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Replace the VAD model parameters.
|
||||||
|
pub fn set_vad_params(&mut self, params: WhisperVadParams) {
|
||||||
|
self.fp.vad_params = params.into_inner();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// following implementations are safe
|
// following implementations are safe
|
||||||
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
|
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
|
||||||
// concurrent usage is prevented by &mut self on methods that modify the struct
|
// concurrent usage is prevented by &mut self on methods that modify the struct
|
||||||
unsafe impl Send for FullParams<'_, '_> {}
|
unsafe impl Send for FullParams<'_, '_, '_> {}
|
||||||
unsafe impl Sync for FullParams<'_, '_> {}
|
unsafe impl Sync for FullParams<'_, '_, '_> {}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test_whisper_params_initial_prompt {
|
mod test_whisper_params_initial_prompt {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
impl<'a, 'b> FullParams<'a, 'b> {
|
impl<'a, 'b, 'c> FullParams<'a, 'b, 'c> {
|
||||||
pub fn get_initial_prompt(&self) -> &str {
|
pub fn get_initial_prompt(&self) -> &str {
|
||||||
// SAFETY: Ensure this is safe and respects the lifetime of the string in self.fp
|
// SAFETY: Ensure this is safe and respects the lifetime of the string in self.fp
|
||||||
unsafe {
|
unsafe {
|
||||||
|
|
|
||||||
|
|
@ -588,4 +588,11 @@ impl WhisperState {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Get the no_speech probability for the specified segment
|
||||||
|
pub fn full_get_segment_no_speech_prob(&self, i_segment: c_int) -> f32 {
|
||||||
|
unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_segment_no_speech_prob_from_state(self.ptr, i_segment)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
298
src/whisper_vad.rs
Normal file
298
src/whisper_vad.rs
Normal file
|
|
@ -0,0 +1,298 @@
|
||||||
|
use crate::WhisperError;
|
||||||
|
use std::ffi::{c_char, CString};
|
||||||
|
use std::iter::Peekable;
|
||||||
|
use std::os::raw::c_int;
|
||||||
|
use whisper_rs_sys::{
|
||||||
|
whisper_vad_context, whisper_vad_context_params, whisper_vad_detect_speech, whisper_vad_free,
|
||||||
|
whisper_vad_free_segments, whisper_vad_init_from_file_with_params, whisper_vad_n_probs,
|
||||||
|
whisper_vad_params, whisper_vad_probs, whisper_vad_segments, whisper_vad_segments_from_probs,
|
||||||
|
whisper_vad_segments_from_samples, whisper_vad_segments_get_segment_t0,
|
||||||
|
whisper_vad_segments_get_segment_t1, whisper_vad_segments_n_segments,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Configuration for Voice Activity Detection in `whisper.cpp`.
|
||||||
|
///
|
||||||
|
/// See [the `whisper.cpp` README](https://github.com/ggml-org/whisper.cpp/#voice-activity-detection-vad) for more details.
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct WhisperVadParams {
|
||||||
|
params: whisper_vad_params,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for WhisperVadParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
params: whisper_vad_params {
|
||||||
|
threshold: 0.5,
|
||||||
|
min_speech_duration_ms: 250,
|
||||||
|
min_silence_duration_ms: 100,
|
||||||
|
max_speech_duration_s: f32::MAX,
|
||||||
|
speech_pad_ms: 30,
|
||||||
|
samples_overlap: 0.1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WhisperVadParams {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the probability threshold to consider as speech.
|
||||||
|
/// A probability for a speech segment/frame above this threshold will be considered as speech.
|
||||||
|
///
|
||||||
|
/// Defaults to 0.5.
|
||||||
|
pub fn set_threshold(&mut self, threshold: f32) {
|
||||||
|
self.params.threshold = threshold;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the minimum duration for a valid speech segment, in milliseconds.
|
||||||
|
/// Speech segments shorter than this value will be discarded to filter out brief noise or false positives.
|
||||||
|
///
|
||||||
|
/// Defaults to 250 milliseconds.
|
||||||
|
pub fn set_min_speech_duration(&mut self, min_speech_duration: c_int) {
|
||||||
|
self.params.min_speech_duration_ms = min_speech_duration;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the minimum silence duration to consider speech as ended.
|
||||||
|
/// Silence periods must be at least this long to end a speech segment.
|
||||||
|
/// Shorter silence periods will be ignored and included as part of the speech.
|
||||||
|
///
|
||||||
|
/// Defaults to 100 milliseconds.
|
||||||
|
pub fn set_min_silence_duration(&mut self, min_silence_duration: c_int) {
|
||||||
|
self.params.min_silence_duration_ms = min_silence_duration;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the maximum duration of a speech segment before forcing a new segment.
|
||||||
|
/// Speech segments longer than this will be automatically split into multiple segments at
|
||||||
|
/// silence points exceeding 98ms to prevent excessively long segments.
|
||||||
|
///
|
||||||
|
/// Defaults to [`f32::MAX`].
|
||||||
|
pub fn set_max_speech_duration(&mut self, max_speech_duration: f32) {
|
||||||
|
self.params.max_speech_duration_s = max_speech_duration;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the amount of padding added before and after speech segments, in milliseconds.
|
||||||
|
/// Adds this amount of padding before and after each detected speech segment to avoid cutting off speech edges.
|
||||||
|
///
|
||||||
|
/// Defaults to 30 milliseconds.
|
||||||
|
pub fn set_speech_pad(&mut self, speech_pad: c_int) {
|
||||||
|
self.params.speech_pad_ms = speech_pad;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the amount of audio to extend from each speech segment into the next one, in seconds (e.g., 0.10 = 100ms overlap).
|
||||||
|
/// This ensures speech isn't cut off abruptly between segments when they're concatenated together.
|
||||||
|
///
|
||||||
|
/// Defaults to 0.1 seconds.
|
||||||
|
pub fn set_samples_overlap(&mut self, samples_overlap: f32) {
|
||||||
|
self.params.samples_overlap = samples_overlap;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn into_inner(self) -> whisper_vad_params {
|
||||||
|
self.params
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Whisper VAD context parameters
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub struct WhisperVadContextParams {
|
||||||
|
params: whisper_vad_context_params,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for WhisperVadContextParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
params: whisper_vad_context_params {
|
||||||
|
n_threads: 4,
|
||||||
|
use_gpu: false,
|
||||||
|
gpu_device: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WhisperVadContextParams {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Set the number of threads to use for processing
|
||||||
|
pub fn set_n_threads(&mut self, n_threads: c_int) {
|
||||||
|
self.params.n_threads = n_threads;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Enable the GPU for VAD?
|
||||||
|
pub fn set_use_gpu(&mut self, use_gpu: bool) {
|
||||||
|
self.params.use_gpu = use_gpu;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The CUDA device to use if `use_gpu` is true
|
||||||
|
pub fn set_gpu_device(&mut self, gpu_device: c_int) {
|
||||||
|
self.params.gpu_device = gpu_device;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn into_inner(self) -> whisper_vad_context_params {
|
||||||
|
self.params
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A handle to use `whisper.cpp`'s built in VAD standalone.
|
||||||
|
///
|
||||||
|
/// You probably want to use [`Self::segments_from_samples`].
|
||||||
|
pub struct WhisperVadContext {
|
||||||
|
ptr: *mut whisper_vad_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WhisperVadContext {
|
||||||
|
pub fn new(model_path: &str, params: WhisperVadContextParams) -> Result<Self, WhisperError> {
|
||||||
|
let model_path = CString::new(model_path)
|
||||||
|
.expect("VAD model path contains null byte")
|
||||||
|
.into_raw() as *const c_char;
|
||||||
|
let ptr =
|
||||||
|
unsafe { whisper_vad_init_from_file_with_params(model_path, params.into_inner()) };
|
||||||
|
|
||||||
|
if ptr.is_null() {
|
||||||
|
Err(WhisperError::NullPointer)
|
||||||
|
} else {
|
||||||
|
Ok(Self { ptr })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Detect speech in `samples`. Call [`Self::segments_from_probabilities`] to finish the pipeline.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// This function will exclusively return `WhisperError::GenericError(-1)` on error.
|
||||||
|
/// If you've registered logging hooks, they will have much more detailed information.
|
||||||
|
pub fn detect_speech(&mut self, samples: &[f32]) -> Result<(), WhisperError> {
|
||||||
|
let (samples, len) = (samples.as_ptr(), samples.len() as c_int);
|
||||||
|
|
||||||
|
let success = unsafe { whisper_vad_detect_speech(self.ptr, samples, len) };
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
Err(WhisperError::GenericError(-1))
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get an array of probabilities. Undocumented use.
|
||||||
|
pub fn probabilities(&self) -> &[f32] {
|
||||||
|
let prob_ptr = unsafe { whisper_vad_probs(self.ptr) };
|
||||||
|
let prob_count = unsafe { whisper_vad_n_probs(self.ptr) }
|
||||||
|
.try_into()
|
||||||
|
.expect("n_probs is too large to fit into usize");
|
||||||
|
unsafe { core::slice::from_raw_parts(prob_ptr, prob_count) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Finish running the VAD pipeline and return segment details.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// The only possible error is [`WhisperError::NullPointer`].
|
||||||
|
pub fn segments_from_probabilities(
|
||||||
|
&mut self,
|
||||||
|
params: WhisperVadParams,
|
||||||
|
) -> Result<WhisperVadSegments, WhisperError> {
|
||||||
|
let ptr = unsafe { whisper_vad_segments_from_probs(self.ptr, params.into_inner()) };
|
||||||
|
|
||||||
|
if ptr.is_null() {
|
||||||
|
Err(WhisperError::NullPointer)
|
||||||
|
} else {
|
||||||
|
Ok(WhisperVadSegments { ptr })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the entire VAD pipeline.
|
||||||
|
/// This calls both [`Self::detect_speech`] and [`Self::segments_from_probabilities`] behind the scenes.
|
||||||
|
///
|
||||||
|
/// # Errors
|
||||||
|
/// The only possible error is [`WhisperError::NullPointer`].
|
||||||
|
pub fn segments_from_samples(
|
||||||
|
&mut self,
|
||||||
|
params: WhisperVadParams,
|
||||||
|
samples: &[f32],
|
||||||
|
) -> Result<WhisperVadSegments, WhisperError> {
|
||||||
|
let (sample_ptr, sample_len) = (samples.as_ptr(), samples.len() as c_int);
|
||||||
|
let ptr = unsafe {
|
||||||
|
whisper_vad_segments_from_samples(self.ptr, params.into_inner(), sample_ptr, sample_len)
|
||||||
|
};
|
||||||
|
|
||||||
|
if ptr.is_null() {
|
||||||
|
Err(WhisperError::NullPointer)
|
||||||
|
} else {
|
||||||
|
Ok(WhisperVadSegments { ptr })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Drop for WhisperVadContext {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
unsafe { whisper_vad_free(self.ptr) }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// You can obtain this struct from a [`WhisperVadContext`].
|
||||||
|
pub struct WhisperVadSegments {
|
||||||
|
ptr: *mut whisper_vad_segments,
|
||||||
|
segment_count: c_int,
|
||||||
|
iter_idx: c_int,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WhisperVadSegments {
|
||||||
|
fn new(ptr: *mut whisper_vad_segments) -> Self {
|
||||||
|
let segment_count = unsafe { whisper_vad_segments_n_segments(ptr) };
|
||||||
|
Self {
|
||||||
|
ptr,
|
||||||
|
segment_count,
|
||||||
|
iter_idx: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn num_segments(&self) -> c_int {
|
||||||
|
self.segment_count
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the start timestamp of this segment in centiseconds (10s of milliseconds).
|
||||||
|
pub fn get_segment_start_timestamp(&self, idx: c_int) -> Option<f32> {
|
||||||
|
if idx < 0 || idx > self.segment_count {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(unsafe { whisper_vad_segments_get_segment_t0(self.ptr, idx) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return the end timestamp of this segment in centiseconds (10s of milliseconds).
|
||||||
|
pub fn get_segment_end_timestamp(&self, idx: c_int) -> Option<f32> {
|
||||||
|
if idx < 0 || idx > self.segment_count {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(unsafe { whisper_vad_segments_get_segment_t1(self.ptr, idx) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Iterator for WhisperVadSegments {
|
||||||
|
type Item = WhisperVadSegment;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.iter_idx > self.segment_count {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let start = unsafe { whisper_vad_segments_get_segment_t0(self.ptr, self.iter_idx) };
|
||||||
|
let end = unsafe { whisper_vad_segments_get_segment_t1(self.ptr, self.iter_idx) };
|
||||||
|
self.iter_idx += 1;
|
||||||
|
Some(WhisperVadSegment { start, end })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A single detected speech segment, as yielded when iterating over the segment list.
///
/// The fields are public so that callers can actually read the timestamps of the
/// segments they iterate over.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct WhisperVadSegment {
    /// Start timestamp of the segment, as reported by `whisper.cpp`.
    pub start: f32,
    /// End timestamp of the segment, as reported by `whisper.cpp`.
    pub end: f32,
}
|
||||||
|
|
||||||
|
impl Drop for WhisperVadSegments {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
unsafe { whisper_vad_free_segments(self.ptr) }
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1 +1 @@
|
||||||
Subproject commit 8a9ad7844d6e2a10cddf4b92de4089d7ac2b14a9
|
Subproject commit a8d002cfd879315632a579e73f0148d06959de36
|
||||||
Loading…
Add table
Add a link
Reference in a new issue