Merge branch 'tazz4843:master' into update-whisper-cpp

This commit is contained in:
arizhih 2024-05-07 13:36:30 +02:00 committed by GitHub
commit 1160b1b94e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 468 additions and 113 deletions

View file

@ -1,19 +1,18 @@
use crate::error::WhisperError;
use crate::whisper_state::WhisperState;
use crate::WhisperToken;
use std::ffi::{c_int, CStr, CString};
/// Safe Rust wrapper around a Whisper context.
///
/// You likely want to create this with [WhisperContext::new_with_params],
/// create a state with [WhisperContext::create_state],
/// You likely want to create this with [WhisperInnerContext::new_with_params],
/// create a state with [WhisperInnerContext::create_state],
/// then run a full transcription with [WhisperState::full].
#[derive(Debug)]
pub struct WhisperContext {
ctx: *mut whisper_rs_sys::whisper_context,
pub struct WhisperInnerContext {
pub(crate) ctx: *mut whisper_rs_sys::whisper_context,
}
impl WhisperContext {
impl WhisperInnerContext {
/// Create a new WhisperContext from a file, with parameters.
///
/// # Arguments
@ -71,68 +70,6 @@ impl WhisperContext {
}
}
/// Load a Whisper model from a file on disk and wrap the resulting context.
///
/// # Arguments
/// * path: Location of the model file.
///
/// # Returns
/// `Ok(Self)` when the C call hands back a non-null context,
/// `Err(WhisperError::InitError)` when it returns null.
/// A `path` that cannot be turned into a `CString` (it contains an interior
/// NUL byte) is reported through the `?` conversion into `WhisperError`.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_file_no_state(const char * path_model)`
#[deprecated = "Use `new_with_params` instead"]
pub fn new(path: &str) -> Result<Self, WhisperError> {
    let model_path = CString::new(path)?;
    // SAFETY: `model_path` is a NUL-terminated C string that stays alive
    // until the end of this function, i.e. for the whole duration of the call.
    let raw = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(model_path.as_ptr()) };
    if raw.is_null() {
        return Err(WhisperError::InitError);
    }
    Ok(Self { ctx: raw })
}
/// Build a Whisper context from a model that is already loaded in memory.
///
/// # Arguments
/// * buffer: Byte slice holding the model data.
///
/// # Returns
/// `Ok(Self)` when the C call hands back a non-null context,
/// `Err(WhisperError::InitError)` when it returns null.
///
/// # C++ equivalent
/// `struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size)`
#[deprecated = "Use `new_from_buffer_with_params` instead"]
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
    let (data, len) = (buffer.as_ptr(), buffer.len());
    // SAFETY: `data` and `len` describe the borrowed slice `buffer`, which
    // remains valid for the duration of the call.
    let raw = unsafe { whisper_rs_sys::whisper_init_from_buffer_no_state(data as _, len) };
    if raw.is_null() {
        return Err(WhisperError::InitError);
    }
    Ok(Self { ctx: raw })
}
// We don't implement `whisper_init()` here because it is unclear what `whisper_model_loader` does.
/// Allocate a fresh state object tied to this context, ready for use.
///
/// # Returns
/// `Ok(WhisperState)` when the C call hands back a non-null state,
/// `Err(WhisperError::InitError)` when it returns null.
///
/// # C++ equivalent
/// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);`
pub fn create_state(&self) -> Result<WhisperState, WhisperError> {
    let raw_state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) };
    if raw_state.is_null() {
        return Err(WhisperError::InitError);
    }
    // SAFETY: null was ruled out above, so this is known to be a valid
    // pointer to a `whisper_state` struct.
    Ok(WhisperState::new(self.ctx, raw_state))
}
/// Convert the provided text into tokens.
///
/// # Arguments
@ -518,23 +455,9 @@ impl WhisperContext {
pub fn token_transcribe(&self) -> WhisperToken {
    let ctx = self.ctx;
    // SAFETY: `ctx` is the context pointer held by this wrapper; the
    // constructors visible in this file only build `Self` from a non-null pointer.
    unsafe { whisper_rs_sys::whisper_token_transcribe(ctx) }
}
/// Report whether the segment following `i_segment` is predicted to be a speaker turn.
///
/// # Arguments
/// * i_segment: Segment index.
///
/// # Returns
/// The `bool` returned by the underlying C call, unchanged.
///
/// # C++ equivalent
/// `bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool {
    let ctx = self.ctx;
    // SAFETY: `ctx` is the context pointer held by this wrapper; the
    // constructors visible in this file only build `Self` from a non-null pointer.
    unsafe { whisper_rs_sys::whisper_full_get_segment_speaker_turn_next(ctx, i_segment) }
}
}
impl Drop for WhisperContext {
impl Drop for WhisperInnerContext {
#[inline]
fn drop(&mut self) {
unsafe { whisper_rs_sys::whisper_free(self.ctx) };
@ -543,8 +466,8 @@ impl Drop for WhisperContext {
// following implementations are safe
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
unsafe impl Send for WhisperContext {}
unsafe impl Sync for WhisperContext {}
unsafe impl Send for WhisperInnerContext {}
unsafe impl Sync for WhisperInnerContext {}
pub struct WhisperContextParameters {
/// Use GPU if available.
@ -639,7 +562,7 @@ mod test_with_tiny_model {
#[test]
fn test_tokenize_round_trip() {
let ctx = WhisperContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
let ctx = WhisperInnerContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.";
let tokens = ctx.tokenize(text_in, 1024).unwrap();
let text_out = tokens