Add set_initial_prompt to WhisperParams
This commit is contained in:
parent
2a17adde85
commit
8c8a5d78d9
1 changed files with 94 additions and 1 deletions
|
|
@ -1,5 +1,5 @@
|
|||
use crate::whisper_grammar::WhisperGrammarElement;
|
||||
use std::ffi::{c_float, c_int, CString};
|
||||
use std::ffi::{c_char, c_float, c_int, CString};
|
||||
use std::marker::PhantomData;
|
||||
use whisper_rs_sys::whisper_token;
|
||||
|
||||
|
|
@ -579,6 +579,22 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
|||
pub fn set_grammar_penalty(&mut self, grammar_penalty: f32) {
|
||||
self.fp.grammar_penalty = grammar_penalty;
|
||||
}
|
||||
|
||||
/// Set the initial prompt for the model.
|
||||
///
|
||||
/// This is the text that will be used as the starting point for the model's decoding.
|
||||
/// Calling this more than once will overwrite the previous initial prompt.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `initial_prompt` - A string slice representing the initial prompt text.
|
||||
///
|
||||
/// # Panics
|
||||
/// This method will panic if `initial_prompt` contains a null byte, as it cannot be converted into a `CString`.
|
||||
pub fn set_initial_prompt(&mut self, initial_prompt: &str) {
|
||||
self.fp.initial_prompt = CString::new(initial_prompt)
|
||||
.expect("Initial prompt contains null byte")
|
||||
.into_raw() as *const c_char;
|
||||
}
|
||||
}
|
||||
|
||||
// following implementations are safe
|
||||
|
|
@ -586,3 +602,80 @@ impl<'a, 'b> FullParams<'a, 'b> {
|
|||
// concurrent usage is prevented by &mut self on methods that modify the struct
|
||||
unsafe impl<'a, 'b> Send for FullParams<'a, 'b> {}
|
||||
unsafe impl<'a, 'b> Sync for FullParams<'a, 'b> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_whisper_params_initial_prompt {
|
||||
use super::*;
|
||||
|
||||
impl<'a, 'b> FullParams<'a, 'b> {
|
||||
pub fn get_initial_prompt(&self) -> &str {
|
||||
// SAFETY: Ensure this is safe and respects the lifetime of the string in self.fp
|
||||
unsafe {
|
||||
std::ffi::CStr::from_ptr(self.fp.initial_prompt)
|
||||
.to_str()
|
||||
.unwrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initial_prompt_normal_usage() {
|
||||
let mut params = FullParams::new(SamplingStrategy::default());
|
||||
let prompt = "Hello, world!";
|
||||
params.set_initial_prompt(prompt);
|
||||
assert_eq!(params.get_initial_prompt(), prompt);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic(expected = "Initial prompt contains null byte")]
|
||||
fn test_initial_prompt_null_byte() {
|
||||
let mut params = FullParams::new(SamplingStrategy::default());
|
||||
let prompt = "Hello\0, world!";
|
||||
params.set_initial_prompt(prompt);
|
||||
// Should panic
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initial_prompt_empty_string() {
|
||||
let mut params = FullParams::new(SamplingStrategy::default());
|
||||
let prompt = "";
|
||||
params.set_initial_prompt(prompt);
|
||||
|
||||
assert_eq!(
|
||||
params.get_initial_prompt(),
|
||||
prompt,
|
||||
"The initial prompt should be an empty string."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initial_prompt_repeated_calls() {
|
||||
let mut params = FullParams::new(SamplingStrategy::default());
|
||||
params.set_initial_prompt("First prompt");
|
||||
assert_eq!(
|
||||
params.get_initial_prompt(),
|
||||
"First prompt",
|
||||
"The initial prompt should be 'First prompt'."
|
||||
);
|
||||
|
||||
params.set_initial_prompt("Second prompt");
|
||||
assert_eq!(
|
||||
params.get_initial_prompt(),
|
||||
"Second prompt",
|
||||
"The initial prompt should be 'Second prompt' after second set."
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_initial_prompt_long_string() {
|
||||
let mut params = FullParams::new(SamplingStrategy::default());
|
||||
let long_prompt = "a".repeat(10000); // a long string of 10,000 'a' characters
|
||||
params.set_initial_prompt(&long_prompt);
|
||||
|
||||
assert_eq!(
|
||||
params.get_initial_prompt(),
|
||||
long_prompt.as_str(),
|
||||
"The initial prompt should match the long string provided."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue