Merge pull request #43 from yuniruyuni/feat/call-method-via-state-object
migrate state method into state object
This commit is contained in:
commit
b5072d486d
5 changed files with 539 additions and 581 deletions
|
|
@ -12,7 +12,7 @@ fn main() -> Result<(), &'static str> {
|
||||||
let ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
|
let ctx = WhisperContext::new("example/path/to/model/whisper.cpp/models/ggml-base.en.bin")
|
||||||
.expect("failed to load model");
|
.expect("failed to load model");
|
||||||
// Create a state
|
// Create a state
|
||||||
let state = ctx.create_state().expect("failed to create key");
|
let mut state = ctx.create_state().expect("failed to create key");
|
||||||
|
|
||||||
// Create a params object for running the model.
|
// Create a params object for running the model.
|
||||||
// The number of past samples to consider defaults to 0.
|
// The number of past samples to consider defaults to 0.
|
||||||
|
|
@ -63,26 +63,25 @@ fn main() -> Result<(), &'static str> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Run the model.
|
// Run the model.
|
||||||
ctx.full(&state, params, &audio[..])
|
state.full(params, &audio[..]).expect("failed to run model");
|
||||||
.expect("failed to run model");
|
|
||||||
|
|
||||||
// Create a file to write the transcript to.
|
// Create a file to write the transcript to.
|
||||||
let mut file = File::create("transcript.txt").expect("failed to create file");
|
let mut file = File::create("transcript.txt").expect("failed to create file");
|
||||||
|
|
||||||
// Iterate through the segments of the transcript.
|
// Iterate through the segments of the transcript.
|
||||||
let num_segments = ctx
|
let num_segments = state
|
||||||
.full_n_segments(&state)
|
.full_n_segments()
|
||||||
.expect("failed to get number of segments");
|
.expect("failed to get number of segments");
|
||||||
for i in 0..num_segments {
|
for i in 0..num_segments {
|
||||||
// Get the transcribed text and timestamps for the current segment.
|
// Get the transcribed text and timestamps for the current segment.
|
||||||
let segment = ctx
|
let segment = state
|
||||||
.full_get_segment_text(&state, i)
|
.full_get_segment_text(i)
|
||||||
.expect("failed to get segment");
|
.expect("failed to get segment");
|
||||||
let start_timestamp = ctx
|
let start_timestamp = state
|
||||||
.full_get_segment_t0(&state, i)
|
.full_get_segment_t0(i)
|
||||||
.expect("failed to get start timestamp");
|
.expect("failed to get start timestamp");
|
||||||
let end_timestamp = ctx
|
let end_timestamp = state
|
||||||
.full_get_segment_t1(&state, i)
|
.full_get_segment_t1(i)
|
||||||
.expect("failed to get end timestamp");
|
.expect("failed to get end timestamp");
|
||||||
|
|
||||||
// Print the segment to stdout.
|
// Print the segment to stdout.
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,7 @@ pub fn usage() -> Result<(), &'static str> {
|
||||||
// load a context and model
|
// load a context and model
|
||||||
let ctx = WhisperContext::new("path/to/model").expect("failed to load model");
|
let ctx = WhisperContext::new("path/to/model").expect("failed to load model");
|
||||||
// make a state
|
// make a state
|
||||||
let state = ctx.create_state().expect("failed to create state");
|
let mut state = ctx.create_state().expect("failed to create state");
|
||||||
|
|
||||||
// create a params object
|
// create a params object
|
||||||
// note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
|
// note that currently the only implemented strategy is Greedy, BeamSearch is a WIP
|
||||||
|
|
@ -44,22 +44,23 @@ pub fn usage() -> Result<(), &'static str> {
|
||||||
|
|
||||||
// now we can run the model
|
// now we can run the model
|
||||||
// note the key we use here is the one we created above
|
// note the key we use here is the one we created above
|
||||||
ctx.full(&state, params, &audio_data[..])
|
state
|
||||||
|
.full(params, &audio_data[..])
|
||||||
.expect("failed to run model");
|
.expect("failed to run model");
|
||||||
|
|
||||||
// fetch the results
|
// fetch the results
|
||||||
let num_segments = ctx
|
let num_segments = state
|
||||||
.full_n_segments(&state)
|
.full_n_segments()
|
||||||
.expect("failed to get number of segments");
|
.expect("failed to get number of segments");
|
||||||
for i in 0..num_segments {
|
for i in 0..num_segments {
|
||||||
let segment = ctx
|
let segment = state
|
||||||
.full_get_segment_text(&state, i)
|
.full_get_segment_text(i)
|
||||||
.expect("failed to get segment");
|
.expect("failed to get segment");
|
||||||
let start_timestamp = ctx
|
let start_timestamp = state
|
||||||
.full_get_segment_t0(&state, i)
|
.full_get_segment_t0(i)
|
||||||
.expect("failed to get segment start timestamp");
|
.expect("failed to get segment start timestamp");
|
||||||
let end_timestamp = ctx
|
let end_timestamp = state
|
||||||
.full_get_segment_t1(&state, i)
|
.full_get_segment_t1(i)
|
||||||
.expect("failed to get segment end timestamp");
|
.expect("failed to get segment end timestamp");
|
||||||
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
|
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -47,17 +47,17 @@ fn main() {
|
||||||
|
|
||||||
let ctx =
|
let ctx =
|
||||||
WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
|
WhisperContext::new(&whisper_path.to_string_lossy()).expect("failed to open model");
|
||||||
let state = ctx.create_state().expect("failed to create key");
|
let mut state = ctx.create_state().expect("failed to create key");
|
||||||
let params = FullParams::new(SamplingStrategy::default());
|
let params = FullParams::new(SamplingStrategy::default());
|
||||||
|
|
||||||
ctx.full(&state, params, &samples)
|
state.full(params, &samples)
|
||||||
.expect("failed to convert samples");
|
.expect("failed to convert samples");
|
||||||
|
|
||||||
let num_segments = ctx.full_n_segments(&state).expect("failed to get number of segments");
|
let num_segments = state.full_n_segments().expect("failed to get number of segments");
|
||||||
for i in 0..num_segments {
|
for i in 0..num_segments {
|
||||||
let segment = ctx.full_get_segment_text(&state, i).expect("failed to get segment");
|
let segment = state.full_get_segment_text(i).expect("failed to get segment");
|
||||||
let start_timestamp = ctx.full_get_segment_t0(&state, i).expect("failed to get start timestamp");
|
let start_timestamp = state.full_get_segment_t0(i).expect("failed to get start timestamp");
|
||||||
let end_timestamp = ctx.full_get_segment_t1(&state, i).expect("failed to get end timestamp");
|
let end_timestamp = state.full_get_segment_t1(i).expect("failed to get end timestamp");
|
||||||
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
|
println!("[{} - {}]: {}", start_timestamp, end_timestamp, segment);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
use crate::error::WhisperError;
|
use crate::error::WhisperError;
|
||||||
use crate::whisper_params::FullParams;
|
|
||||||
use crate::whisper_state::WhisperState;
|
use crate::whisper_state::WhisperState;
|
||||||
use crate::{WhisperToken, WhisperTokenData};
|
use crate::WhisperToken;
|
||||||
use std::ffi::{c_int, CStr, CString};
|
use std::ffi::{c_int, CStr, CString};
|
||||||
|
|
||||||
/// Safe Rust wrapper around a Whisper context.
|
/// Safe Rust wrapper around a Whisper context.
|
||||||
|
|
@ -70,203 +69,7 @@ impl WhisperContext {
|
||||||
Err(WhisperError::InitError)
|
Err(WhisperError::InitError)
|
||||||
} else {
|
} else {
|
||||||
// SAFETY: this is known to be a valid pointer to a `whisper_state` struct
|
// SAFETY: this is known to be a valid pointer to a `whisper_state` struct
|
||||||
Ok(WhisperState::new(state))
|
Ok(WhisperState::new(self.ctx, state))
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
|
|
||||||
/// The resulting spectrogram is stored in the context transparently.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * pcm: The raw PCM audio.
|
|
||||||
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(()) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
|
|
||||||
pub fn pcm_to_mel(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
pcm: &[f32],
|
|
||||||
threads: usize,
|
|
||||||
) -> Result<(), WhisperError> {
|
|
||||||
if threads < 1 {
|
|
||||||
return Err(WhisperError::InvalidThreadCount);
|
|
||||||
}
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_pcm_to_mel_with_state(
|
|
||||||
self.ctx,
|
|
||||||
state.as_ptr(),
|
|
||||||
pcm.as_ptr(),
|
|
||||||
pcm.len() as c_int,
|
|
||||||
threads as c_int,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret == -1 {
|
|
||||||
Err(WhisperError::UnableToCalculateSpectrogram)
|
|
||||||
} else if ret == 0 {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(WhisperError::GenericError(ret))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
|
|
||||||
/// Applies a Phase Vocoder to speed up the audio x2.
|
|
||||||
/// The resulting spectrogram is stored in the context transparently.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * pcm: The raw PCM audio.
|
|
||||||
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(()) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
|
|
||||||
pub fn pcm_to_mel_phase_vocoder(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
pcm: &[f32],
|
|
||||||
threads: usize,
|
|
||||||
) -> Result<(), WhisperError> {
|
|
||||||
if threads < 1 {
|
|
||||||
return Err(WhisperError::InvalidThreadCount);
|
|
||||||
}
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
|
|
||||||
self.ctx,
|
|
||||||
state.as_ptr(),
|
|
||||||
pcm.as_ptr(),
|
|
||||||
pcm.len() as c_int,
|
|
||||||
threads as c_int,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret == -1 {
|
|
||||||
Err(WhisperError::UnableToCalculateSpectrogram)
|
|
||||||
} else if ret == 0 {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(WhisperError::GenericError(ret))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// This can be used to set a custom log mel spectrogram inside the provided whisper state.
|
|
||||||
/// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
|
||||||
///
|
|
||||||
/// # Note
|
|
||||||
/// This is a low-level function.
|
|
||||||
/// If you're a typical user, you probably don't want to use this function.
|
|
||||||
/// See instead [WhisperContext::pcm_to_mel].
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * data: The log mel spectrogram.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(()) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
|
|
||||||
pub fn set_mel(&self, state: &WhisperState, data: &[f32]) -> Result<(), WhisperError> {
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_set_mel_with_state(
|
|
||||||
self.ctx,
|
|
||||||
state.as_ptr(),
|
|
||||||
data.as_ptr(),
|
|
||||||
data.len() as c_int,
|
|
||||||
80 as c_int,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret == -1 {
|
|
||||||
Err(WhisperError::InvalidMelBands)
|
|
||||||
} else if ret == 0 {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(WhisperError::GenericError(ret))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper context.
|
|
||||||
/// Make sure to call [WhisperContext::pcm_to_mel] or [WhisperContext::set_mel] first.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * offset: Can be used to specify the offset of the first frame in the spectrogram. Usually 0.
|
|
||||||
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(()) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)`
|
|
||||||
pub fn encode(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
offset: usize,
|
|
||||||
threads: usize,
|
|
||||||
) -> Result<(), WhisperError> {
|
|
||||||
if threads < 1 {
|
|
||||||
return Err(WhisperError::InvalidThreadCount);
|
|
||||||
}
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_encode_with_state(
|
|
||||||
self.ctx,
|
|
||||||
state.as_ptr(),
|
|
||||||
offset as c_int,
|
|
||||||
threads as c_int,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret == -1 {
|
|
||||||
Err(WhisperError::UnableToCalculateEvaluation)
|
|
||||||
} else if ret == 0 {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(WhisperError::GenericError(ret))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
|
||||||
/// Make sure to call [WhisperContext::encode] first.
|
|
||||||
/// tokens + n_tokens is the provided context for the decoder.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * tokens: The tokens to decode.
|
|
||||||
/// * n_tokens: The number of tokens to decode.
|
|
||||||
/// * n_past: The number of past tokens to use for the decoding.
|
|
||||||
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(()) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
|
|
||||||
pub fn decode(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
tokens: &[WhisperToken],
|
|
||||||
n_past: usize,
|
|
||||||
threads: usize,
|
|
||||||
) -> Result<(), WhisperError> {
|
|
||||||
if threads < 1 {
|
|
||||||
return Err(WhisperError::InvalidThreadCount);
|
|
||||||
}
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_decode_with_state(
|
|
||||||
self.ctx,
|
|
||||||
state.as_ptr(),
|
|
||||||
tokens.as_ptr(),
|
|
||||||
tokens.len() as c_int,
|
|
||||||
n_past as c_int,
|
|
||||||
threads as c_int,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret == -1 {
|
|
||||||
Err(WhisperError::UnableToCalculateEvaluation)
|
|
||||||
} else if ret == 0 {
|
|
||||||
Ok(())
|
|
||||||
} else {
|
|
||||||
Err(WhisperError::GenericError(ret))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -304,74 +107,6 @@ impl WhisperContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Language functions
|
|
||||||
/// Use mel data at offset_ms to try and auto-detect the spoken language
|
|
||||||
/// Make sure to call pcm_to_mel() or set_mel() first
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * offset_ms: The offset in milliseconds to use for the language detection.
|
|
||||||
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(Vec<f32>) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)`
|
|
||||||
pub fn lang_detect(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
offset_ms: usize,
|
|
||||||
threads: usize,
|
|
||||||
) -> Result<Vec<f32>, WhisperError> {
|
|
||||||
if threads < 1 {
|
|
||||||
return Err(WhisperError::InvalidThreadCount);
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_lang_auto_detect_with_state(
|
|
||||||
self.ctx,
|
|
||||||
state.as_ptr(),
|
|
||||||
offset_ms as c_int,
|
|
||||||
threads as c_int,
|
|
||||||
lang_probs.as_mut_ptr(),
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret == -1 {
|
|
||||||
Err(WhisperError::UnableToCalculateEvaluation)
|
|
||||||
} else {
|
|
||||||
assert_eq!(
|
|
||||||
ret as usize,
|
|
||||||
lang_probs.len(),
|
|
||||||
"lang_probs length mismatch: this is a bug in whisper.cpp"
|
|
||||||
);
|
|
||||||
// if we're still running, double check that the length is correct, otherwise print to stderr
|
|
||||||
// and abort, as this will cause Undefined Behavior
|
|
||||||
// might get here due to the unwind being caught by a user-installed panic handler
|
|
||||||
if lang_probs.len() != ret as usize {
|
|
||||||
eprintln!(
|
|
||||||
"lang_probs length mismatch: this is a bug in whisper.cpp, \
|
|
||||||
aborting to avoid Undefined Behavior"
|
|
||||||
);
|
|
||||||
std::process::abort();
|
|
||||||
}
|
|
||||||
Ok(lang_probs)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// model attributes
|
|
||||||
/// Get the mel spectrogram length.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(c_int) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_n_len_from_state(struct whisper_context * ctx)`
|
|
||||||
#[inline]
|
|
||||||
pub fn n_len(&self, state: &WhisperState) -> Result<c_int, WhisperError> {
|
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_n_len_from_state(state.as_ptr()) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get n_vocab.
|
/// Get n_vocab.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
|
|
@ -561,46 +296,7 @@ impl WhisperContext {
|
||||||
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_model_type(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
// logit functions
|
/// token functions
|
||||||
/// Get the logits obtained from the last call to [WhisperContext::decode].
|
|
||||||
/// The logits for the last token are stored in the last row of the matrix.
|
|
||||||
///
|
|
||||||
/// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it
|
|
||||||
/// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: The segment to fetch data for.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `float * whisper_get_logits(struct whisper_context * ctx)`
|
|
||||||
pub fn get_logits(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
) -> Result<Vec<Vec<f32>>, WhisperError> {
|
|
||||||
let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(state.as_ptr()) };
|
|
||||||
if ret.is_null() {
|
|
||||||
return Err(WhisperError::NullPointer);
|
|
||||||
}
|
|
||||||
let mut logits = Vec::new();
|
|
||||||
let n_vocab = self.n_vocab();
|
|
||||||
let n_tokens = self.full_n_tokens(state, segment)?;
|
|
||||||
for i in 0..n_tokens {
|
|
||||||
let mut row = Vec::new();
|
|
||||||
for j in 0..n_vocab {
|
|
||||||
let idx = (i * n_vocab) + j;
|
|
||||||
let val = unsafe { *ret.offset(idx as isize) };
|
|
||||||
row.push(val);
|
|
||||||
}
|
|
||||||
logits.push(row);
|
|
||||||
}
|
|
||||||
Ok(logits)
|
|
||||||
}
|
|
||||||
|
|
||||||
// token functions
|
|
||||||
/// Convert a token ID to a string.
|
/// Convert a token ID to a string.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
|
@ -719,249 +415,6 @@ impl WhisperContext {
|
||||||
pub fn reset_timings(&self) {
|
pub fn reset_timings(&self) {
|
||||||
unsafe { whisper_rs_sys::whisper_reset_timings(self.ctx) }
|
unsafe { whisper_rs_sys::whisper_reset_timings(self.ctx) }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
|
||||||
/// Uses the specified decoding strategy to obtain the text.
|
|
||||||
///
|
|
||||||
/// This is usually the only function you need to call as an end user.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * params: [crate::FullParams] struct.
|
|
||||||
/// * pcm: PCM audio data.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(c_int) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
|
|
||||||
pub fn full(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
params: FullParams,
|
|
||||||
data: &[f32],
|
|
||||||
) -> Result<c_int, WhisperError> {
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_with_state(
|
|
||||||
self.ctx,
|
|
||||||
state.as_ptr(),
|
|
||||||
params.fp,
|
|
||||||
data.as_ptr(),
|
|
||||||
data.len() as c_int,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret == -1 {
|
|
||||||
Err(WhisperError::UnableToCalculateSpectrogram)
|
|
||||||
} else if ret == 7 {
|
|
||||||
Err(WhisperError::FailedToEncode)
|
|
||||||
} else if ret == 8 {
|
|
||||||
Err(WhisperError::FailedToDecode)
|
|
||||||
} else if ret == 0 {
|
|
||||||
Ok(ret)
|
|
||||||
} else {
|
|
||||||
Err(WhisperError::GenericError(ret))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Number of generated text segments.
|
|
||||||
/// A segment can be a few words, a sentence, or even a paragraph.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_full_n_segments(struct whisper_context * ctx)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_n_segments(&self, state: &WhisperState) -> Result<c_int, WhisperError> {
|
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(state.as_ptr()) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Language ID associated with the provided state.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_full_lang_id_from_state(struct whisper_state * state);`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_lang_id_from_state(&self, state: &WhisperState) -> Result<c_int, WhisperError> {
|
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(state.as_ptr()) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the start time of the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_get_segment_t0(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
) -> Result<i64, WhisperError> {
|
|
||||||
Ok(unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_segment_t0_from_state(state.as_ptr(), segment)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the end time of the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_get_segment_t1(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
) -> Result<i64, WhisperError> {
|
|
||||||
Ok(unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_segment_t1_from_state(state.as_ptr(), segment)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the text of the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
|
||||||
pub fn full_get_segment_text(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
) -> Result<String, WhisperError> {
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_segment_text_from_state(state.as_ptr(), segment)
|
|
||||||
};
|
|
||||||
if ret.is_null() {
|
|
||||||
return Err(WhisperError::NullPointer);
|
|
||||||
}
|
|
||||||
let c_str = unsafe { CStr::from_ptr(ret) };
|
|
||||||
let r_str = c_str.to_str()?;
|
|
||||||
Ok(r_str.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get number of tokens in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// c_int
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_n_tokens(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
) -> Result<c_int, WhisperError> {
|
|
||||||
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(state.as_ptr(), segment) })
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the token text of the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// Ok(String) on success, Err(WhisperError) on failure.
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
pub fn full_get_token_text(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<String, WhisperError> {
|
|
||||||
let ret = unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
|
||||||
self.ctx,
|
|
||||||
state.as_ptr(),
|
|
||||||
segment,
|
|
||||||
token,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
if ret.is_null() {
|
|
||||||
return Err(WhisperError::NullPointer);
|
|
||||||
}
|
|
||||||
let c_str = unsafe { CStr::from_ptr(ret) };
|
|
||||||
let r_str = c_str.to_str()?;
|
|
||||||
Ok(r_str.to_string())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the token ID of the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// [crate::WhisperToken]
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
pub fn full_get_token_id(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<WhisperToken, WhisperError> {
|
|
||||||
Ok(unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_token_id_from_state(state.as_ptr(), segment, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get token data for the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// [crate::WhisperTokenData]
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_get_token_data(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<WhisperTokenData, WhisperError> {
|
|
||||||
Ok(unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_token_data_from_state(state.as_ptr(), segment, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the probability of the specified token in the specified segment.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
/// * segment: Segment index.
|
|
||||||
/// * token: Token index.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
/// f32
|
|
||||||
///
|
|
||||||
/// # C++ equivalent
|
|
||||||
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
|
|
||||||
#[inline]
|
|
||||||
pub fn full_get_token_prob(
|
|
||||||
&self,
|
|
||||||
state: &WhisperState,
|
|
||||||
segment: c_int,
|
|
||||||
token: c_int,
|
|
||||||
) -> Result<f32, WhisperError> {
|
|
||||||
Ok(unsafe {
|
|
||||||
whisper_rs_sys::whisper_full_get_token_p_from_state(state.as_ptr(), segment, token)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Drop for WhisperContext {
|
impl Drop for WhisperContext {
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
|
use crate::{FullParams, WhisperContext, WhisperError, WhisperToken, WhisperTokenData};
|
||||||
|
use std::ffi::{c_int, CStr};
|
||||||
use std::marker::PhantomData;
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
/// Rustified pointer to a Whisper state.
|
/// Rustified pointer to a Whisper state.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WhisperState<'a> {
|
pub struct WhisperState<'a> {
|
||||||
|
ctx: *mut whisper_rs_sys::whisper_context,
|
||||||
ptr: *mut whisper_rs_sys::whisper_state,
|
ptr: *mut whisper_rs_sys::whisper_state,
|
||||||
_phantom: PhantomData<&'a ()>,
|
_phantom: PhantomData<&'a WhisperContext>,
|
||||||
}
|
}
|
||||||
|
|
||||||
unsafe impl<'a> Send for WhisperState<'a> {}
|
unsafe impl<'a> Send for WhisperState<'a> {}
|
||||||
|
|
@ -19,14 +22,516 @@ impl<'a> Drop for WhisperState<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> WhisperState<'a> {
|
impl<'a> WhisperState<'a> {
|
||||||
pub(crate) fn new(ptr: *mut whisper_rs_sys::whisper_state) -> Self {
|
pub(crate) fn new(
|
||||||
|
ctx: *mut whisper_rs_sys::whisper_context,
|
||||||
|
ptr: *mut whisper_rs_sys::whisper_state,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
ctx,
|
||||||
ptr,
|
ptr,
|
||||||
_phantom: PhantomData,
|
_phantom: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn as_ptr(&self) -> *mut whisper_rs_sys::whisper_state {
|
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
|
||||||
self.ptr
|
/// The resulting spectrogram is stored in the context transparently.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * pcm: The raw PCM audio.
|
||||||
|
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(()) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
|
||||||
|
pub fn pcm_to_mel(&mut self, pcm: &[f32], threads: usize) -> Result<(), WhisperError> {
|
||||||
|
if threads < 1 {
|
||||||
|
return Err(WhisperError::InvalidThreadCount);
|
||||||
|
}
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_pcm_to_mel_with_state(
|
||||||
|
self.ctx,
|
||||||
|
self.ptr,
|
||||||
|
pcm.as_ptr(),
|
||||||
|
pcm.len() as c_int,
|
||||||
|
threads as c_int,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::UnableToCalculateSpectrogram)
|
||||||
|
} else if ret == 0 {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(WhisperError::GenericError(ret))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert raw PCM audio (floating point 32 bit) to log mel spectrogram.
|
||||||
|
/// Applies a Phase Vocoder to speed up the audio x2.
|
||||||
|
/// The resulting spectrogram is stored in the context transparently.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * pcm: The raw PCM audio.
|
||||||
|
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(()) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_pcm_to_mel(struct whisper_context * ctx, const float * samples, int n_samples, int n_threads)`
|
||||||
|
pub fn pcm_to_mel_phase_vocoder(
|
||||||
|
&mut self,
|
||||||
|
pcm: &[f32],
|
||||||
|
threads: usize,
|
||||||
|
) -> Result<(), WhisperError> {
|
||||||
|
if threads < 1 {
|
||||||
|
return Err(WhisperError::InvalidThreadCount);
|
||||||
|
}
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_pcm_to_mel_phase_vocoder_with_state(
|
||||||
|
self.ctx,
|
||||||
|
self.ptr,
|
||||||
|
pcm.as_ptr(),
|
||||||
|
pcm.len() as c_int,
|
||||||
|
threads as c_int,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::UnableToCalculateSpectrogram)
|
||||||
|
} else if ret == 0 {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(WhisperError::GenericError(ret))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This can be used to set a custom log mel spectrogram inside the provided whisper state.
|
||||||
|
/// Use this instead of whisper_pcm_to_mel() if you want to provide your own log mel spectrogram.
|
||||||
|
///
|
||||||
|
/// # Note
|
||||||
|
/// This is a low-level function.
|
||||||
|
/// If you're a typical user, you probably don't want to use this function.
|
||||||
|
/// See instead [WhisperContext::pcm_to_mel].
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * data: The log mel spectrogram.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(()) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_set_mel(struct whisper_context * ctx, const float * data, int n_len, int n_mel)`
|
||||||
|
pub fn set_mel(&mut self, data: &[f32]) -> Result<(), WhisperError> {
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_set_mel_with_state(
|
||||||
|
self.ctx,
|
||||||
|
self.ptr,
|
||||||
|
data.as_ptr(),
|
||||||
|
data.len() as c_int,
|
||||||
|
80 as c_int,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::InvalidMelBands)
|
||||||
|
} else if ret == 0 {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(WhisperError::GenericError(ret))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the Whisper encoder on the log mel spectrogram stored inside the provided whisper state.
|
||||||
|
/// Make sure to call [WhisperContext::pcm_to_mel] or [WhisperContext::set_mel] first.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * offset: Can be used to specify the offset of the first frame in the spectrogram. Usually 0.
|
||||||
|
/// * threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(()) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_encode(struct whisper_context * ctx, int offset, int n_threads)`
|
||||||
|
pub fn encode(&mut self, offset: usize, threads: usize) -> Result<(), WhisperError> {
|
||||||
|
if threads < 1 {
|
||||||
|
return Err(WhisperError::InvalidThreadCount);
|
||||||
|
}
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_encode_with_state(
|
||||||
|
self.ctx,
|
||||||
|
self.ptr,
|
||||||
|
offset as c_int,
|
||||||
|
threads as c_int,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::UnableToCalculateEvaluation)
|
||||||
|
} else if ret == 0 {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(WhisperError::GenericError(ret))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the Whisper decoder to obtain the logits and probabilities for the next token.
|
||||||
|
/// Make sure to call [WhisperContext::encode] first.
|
||||||
|
/// tokens + n_tokens is the provided context for the decoder.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * tokens: The tokens to decode.
|
||||||
|
/// * n_tokens: The number of tokens to decode.
|
||||||
|
/// * n_past: The number of past tokens to use for the decoding.
|
||||||
|
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(()) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_decode(struct whisper_context * ctx, const whisper_token * tokens, int n_tokens, int n_past, int n_threads)`
|
||||||
|
pub fn decode(
|
||||||
|
&mut self,
|
||||||
|
tokens: &[WhisperToken],
|
||||||
|
n_past: usize,
|
||||||
|
threads: usize,
|
||||||
|
) -> Result<(), WhisperError> {
|
||||||
|
if threads < 1 {
|
||||||
|
return Err(WhisperError::InvalidThreadCount);
|
||||||
|
}
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_decode_with_state(
|
||||||
|
self.ctx,
|
||||||
|
self.ptr,
|
||||||
|
tokens.as_ptr(),
|
||||||
|
tokens.len() as c_int,
|
||||||
|
n_past as c_int,
|
||||||
|
threads as c_int,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::UnableToCalculateEvaluation)
|
||||||
|
} else if ret == 0 {
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(WhisperError::GenericError(ret))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Language functions
|
||||||
|
/// Use mel data at offset_ms to try and auto-detect the spoken language
|
||||||
|
/// Make sure to call pcm_to_mel() or set_mel() first
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * offset_ms: The offset in milliseconds to use for the language detection.
|
||||||
|
/// * n_threads: How many threads to use. Defaults to 1. Must be at least 1, returns an error otherwise.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(Vec<f32>) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_lang_auto_detect(struct whisper_context * ctx, int offset_ms, int n_threads, float * lang_probs)`
|
||||||
|
pub fn lang_detect(&self, offset_ms: usize, threads: usize) -> Result<Vec<f32>, WhisperError> {
|
||||||
|
if threads < 1 {
|
||||||
|
return Err(WhisperError::InvalidThreadCount);
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut lang_probs: Vec<f32> = vec![0.0; crate::standalone::get_lang_max_id() as usize + 1];
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_lang_auto_detect_with_state(
|
||||||
|
self.ctx,
|
||||||
|
self.ptr,
|
||||||
|
offset_ms as c_int,
|
||||||
|
threads as c_int,
|
||||||
|
lang_probs.as_mut_ptr(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::UnableToCalculateEvaluation)
|
||||||
|
} else {
|
||||||
|
assert_eq!(
|
||||||
|
ret as usize,
|
||||||
|
lang_probs.len(),
|
||||||
|
"lang_probs length mismatch: this is a bug in whisper.cpp"
|
||||||
|
);
|
||||||
|
// if we're still running, double check that the length is correct, otherwise print to stderr
|
||||||
|
// and abort, as this will cause Undefined Behavior
|
||||||
|
// might get here due to the unwind being caught by a user-installed panic handler
|
||||||
|
if lang_probs.len() != ret as usize {
|
||||||
|
eprintln!(
|
||||||
|
"lang_probs length mismatch: this is a bug in whisper.cpp, \
|
||||||
|
aborting to avoid Undefined Behavior"
|
||||||
|
);
|
||||||
|
std::process::abort();
|
||||||
|
}
|
||||||
|
Ok(lang_probs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// logit functions
|
||||||
|
/// Get the logits obtained from the last call to [WhisperContext::decode].
|
||||||
|
/// The logits for the last token are stored in the last row of the matrix.
|
||||||
|
///
|
||||||
|
/// Note: this function may be somewhat expensive depending on the size of the matrix returned, as it
|
||||||
|
/// needs to be rebuilt from the raw data. Try to avoid calling it more than once if possible.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: The segment to fetch data for.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// 2D matrix of logits. Row count is equal to n_tokens, column count is equal to n_vocab.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `float * whisper_get_logits(struct whisper_context * ctx)`
|
||||||
|
pub fn get_logits(&self, segment: c_int) -> Result<Vec<Vec<f32>>, WhisperError> {
|
||||||
|
let ret = unsafe { whisper_rs_sys::whisper_get_logits_from_state(self.ptr) };
|
||||||
|
if ret.is_null() {
|
||||||
|
return Err(WhisperError::NullPointer);
|
||||||
|
}
|
||||||
|
let mut logits = Vec::new();
|
||||||
|
let n_vocab = self.n_vocab();
|
||||||
|
let n_tokens = self.full_n_tokens(segment)?;
|
||||||
|
for i in 0..n_tokens {
|
||||||
|
let mut row = Vec::new();
|
||||||
|
for j in 0..n_vocab {
|
||||||
|
let idx = (i * n_vocab) + j;
|
||||||
|
let val = unsafe { *ret.offset(idx as isize) };
|
||||||
|
row.push(val);
|
||||||
|
}
|
||||||
|
logits.push(row);
|
||||||
|
}
|
||||||
|
Ok(logits)
|
||||||
|
}
|
||||||
|
|
||||||
|
// model attributes
|
||||||
|
/// Get the mel spectrogram length.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(c_int) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_n_len_from_state(struct whisper_context * ctx)`
|
||||||
|
#[inline]
|
||||||
|
pub fn n_len(&self) -> Result<c_int, WhisperError> {
|
||||||
|
Ok(unsafe { whisper_rs_sys::whisper_n_len_from_state(self.ptr) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get n_vocab.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// c_int
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_n_vocab (struct whisper_context * ctx)`
|
||||||
|
#[inline]
|
||||||
|
pub fn n_vocab(&self) -> c_int {
|
||||||
|
unsafe { whisper_rs_sys::whisper_n_vocab(self.ctx) }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||||
|
/// Uses the specified decoding strategy to obtain the text.
|
||||||
|
///
|
||||||
|
/// This is usually the only function you need to call as an end user.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * params: [crate::FullParams] struct.
|
||||||
|
/// * pcm: PCM audio data.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(c_int) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_full(struct whisper_context * ctx, struct whisper_full_params params, const float * samples, int n_samples)`
|
||||||
|
pub fn full(&mut self, params: FullParams, data: &[f32]) -> Result<c_int, WhisperError> {
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_with_state(
|
||||||
|
self.ctx,
|
||||||
|
self.ptr,
|
||||||
|
params.fp,
|
||||||
|
data.as_ptr(),
|
||||||
|
data.len() as c_int,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret == -1 {
|
||||||
|
Err(WhisperError::UnableToCalculateSpectrogram)
|
||||||
|
} else if ret == 7 {
|
||||||
|
Err(WhisperError::FailedToEncode)
|
||||||
|
} else if ret == 8 {
|
||||||
|
Err(WhisperError::FailedToDecode)
|
||||||
|
} else if ret == 0 {
|
||||||
|
Ok(ret)
|
||||||
|
} else {
|
||||||
|
Err(WhisperError::GenericError(ret))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Number of generated text segments.
|
||||||
|
/// A segment can be a few words, a sentence, or even a paragraph.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_full_n_segments(struct whisper_context * ctx)`
|
||||||
|
#[inline]
|
||||||
|
pub fn full_n_segments(&self) -> Result<c_int, WhisperError> {
|
||||||
|
Ok(unsafe { whisper_rs_sys::whisper_full_n_segments_from_state(self.ptr) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Language ID associated with the provided state.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_full_lang_id_from_state(struct whisper_state * state);`
|
||||||
|
#[inline]
|
||||||
|
pub fn full_lang_id_from_state(&self) -> Result<c_int, WhisperError> {
|
||||||
|
Ok(unsafe { whisper_rs_sys::whisper_full_lang_id_from_state(self.ptr) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the start time of the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int64_t whisper_full_get_segment_t0(struct whisper_context * ctx, int i_segment)`
|
||||||
|
#[inline]
|
||||||
|
pub fn full_get_segment_t0(&self, segment: c_int) -> Result<i64, WhisperError> {
|
||||||
|
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t0_from_state(self.ptr, segment) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the end time of the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int64_t whisper_full_get_segment_t1(struct whisper_context * ctx, int i_segment)`
|
||||||
|
#[inline]
|
||||||
|
pub fn full_get_segment_t1(&self, segment: c_int) -> Result<i64, WhisperError> {
|
||||||
|
Ok(unsafe { whisper_rs_sys::whisper_full_get_segment_t1_from_state(self.ptr, segment) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the text of the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_segment_text(struct whisper_context * ctx, int i_segment)`
|
||||||
|
pub fn full_get_segment_text(&self, segment: c_int) -> Result<String, WhisperError> {
|
||||||
|
let ret =
|
||||||
|
unsafe { whisper_rs_sys::whisper_full_get_segment_text_from_state(self.ptr, segment) };
|
||||||
|
if ret.is_null() {
|
||||||
|
return Err(WhisperError::NullPointer);
|
||||||
|
}
|
||||||
|
let c_str = unsafe { CStr::from_ptr(ret) };
|
||||||
|
let r_str = c_str.to_str()?;
|
||||||
|
Ok(r_str.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get number of tokens in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// c_int
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `int whisper_full_n_tokens(struct whisper_context * ctx, int i_segment)`
|
||||||
|
#[inline]
|
||||||
|
pub fn full_n_tokens(&self, segment: c_int) -> Result<c_int, WhisperError> {
|
||||||
|
Ok(unsafe { whisper_rs_sys::whisper_full_n_tokens_from_state(self.ptr, segment) })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the token text of the specified token in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
/// * token: Token index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// Ok(String) on success, Err(WhisperError) on failure.
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `const char * whisper_full_get_token_text(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn full_get_token_text(
|
||||||
|
&self,
|
||||||
|
segment: c_int,
|
||||||
|
token: c_int,
|
||||||
|
) -> Result<String, WhisperError> {
|
||||||
|
let ret = unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_token_text_from_state(
|
||||||
|
self.ctx, self.ptr, segment, token,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
if ret.is_null() {
|
||||||
|
return Err(WhisperError::NullPointer);
|
||||||
|
}
|
||||||
|
let c_str = unsafe { CStr::from_ptr(ret) };
|
||||||
|
let r_str = c_str.to_str()?;
|
||||||
|
Ok(r_str.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the token ID of the specified token in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
/// * token: Token index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// [crate::WhisperToken]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `whisper_token whisper_full_get_token_id (struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
pub fn full_get_token_id(
|
||||||
|
&self,
|
||||||
|
segment: c_int,
|
||||||
|
token: c_int,
|
||||||
|
) -> Result<WhisperToken, WhisperError> {
|
||||||
|
Ok(unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_token_id_from_state(self.ptr, segment, token)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get token data for the specified token in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
/// * token: Token index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// [crate::WhisperTokenData]
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `whisper_token_data whisper_full_get_token_data(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
#[inline]
|
||||||
|
pub fn full_get_token_data(
|
||||||
|
&self,
|
||||||
|
segment: c_int,
|
||||||
|
token: c_int,
|
||||||
|
) -> Result<WhisperTokenData, WhisperError> {
|
||||||
|
Ok(unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_token_data_from_state(self.ptr, segment, token)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the probability of the specified token in the specified segment.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * segment: Segment index.
|
||||||
|
/// * token: Token index.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// f32
|
||||||
|
///
|
||||||
|
/// # C++ equivalent
|
||||||
|
/// `float whisper_full_get_token_p(struct whisper_context * ctx, int i_segment, int i_token)`
|
||||||
|
#[inline]
|
||||||
|
pub fn full_get_token_prob(&self, segment: c_int, token: c_int) -> Result<f32, WhisperError> {
|
||||||
|
Ok(
|
||||||
|
unsafe {
|
||||||
|
whisper_rs_sys::whisper_full_get_token_p_from_state(self.ptr, segment, token)
|
||||||
|
},
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue