Merge pull request #102 from codesoda/add_initial_prompt_param

Add initial prompt param
This commit is contained in:
Niko 2023-12-01 19:03:51 -07:00 committed by GitHub
commit 2c0d6404b7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 102 additions and 1 deletions

View file

@ -51,6 +51,7 @@ fn main() {
).expect("failed to open model");
let mut state = ctx.create_state().expect("failed to create key");
let mut params = FullParams::new(SamplingStrategy::default());
params.set_initial_prompt("experience");
params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress));
let st = std::time::Instant::now();

View file

@ -1,5 +1,5 @@
use crate::whisper_grammar::WhisperGrammarElement;
use std::ffi::{c_float, c_int, CString};
use std::ffi::{c_char, c_float, c_int, CString};
use std::marker::PhantomData;
use whisper_rs_sys::whisper_token;
@ -579,6 +579,29 @@ impl<'a, 'b> FullParams<'a, 'b> {
pub fn set_grammar_penalty(&mut self, grammar_penalty: f32) {
self.fp.grammar_penalty = grammar_penalty;
}
/// Set the initial prompt for the model.
///
/// This is the text that will be used as the starting point for the model's decoding.
/// Calling this more than once will overwrite the previous initial prompt.
///
/// # Arguments
/// * `initial_prompt` - A string slice representing the initial prompt text.
///
/// # Panics
/// This method will panic if `initial_prompt` contains a null byte, as it cannot be converted into a `CString`.
///
/// # Examples
/// ```
/// let mut params = FullParams::new(SamplingStrategy::default());
/// params.set_initial_prompt("Hello, world!");
/// // ... further usage of params ...
/// ```
pub fn set_initial_prompt(&mut self, initial_prompt: &str) {
self.fp.initial_prompt = CString::new(initial_prompt)
.expect("Initial prompt contains null byte")
.into_raw() as *const c_char;
}
}
// following implementations are safe
@ -586,3 +609,80 @@ impl<'a, 'b> FullParams<'a, 'b> {
// concurrent usage is prevented by &mut self on methods that modify the struct
unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {}
unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {}
#[cfg(test)]
mod test_whisper_params_initial_prompt {
use super::*;
impl<'a, 'b> FullParams<'a, 'b> {
pub fn get_initial_prompt(&self) -> &str {
// SAFETY: Ensure this is safe and respects the lifetime of the string in self.fp
unsafe {
std::ffi::CStr::from_ptr(self.fp.initial_prompt)
.to_str()
.unwrap()
}
}
}
#[test]
fn test_initial_prompt_normal_usage() {
let mut params = FullParams::new(SamplingStrategy::default());
let prompt = "Hello, world!";
params.set_initial_prompt(prompt);
assert_eq!(params.get_initial_prompt(), prompt);
}
#[test]
#[should_panic(expected = "Initial prompt contains null byte")]
fn test_initial_prompt_null_byte() {
let mut params = FullParams::new(SamplingStrategy::default());
let prompt = "Hello\0, world!";
params.set_initial_prompt(prompt);
// Should panic
}
#[test]
fn test_initial_prompt_empty_string() {
let mut params = FullParams::new(SamplingStrategy::default());
let prompt = "";
params.set_initial_prompt(prompt);
assert_eq!(
params.get_initial_prompt(),
prompt,
"The initial prompt should be an empty string."
);
}
#[test]
fn test_initial_prompt_repeated_calls() {
let mut params = FullParams::new(SamplingStrategy::default());
params.set_initial_prompt("First prompt");
assert_eq!(
params.get_initial_prompt(),
"First prompt",
"The initial prompt should be 'First prompt'."
);
params.set_initial_prompt("Second prompt");
assert_eq!(
params.get_initial_prompt(),
"Second prompt",
"The initial prompt should be 'Second prompt' after second set."
);
}
#[test]
fn test_initial_prompt_long_string() {
let mut params = FullParams::new(SamplingStrategy::default());
let long_prompt = "a".repeat(10000); // a long string of 10,000 'a' characters
params.set_initial_prompt(&long_prompt);
assert_eq!(
params.get_initial_prompt(),
long_prompt.as_str(),
"The initial prompt should match the long string provided."
);
}
}