diff --git a/Cargo.toml b/Cargo.toml
index c825139..848eb52 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -3,7 +3,7 @@ name = "zeroclaw"
 version = "0.1.0"
 edition = "2021"
 authors = ["theonlyhennygod"]
-license = "MIT"
+license = "Apache-2.0"
 description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
 repository = "https://github.com/zeroclaw-labs/zeroclaw"
 readme = "README.md"
diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs
index a9942f0..c1ce0bb 100644
--- a/src/providers/compatible.rs
+++ b/src/providers/compatible.rs
@@ -4,9 +4,10 @@
 
 use crate::providers::traits::{
     ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
-    Provider, ToolCall as ProviderToolCall,
+    Provider, StreamChunk, StreamError, StreamOptions, StreamResult, ToolCall as ProviderToolCall,
 };
 use async_trait::async_trait;
+use futures_util::{stream, StreamExt};
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
 
@@ -219,6 +220,173 @@ struct ResponsesContent {
     text: Option<String>,
 }
 
+// ═══════════════════════════════════════════════════════════════
+// Streaming support (SSE parser)
+// ═══════════════════════════════════════════════════════════════
+
+/// One `data:` payload of an OpenAI-compatible streaming response.
+#[derive(Debug, Deserialize)]
+struct StreamChunkResponse {
+    choices: Vec<StreamChoice>,
+}
+
+/// A single choice inside a streamed chunk.
+#[derive(Debug, Deserialize)]
+struct StreamChoice {
+    delta: StreamDelta,
+    finish_reason: Option<String>,
+}
+
+/// Incremental message content carried by a streamed choice.
+#[derive(Debug, Deserialize)]
+struct StreamDelta {
+    #[serde(default)]
+    content: Option<String>,
+}
+
+/// Parse one line of an SSE (Server-Sent Events) stream from an
+/// OpenAI-compatible provider.
+///
+/// Returns `Ok(Some(text))` when the line is a `data: {...}` event whose first
+/// choice carries a content delta, and `Ok(None)` for everything that carries
+/// no text: blank lines, `:` comments, non-`data` fields, an empty payload and
+/// the `[DONE]` sentinel.
+///
+/// # Errors
+///
+/// Returns [`StreamError::Json`] when the `data:` payload is not valid JSON for
+/// a stream chunk.
+fn parse_sse_line(line: &str) -> StreamResult<Option<String>> {
+    let line = line.trim();
+
+    // Skip empty lines and comments
+    if line.is_empty() || line.starts_with(':') {
+        return Ok(None);
+    }
+
+    // SSE format: "data: {...}"; other fields (`event:`, `id:`, ...) are ignored
+    let data = match line.strip_prefix("data:") {
+        Some(data) => data.trim(),
+        None => return Ok(None),
+    };
+
+    // Nothing to decode for an empty payload or the [DONE] sentinel
+    if data.is_empty() || data == "[DONE]" {
+        return Ok(None);
+    }
+
+    // Parse JSON delta
+    let chunk: StreamChunkResponse = serde_json::from_str(data).map_err(StreamError::Json)?;
+
+    // Extract content from the first choice's delta
+    Ok(chunk
+        .choices
+        .into_iter()
+        .next()
+        .and_then(|choice| choice.delta.content))
+}
+
+/// Decode one raw SSE line, parse it and forward any content delta to `tx`.
+///
+/// Returns `false` when streaming must stop: either the line was invalid (the
+/// error has already been forwarded to `tx`) or the receiver was dropped.
+async fn forward_sse_line(
+    tx: &tokio::sync::mpsc::Sender<StreamResult<StreamChunk>>,
+    raw_line: &[u8],
+    count_tokens: bool,
+) -> bool {
+    let line = match std::str::from_utf8(raw_line) {
+        Ok(line) => line,
+        Err(e) => {
+            let _ = tx
+                .send(Err(StreamError::InvalidSse(format!("Invalid UTF-8: {}", e))))
+                .await;
+            return false;
+        }
+    };
+
+    match parse_sse_line(line) {
+        Ok(Some(content)) => {
+            let mut chunk = StreamChunk::delta(content);
+            if count_tokens {
+                chunk = chunk.with_token_estimate();
+            }
+            tx.send(Ok(chunk)).await.is_ok()
+        }
+        // Empty line, comment or [DONE] sentinel - nothing to forward
+        Ok(None) => true,
+        Err(e) => {
+            let _ = tx.send(Err(e)).await;
+            false
+        }
+    }
+}
+
+/// Convert SSE byte stream to text chunks.
+///
+/// The body is read on a spawned task and split into lines on raw bytes; each
+/// line is decoded only once it is complete, so a multi-byte UTF-8 character
+/// split across two network chunks is handled correctly. A trailing line
+/// without a final newline is processed when the body ends.
+///
+/// The returned stream yields one item per content delta, followed by
+/// `StreamChunk::final_chunk()` when the body ends cleanly. After an error
+/// item the stream ends without a final chunk.
+async fn sse_bytes_to_chunks(
+    response: reqwest::Response,
+    count_tokens: bool,
+) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+    // Create a channel to send chunks
+    let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
+
+    tokio::spawn(async move {
+        // `.err()` releases the borrow of `response` before the first await
+        let status_error = response.error_for_status_ref().err();
+        if let Some(e) = status_error {
+            let _ = tx.send(Err(StreamError::Http(e))).await;
+            return;
+        }
+
+        // Raw bytes of the line currently being assembled
+        let mut buffer: Vec<u8> = Vec::new();
+        let mut bytes_stream = response.bytes_stream();
+
+        while let Some(item) = bytes_stream.next().await {
+            let bytes = match item {
+                Ok(bytes) => bytes,
+                Err(e) => {
+                    let _ = tx.send(Err(StreamError::Http(e))).await;
+                    return;
+                }
+            };
+            buffer.extend_from_slice(&bytes);
+
+            // Process complete lines; `drain` already removes the line
+            // (newline included) from the front of the buffer
+            while let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
+                let line: Vec<u8> = buffer.drain(..=pos).collect();
+                if !forward_sse_line(&tx, &line, count_tokens).await {
+                    return;
+                }
+            }
+        }
+
+        // Process a trailing line that was not newline-terminated
+        if !buffer.is_empty() && !forward_sse_line(&tx, &buffer, count_tokens).await {
+            return;
+        }
+
+        // Send final chunk
+        let _ = tx.send(Ok(StreamChunk::final_chunk())).await;
+    });
+
+    // Convert channel receiver to stream
+    stream::unfold(rx, |mut rx| async move {
+        rx.recv().await.map(|chunk| (chunk, rx))
+    })
+    .boxed()
+}
+
 fn first_nonempty(text: Option<&str>) -> Option<String> {
     text.and_then(|value| {
         let trimmed = value.trim();
@@ -525,6 +693,118 @@ impl Provider for OpenAiCompatibleProvider {
     fn supports_native_tools(&self) -> bool {
         true
     }
+
+    fn supports_streaming(&self) -> bool {
+        true
+    }
+
+    /// Stream a single-turn chat completion as incremental text chunks.
+    ///
+    /// The request is sent from a spawned task; the returned stream yields the
+    /// chunks produced by `sse_bytes_to_chunks`. A missing API key, a transport
+    /// failure or a non-success HTTP status is reported as a single error item.
+    fn stream_chat_with_system(
+        &self,
+        system_prompt: Option<&str>,
+        message: &str,
+        model: &str,
+        temperature: f64,
+        options: StreamOptions,
+    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+        let api_key = match self.api_key.as_ref() {
+            Some(key) => key.clone(),
+            None => {
+                let provider_name = self.name.clone();
+                return stream::once(async move {
+                    Err(StreamError::Provider(format!(
+                        "{} API key not set",
+                        provider_name
+                    )))
+                })
+                .boxed();
+            }
+        };
+
+        let mut messages = Vec::new();
+        if let Some(sys) = system_prompt {
+            messages.push(Message {
+                role: "system".to_string(),
+                content: sys.to_string(),
+            });
+        }
+        messages.push(Message {
+            role: "user".to_string(),
+            content: message.to_string(),
+        });
+
+        let request = ChatRequest {
+            model: model.to_string(),
+            messages,
+            temperature,
+            stream: Some(options.enabled),
+        };
+
+        let url = self.chat_completions_url();
+        let client = self.client.clone();
+        let auth_header = self.auth_header.clone();
+
+        // Use a channel to bridge the async HTTP response to the stream
+        let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
+
+        tokio::spawn(async move {
+            // Build request with auth
+            let mut req_builder = client.post(&url).json(&request);
+
+            // Apply auth header
+            req_builder = match &auth_header {
+                AuthStyle::Bearer => {
+                    req_builder.header("Authorization", format!("Bearer {}", api_key))
+                }
+                AuthStyle::XApiKey => req_builder.header("x-api-key", &api_key),
+                AuthStyle::Custom(header) => req_builder.header(header, &api_key),
+            };
+
+            // Set accept header for streaming
+            req_builder = req_builder.header("Accept", "text/event-stream");
+
+            // Send request
+            let response = match req_builder.send().await {
+                Ok(r) => r,
+                Err(e) => {
+                    let _ = tx.send(Err(StreamError::Http(e))).await;
+                    return;
+                }
+            };
+
+            // Check status
+            let status = response.status();
+            if !status.is_success() {
+                // Prefer the response body as the error detail
+                let error = match response.text().await {
+                    Ok(body) => body,
+                    Err(_) => format!("HTTP error: {}", status),
+                };
+                let _ = tx
+                    .send(Err(StreamError::Provider(format!("{}: {}", status, error))))
+                    .await;
+                return;
+            }
+
+            // Convert to chunk stream and forward to channel
+            let mut chunk_stream = sse_bytes_to_chunks(response, options.count_tokens).await;
+            while let Some(chunk) = chunk_stream.next().await {
+                if tx.send(chunk).await.is_err() {
+                    break; // Receiver dropped
+                }
+            }
+        });
+
+        // Convert channel receiver to stream
+        stream::unfold(rx, |mut rx| async move {
+            rx.recv().await.map(|chunk| (chunk, rx))
+        })
+        .boxed()
+    }
 }
 
 #[cfg(test)]
diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs
index 41a0a1a..f5e1e23 100644
--- a/src/providers/reliable.rs
+++ b/src/providers/reliable.rs
@@ -1,6 +1,7 @@
-use super::traits::ChatMessage;
+use super::traits::{ChatMessage, StreamChunk, StreamOptions, StreamResult};
 use super::Provider;
 use async_trait::async_trait;
+use futures_util::{stream, StreamExt};
 use std::collections::HashMap;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::time::Duration;
@@ -337,6 +338,75 @@ impl Provider for ReliableProvider {
             failures.join("\n")
         )
     }
+
+    fn supports_streaming(&self) -> bool {
+        self.providers.iter().any(|(_, p)| p.supports_streaming())
+    }
+
+    /// Stream a chat completion through the first provider that supports
+    /// streaming, using the first model of the fallback chain for `model`.
+    ///
+    /// Streaming is attempted once: errors are logged and passed through to the
+    /// caller, which can retry the entire request if needed. When streaming is
+    /// disabled in `options` or no provider supports it, the returned stream
+    /// yields a single error.
+    fn stream_chat_with_system(
+        &self,
+        system_prompt: Option<&str>,
+        message: &str,
+        model: &str,
+        temperature: f64,
+        options: StreamOptions,
+    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+        // `options.enabled` does not depend on the provider, so check it once
+        let selected = if options.enabled {
+            self.providers
+                .iter()
+                .find(|(_, provider)| provider.supports_streaming())
+        } else {
+            None
+        };
+
+        let (provider_name, provider) = match selected {
+            Some((name, provider)) => (name.clone(), provider),
+            None => {
+                // No streaming support available
+                return stream::once(async move {
+                    Err(super::traits::StreamError::Provider(
+                        "No provider supports streaming".to_string(),
+                    ))
+                })
+                .boxed();
+            }
+        };
+
+        // Try the first model in the chain for streaming
+        let current_model = match self.model_chain(model).first() {
+            Some(m) => m.to_string(),
+            None => model.to_string(),
+        };
+
+        let inner = provider.stream_chat_with_system(
+            system_prompt,
+            message,
+            &current_model,
+            temperature,
+            options,
+        );
+
+        // Log errors in place instead of relaying through another task and channel
+        inner
+            .inspect(move |chunk| {
+                if let Err(e) = chunk {
+                    tracing::warn!(
+                        provider = provider_name,
+                        model = current_model,
+                        "Streaming error: {e}"
+                    );
+                }
+            })
+            .boxed()
+    }
 }
 
 #[cfg(test)]