Merge pull request #142 from arizhih/update-whisper-cpp
Update whisper.cpp version to 1.6.2
This commit is contained in:
commit
d7c20844fd
9 changed files with 1875 additions and 256 deletions
|
|
@ -4,7 +4,7 @@ exclude = ["examples/full_usage"]
|
||||||
|
|
||||||
[package]
|
[package]
|
||||||
name = "whisper-rs"
|
name = "whisper-rs"
|
||||||
version = "0.11.0"
|
version = "0.12.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Rust bindings for whisper.cpp"
|
description = "Rust bindings for whisper.cpp"
|
||||||
license = "Unlicense"
|
license = "Unlicense"
|
||||||
|
|
@ -14,7 +14,7 @@ repository = "https://github.com/tazz4843/whisper-rs"
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
whisper-rs-sys = { path = "sys", version = "0.8" }
|
whisper-rs-sys = { path = "sys", version = "0.10.0" }
|
||||||
log = { version = "0.4", optional = true }
|
log = { version = "0.4", optional = true }
|
||||||
tracing = { version = "0.1", optional = true }
|
tracing = { version = "0.1", optional = true }
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,37 @@ use whisper_rs::{FullParams, SamplingStrategy, WhisperContext, WhisperContextPar
|
||||||
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
|
/// Loads a context and model, processes an audio file, and prints the resulting transcript to stdout.
|
||||||
fn main() -> Result<(), &'static str> {
|
fn main() -> Result<(), &'static str> {
|
||||||
// Load a context and model.
|
// Load a context and model.
|
||||||
|
let mut context_param = WhisperContextParameters::default();
|
||||||
|
|
||||||
|
// Enable DTW token level timestamp for known model by using model preset
|
||||||
|
context_param.dtw_parameters.mode = whisper_rs::DtwMode::ModelPreset {
|
||||||
|
model_preset: whisper_rs::DtwModelPreset::BaseEn,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Enable DTW token level timestamp for unknown model by providing custom aheads
|
||||||
|
// see details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
|
||||||
|
// values corresponds to ggml-base.en.bin, result will be the same as with DtwModelPreset::BaseEn
|
||||||
|
let custom_aheads = [
|
||||||
|
(3, 1),
|
||||||
|
(4, 2),
|
||||||
|
(4, 3),
|
||||||
|
(4, 7),
|
||||||
|
(5, 1),
|
||||||
|
(5, 2),
|
||||||
|
(5, 4),
|
||||||
|
(5, 6),
|
||||||
|
]
|
||||||
|
.map(|(n_text_layer, n_head)| whisper_rs::DtwAhead {
|
||||||
|
n_text_layer,
|
||||||
|
n_head,
|
||||||
|
});
|
||||||
|
context_param.dtw_parameters.mode = whisper_rs::DtwMode::Custom {
|
||||||
|
aheads: &custom_aheads,
|
||||||
|
};
|
||||||
|
|
||||||
let ctx = WhisperContext::new_with_params(
|
let ctx = WhisperContext::new_with_params(
|
||||||
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
|
"example/path/to/model/whisper.cpp/models/ggml-base.en.bin",
|
||||||
WhisperContextParameters::default(),
|
context_param,
|
||||||
)
|
)
|
||||||
.expect("failed to load model");
|
.expect("failed to load model");
|
||||||
// Create a state
|
// Create a state
|
||||||
|
|
@ -33,6 +61,8 @@ fn main() -> Result<(), &'static str> {
|
||||||
params.set_print_progress(false);
|
params.set_print_progress(false);
|
||||||
params.set_print_realtime(false);
|
params.set_print_realtime(false);
|
||||||
params.set_print_timestamps(false);
|
params.set_print_timestamps(false);
|
||||||
|
// Enable token level timestamps
|
||||||
|
params.set_token_timestamps(true);
|
||||||
|
|
||||||
// Open the audio file.
|
// Open the audio file.
|
||||||
let reader = hound::WavReader::open("audio.wav").expect("failed to open file");
|
let reader = hound::WavReader::open("audio.wav").expect("failed to open file");
|
||||||
|
|
@ -87,8 +117,24 @@ fn main() -> Result<(), &'static str> {
|
||||||
.full_get_segment_t1(i)
|
.full_get_segment_t1(i)
|
||||||
.expect("failed to get end timestamp");
|
.expect("failed to get end timestamp");
|
||||||
|
|
||||||
|
let first_token_dtw_ts = if let Ok(token_count) = state.full_n_tokens(i) {
|
||||||
|
if token_count > 0 {
|
||||||
|
if let Ok(token_data) = state.full_get_token_data(i, 0) {
|
||||||
|
token_data.t_dtw
|
||||||
|
} else {
|
||||||
|
-1i64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
-1i64
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
-1i64
|
||||||
|
};
|
||||||
// Print the segment to stdout.
|
// Print the segment to stdout.
|
||||||
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
|
println!(
|
||||||
|
"[{} - {} ({})]: {}",
|
||||||
|
start_timestamp, end_timestamp, first_token_dtw_ts, segment
|
||||||
|
);
|
||||||
|
|
||||||
// Format the segment information as a string.
|
// Format the segment information as a string.
|
||||||
let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment);
|
let line = format!("[{} - {}]: {}\n", start_timestamp, end_timestamp, segment);
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,9 @@ pub use standalone::*;
|
||||||
#[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))]
|
#[cfg(any(feature = "whisper-cpp-log", feature = "whisper-cpp-tracing"))]
|
||||||
use std::sync::Once;
|
use std::sync::Once;
|
||||||
pub use utilities::*;
|
pub use utilities::*;
|
||||||
|
pub use whisper_ctx::DtwMode;
|
||||||
|
pub use whisper_ctx::DtwModelPreset;
|
||||||
|
pub use whisper_ctx::DtwParameters;
|
||||||
pub use whisper_ctx::WhisperContextParameters;
|
pub use whisper_ctx::WhisperContextParameters;
|
||||||
use whisper_ctx::WhisperInnerContext;
|
use whisper_ctx::WhisperInnerContext;
|
||||||
pub use whisper_ctx_wrapper::WhisperContext;
|
pub use whisper_ctx_wrapper::WhisperContext;
|
||||||
|
|
@ -44,5 +47,6 @@ pub type WhisperNewSegmentCallback = whisper_rs_sys::whisper_new_segment_callbac
|
||||||
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
|
pub type WhisperStartEncoderCallback = whisper_rs_sys::whisper_encoder_begin_callback;
|
||||||
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
|
pub type WhisperProgressCallback = whisper_rs_sys::whisper_progress_callback;
|
||||||
pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback;
|
pub type WhisperLogitsFilterCallback = whisper_rs_sys::whisper_logits_filter_callback;
|
||||||
pub type WhisperAbortCallback = whisper_rs_sys::whisper_abort_callback;
|
pub type WhisperAbortCallback = whisper_rs_sys::ggml_abort_callback;
|
||||||
pub type WhisperLogCallback = whisper_rs_sys::ggml_log_callback;
|
pub type WhisperLogCallback = whisper_rs_sys::ggml_log_callback;
|
||||||
|
pub type DtwAhead = whisper_rs_sys::whisper_ahead;
|
||||||
|
|
|
||||||
|
|
@ -105,7 +105,7 @@ pub struct SystemInfo {
|
||||||
pub f16c: bool,
|
pub f16c: bool,
|
||||||
pub blas: bool,
|
pub blas: bool,
|
||||||
pub clblast: bool,
|
pub clblast: bool,
|
||||||
pub cublas: bool,
|
pub cuda: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for SystemInfo {
|
impl Default for SystemInfo {
|
||||||
|
|
@ -118,7 +118,7 @@ impl Default for SystemInfo {
|
||||||
f16c: whisper_rs_sys::ggml_cpu_has_f16c() != 0,
|
f16c: whisper_rs_sys::ggml_cpu_has_f16c() != 0,
|
||||||
blas: whisper_rs_sys::ggml_cpu_has_blas() != 0,
|
blas: whisper_rs_sys::ggml_cpu_has_blas() != 0,
|
||||||
clblast: whisper_rs_sys::ggml_cpu_has_clblast() != 0,
|
clblast: whisper_rs_sys::ggml_cpu_has_clblast() != 0,
|
||||||
cublas: whisper_rs_sys::ggml_cpu_has_cublas() != 0,
|
cuda: whisper_rs_sys::ggml_cpu_has_cuda() != 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -469,23 +469,34 @@ impl Drop for WhisperInnerContext {
|
||||||
unsafe impl Send for WhisperInnerContext {}
|
unsafe impl Send for WhisperInnerContext {}
|
||||||
unsafe impl Sync for WhisperInnerContext {}
|
unsafe impl Sync for WhisperInnerContext {}
|
||||||
|
|
||||||
pub struct WhisperContextParameters {
|
pub struct WhisperContextParameters<'a> {
|
||||||
/// Use GPU if available.
|
/// Use GPU if available.
|
||||||
///
|
///
|
||||||
/// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
|
/// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
|
||||||
/// (in that case, GPU is always enabled).
|
/// (in that case, GPU is always enabled).
|
||||||
pub use_gpu: bool,
|
pub use_gpu: bool,
|
||||||
|
/// Enable flash attention, default false
|
||||||
|
///
|
||||||
|
/// **Warning** Can't be used with DTW. DTW will be disabled if flash_attn is true
|
||||||
|
pub flash_attn: bool,
|
||||||
|
/// GPU device id, default 0
|
||||||
|
pub gpu_device: c_int,
|
||||||
|
/// DTW token level timestamp parameters
|
||||||
|
pub dtw_parameters: DtwParameters<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::derivable_impls)] // this impl cannot be derived
|
#[allow(clippy::derivable_impls)] // this impl cannot be derived
|
||||||
impl Default for WhisperContextParameters {
|
impl<'a> Default for WhisperContextParameters<'a> {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
use_gpu: cfg!(feature = "_gpu"),
|
use_gpu: cfg!(feature = "_gpu"),
|
||||||
|
flash_attn: false,
|
||||||
|
gpu_device: 0,
|
||||||
|
dtw_parameters: DtwParameters::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl WhisperContextParameters {
|
impl<'a> WhisperContextParameters<'a> {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self::default()
|
Self::default()
|
||||||
}
|
}
|
||||||
|
|
@ -493,13 +504,156 @@ impl WhisperContextParameters {
|
||||||
self.use_gpu = use_gpu;
|
self.use_gpu = use_gpu;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
pub fn flash_attn(&mut self, flash_attn: bool) -> &mut Self {
|
||||||
|
self.flash_attn = flash_attn;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
pub fn gpu_device(&mut self, gpu_device: c_int) -> &mut Self {
|
||||||
|
self.gpu_device = gpu_device;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
pub fn dtw_parameters(&mut self, dtw_parameters: DtwParameters<'a>) -> &mut Self {
|
||||||
|
self.dtw_parameters = dtw_parameters;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
|
fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
|
||||||
|
let dtw_token_timestamps = !matches!(self.dtw_parameters.mode, DtwMode::None);
|
||||||
|
let mut dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE;
|
||||||
|
let mut dtw_n_top: c_int = -1;
|
||||||
|
let mut dtw_aheads = whisper_rs_sys::whisper_aheads {
|
||||||
|
n_heads: 0,
|
||||||
|
heads: std::ptr::null(),
|
||||||
|
};
|
||||||
|
|
||||||
|
match &self.dtw_parameters.mode {
|
||||||
|
DtwMode::None => {}
|
||||||
|
DtwMode::TopMost { n_top } => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_N_TOP_MOST;
|
||||||
|
dtw_n_top = *n_top;
|
||||||
|
}
|
||||||
|
DtwMode::Custom { aheads } => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_CUSTOM;
|
||||||
|
|
||||||
|
dtw_aheads = whisper_rs_sys::whisper_aheads {
|
||||||
|
n_heads: aheads.len(),
|
||||||
|
heads: aheads.as_ptr(),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
DtwMode::ModelPreset { model_preset } => match model_preset {
|
||||||
|
DtwModelPreset::TinyEn => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY_EN;
|
||||||
|
}
|
||||||
|
DtwModelPreset::Tiny => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_TINY;
|
||||||
|
}
|
||||||
|
DtwModelPreset::BaseEn => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE_EN;
|
||||||
|
}
|
||||||
|
DtwModelPreset::Base => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_BASE;
|
||||||
|
}
|
||||||
|
DtwModelPreset::SmallEn => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL_EN;
|
||||||
|
}
|
||||||
|
DtwModelPreset::Small => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_SMALL;
|
||||||
|
}
|
||||||
|
DtwModelPreset::MediumEn => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM_EN;
|
||||||
|
}
|
||||||
|
DtwModelPreset::Medium => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_MEDIUM;
|
||||||
|
}
|
||||||
|
DtwModelPreset::LargeV1 => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V1;
|
||||||
|
}
|
||||||
|
DtwModelPreset::LargeV2 => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V2;
|
||||||
|
}
|
||||||
|
DtwModelPreset::LargeV3 => {
|
||||||
|
dtw_aheads_preset =
|
||||||
|
whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_LARGE_V3;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
whisper_rs_sys::whisper_context_params {
|
whisper_rs_sys::whisper_context_params {
|
||||||
use_gpu: self.use_gpu,
|
use_gpu: self.use_gpu,
|
||||||
|
flash_attn: self.flash_attn,
|
||||||
|
gpu_device: self.gpu_device,
|
||||||
|
dtw_token_timestamps,
|
||||||
|
dtw_aheads_preset,
|
||||||
|
dtw_n_top,
|
||||||
|
dtw_aheads,
|
||||||
|
dtw_mem_size: self.dtw_parameters.dtw_mem_size,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default Disabled
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DtwParameters<'a> {
|
||||||
|
pub mode: DtwMode<'a>,
|
||||||
|
pub dtw_mem_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DtwParameters<'_> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
mode: DtwMode::None,
|
||||||
|
dtw_mem_size: 1024 * 1024 * 128,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum DtwMode<'a> {
|
||||||
|
/// DTW token level timestamps disabled
|
||||||
|
None,
|
||||||
|
/// Use N Top Most layers from loaded model
|
||||||
|
TopMost {
|
||||||
|
/// Number of top text layers used from model, should be 0 < n_top <= model n_text_layer
|
||||||
|
n_top: c_int,
|
||||||
|
},
|
||||||
|
/// Use custom aheads, non-empty list of whisper_ahead.
|
||||||
|
/// 0 < n_text_layer < model n_text_layer, 0 < n_head < model n_text_head for each element
|
||||||
|
/// See details https://github.com/ggerganov/whisper.cpp/pull/1485#discussion_r1519681143
|
||||||
|
Custom {
|
||||||
|
aheads: &'a [whisper_rs_sys::whisper_ahead],
|
||||||
|
},
|
||||||
|
/// Use predefined preset for standard models
|
||||||
|
ModelPreset { model_preset: DtwModelPreset },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum DtwModelPreset {
|
||||||
|
TinyEn,
|
||||||
|
Tiny,
|
||||||
|
BaseEn,
|
||||||
|
Base,
|
||||||
|
SmallEn,
|
||||||
|
Small,
|
||||||
|
MediumEn,
|
||||||
|
Medium,
|
||||||
|
LargeV1,
|
||||||
|
LargeV2,
|
||||||
|
LargeV3,
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
#[cfg(feature = "test-with-tiny-model")]
|
#[cfg(feature = "test-with-tiny-model")]
|
||||||
mod test_with_tiny_model {
|
mod test_with_tiny_model {
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
[package]
|
[package]
|
||||||
name = "whisper-rs-sys"
|
name = "whisper-rs-sys"
|
||||||
version = "0.8.1"
|
version = "0.10.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Rust bindings for whisper.cpp (FFI bindings)"
|
description = "Rust bindings for whisper.cpp (FFI bindings)"
|
||||||
license = "Unlicense"
|
license = "Unlicense"
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,7 @@ fn main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg!(feature = "cuda") {
|
if cfg!(feature = "cuda") {
|
||||||
config.define("WHISPER_CUBLAS", "ON");
|
config.define("WHISPER_CUDA", "ON");
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg!(feature = "openblas") {
|
if cfg!(feature = "openblas") {
|
||||||
|
|
|
||||||
1901
sys/src/bindings.rs
1901
sys/src/bindings.rs
File diff suppressed because it is too large
Load diff
|
|
@ -1 +1 @@
|
||||||
Subproject commit 0b9af32a8b3fa7e2ae5f15a9a08f5b10394993f5
|
Subproject commit c7b6988678779901d02ceba1a8212d2c9908956e
|
||||||
Loading…
Add table
Add a link
Reference in a new issue