Merge pull request #47 from jcsoo/tokenize_with_cstring

Tokenize with cstring
This commit is contained in:
0/0 2023-05-05 01:26:47 +00:00 committed by GitHub
commit 3fce9cbbde
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 28 additions and 1 deletions

View file

@ -23,6 +23,7 @@ hound = "3.5.0"
[features] [features]
simd = [] simd = []
coreml = ["whisper-rs-sys/coreml"] coreml = ["whisper-rs-sys/coreml"]
test-with-tiny-model = []
[package.metadata.docs.rs] [package.metadata.docs.rs]
features = ["simd"] features = ["simd"]

View file

@ -88,12 +88,15 @@ impl WhisperContext {
text: &str, text: &str,
max_tokens: usize, max_tokens: usize,
) -> Result<Vec<WhisperToken>, WhisperError> { ) -> Result<Vec<WhisperToken>, WhisperError> {
// convert the text to a nul-terminated C string. Will raise an error if the text contains
// any nul bytes.
let text = CString::new(text)?;
// allocate at least max_tokens to ensure the memory is valid // allocate at least max_tokens to ensure the memory is valid
let mut tokens: Vec<WhisperToken> = Vec::with_capacity(max_tokens); let mut tokens: Vec<WhisperToken> = Vec::with_capacity(max_tokens);
let ret = unsafe { let ret = unsafe {
whisper_rs_sys::whisper_tokenize( whisper_rs_sys::whisper_tokenize(
self.ctx, self.ctx,
text.as_ptr() as *const _, text.as_ptr(),
tokens.as_mut_ptr(), tokens.as_mut_ptr(),
max_tokens as c_int, max_tokens as c_int,
) )
@ -428,3 +431,26 @@ impl Drop for WhisperContext {
// see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388 // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
unsafe impl Send for WhisperContext {} unsafe impl Send for WhisperContext {}
unsafe impl Sync for WhisperContext {} unsafe impl Sync for WhisperContext {}
/// Tests that load a real model file from disk.
///
/// Compiled only for test builds with the `test-with-tiny-model` feature
/// enabled. They expect the tiny.en model to have been fetched beforehand with
/// `sys/whisper.cpp/models/download-ggml-model.sh tiny.en`.
#[cfg(all(test, feature = "test-with-tiny-model"))]
mod test_with_tiny_model {
    use super::*;

    /// Path of the model file the tests load.
    const MODEL_PATH: &str = "./sys/whisper.cpp/models/ggml-tiny.en.bin";

    /// Tokenizes a sentence, maps every token back to its string form, and
    /// checks that concatenating those pieces reproduces the input exactly.
    #[test]
    fn test_tokenize_round_trip() {
        let context = WhisperContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
        let original = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.";

        let token_ids = context.tokenize(original, 1024).unwrap();

        // Concatenate the per-token strings directly into one String.
        let round_tripped: String = token_ids
            .into_iter()
            .map(|token| context.token_to_str(token).unwrap())
            .collect();

        assert_eq!(original, round_tripped);
    }
}