From 8c8a5d78d946dce9c53ee1a4e37ce9a19fdcc2bc Mon Sep 17 00:00:00 2001 From: Chris Raethke Date: Mon, 27 Nov 2023 11:15:57 +1000 Subject: [PATCH 1/2] Add set_initial_prompt to WhisperParams --- src/whisper_params.rs | 95 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 692a040..795cada 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -1,5 +1,5 @@ use crate::whisper_grammar::WhisperGrammarElement; -use std::ffi::{c_float, c_int, CString}; +use std::ffi::{c_char, c_float, c_int, CString}; use std::marker::PhantomData; use whisper_rs_sys::whisper_token; @@ -579,6 +579,22 @@ impl<'a, 'b> FullParams<'a, 'b> { pub fn set_grammar_penalty(&mut self, grammar_penalty: f32) { self.fp.grammar_penalty = grammar_penalty; } + + /// Set the initial prompt for the model. + /// + /// This is the text that will be used as the starting point for the model's decoding. + /// Calling this more than once will overwrite the previous initial prompt. + /// + /// # Arguments + /// * `initial_prompt` - A string slice representing the initial prompt text. + /// + /// # Panics + /// This method will panic if `initial_prompt` contains a null byte, as it cannot be converted into a `CString`. + pub fn set_initial_prompt(&mut self, initial_prompt: &str) { + self.fp.initial_prompt = CString::new(initial_prompt) + .expect("Initial prompt contains null byte") + .into_raw() as *const c_char; + } } // following implementations are safe @@ -586,3 +602,80 @@ impl<'a, 'b> FullParams<'a, 'b> { // concurrent usage is prevented by &mut self on methods that modify the struct unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {} unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {} + +#[cfg(test)] +mod test_whisper_params_initial_prompt { + use super::*; + + impl<'a, 'b> FullParams<'a, 'b> { + pub fn get_initial_prompt(&self) -> &str { + // SAFETY: Ensure this is safe and respects the lifetime of the string in self.fp + unsafe { + std::ffi::CStr::from_ptr(self.fp.initial_prompt) + .to_str() + .unwrap() + } + } + } + + #[test] + fn test_initial_prompt_normal_usage() { + let mut params = FullParams::new(SamplingStrategy::default()); + let prompt = "Hello, world!"; + params.set_initial_prompt(prompt); + assert_eq!(params.get_initial_prompt(), prompt); + } + + #[test] + #[should_panic(expected = "Initial prompt contains null byte")] + fn test_initial_prompt_null_byte() { + let mut params = FullParams::new(SamplingStrategy::default()); + let prompt = "Hello\0, world!"; + params.set_initial_prompt(prompt); + // Should panic + } + + #[test] + fn test_initial_prompt_empty_string() { + let mut params = FullParams::new(SamplingStrategy::default()); + let prompt = ""; + params.set_initial_prompt(prompt); + + assert_eq!( + params.get_initial_prompt(), + prompt, + "The initial prompt should be an empty string." + ); + } + + #[test] + fn test_initial_prompt_repeated_calls() { + let mut params = FullParams::new(SamplingStrategy::default()); + params.set_initial_prompt("First prompt"); + assert_eq!( + params.get_initial_prompt(), + "First prompt", + "The initial prompt should be 'First prompt'." + ); + + params.set_initial_prompt("Second prompt"); + assert_eq!( + params.get_initial_prompt(), + "Second prompt", + "The initial prompt should be 'Second prompt' after second set." + ); + } + + #[test] + fn test_initial_prompt_long_string() { + let mut params = FullParams::new(SamplingStrategy::default()); + let long_prompt = "a".repeat(10000); // a long string of 10,000 'a' characters + params.set_initial_prompt(&long_prompt); + + assert_eq!( + params.get_initial_prompt(), + long_prompt.as_str(), + "The initial prompt should match the long string provided." + ); + } +} From 9aa24296063d38b059d5d9183717d8c2e4aa0724 Mon Sep 17 00:00:00 2001 From: Chris Raethke Date: Mon, 27 Nov 2023 11:53:19 +1000 Subject: [PATCH 2/2] Add examples for set_initial_prompt(...) --- examples/full_usage/src/main.rs | 1 + src/whisper_params.rs | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/examples/full_usage/src/main.rs b/examples/full_usage/src/main.rs index 5826a5b..eee6658 100644 --- a/examples/full_usage/src/main.rs +++ b/examples/full_usage/src/main.rs @@ -51,6 +51,7 @@ fn main() { ).expect("failed to open model"); let mut state = ctx.create_state().expect("failed to create key"); let mut params = FullParams::new(SamplingStrategy::default()); + params.set_initial_prompt("experience"); params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress)); let st = std::time::Instant::now(); diff --git a/src/whisper_params.rs b/src/whisper_params.rs index 795cada..37281cf 100644 --- a/src/whisper_params.rs +++ b/src/whisper_params.rs @@ -590,6 +590,13 @@ impl<'a, 'b> FullParams<'a, 'b> { /// /// # Panics /// This method will panic if `initial_prompt` contains a null byte, as it cannot be converted into a `CString`. + /// + /// # Examples + /// ``` + /// let mut params = FullParams::new(SamplingStrategy::default()); + /// params.set_initial_prompt("Hello, world!"); + /// // ... further usage of params ... + /// ``` pub fn set_initial_prompt(&mut self, initial_prompt: &str) { self.fp.initial_prompt = CString::new(initial_prompt) .expect("Initial prompt contains null byte")