Pass tokenize() input to whisper.cpp as a NUL-terminated C string

Cargo.toml: adds an empty `test-with-tiny-model` feature that gates the new test module.
src/whisper_ctx.rs: `tokenize` now builds a `CString` from `text` (returning early via `?` if
that fails, e.g. on an interior NUL byte) and passes its pointer to `whisper_tokenize` instead
of the pointer of the `&str`, which is not NUL-terminated. Also adds a tokenize/token_to_str
round-trip test that needs the ggml-tiny.en model file on disk.

NOTE(review): generic arguments that had been stripped from this patch text
(`Result, WhisperError>`, `Vec = ...`, `collect::>()`) are restored below as
`Vec<WhisperToken>` and `Vec<_>`; confirm the two context lines against src/whisper_ctx.rs
at blob 98afa99.
NOTE(review): `CString::new(text)?` needs `CString` in scope and a `From<NulError>` conversion
for `WhisperError`; neither is visible in the hunks below, so confirm they exist.

diff --git a/Cargo.toml b/Cargo.toml
index 35a1bcb..5cac3fb 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -23,6 +23,7 @@ hound = "3.5.0"
 [features]
 simd = []
 coreml = ["whisper-rs-sys/coreml"]
+test-with-tiny-model = []
 
 [package.metadata.docs.rs]
 features = ["simd"]
diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs
index 98afa99..8783737 100644
--- a/src/whisper_ctx.rs
+++ b/src/whisper_ctx.rs
@@ -88,12 +88,15 @@ impl WhisperContext {
         text: &str,
         max_tokens: usize,
     ) -> Result<Vec<WhisperToken>, WhisperError> {
+        // convert the text to a nul-terminated C string. Will raise an error if the text contains
+        // any nul bytes.
+        let text = CString::new(text)?;
         // allocate at least max_tokens to ensure the memory is valid
         let mut tokens: Vec<WhisperToken> = Vec::with_capacity(max_tokens);
         let ret = unsafe {
             whisper_rs_sys::whisper_tokenize(
                 self.ctx,
-                text.as_ptr() as *const _,
+                text.as_ptr(),
                 tokens.as_mut_ptr(),
                 max_tokens as c_int,
             )
@@ -428,3 +431,26 @@ impl Drop for WhisperContext {
 // see https://github.com/ggerganov/whisper.cpp/issues/32#issuecomment-1272790388
 unsafe impl Send for WhisperContext {}
 unsafe impl Sync for WhisperContext {}
+
+#[cfg(test)]
+#[cfg(feature = "test-with-tiny-model")]
+mod test_with_tiny_model {
+    use super::*;
+    const MODEL_PATH: &str = "./sys/whisper.cpp/models/ggml-tiny.en.bin";
+
+    // These tests expect that the tiny.en model has been downloaded
+    // using the script `sys/whisper.cpp/models/download-ggml-model.sh tiny.en`
+
+    #[test]
+    fn test_tokenize_round_trip() {
+        let ctx = WhisperContext::new(MODEL_PATH).expect("Download the ggml-tiny.en model using 'sys/whisper.cpp/models/download-ggml-model.sh tiny.en'");
+        let text_in = " And so my fellow Americans, ask not what your country can do for you, ask what you can do for your country.";
+        let tokens = ctx.tokenize(text_in, 1024).unwrap();
+        let text_out = tokens
+            .into_iter()
+            .map(|t| ctx.token_to_str(t).unwrap())
+            .collect::<Vec<_>>()
+            .join("");
+        assert_eq!(text_in, text_out);
+    }
+}