diff --git a/wyoming-whisper-rs/src/protocol.rs b/wyoming-whisper-rs/src/protocol.rs index 02ddbe1..860456e 100644 --- a/wyoming-whisper-rs/src/protocol.rs +++ b/wyoming-whisper-rs/src/protocol.rs @@ -6,26 +6,24 @@ use crate::error::Error; const MAX_PAYLOAD: usize = 100 * 1024 * 1024; // 100 MB +/// Wire header: the JSON line sent before data/payload bytes. #[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Event { +struct Header { #[serde(rename = "type")] - pub event_type: String, + event_type: String, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub data: Option, + #[serde(default)] + data_length: usize, - #[serde(default, skip_serializing_if = "is_zero")] - pub data_length: usize, - - #[serde(default, skip_serializing_if = "is_zero")] - pub payload_length: usize, - - #[serde(skip)] - pub payload: Option>, + #[serde(default)] + payload_length: usize, } -fn is_zero(v: &usize) -> bool { - *v == 0 +#[derive(Debug, Clone)] +pub struct Event { + pub event_type: String, + pub data: Option, + pub payload: Option>, } impl Event { @@ -33,15 +31,11 @@ impl Event { Self { event_type: event_type.into(), data: None, - data_length: 0, - payload_length: 0, payload: None, } } pub fn with_data(mut self, data: Value) -> Self { - let serialized = serde_json::to_string(&data).unwrap_or_default(); - self.data_length = serialized.len(); self.data = Some(data); self } @@ -56,31 +50,67 @@ pub async fn read_event( return Ok(None); } - let mut event: Event = serde_json::from_str(line.trim())?; + let header: Header = serde_json::from_str(line.trim())?; - if event.payload_length > 0 { - if event.payload_length > MAX_PAYLOAD { + let data = if header.data_length > 0 { + if header.data_length > MAX_PAYLOAD { return Err(Error::PayloadTooLarge { - size: event.payload_length, + size: header.data_length, max: MAX_PAYLOAD, }); } - let mut buf = vec![0u8; event.payload_length]; + let mut buf = vec![0u8; header.data_length]; reader.read_exact(&mut buf).await?; - event.payload = Some(buf); - } + Some(serde_json::from_slice(&buf)?) + } else { + None + }; - Ok(Some(event)) + let payload = if header.payload_length > 0 { + if header.payload_length > MAX_PAYLOAD { + return Err(Error::PayloadTooLarge { + size: header.payload_length, + max: MAX_PAYLOAD, + }); + } + let mut buf = vec![0u8; header.payload_length]; + reader.read_exact(&mut buf).await?; + Some(buf) + } else { + None + }; + + Ok(Some(Event { + event_type: header.event_type, + data, + payload, + })) } pub async fn write_event( writer: &mut W, event: &Event, ) -> Result<(), Error> { - let json = serde_json::to_string(event)?; - writer.write_all(json.as_bytes()).await?; + let data_bytes = event + .data + .as_ref() + .map(|d| serde_json::to_vec(d)) + .transpose()?; + + let header = Header { + event_type: event.event_type.clone(), + data_length: data_bytes.as_ref().map_or(0, |b| b.len()), + payload_length: event.payload.as_ref().map_or(0, |b| b.len()), + }; + + let header_json = serde_json::to_string(&header)?; + writer.write_all(header_json.as_bytes()).await?; writer.write_all(b"\n").await?; + if let Some(ref data) = data_bytes { + writer.write_all(data).await?; + } + if let Some(ref payload) = event.payload { writer.write_all(payload).await?; }