diff --git a/Cargo.lock b/Cargo.lock index 0dd6b26..d940f9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4862,6 +4862,7 @@ dependencies = [ "dialoguer", "directories", "fantoccini", + "futures", "futures-util", "glob", "hex", diff --git a/Cargo.toml b/Cargo.toml index c825139..79dcdfe 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -3,7 +3,7 @@ name = "zeroclaw" version = "0.1.0" edition = "2021" authors = ["theonlyhennygod"] -license = "MIT" +license = "Apache-2.0" description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant." repository = "https://github.com/zeroclaw-labs/zeroclaw" readme = "README.md" @@ -85,6 +85,7 @@ glob = "0.3" # Discord WebSocket gateway tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] } futures-util = { version = "0.3", default-features = false, features = ["sink"] } +futures = "0.3" hostname = "0.4.2" lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] } mail-parser = "0.11.2" diff --git a/firmware/zeroclaw-esp32/Cargo.lock b/firmware/zeroclaw-esp32/Cargo.lock index 6f8ad22..2580883 100644 --- a/firmware/zeroclaw-esp32/Cargo.lock +++ b/firmware/zeroclaw-esp32/Cargo.lock @@ -483,7 +483,6 @@ dependencies = [ "tempfile", "thiserror 1.0.69", "which", - "xmas-elf", ] [[package]] @@ -1806,21 +1805,6 @@ dependencies = [ "wasmparser", ] -[[package]] -name = "xmas-elf" -version = "0.9.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42c49817e78342f7f30a181573d82ff55b88a35f86ccaf07fc64b3008f56d1c6" -dependencies = [ - "zero", -] - -[[package]] -name = "zero" -version = "0.1.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fe21bcc34ca7fe6dd56cc2cb1261ea59d6b93620215aefb5ea6032265527784" - [[package]] name = "zeroclaw-esp32" version = "0.1.0" diff --git a/firmware/zeroclaw-esp32/Cargo.toml b/firmware/zeroclaw-esp32/Cargo.toml index 2f7a001..70d2611 100644 --- 
a/firmware/zeroclaw-esp32/Cargo.toml
+++ b/firmware/zeroclaw-esp32/Cargo.toml
@@ -22,7 +22,7 @@ serde = { version = "1.0", features = ["derive"] }
 serde_json = "1.0"
 
 [build-dependencies]
-embuild = { version = "0.31", features = ["elf"] }
+embuild = "0.31"
 
 [profile.release]
 opt-level = "s"
diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs
index a9942f0..cca5623 100644
--- a/src/providers/compatible.rs
+++ b/src/providers/compatible.rs
@@ -4,9 +4,10 @@ use crate::providers::traits::{
     ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
-    Provider, ToolCall as ProviderToolCall,
+    Provider, StreamChunk, StreamError, StreamOptions, StreamResult, ToolCall as ProviderToolCall,
 };
 use async_trait::async_trait;
+use futures_util::{stream, StreamExt};
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
@@ -219,6 +220,154 @@ struct ResponsesContent {
     text: Option<String>,
 }
 
+// ═══════════════════════════════════════════════════════════════
+// Streaming support (SSE parser)
+// ═══════════════════════════════════════════════════════════════
+
+/// Server-Sent Event stream chunk for OpenAI-compatible streaming.
+#[derive(Debug, Deserialize)]
+struct StreamChunkResponse {
+    choices: Vec<StreamChoice>,
+}
+
+#[derive(Debug, Deserialize)]
+struct StreamChoice {
+    delta: StreamDelta,
+    finish_reason: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct StreamDelta {
+    #[serde(default)]
+    content: Option<String>,
+}
+
+/// Parse SSE (Server-Sent Events) stream from OpenAI-compatible providers.
+/// Handles the `data: {...}` format and `[DONE]` sentinel.
+fn parse_sse_line(line: &str) -> StreamResult<Option<String>> {
+    let line = line.trim();
+
+    // Skip empty lines and comments
+    if line.is_empty() || line.starts_with(':') {
+        return Ok(None);
+    }
+
+    // SSE format: "data: {...}"
+    if let Some(data) = line.strip_prefix("data:") {
+        let data = data.trim();
+
+        // Check for [DONE] sentinel
+        if data == "[DONE]" {
+            return Ok(None);
+        }
+
+        // Parse JSON delta
+        let chunk: StreamChunkResponse = serde_json::from_str(data).map_err(StreamError::Json)?;
+
+        // Extract content from delta
+        if let Some(choice) = chunk.choices.first() {
+            if let Some(content) = &choice.delta.content {
+                return Ok(Some(content.clone()));
+            }
+        }
+    }
+
+    Ok(None)
+}
+
+/// Convert SSE byte stream to text chunks.
+async fn sse_bytes_to_chunks(
+    mut response: reqwest::Response,
+    count_tokens: bool,
+) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+    use tokio::io::AsyncBufReadExt;
+
+    let name = "stream".to_string();
+
+    // Create a channel to send chunks
+    let (mut tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
+
+    tokio::spawn(async move {
+        // Buffer for incomplete lines
+        let mut buffer = String::new();
+
+        // Get response body as bytes stream
+        match response.error_for_status_ref() {
+            Ok(_) => {}
+            Err(e) => {
+                let _ = tx.send(Err(StreamError::Http(e))).await;
+                return;
+            }
+        }
+
+        let mut bytes_stream = response.bytes_stream();
+
+        while let Some(item) = bytes_stream.next().await {
+            match item {
+                Ok(bytes) => {
+                    // Convert bytes to string and process line by line
+                    let text = match String::from_utf8(bytes.to_vec()) {
+                        Ok(t) => t,
+                        Err(e) => {
+                            let _ = tx
+                                .send(Err(StreamError::InvalidSse(format!(
+                                    "Invalid UTF-8: {}",
+                                    e
+                                ))))
+                                .await;
+                            break;
+                        }
+                    };
+
+                    buffer.push_str(&text);
+
+                    // Process complete lines
+                    while let Some(pos) = buffer.find('\n') {
+                        // `drain(..=pos)` already removes the line *and* its
+                        // trailing newline from the buffer, so no further
+                        // truncation of `buffer` is needed here.
+                        let line = buffer.drain(..=pos).collect::<String>();
+
+                        match parse_sse_line(&line) {
+                            Ok(Some(content)) => {
+                                let mut chunk =
StreamChunk::delta(content);
+                                if count_tokens {
+                                    chunk = chunk.with_token_estimate();
+                                }
+                                if tx.send(Ok(chunk)).await.is_err() {
+                                    return; // Receiver dropped
+                                }
+                            }
+                            Ok(None) => {
+                                // Empty line or [DONE] sentinel - continue
+                                continue;
+                            }
+                            Err(e) => {
+                                let _ = tx.send(Err(e)).await;
+                                return;
+                            }
+                        }
+                    }
+                }
+                Err(e) => {
+                    let _ = tx.send(Err(StreamError::Http(e))).await;
+                    break;
+                }
+            }
+        }
+
+        // Send final chunk
+        let _ = tx.send(Ok(StreamChunk::final_chunk())).await;
+    });
+
+    // Convert channel receiver to stream
+    stream::unfold(rx, |mut rx| async move {
+        match rx.recv().await {
+            Some(chunk) => Some((chunk, rx)),
+            None => None,
+        }
+    })
+    .boxed()
+}
+
 fn first_nonempty(text: Option<&str>) -> Option<String> {
     text.and_then(|value| {
         let trimmed = value.trim();
@@ -525,6 +674,115 @@ impl Provider for OpenAiCompatibleProvider {
     fn supports_native_tools(&self) -> bool {
         true
     }
+
+    fn supports_streaming(&self) -> bool {
+        true
+    }
+
+    fn stream_chat_with_system(
+        &self,
+        system_prompt: Option<&str>,
+        message: &str,
+        model: &str,
+        temperature: f64,
+        options: StreamOptions,
+    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+        let api_key = match self.api_key.as_ref() {
+            Some(key) => key.clone(),
+            None => {
+                let provider_name = self.name.clone();
+                return stream::once(async move {
+                    Err(StreamError::Provider(format!(
+                        "{} API key not set",
+                        provider_name
+                    )))
+                })
+                .boxed();
+            }
+        };
+
+        let mut messages = Vec::new();
+        if let Some(sys) = system_prompt {
+            messages.push(Message {
+                role: "system".to_string(),
+                content: sys.to_string(),
+            });
+        }
+        messages.push(Message {
+            role: "user".to_string(),
+            content: message.to_string(),
+        });
+
+        let request = ChatRequest {
+            model: model.to_string(),
+            messages,
+            temperature,
+            stream: Some(options.enabled),
+        };
+
+        let url = self.chat_completions_url();
+        let client = self.client.clone();
+        let auth_header = self.auth_header.clone();
+
+        // Use a channel to bridge the async HTTP response to the stream
+        let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
+
+        tokio::spawn(async move {
+            // Build request with auth
+            let mut req_builder = client.post(&url).json(&request);
+
+            // Apply auth header
+            req_builder = match &auth_header {
+                AuthStyle::Bearer => {
+                    req_builder.header("Authorization", format!("Bearer {}", api_key))
+                }
+                AuthStyle::XApiKey => req_builder.header("x-api-key", &api_key),
+                AuthStyle::Custom(header) => req_builder.header(header, &api_key),
+            };
+
+            // Set accept header for streaming
+            req_builder = req_builder.header("Accept", "text/event-stream");
+
+            // Send request
+            let response = match req_builder.send().await {
+                Ok(r) => r,
+                Err(e) => {
+                    let _ = tx.send(Err(StreamError::Http(e))).await;
+                    return;
+                }
+            };
+
+            // Check status
+            if !response.status().is_success() {
+                let status = response.status();
+                let error = match response.text().await {
+                    Ok(e) => e,
+                    Err(_) => format!("HTTP error: {}", status),
+                };
+                let _ = tx
+                    .send(Err(StreamError::Provider(format!("{}: {}", status, error))))
+                    .await;
+                return;
+            }
+
+            // Convert to chunk stream and forward to channel
+            let mut chunk_stream = sse_bytes_to_chunks(response, options.count_tokens).await;
+            while let Some(chunk) = chunk_stream.next().await {
+                if tx.send(chunk).await.is_err() {
+                    break; // Receiver dropped
+                }
+            }
+        });
+
+        // Convert channel receiver to stream
+        stream::unfold(rx, |mut rx| async move {
+            match rx.recv().await {
+                Some(chunk) => Some((chunk, rx)),
+                None => None,
+            }
+        })
+        .boxed()
+    }
 }
 
 #[cfg(test)]
diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs
index 41a0a1a..d91f02c 100644
--- a/src/providers/reliable.rs
+++ b/src/providers/reliable.rs
@@ -1,6 +1,7 @@
-use super::traits::ChatMessage;
+use super::traits::{ChatMessage, StreamChunk, StreamOptions, StreamResult};
 use super::Provider;
 use async_trait::async_trait;
+use futures_util::{stream, StreamExt};
 use std::collections::HashMap;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use
std::time::Duration;
@@ -337,6 +338,82 @@ impl Provider for ReliableProvider {
             failures.join("\n")
         )
     }
+
+    fn supports_streaming(&self) -> bool {
+        self.providers.iter().any(|(_, p)| p.supports_streaming())
+    }
+
+    fn stream_chat_with_system(
+        &self,
+        system_prompt: Option<&str>,
+        message: &str,
+        model: &str,
+        temperature: f64,
+        options: StreamOptions,
+    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+        // Try each provider/model combination for streaming
+        // For streaming, we use the first provider that supports it and has streaming enabled
+        for (provider_name, provider) in &self.providers {
+            if !provider.supports_streaming() || !options.enabled {
+                continue;
+            }
+
+            // Clone provider data for the stream
+            let provider_clone = provider_name.clone();
+
+            // Try the first model in the chain for streaming
+            let current_model = match self.model_chain(model).first() {
+                Some(m) => m.to_string(),
+                None => model.to_string(),
+            };
+
+            // For streaming, we attempt once and propagate errors
+            // The caller can retry the entire request if needed
+            let stream = provider.stream_chat_with_system(
+                system_prompt,
+                message,
+                &current_model,
+                temperature,
+                options,
+            );
+
+            // Use a channel to bridge the stream with logging
+            let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
+
+            tokio::spawn(async move {
+                let mut stream = stream;
+                while let Some(chunk) = stream.next().await {
+                    if let Err(ref e) = chunk {
+                        tracing::warn!(
+                            provider = %provider_clone,
+                            model = %current_model,
+                            "Streaming error: {e}"
+                        );
+                    }
+                    if tx.send(chunk).await.is_err() {
+                        break; // Receiver dropped
+                    }
+                }
+            });
+
+            // Convert channel receiver to stream
+            return stream::unfold(rx, |mut rx| async move {
+                match rx.recv().await {
+                    Some(chunk) => Some((chunk, rx)),
+                    None => None,
+                }
+            })
+            .boxed();
+        }
+
+        // No streaming support available
+        stream::once(async move {
+            Err(super::traits::StreamError::Provider(
+                "No provider supports streaming".to_string(),
+            ))
+        })
+        .boxed()
+    }
 }
#[cfg(test)]
diff --git a/src/providers/traits.rs b/src/providers/traits.rs
index 7c61769..31f2cf5 100644
--- a/src/providers/traits.rs
+++ b/src/providers/traits.rs
@@ -1,5 +1,6 @@
 use crate::tools::ToolSpec;
 use async_trait::async_trait;
+use futures_util::{stream, StreamExt};
 use serde::{Deserialize, Serialize};
 
 /// A single message in a conversation.
@@ -97,6 +98,99 @@ pub enum ConversationMessage {
     ToolResults(Vec),
 }
 
+/// A chunk of content from a streaming response.
+#[derive(Debug, Clone)]
+pub struct StreamChunk {
+    /// Text delta for this chunk.
+    pub delta: String,
+    /// Whether this is the final chunk.
+    pub is_final: bool,
+    /// Approximate token count for this chunk (estimated).
+    pub token_count: usize,
+}
+
+impl StreamChunk {
+    /// Create a new non-final chunk.
+    pub fn delta(text: impl Into<String>) -> Self {
+        Self {
+            delta: text.into(),
+            is_final: false,
+            token_count: 0,
+        }
+    }
+
+    /// Create a final chunk.
+    pub fn final_chunk() -> Self {
+        Self {
+            delta: String::new(),
+            is_final: true,
+            token_count: 0,
+        }
+    }
+
+    /// Create an error chunk.
+    pub fn error(message: impl Into<String>) -> Self {
+        Self {
+            delta: message.into(),
+            is_final: true,
+            token_count: 0,
+        }
+    }
+
+    /// Estimate tokens (rough approximation: ~4 chars per token).
+    pub fn with_token_estimate(mut self) -> Self {
+        self.token_count = (self.delta.len() + 3) / 4;
+        self
+    }
+}
+
+/// Options for streaming chat requests.
+#[derive(Debug, Clone, Copy, Default)]
+pub struct StreamOptions {
+    /// Whether to enable streaming. NOTE: the derived `Default` yields
+    /// `enabled: false`; use `StreamOptions::new(true)` to turn streaming on.
+    pub enabled: bool,
+    /// Whether to include token counts in chunks.
+    pub count_tokens: bool,
+}
+
+impl StreamOptions {
+    /// Create new streaming options with enabled flag.
+    pub fn new(enabled: bool) -> Self {
+        Self {
+            enabled,
+            count_tokens: false,
+        }
+    }
+
+    /// Enable token counting.
+    pub fn with_token_count(mut self) -> Self {
+        self.count_tokens = true;
+        self
+    }
+}
+
+/// Result type for streaming operations.
+pub type StreamResult<T> = std::result::Result<T, StreamError>;
+
+/// Errors that can occur during streaming.
+#[derive(Debug, thiserror::Error)]
+pub enum StreamError {
+    #[error("HTTP error: {0}")]
+    Http(reqwest::Error),
+
+    #[error("JSON parse error: {0}")]
+    Json(serde_json::Error),
+
+    #[error("Invalid SSE format: {0}")]
+    InvalidSse(String),
+
+    #[error("Provider error: {0}")]
+    Provider(String),
+
+    #[error("IO error: {0}")]
+    Io(#[from] std::io::Error),
+}
+
 #[async_trait]
 pub trait Provider: Send + Sync {
     /// Simple one-shot chat (single user message, no explicit system prompt).
@@ -187,6 +281,55 @@ pub trait Provider: Send + Sync {
             tool_calls: Vec::new(),
         })
     }
+
+    /// Whether provider supports streaming responses.
+    /// Default implementation returns false.
+    fn supports_streaming(&self) -> bool {
+        false
+    }
+
+    /// Streaming chat with optional system prompt.
+    /// Returns an async stream of text chunks.
+    /// Default implementation returns an empty stream (streaming unsupported);
+    /// callers should check `supports_streaming()` and use the non-streaming
+    /// `chat_*` methods as a fallback.
+    fn stream_chat_with_system(
+        &self,
+        _system_prompt: Option<&str>,
+        _message: &str,
+        _model: &str,
+        _temperature: f64,
+        _options: StreamOptions,
+    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+        // Default: return an empty stream (not supported)
+        stream::empty().boxed()
+    }
+
+    /// Streaming chat with history.
+    /// Default implementation falls back to stream_chat_with_system with last user message.
+    fn stream_chat_with_history(
+        &self,
+        messages: &[ChatMessage],
+        model: &str,
+        temperature: f64,
+        options: StreamOptions,
+    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
+        // Extract the system prompt (if any) and the most recent user message,
+        // then delegate to `stream_chat_with_system` as documented above.
+        let system = messages
+            .iter()
+            .find(|m| m.role == "system")
+            .map(|m| m.content.clone());
+        let last_user = messages
+            .iter()
+            .rfind(|m| m.role == "user")
+            .map(|m| m.content.clone())
+            .unwrap_or_default();
+
+        self.stream_chat_with_system(system.as_deref(), &last_user, model, temperature, options)
+    }
 }
 
 #[cfg(test)]