rename library to whisper-rs

This commit is contained in:
0/0 2022-10-09 20:48:00 -06:00
parent fec86e0d56
commit 82c83c860f
No known key found for this signature in database
GPG key ID: DE8D5010C0AAA3DC
6 changed files with 38 additions and 38 deletions

View file

@ -2,14 +2,14 @@
members = ["sys"] members = ["sys"]
[package] [package]
name = "whisper-cpp" name = "whisper-rs"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
whisper-cpp-sys = { path = "sys", version = "0.1" } whisper-rs-sys = { path = "sys", version = "0.1" }
[features] [features]
simd = [] simd = []

View file

@ -1,4 +1,4 @@
use whisper_cpp::{DecodeStrategy, FullParams, WhisperContext}; use whisper_rs::{DecodeStrategy, FullParams, WhisperContext};
// note that running this example will not do anything, as it is just a // note that running this example will not do anything, as it is just a
// demonstration of how to use the library, and actual usage requires // demonstration of how to use the library, and actual usage requires
@ -34,8 +34,8 @@ pub fn usage() {
// note that you don't need to use these, you can do it yourself or any other way you want // note that you don't need to use these, you can do it yourself or any other way you want
// these are just provided for convenience // these are just provided for convenience
// SIMD variants of these functions are also available, but only on nightly Rust: see the docs // SIMD variants of these functions are also available, but only on nightly Rust: see the docs
let audio_data = whisper_cpp::convert_stereo_to_mono_audio( let audio_data = whisper_rs::convert_stereo_to_mono_audio(
&whisper_cpp::convert_integer_to_float_audio(&audio_data), &whisper_rs::convert_integer_to_float_audio(&audio_data),
); );
// now we can run the model // now we can run the model

View file

@ -18,7 +18,7 @@ use std::ffi::{c_int, CString};
/// `int whisper_lang_id(const char * lang)` /// `int whisper_lang_id(const char * lang)`
pub fn get_lang_id(lang: &str) -> Option<c_int> { pub fn get_lang_id(lang: &str) -> Option<c_int> {
let c_lang = CString::new(lang).expect("Language contains null byte"); let c_lang = CString::new(lang).expect("Language contains null byte");
let ret = unsafe { whisper_cpp_sys::whisper_lang_id(c_lang.as_ptr()) }; let ret = unsafe { whisper_rs_sys::whisper_lang_id(c_lang.as_ptr()) };
if ret == -1 { if ret == -1 {
None None
} else { } else {
@ -32,7 +32,7 @@ pub fn get_lang_id(lang: &str) -> Option<c_int> {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_token_translate ()` /// `whisper_token whisper_token_translate ()`
pub fn token_translate() -> WhisperToken { pub fn token_translate() -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_translate() } unsafe { whisper_rs_sys::whisper_token_translate() }
} }
/// Get the ID of the transcribe task token. /// Get the ID of the transcribe task token.
@ -40,5 +40,5 @@ pub fn token_translate() -> WhisperToken {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_token_transcribe()` /// `whisper_token whisper_token_transcribe()`
pub fn token_transcribe() -> WhisperToken { pub fn token_transcribe() -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_transcribe() } unsafe { whisper_rs_sys::whisper_token_transcribe() }
} }

View file

@ -9,7 +9,7 @@ use std::ffi::{c_int, CStr, CString};
/// then run a full transcription with [WhisperContext::full]. /// then run a full transcription with [WhisperContext::full].
#[derive(Debug)] #[derive(Debug)]
pub struct WhisperContext { pub struct WhisperContext {
ctx: *mut whisper_cpp_sys::whisper_context, ctx: *mut whisper_rs_sys::whisper_context,
/// has the spectrogram been initialized in at least one way? /// has the spectrogram been initialized in at least one way?
spectrogram_initialized: bool, spectrogram_initialized: bool,
/// has the data been encoded? /// has the data been encoded?
@ -31,7 +31,7 @@ impl WhisperContext {
/// `struct whisper_context * whisper_init(const char * path_model);` /// `struct whisper_context * whisper_init(const char * path_model);`
pub fn new(path: &str) -> Result<Self, WhisperError> { pub fn new(path: &str) -> Result<Self, WhisperError> {
let path_cstr = CString::new(path)?; let path_cstr = CString::new(path)?;
let ctx = unsafe { whisper_cpp_sys::whisper_init(path_cstr.as_ptr()) }; let ctx = unsafe { whisper_rs_sys::whisper_init(path_cstr.as_ptr()) };
if ctx.is_null() { if ctx.is_null() {
Err(WhisperError::InitError) Err(WhisperError::InitError)
} else { } else {
@ -61,7 +61,7 @@ impl WhisperContext {
return Err(WhisperError::InvalidThreadCount); return Err(WhisperError::InvalidThreadCount);
} }
let ret = unsafe { let ret = unsafe {
whisper_cpp_sys::whisper_pcm_to_mel( whisper_rs_sys::whisper_pcm_to_mel(
self.ctx, self.ctx,
pcm.as_ptr(), pcm.as_ptr(),
pcm.len() as c_int, pcm.len() as c_int,
@ -94,7 +94,7 @@ impl WhisperContext {
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)` /// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> { pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> {
let ret = unsafe { let ret = unsafe {
whisper_cpp_sys::whisper_set_mel( whisper_rs_sys::whisper_set_mel(
self.ctx, self.ctx,
data.as_ptr(), data.as_ptr(),
data.len() as c_int, data.len() as c_int,
@ -129,7 +129,7 @@ impl WhisperContext {
return Err(WhisperError::InvalidThreadCount); return Err(WhisperError::InvalidThreadCount);
} }
let ret = let ret =
unsafe { whisper_cpp_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) }; unsafe { whisper_rs_sys::whisper_encode(self.ctx, offset as c_int, threads as c_int) };
if ret == 0 { if ret == 0 {
self.encode_complete = true; self.encode_complete = true;
Ok(()) Ok(())
@ -166,7 +166,7 @@ impl WhisperContext {
return Err(WhisperError::InvalidThreadCount); return Err(WhisperError::InvalidThreadCount);
} }
let ret = unsafe { let ret = unsafe {
whisper_cpp_sys::whisper_decode( whisper_rs_sys::whisper_decode(
self.ctx, self.ctx,
tokens.as_ptr(), tokens.as_ptr(),
tokens.len() as c_int, tokens.len() as c_int,
@ -198,7 +198,7 @@ impl WhisperContext {
if !self.decode_once { if !self.decode_once {
return Err(WhisperError::DecodeNotComplete); return Err(WhisperError::DecodeNotComplete);
} }
let ret = unsafe { whisper_cpp_sys::whisper_sample_best(self.ctx, needs_timestamp) }; let ret = unsafe { whisper_rs_sys::whisper_sample_best(self.ctx, needs_timestamp) };
Ok(ret) Ok(ret)
} }
@ -214,7 +214,7 @@ impl WhisperContext {
if !self.decode_once { if !self.decode_once {
return Err(WhisperError::DecodeNotComplete); return Err(WhisperError::DecodeNotComplete);
} }
let ret = unsafe { whisper_cpp_sys::whisper_sample_timestamp(self.ctx) }; let ret = unsafe { whisper_rs_sys::whisper_sample_timestamp(self.ctx) };
Ok(ret) Ok(ret)
} }
@ -227,7 +227,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_n_len (struct whisper_context * ctx)` /// `int whisper_n_len (struct whisper_context * ctx)`
pub fn n_len(&self) -> Result<c_int, WhisperError> { pub fn n_len(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_len(self.ctx) }; let ret = unsafe { whisper_rs_sys::whisper_n_len(self.ctx) };
if ret < 0 { if ret < 0 {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
} else { } else {
@ -243,7 +243,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_n_vocab (struct whisper_context * ctx)` /// `int whisper_n_vocab (struct whisper_context * ctx)`
pub fn n_vocab(&self) -> Result<c_int, WhisperError> { pub fn n_vocab(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_vocab(self.ctx) }; let ret = unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) };
if ret < 0 { if ret < 0 {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
} else { } else {
@ -259,7 +259,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_n_text_ctx (struct whisper_context * ctx)` /// `int whisper_n_text_ctx (struct whisper_context * ctx)`
pub fn n_text_ctx(&self) -> Result<c_int, WhisperError> { pub fn n_text_ctx(&self) -> Result<c_int, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_n_text_ctx(self.ctx) }; let ret = unsafe { whisper_rs_sys::whisper_n_text_ctx(self.ctx) };
if ret < 0 { if ret < 0 {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
} else { } else {
@ -272,7 +272,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_is_multilingual(struct whisper_context * ctx)` /// `int whisper_is_multilingual(struct whisper_context * ctx)`
pub fn is_multilingual(&self) -> bool { pub fn is_multilingual(&self) -> bool {
unsafe { whisper_cpp_sys::whisper_is_multilingual(self.ctx) != 0 } unsafe { whisper_rs_sys::whisper_is_multilingual(self.ctx) != 0 }
} }
/// The probabilities for the next token. /// The probabilities for the next token.
@ -287,7 +287,7 @@ impl WhisperContext {
if !self.decode_once { if !self.decode_once {
return Err(WhisperError::DecodeNotComplete); return Err(WhisperError::DecodeNotComplete);
} }
let ret = unsafe { whisper_cpp_sys::whisper_get_probs(self.ctx) }; let ret = unsafe { whisper_rs_sys::whisper_get_probs(self.ctx) };
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
@ -305,7 +305,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)` /// `const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token)`
pub fn token_to_str(&self, token_id: WhisperToken) -> Result<String, WhisperError> { pub fn token_to_str(&self, token_id: WhisperToken) -> Result<String, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_token_to_str(self.ctx, token_id) }; let ret = unsafe { whisper_rs_sys::whisper_token_to_str(self.ctx, token_id) };
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
@ -320,7 +320,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_token_eot (struct whisper_context * ctx)` /// `whisper_token whisper_token_eot (struct whisper_context * ctx)`
pub fn token_eot(&self) -> WhisperToken { pub fn token_eot(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_eot(self.ctx) } unsafe { whisper_rs_sys::whisper_token_eot(self.ctx) }
} }
/// Get the ID of the sot token. /// Get the ID of the sot token.
@ -328,7 +328,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_token_sot (struct whisper_context * ctx)` /// `whisper_token whisper_token_sot (struct whisper_context * ctx)`
pub fn token_sot(&self) -> WhisperToken { pub fn token_sot(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_sot(self.ctx) } unsafe { whisper_rs_sys::whisper_token_sot(self.ctx) }
} }
/// Get the ID of the prev token. /// Get the ID of the prev token.
@ -336,7 +336,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_token_prev(struct whisper_context * ctx)` /// `whisper_token whisper_token_prev(struct whisper_context * ctx)`
pub fn token_prev(&self) -> WhisperToken { pub fn token_prev(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_prev(self.ctx) } unsafe { whisper_rs_sys::whisper_token_prev(self.ctx) }
} }
/// Get the ID of the solm token. /// Get the ID of the solm token.
@ -344,7 +344,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_token_solm(struct whisper_context * ctx)` /// `whisper_token whisper_token_solm(struct whisper_context * ctx)`
pub fn token_solm(&self) -> WhisperToken { pub fn token_solm(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_solm(self.ctx) } unsafe { whisper_rs_sys::whisper_token_solm(self.ctx) }
} }
/// Get the ID of the not token. /// Get the ID of the not token.
@ -352,7 +352,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_token_not (struct whisper_context * ctx)` /// `whisper_token whisper_token_not (struct whisper_context * ctx)`
pub fn token_not(&self) -> WhisperToken { pub fn token_not(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_not(self.ctx) } unsafe { whisper_rs_sys::whisper_token_not(self.ctx) }
} }
/// Get the ID of the beg token. /// Get the ID of the beg token.
@ -360,7 +360,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `whisper_token whisper_token_beg (struct whisper_context * ctx)` /// `whisper_token whisper_token_beg (struct whisper_context * ctx)`
pub fn token_beg(&self) -> WhisperToken { pub fn token_beg(&self) -> WhisperToken {
unsafe { whisper_cpp_sys::whisper_token_beg(self.ctx) } unsafe { whisper_rs_sys::whisper_token_beg(self.ctx) }
} }
/// Print performance statistics to stdout. /// Print performance statistics to stdout.
@ -368,7 +368,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `void whisper_print_timings(struct whisper_context * ctx)` /// `void whisper_print_timings(struct whisper_context * ctx)`
pub fn print_timings(&self) { pub fn print_timings(&self) {
unsafe { whisper_cpp_sys::whisper_print_timings(self.ctx) } unsafe { whisper_rs_sys::whisper_print_timings(self.ctx) }
} }
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text /// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
@ -387,7 +387,7 @@ impl WhisperContext {
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)` /// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> { pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
let ret = unsafe { let ret = unsafe {
whisper_cpp_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int) whisper_rs_sys::whisper_full(self.ctx, params.fp, data.as_ptr(), data.len() as c_int)
}; };
if ret < 0 { if ret < 0 {
Err(WhisperError::GenericError(ret)) Err(WhisperError::GenericError(ret))
@ -402,7 +402,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int whisper_full_n_segments(struct whisper_context * ctx)` /// `int whisper_full_n_segments(struct whisper_context * ctx)`
pub fn full_n_segments(&self) -> c_int { pub fn full_n_segments(&self) -> c_int {
unsafe { whisper_cpp_sys::whisper_full_n_segments(self.ctx) } unsafe { whisper_rs_sys::whisper_full_n_segments(self.ctx) }
} }
/// Get the start time of the specified segment. /// Get the start time of the specified segment.
@ -413,7 +413,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)` /// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_t0(&self, segment: c_int) -> i64 { pub fn full_get_segment_t0(&self, segment: c_int) -> i64 {
unsafe { whisper_cpp_sys::whisper_full_get_segment_t0(self.ctx, segment) } unsafe { whisper_rs_sys::whisper_full_get_segment_t0(self.ctx, segment) }
} }
/// Get the end time of the specified segment. /// Get the end time of the specified segment.
@ -424,7 +424,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)` /// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_t1(&self, segment: c_int) -> i64 { pub fn full_get_segment_t1(&self, segment: c_int) -> i64 {
unsafe { whisper_cpp_sys::whisper_full_get_segment_t1(self.ctx, segment) } unsafe { whisper_rs_sys::whisper_full_get_segment_t1(self.ctx, segment) }
} }
/// Get the text of the specified segment. /// Get the text of the specified segment.
@ -438,7 +438,7 @@ impl WhisperContext {
/// # C++ equivalent /// # C++ equivalent
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)` /// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> { pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
let ret = unsafe { whisper_cpp_sys::whisper_full_get_segment_text(self.ctx, segment) }; let ret = unsafe { whisper_rs_sys::whisper_full_get_segment_text(self.ctx, segment) };
if ret.is_null() { if ret.is_null() {
return Err(WhisperError::NullPointer); return Err(WhisperError::NullPointer);
} }
@ -450,6 +450,6 @@ impl WhisperContext {
impl Drop for WhisperContext { impl Drop for WhisperContext {
fn drop(&mut self) { fn drop(&mut self) {
unsafe { whisper_cpp_sys::whisper_free(self.ctx) }; unsafe { whisper_rs_sys::whisper_free(self.ctx) };
} }
} }

View file

@ -14,7 +14,7 @@ pub enum DecodeStrategy {
} }
pub struct FullParams<'a> { pub struct FullParams<'a> {
pub(crate) fp: whisper_cpp_sys::whisper_full_params, pub(crate) fp: whisper_rs_sys::whisper_full_params,
phantom: PhantomData<&'a str>, phantom: PhantomData<&'a str>,
} }
@ -22,7 +22,7 @@ impl<'a> FullParams<'a> {
/// Create a new set of parameters for the decoder. /// Create a new set of parameters for the decoder.
pub fn new(decode_strategy: DecodeStrategy) -> FullParams<'a> { pub fn new(decode_strategy: DecodeStrategy) -> FullParams<'a> {
let mut fp = unsafe { let mut fp = unsafe {
whisper_cpp_sys::whisper_full_default_params(match decode_strategy { whisper_rs_sys::whisper_full_default_params(match decode_strategy {
DecodeStrategy::Greedy { .. } => 0, DecodeStrategy::Greedy { .. } => 0,
DecodeStrategy::BeamSearch { .. } => 1, DecodeStrategy::BeamSearch { .. } => 1,
} as _) } as _)

View file

@ -1,5 +1,5 @@
[package] [package]
name = "whisper-cpp-sys" name = "whisper-rs-sys"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"