refactor(state): remove lifetime binding from whisper context
This commit is contained in:
parent
c1a37751dd
commit
74e83185bf
1 changed files with 28 additions and 28 deletions
|
|
@ -1,19 +1,20 @@
|
||||||
use crate::{FullParams, WhisperContext, WhisperError, WhisperToken, WhisperTokenData};
|
|
||||||
use std::ffi::{c_int, CStr};
|
use std::ffi::{c_int, CStr};
|
||||||
use std::marker::PhantomData;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use crate::{FullParams, WhisperContext, WhisperError, WhisperToken, WhisperTokenData};
|
||||||
|
|
||||||
/// Rustified pointer to a Whisper state.
|
/// Rustified pointer to a Whisper state.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WhisperState<'a> {
|
pub struct WhisperState {
|
||||||
ctx: *mut whisper_rs_sys::whisper_context,
|
ctx: Arc<WhisperContext>,
|
||||||
ptr: *mut whisper_rs_sys::whisper_state,
|
ptr: *mut whisper_rs_sys::whisper_state,
|
||||||
_phantom: PhantomData<&'a WhisperContext>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<'a> Send for WhisperState<'a> {}
|
unsafe impl Send for WhisperState {}
|
||||||
unsafe impl<'a> Sync for WhisperState<'a> {}
|
|
||||||
|
|
||||||
impl<'a> Drop for WhisperState<'a> {
|
unsafe impl Sync for WhisperState {}
|
||||||
|
|
||||||
|
impl Drop for WhisperState {
|
||||||
fn drop(&mut self) {
|
fn drop(&mut self) {
|
||||||
unsafe {
|
unsafe {
|
||||||
whisper_rs_sys::whisper_free_state(self.ptr);
|
whisper_rs_sys::whisper_free_state(self.ptr);
|
||||||
|
|
@ -21,15 +22,14 @@ impl<'a> Drop for WhisperState<'a> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> WhisperState<'a> {
|
impl WhisperState {
|
||||||
pub(crate) fn new(
|
pub(crate) fn new(
|
||||||
ctx: *mut whisper_rs_sys::whisper_context,
|
ctx: Arc<WhisperContext>,
|
||||||
ptr: *mut whisper_rs_sys::whisper_state,
|
ptr: *mut whisper_rs_sys::whisper_state,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ctx,
|
ctx,
|
||||||
ptr,
|
ptr,
|
||||||
_phantom: PhantomData,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -45,13 +45,13 @@ impl<'a> WhisperState<'a> {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
|
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
|
||||||
pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
|
pub fn pcm_to_mel(&self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
|
||||||
if threads < 1 {
|
if threads < 1 {
|
||||||
return Err(WhisperError::InvalidThreadCount);
|
return Err(WhisperError::InvalidThreadCount);
|
||||||
}
|
}
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_pcm_to_mel_with_state(
|
whisper_rs_sys::whisper_pcm_to_mel_with_state(
|
||||||
self.ctx,
|
self.ctx.ctx,
|
||||||
self.ptr,
|
self.ptr,
|
||||||
pcm.as_ptr(),
|
pcm.as_ptr(),
|
||||||
pcm.len() as c_int,
|
pcm.len() as c_int,
|
||||||
|
|
@ -81,7 +81,7 @@ impl<'a> WhisperState<'a> {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
|
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
|
||||||
pub fn pcm_to_mel_phase_vocoder(
|
pub fn pcm_to_mel_phase_vocoder(
|
||||||
&mut self,
|
&self,
|
||||||
pcm: &[f32],
|
pcm: &[f32],
|
||||||
threads: usize,
|
threads: usize,
|
||||||
) -> Result<(), WhisperError> {
|
) -> Result<(), WhisperError> {
|
||||||
|
|
@ -90,7 +90,7 @@ impl<'a> WhisperState<'a> {
|
||||||
}
|
}
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
|
whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
|
||||||
self.ctx,
|
self.ctx.ctx,
|
||||||
self.ptr,
|
self.ptr,
|
||||||
pcm.as_ptr(),
|
pcm.as_ptr(),
|
||||||
pcm.len() as c_int,
|
pcm.len() as c_int,
|
||||||
|
|
@ -122,12 +122,12 @@ impl<'a> WhisperState<'a> {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
|
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
|
||||||
pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> {
|
pub fn set_mel(&self, data: &[f32]) -> Result<(), WhisperError> {
|
||||||
let hop_size = 160;
|
let hop_size = 160;
|
||||||
let n_len = (data.len() / hop_size) * 2;
|
let n_len = (data.len() / hop_size) * 2;
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_set_mel_with_state(
|
whisper_rs_sys::whisper_set_mel_with_state(
|
||||||
self.ctx,
|
self.ctx.ctx,
|
||||||
self.ptr,
|
self.ptr,
|
||||||
data.as_ptr(),
|
data.as_ptr(),
|
||||||
n_len as c_int,
|
n_len as c_int,
|
||||||
|
|
@ -155,13 +155,13 @@ impl<'a> WhisperState<'a> {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)`
|
/// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)`
|
||||||
pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> {
|
pub fn encode(&self, offset: usize, threads: usize) -> Result<(), WhisperError> {
|
||||||
if threads < 1 {
|
if threads < 1 {
|
||||||
return Err(WhisperError::InvalidThreadCount);
|
return Err(WhisperError::InvalidThreadCount);
|
||||||
}
|
}
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_encode_with_state(
|
whisper_rs_sys::whisper_encode_with_state(
|
||||||
self.ctx,
|
self.ctx.ctx,
|
||||||
self.ptr,
|
self.ptr,
|
||||||
offset as c_int,
|
offset as c_int,
|
||||||
threads as c_int,
|
threads as c_int,
|
||||||
|
|
@ -192,7 +192,7 @@ impl<'a> WhisperState<'a> {
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
|
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
|
||||||
pub fn decode(
|
pub fn decode(
|
||||||
&mut self,
|
&self,
|
||||||
tokens: &[WhisperToken],
|
tokens: &[WhisperToken],
|
||||||
n_past: usize,
|
n_past: usize,
|
||||||
threads: usize,
|
threads: usize,
|
||||||
|
|
@ -202,7 +202,7 @@ impl<'a> WhisperState<'a> {
|
||||||
}
|
}
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_decode_with_state(
|
whisper_rs_sys::whisper_decode_with_state(
|
||||||
self.ctx,
|
self.ctx.ctx,
|
||||||
self.ptr,
|
self.ptr,
|
||||||
tokens.as_ptr(),
|
tokens.as_ptr(),
|
||||||
tokens.len() as c_int,
|
tokens.len() as c_int,
|
||||||
|
|
@ -240,7 +240,7 @@ impl<'a> WhisperState<'a> {
|
||||||
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
|
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_lang_auto_detect_with_state(
|
whisper_rs_sys::whisper_lang_auto_detect_with_state(
|
||||||
self.ctx,
|
self.ctx.ctx,
|
||||||
self.ptr,
|
self.ptr,
|
||||||
offset_ms as c_int,
|
offset_ms as c_int,
|
||||||
threads as c_int,
|
threads as c_int,
|
||||||
|
|
@ -309,7 +309,7 @@ impl<'a> WhisperState<'a> {
|
||||||
/// `int whisper_n_vocab (struct whisper_context * ctx)`
|
/// `int whisper_n_vocab (struct whisper_context * ctx)`
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn n_vocab(&self) -> c_int {
|
pub fn n_vocab(&self) -> c_int {
|
||||||
unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||||
|
|
@ -327,7 +327,7 @@ impl<'a> WhisperState<'a> {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
|
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
|
||||||
pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
|
pub fn full(&self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
|
||||||
if data.is_empty() {
|
if data.is_empty() {
|
||||||
// can randomly trigger segmentation faults if we don't check this
|
// can randomly trigger segmentation faults if we don't check this
|
||||||
return Err(WhisperError::NoSamples);
|
return Err(WhisperError::NoSamples);
|
||||||
|
|
@ -335,7 +335,7 @@ impl<'a> WhisperState<'a> {
|
||||||
|
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_full_with_state(
|
whisper_rs_sys::whisper_full_with_state(
|
||||||
self.ctx,
|
self.ctx.ctx,
|
||||||
self.ptr,
|
self.ptr,
|
||||||
params.fp,
|
params.fp,
|
||||||
data.as_ptr(),
|
data.as_ptr(),
|
||||||
|
|
@ -495,7 +495,7 @@ impl<'a> WhisperState<'a> {
|
||||||
) -> Result<String, WhisperError> {
|
) -> Result<String, WhisperError> {
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||||
self.ctx, self.ptr, segment, token,
|
self.ctx.ctx, self.ptr, segment, token,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
if ret.is_null() {
|
if ret.is_null() {
|
||||||
|
|
@ -527,7 +527,7 @@ impl<'a> WhisperState<'a> {
|
||||||
) -> Result<String, WhisperError> {
|
) -> Result<String, WhisperError> {
|
||||||
let ret = unsafe {
|
let ret = unsafe {
|
||||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||||
self.ctx, self.ptr, segment, token,
|
self.ctx.ctx, self.ptr, segment, token,
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
if ret.is_null() {
|
if ret.is_null() {
|
||||||
|
|
@ -610,7 +610,7 @@ impl<'a> WhisperState<'a> {
|
||||||
///
|
///
|
||||||
/// # C++ equivalent
|
/// # C++ equivalent
|
||||||
/// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)`
|
/// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)`
|
||||||
pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool {
|
pub fn full_get_segment_speaker_turn_next(&self, i_segment: c_int) -> bool {
|
||||||
unsafe {
|
unsafe {
|
||||||
whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state(
|
whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state(
|
||||||
self.ptr, i_segment,
|
self.ptr, i_segment,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue