Update whisper.cpp to v1.7.6
This commit is contained in:
parent
b202069aa8
commit
55e54212f1
6 changed files with 725 additions and 458 deletions
|
|
@ -15,6 +15,7 @@ mod whisper_grammar;
|
|||
mod whisper_logging_hook;
|
||||
mod whisper_params;
|
||||
mod whisper_state;
|
||||
mod whisper_vad;
|
||||
|
||||
pub use common_logging::GGMLLogLevel;
|
||||
pub use error::WhisperError;
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
use crate::whisper_grammar::WhisperGrammarElement;
|
||||
use crate::whisper_vad::WhisperVadParams;
|
||||
use std::ffi::{c_char, c_float, c_int, CString};
|
||||
use std::marker::PhantomData;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -33,19 +34,20 @@ pub struct SegmentCallbackData {
|
|||
type SegmentCallbackFn = Box<dyn FnMut(SegmentCallbackData)>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct FullParams<'a, 'b> {
|
||||
pub struct FullParams<'a, 'b, 'c> {
|
||||
pub(crate) fp: whisper_rs_sys::whisper_full_params,
|
||||
phantom_lang: PhantomData<&'a str>,
|
||||
phantom_tokens: PhantomData<&'b [c_int]>,
|
||||
phantom_model_path: PhantomData<&'c str>,
|
||||
grammar: Option<Vec<whisper_rs_sys::whisper_grammar_element>>,
|
||||
progress_callback_safe: Option<Arc<Box<dyn FnMut(i32)>>>,
|
||||
abort_callback_safe: Option<Arc<Box<dyn FnMut() -> bool>>>,
|
||||
segment_calllback_safe: Option<Arc<SegmentCallbackFn>>,
|
||||
}
|
||||
|
||||
impl<'a, 'b> FullParams<'a, 'b> {
|
||||
impl<'a, 'b, 'c> FullParams<'a, 'b, 'c> {
|
||||
/// Create a new set of parameters for the decoder.
|
||||
pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a, 'b> {
|
||||
pub fn new(sampling_strategy: SamplingStrategy) -> FullParams<'a, 'b, 'c> {
|
||||
let mut fp = unsafe {
|
||||
whisper_rs_sys::whisper_full_default_params(match sampling_strategy {
|
||||
SamplingStrategy::Greedy { .. } => {
|
||||
|
|
@ -74,6 +76,7 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
|||
fp,
|
||||
phantom_lang: PhantomData,
|
||||
phantom_tokens: PhantomData,
|
||||
phantom_model_path: PhantomData,
|
||||
grammar: None,
|
||||
progress_callback_safe: None,
|
||||
abort_callback_safe: None,
|
||||
|
|
@ -800,19 +803,52 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
|||
.expect("Initial prompt contains null byte")
|
||||
.into_raw() as *const c_char;
|
||||
}
|
||||
|
||||
/// Enable or disable VAD.
|
||||
///
|
||||
/// # Panics
|
||||
/// This method will panic if `vad_model_path` is not set prior to enabling VAD.
|
||||
pub fn enable_vad(&mut self, vad: bool) {
|
||||
if vad && self.fp.vad_model_path.is_null() {
|
||||
panic!("Set a VAD model path before calling enable_vad");
|
||||
}
|
||||
|
||||
self.fp.vad = vad;
|
||||
}
|
||||
|
||||
/// Set the path where a VAD model can be found. Passing `None` will clear it and disable VAD.
|
||||
///
|
||||
/// # Panics
|
||||
/// This method will panic if `vad_model_path` contains a null byte.
|
||||
pub fn set_vad_model_path(&mut self, vad_model_path: Option<&str>) {
|
||||
self.fp.vad_model_path = if let Some(vad_model_path) = vad_model_path {
|
||||
CString::new(vad_model_path)
|
||||
.expect("VAD model path contains null byte")
|
||||
.into_raw() as *const c_char
|
||||
} else {
|
||||
self.fp.vad = false;
|
||||
|
||||
std::ptr::null()
|
||||
};
|
||||
}
|
||||
|
||||
/// Replace the VAD model parameters.
|
||||
pub fn set_vad_params(&mut self, params: WhisperVadParams) {
|
||||
self.fp.vad_params = params.into_inner();
|
||||
}
|
||||
}
|
||||
|
||||
// following implementations are safe
|
||||
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
|
||||
// concurrent usage is prevented by &mut self on methods that modify the struct
|
||||
unsafe impl Send for FullParams<'_, '_> {}
|
||||
unsafe impl Sync for FullParams<'_, '_> {}
|
||||
unsafe impl Send for FullParams<'_, '_, '_> {}
|
||||
unsafe impl Sync for FullParams<'_, '_, '_> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_whisper_params_initial_prompt {
|
||||
use super::*;
|
||||
|
||||
impl<'a, 'b> FullParams<'a, 'b> {
|
||||
impl<'a, 'b, 'c> FullParams<'a, 'b, 'c> {
|
||||
pub fn get_initial_prompt(&self) -> &str {
|
||||
// SAFETY: Ensure this is safe and respects the lifetime of the string in self.fp
|
||||
unsafe {
|
||||
|
|
|
|||
|
|
@ -588,4 +588,11 @@ impl WhisperState {
|
|||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the no_speech probability for the specified segment
|
||||
pub fn full_get_segment_no_speech_prob(&self, i_segment: c_int) -> f32 {
|
||||
unsafe {
|
||||
whisper_rs_sys::whisper_full_get_segment_no_speech_prob_from_state(self.ptr, i_segment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
298
src/whisper_vad.rs
Normal file
298
src/whisper_vad.rs
Normal file
|
|
@ -0,0 +1,298 @@
|
|||
use crate::WhisperError;
|
||||
use std::ffi::{c_char, CString};
|
||||
use std::iter::Peekable;
|
||||
use std::os::raw::c_int;
|
||||
use whisper_rs_sys::{
|
||||
whisper_vad_context, whisper_vad_context_params, whisper_vad_detect_speech, whisper_vad_free,
|
||||
whisper_vad_free_segments, whisper_vad_init_from_file_with_params, whisper_vad_n_probs,
|
||||
whisper_vad_params, whisper_vad_probs, whisper_vad_segments, whisper_vad_segments_from_probs,
|
||||
whisper_vad_segments_from_samples, whisper_vad_segments_get_segment_t0,
|
||||
whisper_vad_segments_get_segment_t1, whisper_vad_segments_n_segments,
|
||||
};
|
||||
|
||||
/// Configuration for Voice Activity Detection in `whisper.cpp`.
|
||||
///
|
||||
/// See [the `whisper.cpp` README](https://github.com/ggml-org/whisper.cpp/#voice-activity-detection-vad) for more details.
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct WhisperVadParams {
|
||||
params: whisper_vad_params,
|
||||
}
|
||||
|
||||
impl Default for WhisperVadParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
params: whisper_vad_params {
|
||||
threshold: 0.5,
|
||||
min_speech_duration_ms: 250,
|
||||
min_silence_duration_ms: 100,
|
||||
max_speech_duration_s: f32::MAX,
|
||||
speech_pad_ms: 30,
|
||||
samples_overlap: 0.1,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WhisperVadParams {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set the probability threshold to consider as speech.
|
||||
/// A probability for a speech segment/frame above this threshold will be considered as speech.
|
||||
///
|
||||
/// Defaults to 0.5.
|
||||
pub fn set_threshold(&mut self, threshold: f32) {
|
||||
self.params.threshold = threshold;
|
||||
}
|
||||
|
||||
/// Set the minimum duration for a valid speech segment, in milliseconds.
|
||||
/// Speech segments shorter than this value will be discarded to filter out brief noise or false positives.
|
||||
///
|
||||
/// Defaults to 250 milliseconds.
|
||||
pub fn set_min_speech_duration(&mut self, min_speech_duration: c_int) {
|
||||
self.params.min_speech_duration_ms = min_speech_duration;
|
||||
}
|
||||
|
||||
/// Set the minimum silence duration to consider speech as ended.
|
||||
/// Silence periods must be at least this long to end a speech segment.
|
||||
/// Shorter silence periods will be ignored and included as part of the speech.
|
||||
///
|
||||
/// Defaults to 100 milliseconds.
|
||||
pub fn set_min_silence_duration(&mut self, min_silence_duration: c_int) {
|
||||
self.params.min_silence_duration_ms = min_silence_duration;
|
||||
}
|
||||
|
||||
/// Set the maximum duration of a speech segment before forcing a new segment.
|
||||
/// Speech segments longer than this will be automatically split into multiple segments at
|
||||
/// silence points exceeding 98ms to prevent excessively long segments.
|
||||
///
|
||||
/// Defaults to [`f32::MAX`].
|
||||
pub fn set_max_speech_duration(&mut self, max_speech_duration: f32) {
|
||||
self.params.max_speech_duration_s = max_speech_duration;
|
||||
}
|
||||
|
||||
/// Set the amount of padding added before and after speech segments, in milliseconds.
|
||||
/// Adds this amount of padding before and after each detected speech segment to avoid cutting off speech edges.
|
||||
///
|
||||
/// Defaults to 30 milliseconds.
|
||||
pub fn set_speech_pad(&mut self, speech_pad: c_int) {
|
||||
self.params.speech_pad_ms = speech_pad;
|
||||
}
|
||||
|
||||
/// Sets the amount of audio to extend from each speech segment into the next one, in seconds (e.g., 0.10 = 100ms overlap).
|
||||
/// This ensures speech isn't cut off abruptly between segments when they're concatenated together.
|
||||
///
|
||||
/// Defaults to 0.1 seconds.
|
||||
pub fn set_samples_overlap(&mut self, samples_overlap: f32) {
|
||||
self.params.samples_overlap = samples_overlap;
|
||||
}
|
||||
|
||||
pub(crate) fn into_inner(self) -> whisper_vad_params {
|
||||
self.params
|
||||
}
|
||||
}
|
||||
|
||||
/// Whisper VAD context parameters
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct WhisperVadContextParams {
|
||||
params: whisper_vad_context_params,
|
||||
}
|
||||
|
||||
impl Default for WhisperVadContextParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
params: whisper_vad_context_params {
|
||||
n_threads: 4,
|
||||
use_gpu: false,
|
||||
gpu_device: 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl WhisperVadContextParams {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set the number of threads to use for processing
|
||||
pub fn set_n_threads(&mut self, n_threads: c_int) {
|
||||
self.params.n_threads = n_threads;
|
||||
}
|
||||
|
||||
/// Enable the GPU for VAD?
|
||||
pub fn set_use_gpu(&mut self, use_gpu: bool) {
|
||||
self.params.use_gpu = use_gpu;
|
||||
}
|
||||
|
||||
/// The CUDA device to use if `use_gpu` is true
|
||||
pub fn set_gpu_device(&mut self, gpu_device: c_int) {
|
||||
self.params.gpu_device = gpu_device;
|
||||
}
|
||||
|
||||
fn into_inner(self) -> whisper_vad_context_params {
|
||||
self.params
|
||||
}
|
||||
}
|
||||
|
||||
/// A handle to use `whisper.cpp`'s built in VAD standalone.
|
||||
///
|
||||
/// You probably want to use [`Self::segments_from_samples`].
|
||||
pub struct WhisperVadContext {
|
||||
ptr: *mut whisper_vad_context,
|
||||
}
|
||||
|
||||
impl WhisperVadContext {
|
||||
pub fn new(model_path: &str, params: WhisperVadContextParams) -> Result<Self, WhisperError> {
|
||||
let model_path = CString::new(model_path)
|
||||
.expect("VAD model path contains null byte")
|
||||
.into_raw() as *const c_char;
|
||||
let ptr =
|
||||
unsafe { whisper_vad_init_from_file_with_params(model_path, params.into_inner()) };
|
||||
|
||||
if ptr.is_null() {
|
||||
Err(WhisperError::NullPointer)
|
||||
} else {
|
||||
Ok(Self { ptr })
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect speech in `samples`. Call [`Self::segments_from_probabilities`] to finish the pipeline.
|
||||
///
|
||||
/// # Errors
|
||||
/// This function will exclusively return `WhisperError::GenericError(-1)` on error.
|
||||
/// If you've registered logging hooks, they will have much more detailed information.
|
||||
pub fn detect_speech(&mut self, samples: &[f32]) -> Result<(), WhisperError> {
|
||||
let (samples, len) = (samples.as_ptr(), samples.len() as c_int);
|
||||
|
||||
let success = unsafe { whisper_vad_detect_speech(self.ptr, samples, len) };
|
||||
|
||||
if !success {
|
||||
Err(WhisperError::GenericError(-1))
|
||||
} else {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Get an array of probabilities. Undocumented use.
|
||||
pub fn probabilities(&self) -> &[f32] {
|
||||
let prob_ptr = unsafe { whisper_vad_probs(self.ptr) };
|
||||
let prob_count = unsafe { whisper_vad_n_probs(self.ptr) }
|
||||
.try_into()
|
||||
.expect("n_probs is too large to fit into usize");
|
||||
unsafe { core::slice::from_raw_parts(prob_ptr, prob_count) }
|
||||
}
|
||||
|
||||
/// Finish running the VAD pipeline and return segment details.
|
||||
///
|
||||
/// # Errors
|
||||
/// The only possible error is [`WhisperError::NullPointer`].
|
||||
pub fn segments_from_probabilities(
|
||||
&mut self,
|
||||
params: WhisperVadParams,
|
||||
) -> Result<WhisperVadSegments, WhisperError> {
|
||||
let ptr = unsafe { whisper_vad_segments_from_probs(self.ptr, params.into_inner()) };
|
||||
|
||||
if ptr.is_null() {
|
||||
Err(WhisperError::NullPointer)
|
||||
} else {
|
||||
Ok(WhisperVadSegments { ptr })
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the entire VAD pipeline.
|
||||
/// This calls both [`Self::detect_speech`] and [`Self::segments_from_probabilities`] behind the scenes.
|
||||
///
|
||||
/// # Errors
|
||||
/// The only possible error is [`WhisperError::NullPointer`].
|
||||
pub fn segments_from_samples(
|
||||
&mut self,
|
||||
params: WhisperVadParams,
|
||||
samples: &[f32],
|
||||
) -> Result<WhisperVadSegments, WhisperError> {
|
||||
let (sample_ptr, sample_len) = (samples.as_ptr(), samples.len() as c_int);
|
||||
let ptr = unsafe {
|
||||
whisper_vad_segments_from_samples(self.ptr, params.into_inner(), sample_ptr, sample_len)
|
||||
};
|
||||
|
||||
if ptr.is_null() {
|
||||
Err(WhisperError::NullPointer)
|
||||
} else {
|
||||
Ok(WhisperVadSegments { ptr })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WhisperVadContext {
|
||||
fn drop(&mut self) {
|
||||
unsafe { whisper_vad_free(self.ptr) }
|
||||
}
|
||||
}
|
||||
|
||||
/// You can obtain this struct from a [`WhisperVadContext`].
|
||||
pub struct WhisperVadSegments {
|
||||
ptr: *mut whisper_vad_segments,
|
||||
segment_count: c_int,
|
||||
iter_idx: c_int,
|
||||
}
|
||||
|
||||
impl WhisperVadSegments {
|
||||
fn new(ptr: *mut whisper_vad_segments) -> Self {
|
||||
let segment_count = unsafe { whisper_vad_segments_n_segments(ptr) };
|
||||
Self {
|
||||
ptr,
|
||||
segment_count,
|
||||
iter_idx: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn num_segments(&self) -> c_int {
|
||||
self.segment_count
|
||||
}
|
||||
|
||||
/// Return the start timestamp of this segment in centiseconds (10s of milliseconds).
|
||||
pub fn get_segment_start_timestamp(&self, idx: c_int) -> Option<f32> {
|
||||
if idx < 0 || idx > self.segment_count {
|
||||
None
|
||||
} else {
|
||||
Some(unsafe { whisper_vad_segments_get_segment_t0(self.ptr, idx) })
|
||||
}
|
||||
}
|
||||
|
||||
/// Return the end timestamp of this segment in centiseconds (10s of milliseconds).
|
||||
pub fn get_segment_end_timestamp(&self, idx: c_int) -> Option<f32> {
|
||||
if idx < 0 || idx > self.segment_count {
|
||||
None
|
||||
} else {
|
||||
Some(unsafe { whisper_vad_segments_get_segment_t1(self.ptr, idx) })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Iterator for WhisperVadSegments {
|
||||
type Item = WhisperVadSegment;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.iter_idx > self.segment_count {
|
||||
return None;
|
||||
}
|
||||
|
||||
let start = unsafe { whisper_vad_segments_get_segment_t0(self.ptr, self.iter_idx) };
|
||||
let end = unsafe { whisper_vad_segments_get_segment_t1(self.ptr, self.iter_idx) };
|
||||
self.iter_idx += 1;
|
||||
Some(WhisperVadSegment { start, end })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone)]
|
||||
pub struct WhisperVadSegment {
|
||||
start: f32,
|
||||
end: f32,
|
||||
}
|
||||
|
||||
impl Drop for WhisperVadSegments {
|
||||
fn drop(&mut self) {
|
||||
unsafe { whisper_vad_free_segments(self.ptr) }
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue