Merge remote-tracking branch 'origin/master' into feature/fix-metal
This commit is contained in:
commit
3f27c17fdf
9 changed files with 1882 additions and 256 deletions
|
|
@ -22,6 +22,9 @@ pub use standalone::*;
|
|||
#[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))]
|
||||
use std::sync::Once;
|
||||
pub use utilities::*;
|
||||
pub use whisper_ctx::DtwMode;
|
||||
pub use whisper_ctx::DtwModelPreset;
|
||||
pub use whisper_ctx::DtwParameters;
|
||||
pub use whisper_ctx::WhisperContextParameters;
|
||||
use whisper_ctx::WhisperInnerContext;
|
||||
pub use whisper_ctx_wrapper::WhisperContext;
|
||||
|
|
@ -44,5 +47,6 @@ pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callbac
|
|||
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
|
||||
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
|
||||
pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback;
|
||||
pub type WhisperAbortCallback = whisper_rs_sys::whisper_abort_callback;
|
||||
pub type WhisperAbortCallback = whisper_rs_sys::ggml_abort_callback;
|
||||
pub type WhisperLogCallback = whisper_rs_sys::ggml_log_callback;
|
||||
pub type DtwAhead = whisper_rs_sys::whisper_ahead;
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ pub struct SystemInfo {
|
|||
pub f16c: bool,
|
||||
pub blas: bool,
|
||||
pub clblast: bool,
|
||||
pub cublas: bool,
|
||||
pub cuda: bool,
|
||||
}
|
||||
|
||||
impl Default for SystemInfo {
|
||||
|
|
@ -124,7 +124,7 @@ impl Default for SystemInfo {
|
|||
f16c: whisper_rs_sys::ggml_cpu_has_f16c() != 0,
|
||||
blas: whisper_rs_sys::ggml_cpu_has_blas() != 0,
|
||||
clblast: whisper_rs_sys::ggml_cpu_has_clblast() != 0,
|
||||
cublas: whisper_rs_sys::ggml_cpu_has_cublas() != 0,
|
||||
cuda: whisper_rs_sys::ggml_cpu_has_cuda() != 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -469,23 +469,34 @@ impl Drop for WhisperInnerContext {
|
|||
unsafe impl Send for WhisperInnerContext {}
|
||||
unsafe impl Sync for WhisperInnerContext {}
|
||||
|
||||
pub struct WhisperContextParameters {
|
||||
pub struct WhisperContextParameters<'a> {
|
||||
/// Use GPU if available.
|
||||
///
|
||||
/// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
|
||||
/// (in that case, GPU is always enabled).
|
||||
pub use_gpu: bool,
|
||||
/// Enable flash attention, default false
|
||||
///
|
||||
/// **Warning** Can't be used with DTW. DTW will be disabled if flash_attn is true
|
||||
pub flash_attn: bool,
|
||||
/// GPU device id, default 0
|
||||
pub gpu_device: c_int,
|
||||
/// DTW token level timestamp parameters
|
||||
pub dtw_parameters: DtwParameters<'a>,
|
||||
}
|
||||
|
||||
#[allow(clippy::derivable_impls)] // this impl cannot be derived
|
||||
impl Default for WhisperContextParameters {
|
||||
impl<'a> Default for WhisperContextParameters<'a> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
use_gpu: cfg!(feature = "_gpu"),
|
||||
flash_attn: false,
|
||||
gpu_device: 0,
|
||||
dtw_parameters: DtwParameters::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
impl WhisperContextParameters {
|
||||
impl<'a> WhisperContextParameters<'a> {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
|
@ -493,13 +504,156 @@ impl WhisperContextParameters {
|
|||
self.use_gpu = use_gpu;
|
||||
self
|
||||
}
|
||||
pub fn flash_attn(&mut self, flash_attn: bool) -> &mut Self {
|
||||
self.flash_attn = flash_attn;
|
||||
self
|
||||
}
|
||||
pub fn gpu_device(&mut self, gpu_device: c_int) -> &mut Self {
|
||||
self.gpu_device = gpu_device;
|
||||
self
|
||||
}
|
||||
pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters<'a>) -> &mut Self {
|
||||
self.dtw_parameters = dtw_parameters;
|
||||
self
|
||||
}
|
||||
|
||||
fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
|
||||
let dtw_token_timestamps = !matches!(self.dtw_parameters.mode, DtwMode::None);
|
||||
let mut dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE;
|
||||
let mut dtw_n_top: c_int = -1;
|
||||
let mut dtw_aheads = whisper_rs_sys::whisper_aheads {
|
||||
n_heads: 0,
|
||||
heads: std::ptr::null(),
|
||||
};
|
||||
|
||||
match &self.dtw_parameters.mode {
|
||||
DtwMode::None => {}
|
||||
DtwMode::TopMost { n_top } => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST;
|
||||
dtw_n_top = *n_top;
|
||||
}
|
||||
DtwMode::Custom { aheads } => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM;
|
||||
|
||||
dtw_aheads = whisper_rs_sys::whisper_aheads {
|
||||
n_heads: aheads.len(),
|
||||
heads: aheads.as_ptr(),
|
||||
};
|
||||
}
|
||||
DtwMode::ModelPreset { model_preset } => match model_preset {
|
||||
DtwModelPreset::TinyEn => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN;
|
||||
}
|
||||
DtwModelPreset::Tiny => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY;
|
||||
}
|
||||
DtwModelPreset::BaseEn => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN;
|
||||
}
|
||||
DtwModelPreset::Base => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE;
|
||||
}
|
||||
DtwModelPreset::SmallEn => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN;
|
||||
}
|
||||
DtwModelPreset::Small => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL;
|
||||
}
|
||||
DtwModelPreset::MediumEn => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN;
|
||||
}
|
||||
DtwModelPreset::Medium => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM;
|
||||
}
|
||||
DtwModelPreset::LargeV1 => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1;
|
||||
}
|
||||
DtwModelPreset::LargeV2 => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2;
|
||||
}
|
||||
DtwModelPreset::LargeV3 => {
|
||||
dtw_aheads_preset =
|
||||
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3;
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
whisper_rs_sys::whisper_context_params {
|
||||
use_gpu: self.use_gpu,
|
||||
flash_attn: self.flash_attn,
|
||||
gpu_device: self.gpu_device,
|
||||
dtw_token_timestamps,
|
||||
dtw_aheads_preset,
|
||||
dtw_n_top,
|
||||
dtw_aheads,
|
||||
dtw_mem_size: self.dtw_parameters.dtw_mem_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DtwParameters<'a> {
|
||||
pub mode: DtwMode<'a>,
|
||||
pub dtw_mem_size: usize,
|
||||
}
|
||||
|
||||
impl Default for DtwParameters<'_> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: DtwMode::None,
|
||||
dtw_mem_size: 1024 * 1024 * 128,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DtwMode<'a> {
|
||||
/// DTW token level timestamps disabled
|
||||
None,
|
||||
/// Use N Top Most layers from loaded model
|
||||
TopMost {
|
||||
/// Number of top text layers used from model, should be 0 < n_top <= model n_text_layer
|
||||
n_top: c_int,
|
||||
},
|
||||
/// Use custom aheads, non-empty list of whisper_ahead.
|
||||
/// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element
|
||||
/// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
|
||||
Custom {
|
||||
aheads: &'a [whisper_rs_sys::whisper_ahead],
|
||||
},
|
||||
/// Use predefined preset for standard models
|
||||
ModelPreset { model_preset: DtwModelPreset },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DtwModelPreset {
|
||||
TinyEn,
|
||||
Tiny,
|
||||
BaseEn,
|
||||
Base,
|
||||
SmallEn,
|
||||
Small,
|
||||
MediumEn,
|
||||
Medium,
|
||||
LargeV1,
|
||||
LargeV2,
|
||||
LargeV3,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(feature = "test-with-tiny-model")]
|
||||
mod test_with_tiny_model {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue