Merge pull request #102 from codesoda/add_initial_prompt_param
Add initial prompt param
This commit is contained in:
commit
2c0d6404b7
2 changed files with 102 additions and 1 deletions
|
|
@ -51,6 +51,7 @@ fn main() {
|
||||||
).expect("failed to open model");
|
).expect("failed to open model");
|
||||||
let mut state = ctx.create_state().expect("failed to create key");
|
let mut state = ctx.create_state().expect("failed to create key");
|
||||||
let mut params = FullParams::new(SamplingStrategy::default());
|
let mut params = FullParams::new(SamplingStrategy::default());
|
||||||
|
params.set_initial_prompt("experience");
|
||||||
params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress));
|
params.set_progress_callback_safe(|progress| println!("Progress callback: {}%", progress));
|
||||||
|
|
||||||
let st = std::time::Instant::now();
|
let st = std::time::Instant::now();
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use crate::whisper_grammar::WhisperGrammarElement;
|
use crate::whisper_grammar::WhisperGrammarElement;
|
||||||
use std::ffi::{c_float, c_int, CString};
|
use std::ffi::{c_char, c_float, c_int, CString};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
use whisper_rs_sys::whisper_token;
|
use whisper_rs_sys::whisper_token;
|
||||||
|
|
||||||
|
|
@ -579,6 +579,29 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
pub fn set_grammar_penalty(&mut self, grammar_penalty: f32) {
|
pub fn set_grammar_penalty(&mut self, grammar_penalty: f32) {
|
||||||
self.fp.grammar_penalty = grammar_penalty;
|
self.fp.grammar_penalty = grammar_penalty;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Set the initial prompt for the model.
|
||||||
|
///
|
||||||
|
/// This is the text that will be used as the starting point for the model's decoding.
|
||||||
|
/// Calling this more than once will overwrite the previous initial prompt.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `initial_prompt` - A string slice representing the initial prompt text.
|
||||||
|
///
|
||||||
|
/// # Panics
|
||||||
|
/// This method will panic if `initial_prompt` contains a null byte, as it cannot be converted into a `CString`.
|
||||||
|
///
|
||||||
|
/// # Examples
|
||||||
|
/// ```
|
||||||
|
/// let mut params = FullParams::new(SamplingStrategy::default());
|
||||||
|
/// params.set_initial_prompt("Hello, world!");
|
||||||
|
/// // ... further usage of params ...
|
||||||
|
/// ```
|
||||||
|
pub fn set_initial_prompt(&mut self, initial_prompt: &str) {
|
||||||
|
self.fp.initial_prompt = CString::new(initial_prompt)
|
||||||
|
.expect("Initial prompt contains null byte")
|
||||||
|
.into_raw() as *const c_char;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// following implementations are safe
|
// following implementations are safe
|
||||||
|
|
@ -586,3 +609,80 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
// concurrent usage is prevented by &mut self on methods that modify the struct
|
// concurrent usage is prevented by &mut self on methods that modify the struct
|
||||||
unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {}
|
unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {}
|
||||||
unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {}
|
unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod test_whisper_params_initial_prompt {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
impl<'a, 'b> FullParams<'a, 'b> {
|
||||||
|
pub fn get_initial_prompt(&self) -> &str {
|
||||||
|
// SAFETY: Ensure this is safe and respects the lifetime of the string in self.fp
|
||||||
|
unsafe {
|
||||||
|
std::ffi::CStr::from_ptr(self.fp.initial_prompt)
|
||||||
|
.to_str()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_initial_prompt_normal_usage() {
|
||||||
|
let mut params = FullParams::new(SamplingStrategy::default());
|
||||||
|
let prompt = "Hello, world!";
|
||||||
|
params.set_initial_prompt(prompt);
|
||||||
|
assert_eq!(params.get_initial_prompt(), prompt);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic(expected = "Initial prompt contains null byte")]
|
||||||
|
fn test_initial_prompt_null_byte() {
|
||||||
|
let mut params = FullParams::new(SamplingStrategy::default());
|
||||||
|
let prompt = "Hello\0, world!";
|
||||||
|
params.set_initial_prompt(prompt);
|
||||||
|
// Should panic
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_initial_prompt_empty_string() {
|
||||||
|
let mut params = FullParams::new(SamplingStrategy::default());
|
||||||
|
let prompt = "";
|
||||||
|
params.set_initial_prompt(prompt);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
params.get_initial_prompt(),
|
||||||
|
prompt,
|
||||||
|
"The initial prompt should be an empty string."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_initial_prompt_repeated_calls() {
|
||||||
|
let mut params = FullParams::new(SamplingStrategy::default());
|
||||||
|
params.set_initial_prompt("First prompt");
|
||||||
|
assert_eq!(
|
||||||
|
params.get_initial_prompt(),
|
||||||
|
"First prompt",
|
||||||
|
"The initial prompt should be 'First prompt'."
|
||||||
|
);
|
||||||
|
|
||||||
|
params.set_initial_prompt("Second prompt");
|
||||||
|
assert_eq!(
|
||||||
|
params.get_initial_prompt(),
|
||||||
|
"Second prompt",
|
||||||
|
"The initial prompt should be 'Second prompt' after second set."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_initial_prompt_long_string() {
|
||||||
|
let mut params = FullParams::new(SamplingStrategy::default());
|
||||||
|
let long_prompt = "a".repeat(10000); // a long string of 10,000 'a' characters
|
||||||
|
params.set_initial_prompt(&long_prompt);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
params.get_initial_prompt(),
|
||||||
|
long_prompt.as_str(),
|
||||||
|
"The initial prompt should match the long string provided."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue