From 384ddd77a5432bacd66e3022e66f450835ac442e Mon Sep 17 00:00:00 2001 From: Wenqing Zong Date: Wed, 20 Mar 2024 16:30:13 +0000 Subject: [PATCH 1/2] Add safe bindings for speaker diarization --- src/whisper_ctx.rs | 14 ++++++++++++++ src/whisper_state.rs | 16 ++++++++++++++++ 2 files changed, 30 insertions(+) diff --git a/src/whisper_ctx.rs b/src/whisper_ctx.rs index 330391e..9f7df00 100644 --- a/src/whisper_ctx.rs +++ b/src/whisper_ctx.rs @@ -517,6 +517,20 @@ impl WhisperContext { pub fn token_transcribe(&self) -> WhisperToken { unsafe { whisper_rs_sys::whisper_token_transcribe(self.ctx) } } + + /// Get whether the next segment is predicted as a speaker turn + /// + /// # Arguments + /// * i_segment: Segment index. + /// + /// # Returns + /// bool + /// + /// # C++ equivalent + /// `bool whisper_full_get_segment_speaker_turn_next(struct whisper_context * ctx, int i_segment)` + pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool { + unsafe { whisper_rs_sys::whisper_full_get_segment_speaker_turn_next(self.ctx, i_segment) } + } } impl Drop for WhisperContext { diff --git a/src/whisper_state.rs b/src/whisper_state.rs index f336638..b657e8e 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -544,4 +544,20 @@ impl<'a> WhisperState<'a> { }, ) } + + /// Get whether the next segment is predicted as a speaker turn. + /// + /// # Arguments + /// * i_segment: Segment index. + /// + /// # Returns + /// bool + /// + /// # C++ equivalent + /// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)` + pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool { + unsafe { + whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state(self.ptr, i_segment) + } + } } From a026af12812d521ed8f7caab938dd2d2e8178271 Mon Sep 17 00:00:00 2001 From: Wenqing Zong Date: Wed, 20 Mar 2024 16:36:08 +0000 Subject: [PATCH 2/2] Pass fmt check --- src/whisper_state.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/whisper_state.rs b/src/whisper_state.rs index b657e8e..088af9c 100644 --- a/src/whisper_state.rs +++ b/src/whisper_state.rs @@ -557,7 +557,9 @@ impl<'a> WhisperState<'a> { /// `bool whisper_full_get_segment_speaker_turn_next_from_state(struct whisper_state * state, int i_segment)` pub fn full_get_segment_speaker_turn_next(&mut self, i_segment: c_int) -> bool { unsafe { - whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state(self.ptr, i_segment) + whisper_rs_sys::whisper_full_get_segment_speaker_turn_next_from_state( + self.ptr, i_segment, + ) } } }