diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs
index ea22ae0..0609af2 100644
--- a/src/whisper_ctx.rs
+++ b/src/whisper_ctx.rs
@@ -475,6 +475,8 @@ pub struct WhisperContextParameters {
     /// **Warning**: Does not have an effect if OpenCL is selected as GPU backend
     /// (in that case, GPU is always enabled).
     pub use_gpu: bool,
+    /// Enable flash attention, default false
+    pub flash_attn: bool,
     /// GPU device id, default 0
     pub gpu_device: c_int,
     /// [EXPERIMENTAL] Enable Token-level timestamps with DTW, default 0
@@ -497,6 +499,7 @@ impl Default for WhisperContextParameters {
     fn default() -> Self {
         Self {
             use_gpu: cfg!(feature = "_gpu"),
+            flash_attn: false,
             gpu_device: 0,
             dtw_token_timestamps: false,
             dtw_aheads_preset : whisper_rs_sys::whisper_alignment_heads_preset_WHISPER_AHEADS_NONE,
@@ -514,6 +517,11 @@ impl WhisperContextParameters {
         self.use_gpu = use_gpu;
         self
     }
+    /// Set whether flash attention is enabled (default: `false`).
+    pub fn flash_attn(&mut self, flash_attn: bool) -> &mut Self {
+        self.flash_attn = flash_attn;
+        self
+    }
     pub fn gpu_device(&mut self, gpu_device: c_int) -> &mut Self {
         self.gpu_device = gpu_device;
         self
@@ -541,6 +549,7 @@ impl WhisperContextParameters {
     fn to_c_struct(&self) -> whisper_rs_sys::whisper_context_params {
         whisper_rs_sys::whisper_context_params {
             use_gpu: self.use_gpu,
+            flash_attn: self.flash_attn,
             gpu_device: self.gpu_device,
             dtw_token_timestamps: self.dtw_token_timestamps,
             dtw_aheads_preset: self.dtw_aheads_preset,
diff --git a/sys/whisper.cpp b/sys/whisper.cpp
index 7395c70..08981d1 160000
--- a/sys/whisper.cpp
+++ b/sys/whisper.cpp
@@ -1 +1 @@
-Subproject commit 7395c70a748753e3800b63e3422a2b558a097c80
+Subproject commit 08981d1bacbe494ff1c943af6c577c669a2d9f4d