Merge branch 'tazz4843:master' into update-whisper-cpp
This commit is contained in:
commit
1160b1b94e
5 changed files with 468 additions and 113 deletions
|
|
@ -1,19 +1,18 @@
|
|||
use crate::error::WhisperError;
|
||||
use crate::whisper_state::WhisperState;
|
||||
use crate::WhisperToken;
|
||||
use std::ffi::{c_int, CStr, CString};
|
||||
|
||||
/// Safe Rust wrapper around a Whisper context.
|
||||
///
|
||||
/// You likely want to create this with [WhisperContext::new_with_params],
|
||||
/// create a state with [WhisperContext::create_state],
|
||||
/// You likely want to create this with [WhisperInnerContext::new_with_params],
|
||||
/// create a state with [WhisperInnerContext::create_state],
|
||||
/// then run a full transcription with [WhisperState::full].
|
||||
#[derive(Debug)]
|
||||
pub struct WhisperContext {
|
||||
ctx: *mut whisper_rs_sys::whisper_context,
|
||||
pub struct WhisperInnerContext {
|
||||
pub(crate) ctx: *mut whisper_rs_sys::whisper_context,
|
||||
}
|
||||
|
||||
impl WhisperContext {
|
||||
impl WhisperInnerContext {
|
||||
/// Create a new WhisperContext from a file, with parameters.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
@ -71,68 +70,6 @@ impl WhisperContext {
|
|||
}
|
||||
}
|
||||
|
||||
/// Create a new WhisperContext from a file.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * path: The path to the model file.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(Self) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_context * whisper_init_from_file_no_state(const char * path_model)`
|
||||
#[deprecated = "Use `new_with_params` instead"]
|
||||
pub fn new(path: &str) -> Result<Self, WhisperError> {
|
||||
let path_cstr = CString::new(path)?;
|
||||
let ctx = unsafe { whisper_rs_sys::whisper_init_from_file_no_state(path_cstr.as_ptr()) };
|
||||
if ctx.is_null() {
|
||||
Err(WhisperError::InitError)
|
||||
} else {
|
||||
Ok(Self { ctx })
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new WhisperContext from a buffer.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * buffer: The buffer containing the model.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(Self) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_context * whisper_init_from_buffer_no_state(void * buffer, size_t buffer_size)`
|
||||
#[deprecated = "Use `new_from_buffer_with_params` instead"]
|
||||
pub fn new_from_buffer(buffer: &[u8]) -> Result<Self, WhisperError> {
|
||||
let ctx = unsafe {
|
||||
whisper_rs_sys::whisper_init_from_buffer_no_state(buffer.as_ptr() as _, buffer.len())
|
||||
};
|
||||
if ctx.is_null() {
|
||||
Err(WhisperError::InitError)
|
||||
} else {
|
||||
Ok(Self { ctx })
|
||||
}
|
||||
}
|
||||
|
||||
// we don't implement `whisper_init()` here since i have zero clue what `whisper_model_loader` does
|
||||
|
||||
/// Create a new state object, ready for use.
|
||||
///
|
||||
/// # Returns
|
||||
/// Ok(WhisperState) on success, Err(WhisperError) on failure.
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `struct whisper_state * whisper_init_state(struct whisper_context * ctx);`
|
||||
pub fn create_state(&self) -> Result<WhisperState, WhisperError> {
|
||||
let state = unsafe { whisper_rs_sys::whisper_init_state(self.ctx) };
|
||||
if state.is_null() {
|
||||
Err(WhisperError::InitError)
|
||||
} else {
|
||||
// SAFETY: this is known to be a valid pointer to a `whisper_state` struct
|
||||
Ok(WhisperState::new(self.ctx, state))
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert the provided text into tokens.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
@ -518,23 +455,9 @@ impl WhisperContext {
|
|||
pub fn token_transcribe(&self) -> WhisperToken {
|
||||
unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) }
|
||||
}
|
||||
|
||||
/// Get whether the next segment is predicted as a speaker turn
|
||||
///
|
||||
/// # Arguments
|
||||
/// * i_segment: Segment index.
|
||||
///
|
||||
/// # Returns
|
||||
/// bool
|
||||
///
|
||||
/// # C++ equivalent
|
||||
/// `bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment)`
|
||||
pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool {
|
||||
unsafe { whisper_rs_sys::whisper_full_get_segment_speaker_turn_next(self.ctx, i_segment) }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for WhisperContext {
|
||||
impl Drop for WhisperInnerContext {
|
||||
#[inline]
|
||||
fn drop(&mut self) {
|
||||
unsafe { whisper_rs_sys::whisper_free(self.ctx) };
|
||||
|
|
@ -543,8 +466,8 @@ impl Drop for WhisperContext {
|
|||
|
||||
// following implementations are safe
|
||||
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
|
||||
unsafe impl Send for WhisperContext {}
|
||||
unsafe impl Sync for WhisperContext {}
|
||||
unsafe impl Send for WhisperInnerContext {}
|
||||
unsafe impl Sync for WhisperInnerContext {}
|
||||
|
||||
pub struct WhisperContextParameters {
|
||||
/// Use GPU if available.
|
||||
|
|
@ -639,7 +562,7 @@ mod test_with_tiny_model {
|
|||
|
||||
#[test]
|
||||
fn test_tokenize_round_trip() {
|
||||
let ctx = WhisperContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
|
||||
let ctx = WhisperInnerContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
|
||||
let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.";
|
||||
let tokens = ctx.tokenize(text_in, 1024).unwrap();
|
||||
let text_out = tokens
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue