rename library to whisper-rs

This commit is contained in:
0/0 2022-10-09 20:48:00 -06:00
parent fec86e0d56
commit 82c83c860f
No known key found for this signature in database
GPG key ID: DE8D5010C0AAA3DC
6 changed files with 38 additions and 38 deletions

View file

@ -9,7 +9,7 @@ use std::ffi::{c_int, CStr, CString};
/// then run a full transcription with [WhisperContext::full].
#[derive(Debug)]
pub struct WhisperContext {
ctx: *mut whisper_cpp_sys::whisper_context,
ctx: *mut whisper_rs_sys::whisper_context,
/// has the spectrogram been initialized in at least one way?
spectrogram_initialized: bool,
/// has the data been encoded?
@ -31,7 +31,7 @@ impl WhisperContext {
/// `struct whisper_context * whisper_init(const char * path_model);`
pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_cpp_sys::whisper_init(path_cstr.as_ptr()) };
let ctx = unsafe { whisper_rs_sys::whisper_init(path_cstr.as_ptr()) };
if ctx.is_null() {
Err(WhisperError::InitError)
} else {
@ -61,7 +61,7 @@ impl WhisperContext {
return Err(WhisperError::InvalidThreadCount);
}
let ret = unsafe {
whisper_cpp_sys::whisper_pcm_to_mel(
whisper_rs_sys::whisper_pcm_to_mel(
self.ctx,
pcm.as_ptr(),
pcm.len() as c_int,
@ -94,7 +94,7 @@ impl WhisperContext {
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> {
let ret = unsafe {
whisper_cpp_sys::whisper_set_mel(
whisper_rs_sys::whisper_set_mel(
self.ctx,
data.as_ptr(),
data.len() as c_int,
@ -129,7 +129,7 @@ impl WhisperContext {
return Err(WhisperError::InvalidThreadCount);
}
let ret =
unsafe { whisper_cpp_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) };
unsafe { whisper_rs_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) };
if ret == 0 {
self.encode_complete = true;
Ok(())
@ -166,7 +166,7 @@ impl WhisperContext {
return Err(WhisperError::InvalidThreadCount);
}
let ret = unsafe {
whisper_cpp_sys::whisper_decode(
whisper_rs_sys::whisper_decode(
self.ctx,
tokens.as_ptr(),
tokens.len() as c_int,
@ -198,7 +198,7 @@ impl WhisperContext {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
}
let ret = unsafe { whisper_cpp_sys::whisper_sample_best(self.ctx, needs_timestamp) };
let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx, needs_timestamp) };
Ok(ret)
}
@ -214,7 +214,7 @@ impl WhisperContext {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
}
let ret = unsafe { whisper_cpp_sys::whisper_sample_timestamp(self.ctx) };
let ret = unsafe { whisper_rs_sys::whisper_sample_timestamp(self.ctx) };
Ok(ret)
}
@ -227,7 +227,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_n_len (struct whisper_context * ctx)`
pub fn n_len(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_len(self.ctx) };
let ret = unsafe { whisper_rs_sys::whisper_n_len(self.ctx) };
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
@ -243,7 +243,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_n_vocab (struct whisper_context * ctx)`
pub fn n_vocab(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_vocab(self.ctx) };
let ret = unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) };
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
@ -259,7 +259,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_n_text_ctx (struct whisper_context * ctx)`
pub fn n_text_ctx(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_text_ctx(self.ctx) };
let ret = unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) };
if ret < 0 {
Err(WhisperError::GenericError(ret))
} else {
@ -272,7 +272,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_is_multilingual(struct whisper_context * ctx)`
pub fn is_multilingual(&self) -> bool {
unsafe { whisper_cpp_sys::whisper_is_multilingual(self.ctx) != 0 }
unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 }
}
/// The probabilities for the next token.
@ -287,7 +287,7 @@ impl WhisperContext {
if !self.decode_once {
return Err(WhisperError::DecodeNotComplete);
}
let ret = unsafe { whisper_cpp_sys::whisper_get_probs(self.ctx) };
let ret = unsafe { whisper_rs_sys::whisper_get_probs(self.ctx) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
@ -305,7 +305,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_str(&self, token_id: WhisperToken) -> Result<String, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_token_to_str(self.ctx, token_id) };
let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
@ -320,7 +320,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
pub fn token_eot(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_eot(self.ctx) }
unsafe { whisper_rs_sys::whisper_token_eot(self.ctx) }
}
/// Get the ID of the sot token.
@ -328,7 +328,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
pub fn token_sot(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_sot(self.ctx) }
unsafe { whisper_rs_sys::whisper_token_sot(self.ctx) }
}
/// Get the ID of the prev token.
@ -336,7 +336,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
pub fn token_prev(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_prev(self.ctx) }
unsafe { whisper_rs_sys::whisper_token_prev(self.ctx) }
}
/// Get the ID of the solm token.
@ -344,7 +344,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
pub fn token_solm(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_solm(self.ctx) }
unsafe { whisper_rs_sys::whisper_token_solm(self.ctx) }
}
/// Get the ID of the not token.
@ -352,7 +352,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `whisper_token whisper_token_not (struct whisper_context * ctx)`
pub fn token_not(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_not(self.ctx) }
unsafe { whisper_rs_sys::whisper_token_not(self.ctx) }
}
/// Get the ID of the beg token.
@ -360,7 +360,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
pub fn token_beg(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_beg(self.ctx) }
unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
}
/// Print performance statistics to stdout.
@ -368,7 +368,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `void whisper_print_timings(struct whisper_context * ctx)`
pub fn print_timings(&self) {
unsafe { whisper_cpp_sys::whisper_print_timings(self.ctx) }
unsafe { whisper_rs_sys::whisper_print_timings(self.ctx) }
}
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
@ -387,7 +387,7 @@ impl WhisperContext {
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
let ret = unsafe {
whisper_cpp_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int)
whisper_rs_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int)
};
if ret < 0 {
Err(WhisperError::GenericError(ret))
@ -402,7 +402,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `int whisper_full_n_segments(struct whisper_context * ctx)`
pub fn full_n_segments(&self) -> c_int {
unsafe { whisper_cpp_sys::whisper_full_n_segments(self.ctx) }
unsafe { whisper_rs_sys::whisper_full_n_segments(self.ctx) }
}
/// Get the start time of the specified segment.
@ -413,7 +413,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_t0(&self, segment: c_int) -> i64 {
unsafe { whisper_cpp_sys::whisper_full_get_segment_t0(self.ctx, segment) }
unsafe { whisper_rs_sys::whisper_full_get_segment_t0(self.ctx, segment) }
}
/// Get the end time of the specified segment.
@ -424,7 +424,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_t1(&self, segment: c_int) -> i64 {
unsafe { whisper_cpp_sys::whisper_full_get_segment_t1(self.ctx, segment) }
unsafe { whisper_rs_sys::whisper_full_get_segment_t1(self.ctx, segment) }
}
/// Get the text of the specified segment.
@ -438,7 +438,7 @@ impl WhisperContext {
/// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_full_get_segment_text(self.ctx, segment) };
let ret = unsafe { whisper_rs_sys::whisper_full_get_segment_text(self.ctx, segment) };
if ret.is_null() {
return Err(WhisperError::NullPointer);
}
@ -450,6 +450,6 @@ impl WhisperContext {
impl Drop for WhisperContext {
fn drop(&mut self) {
unsafe { whisper_cpp_sys::whisper_free(self.ctx) };
unsafe { whisper_rs_sys::whisper_free(self.ctx) };
}
}