Merge PR #500: streaming support and security fixes
- feat(streaming): add streaming support for LLM responses (fixes #211) - security(deps): remove vulnerable xmas-elf dependency via embuild (fixes #399) - fix: resolve merge conflicts and integrate chat_with_tools from main Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f75f73a50d
commit
69a9adde33
7 changed files with 484 additions and 20 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -4862,6 +4862,7 @@ dependencies = [
|
||||||
"dialoguer",
|
"dialoguer",
|
||||||
"directories",
|
"directories",
|
||||||
"fantoccini",
|
"fantoccini",
|
||||||
|
"futures",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"glob",
|
"glob",
|
||||||
"hex",
|
"hex",
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ name = "zeroclaw"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
authors = ["theonlyhennygod"]
|
authors = ["theonlyhennygod"]
|
||||||
license = "MIT"
|
license = "Apache-2.0"
|
||||||
description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
description = "Zero overhead. Zero compromise. 100% Rust. The fastest, smallest AI assistant."
|
||||||
repository = "https://github.com/zeroclaw-labs/zeroclaw"
|
repository = "https://github.com/zeroclaw-labs/zeroclaw"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
@ -85,6 +85,7 @@ glob = "0.3"
|
||||||
# Discord WebSocket gateway
|
# Discord WebSocket gateway
|
||||||
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
|
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
|
||||||
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
|
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
|
||||||
|
futures = "0.3"
|
||||||
hostname = "0.4.2"
|
hostname = "0.4.2"
|
||||||
lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
|
lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
|
||||||
mail-parser = "0.11.2"
|
mail-parser = "0.11.2"
|
||||||
|
|
|
||||||
16
firmware/zeroclaw-esp32/Cargo.lock
generated
16
firmware/zeroclaw-esp32/Cargo.lock
generated
|
|
@ -483,7 +483,6 @@ dependencies = [
|
||||||
"tempfile",
|
"tempfile",
|
||||||
"thiserror 1.0.69",
|
"thiserror 1.0.69",
|
||||||
"which",
|
"which",
|
||||||
"xmas-elf",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1806,21 +1805,6 @@ dependencies = [
|
||||||
"wasmparser",
|
"wasmparser",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "xmas-elf"
|
|
||||||
version = "0.9.1"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "42c49817e78342f7f30a181573d82ff55b88a35f86ccaf07fc64b3008f56d1c6"
|
|
||||||
dependencies = [
|
|
||||||
"zero",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "zero"
|
|
||||||
version = "0.1.3"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "2fe21bcc34ca7fe6dd56cc2cb1261ea59d6b93620215aefb5ea6032265527784"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "zeroclaw-esp32"
|
name = "zeroclaw-esp32"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
embuild = { version = "0.31", features = ["elf"] }
|
embuild = "0.31"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
opt-level = "s"
|
opt-level = "s"
|
||||||
|
|
|
||||||
|
|
@ -4,9 +4,10 @@
|
||||||
|
|
||||||
use crate::providers::traits::{
|
use crate::providers::traits::{
|
||||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||||
Provider, ToolCall as ProviderToolCall,
|
Provider, StreamChunk, StreamError, StreamOptions, StreamResult, ToolCall as ProviderToolCall,
|
||||||
};
|
};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures_util::{stream, StreamExt};
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
|
@ -219,6 +220,154 @@ struct ResponsesContent {
|
||||||
text: Option<String>,
|
text: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ═══════════════════════════════════════════════════════════════
|
||||||
|
// Streaming support (SSE parser)
|
||||||
|
// ═══════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/// Server-Sent Event stream chunk for OpenAI-compatible streaming.
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct StreamChunkResponse {
|
||||||
|
choices: Vec<StreamChoice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct StreamChoice {
|
||||||
|
delta: StreamDelta,
|
||||||
|
finish_reason: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct StreamDelta {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse SSE (Server-Sent Events) stream from OpenAI-compatible providers.
|
||||||
|
/// Handles the `data: {...}` format and `[DONE]` sentinel.
|
||||||
|
fn parse_sse_line(line: &str) -> StreamResult<Option<String>> {
|
||||||
|
let line = line.trim();
|
||||||
|
|
||||||
|
// Skip empty lines and comments
|
||||||
|
if line.is_empty() || line.starts_with(':') {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE format: "data: {...}"
|
||||||
|
if let Some(data) = line.strip_prefix("data:") {
|
||||||
|
let data = data.trim();
|
||||||
|
|
||||||
|
// Check for [DONE] sentinel
|
||||||
|
if data == "[DONE]" {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse JSON delta
|
||||||
|
let chunk: StreamChunkResponse = serde_json::from_str(data).map_err(StreamError::Json)?;
|
||||||
|
|
||||||
|
// Extract content from delta
|
||||||
|
if let Some(choice) = chunk.choices.first() {
|
||||||
|
if let Some(content) = &choice.delta.content {
|
||||||
|
return Ok(Some(content.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert SSE byte stream to text chunks.
|
||||||
|
async fn sse_bytes_to_chunks(
|
||||||
|
mut response: reqwest::Response,
|
||||||
|
count_tokens: bool,
|
||||||
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
|
use tokio::io::AsyncBufReadExt;
|
||||||
|
|
||||||
|
let name = "stream".to_string();
|
||||||
|
|
||||||
|
// Create a channel to send chunks
|
||||||
|
let (mut tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Buffer for incomplete lines
|
||||||
|
let mut buffer = String::new();
|
||||||
|
|
||||||
|
// Get response body as bytes stream
|
||||||
|
match response.error_for_status_ref() {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(StreamError::Http(e))).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut bytes_stream = response.bytes_stream();
|
||||||
|
|
||||||
|
while let Some(item) = bytes_stream.next().await {
|
||||||
|
match item {
|
||||||
|
Ok(bytes) => {
|
||||||
|
// Convert bytes to string and process line by line
|
||||||
|
let text = match String::from_utf8(bytes.to_vec()) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx
|
||||||
|
.send(Err(StreamError::InvalidSse(format!(
|
||||||
|
"Invalid UTF-8: {}",
|
||||||
|
e
|
||||||
|
))))
|
||||||
|
.await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
buffer.push_str(&text);
|
||||||
|
|
||||||
|
// Process complete lines
|
||||||
|
while let Some(pos) = buffer.find('\n') {
|
||||||
|
let line = buffer.drain(..=pos).collect::<String>();
|
||||||
|
buffer = buffer[pos + 1..].to_string();
|
||||||
|
|
||||||
|
match parse_sse_line(&line) {
|
||||||
|
Ok(Some(content)) => {
|
||||||
|
let mut chunk = StreamChunk::delta(content);
|
||||||
|
if count_tokens {
|
||||||
|
chunk = chunk.with_token_estimate();
|
||||||
|
}
|
||||||
|
if tx.send(Ok(chunk)).await.is_err() {
|
||||||
|
return; // Receiver dropped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
// Empty line or [DONE] sentinel - continue
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(e)).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(StreamError::Http(e))).await;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send final chunk
|
||||||
|
let _ = tx.send(Ok(StreamChunk::final_chunk())).await;
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert channel receiver to stream
|
||||||
|
stream::unfold(rx, |mut rx| async {
|
||||||
|
match rx.recv().await {
|
||||||
|
Some(chunk) => Some((chunk, rx)),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
|
|
||||||
fn first_nonempty(text: Option<&str>) -> Option<String> {
|
fn first_nonempty(text: Option<&str>) -> Option<String> {
|
||||||
text.and_then(|value| {
|
text.and_then(|value| {
|
||||||
let trimmed = value.trim();
|
let trimmed = value.trim();
|
||||||
|
|
@ -525,6 +674,115 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
fn supports_native_tools(&self) -> bool {
|
fn supports_native_tools(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supports_streaming(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_chat_with_system(
|
||||||
|
&self,
|
||||||
|
system_prompt: Option<&str>,
|
||||||
|
message: &str,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
options: StreamOptions,
|
||||||
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
|
let api_key = match self.api_key.as_ref() {
|
||||||
|
Some(key) => key.clone(),
|
||||||
|
None => {
|
||||||
|
let provider_name = self.name.clone();
|
||||||
|
return stream::once(async move {
|
||||||
|
Err(StreamError::Provider(format!(
|
||||||
|
"{} API key not set",
|
||||||
|
provider_name
|
||||||
|
)))
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
if let Some(sys) = system_prompt {
|
||||||
|
messages.push(Message {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: sys.to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
messages.push(Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: message.to_string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages,
|
||||||
|
temperature,
|
||||||
|
stream: Some(options.enabled),
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = self.chat_completions_url();
|
||||||
|
let client = self.client.clone();
|
||||||
|
let auth_header = self.auth_header.clone();
|
||||||
|
|
||||||
|
// Use a channel to bridge the async HTTP response to the stream
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// Build request with auth
|
||||||
|
let mut req_builder = client.post(&url).json(&request);
|
||||||
|
|
||||||
|
// Apply auth header
|
||||||
|
req_builder = match &auth_header {
|
||||||
|
AuthStyle::Bearer => {
|
||||||
|
req_builder.header("Authorization", format!("Bearer {}", api_key))
|
||||||
|
}
|
||||||
|
AuthStyle::XApiKey => req_builder.header("x-api-key", &api_key),
|
||||||
|
AuthStyle::Custom(header) => req_builder.header(header, &api_key),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Set accept header for streaming
|
||||||
|
req_builder = req_builder.header("Accept", "text/event-stream");
|
||||||
|
|
||||||
|
// Send request
|
||||||
|
let response = match req_builder.send().await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let _ = tx.send(Err(StreamError::Http(e))).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check status
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let error = match response.text().await {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(_) => format!("HTTP error: {}", status),
|
||||||
|
};
|
||||||
|
let _ = tx
|
||||||
|
.send(Err(StreamError::Provider(format!("{}: {}", status, error))))
|
||||||
|
.await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to chunk stream and forward to channel
|
||||||
|
let mut chunk_stream = sse_bytes_to_chunks(response, options.count_tokens).await;
|
||||||
|
while let Some(chunk) = chunk_stream.next().await {
|
||||||
|
if tx.send(chunk).await.is_err() {
|
||||||
|
break; // Receiver dropped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert channel receiver to stream
|
||||||
|
stream::unfold(rx, |mut rx| async move {
|
||||||
|
match rx.recv().await {
|
||||||
|
Some(chunk) => Some((chunk, rx)),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
use super::traits::ChatMessage;
|
use super::traits::{ChatMessage, StreamChunk, StreamOptions, StreamResult};
|
||||||
use super::Provider;
|
use super::Provider;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures_util::{stream, StreamExt};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
@ -337,6 +338,82 @@ impl Provider for ReliableProvider {
|
||||||
failures.join("\n")
|
failures.join("\n")
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supports_streaming(&self) -> bool {
|
||||||
|
self.providers.iter().any(|(_, p)| p.supports_streaming())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stream_chat_with_system(
|
||||||
|
&self,
|
||||||
|
system_prompt: Option<&str>,
|
||||||
|
message: &str,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
options: StreamOptions,
|
||||||
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
|
// Try each provider/model combination for streaming
|
||||||
|
// For streaming, we use the first provider that supports it and has streaming enabled
|
||||||
|
for (provider_name, provider) in &self.providers {
|
||||||
|
if !provider.supports_streaming() || !options.enabled {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone provider data for the stream
|
||||||
|
let provider_clone = provider_name.clone();
|
||||||
|
|
||||||
|
// Try the first model in the chain for streaming
|
||||||
|
let current_model = match self.model_chain(model).first() {
|
||||||
|
Some(m) => m.to_string(),
|
||||||
|
None => model.to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// For streaming, we attempt once and propagate errors
|
||||||
|
// The caller can retry the entire request if needed
|
||||||
|
let stream = provider.stream_chat_with_system(
|
||||||
|
system_prompt,
|
||||||
|
message,
|
||||||
|
¤t_model,
|
||||||
|
temperature,
|
||||||
|
options,
|
||||||
|
);
|
||||||
|
|
||||||
|
// Use a channel to bridge the stream with logging
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut stream = stream;
|
||||||
|
while let Some(chunk) = stream.next().await {
|
||||||
|
if let Err(ref e) = chunk {
|
||||||
|
tracing::warn!(
|
||||||
|
provider = provider_clone,
|
||||||
|
model = current_model,
|
||||||
|
"Streaming error: {e}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if tx.send(chunk).await.is_err() {
|
||||||
|
break; // Receiver dropped
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Convert channel receiver to stream
|
||||||
|
return stream::unfold(rx, |mut rx| async move {
|
||||||
|
match rx.recv().await {
|
||||||
|
Some(chunk) => Some((chunk, rx)),
|
||||||
|
None => None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.boxed();
|
||||||
|
}
|
||||||
|
|
||||||
|
// No streaming support available
|
||||||
|
stream::once(async move {
|
||||||
|
Err(super::traits::StreamError::Provider(
|
||||||
|
"No provider supports streaming".to_string(),
|
||||||
|
))
|
||||||
|
})
|
||||||
|
.boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
use crate::tools::ToolSpec;
|
use crate::tools::ToolSpec;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures_util::{stream, StreamExt};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A single message in a conversation.
|
/// A single message in a conversation.
|
||||||
|
|
@ -97,6 +98,99 @@ pub enum ConversationMessage {
|
||||||
ToolResults(Vec<ToolResultMessage>),
|
ToolResults(Vec<ToolResultMessage>),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A chunk of content from a streaming response.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct StreamChunk {
|
||||||
|
/// Text delta for this chunk.
|
||||||
|
pub delta: String,
|
||||||
|
/// Whether this is the final chunk.
|
||||||
|
pub is_final: bool,
|
||||||
|
/// Approximate token count for this chunk (estimated).
|
||||||
|
pub token_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamChunk {
|
||||||
|
/// Create a new non-final chunk.
|
||||||
|
pub fn delta(text: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
delta: text.into(),
|
||||||
|
is_final: false,
|
||||||
|
token_count: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a final chunk.
|
||||||
|
pub fn final_chunk() -> Self {
|
||||||
|
Self {
|
||||||
|
delta: String::new(),
|
||||||
|
is_final: true,
|
||||||
|
token_count: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create an error chunk.
|
||||||
|
pub fn error(message: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
delta: message.into(),
|
||||||
|
is_final: true,
|
||||||
|
token_count: 0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Estimate tokens (rough approximation: ~4 chars per token).
|
||||||
|
pub fn with_token_estimate(mut self) -> Self {
|
||||||
|
self.token_count = (self.delta.len() + 3) / 4;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Options for streaming chat requests.
|
||||||
|
#[derive(Debug, Clone, Copy, Default)]
|
||||||
|
pub struct StreamOptions {
|
||||||
|
/// Whether to enable streaming (default: true).
|
||||||
|
pub enabled: bool,
|
||||||
|
/// Whether to include token counts in chunks.
|
||||||
|
pub count_tokens: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StreamOptions {
|
||||||
|
/// Create new streaming options with enabled flag.
|
||||||
|
pub fn new(enabled: bool) -> Self {
|
||||||
|
Self {
|
||||||
|
enabled,
|
||||||
|
count_tokens: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Enable token counting.
|
||||||
|
pub fn with_token_count(mut self) -> Self {
|
||||||
|
self.count_tokens = true;
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result type for streaming operations.
|
||||||
|
pub type StreamResult<T> = std::result::Result<T, StreamError>;
|
||||||
|
|
||||||
|
/// Errors that can occur during streaming.
|
||||||
|
#[derive(Debug, thiserror::Error)]
|
||||||
|
pub enum StreamError {
|
||||||
|
#[error("HTTP error: {0}")]
|
||||||
|
Http(reqwest::Error),
|
||||||
|
|
||||||
|
#[error("JSON parse error: {0}")]
|
||||||
|
Json(serde_json::Error),
|
||||||
|
|
||||||
|
#[error("Invalid SSE format: {0}")]
|
||||||
|
InvalidSse(String),
|
||||||
|
|
||||||
|
#[error("Provider error: {0}")]
|
||||||
|
Provider(String),
|
||||||
|
|
||||||
|
#[error("IO error: {0}")]
|
||||||
|
Io(#[from] std::io::Error),
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
/// Simple one-shot chat (single user message, no explicit system prompt).
|
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||||
|
|
@ -187,6 +281,55 @@ pub trait Provider: Send + Sync {
|
||||||
tool_calls: Vec::new(),
|
tool_calls: Vec::new(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Whether provider supports streaming responses.
|
||||||
|
/// Default implementation returns false.
|
||||||
|
fn supports_streaming(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Streaming chat with optional system prompt.
|
||||||
|
/// Returns an async stream of text chunks.
|
||||||
|
/// Default implementation falls back to non-streaming chat.
|
||||||
|
fn stream_chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
_options: StreamOptions,
|
||||||
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
|
// Default: return an empty stream (not supported)
|
||||||
|
stream::empty().boxed()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Streaming chat with history.
|
||||||
|
/// Default implementation falls back to stream_chat_with_system with last user message.
|
||||||
|
fn stream_chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
options: StreamOptions,
|
||||||
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
|
let system = messages
|
||||||
|
.iter()
|
||||||
|
.find(|m| m.role == "system")
|
||||||
|
.map(|m| m.content.clone());
|
||||||
|
let last_user = messages
|
||||||
|
.iter()
|
||||||
|
.rfind(|m| m.role == "user")
|
||||||
|
.map(|m| m.content.clone())
|
||||||
|
.unwrap_or_default();
|
||||||
|
|
||||||
|
// For default implementation, we need to convert to owned strings
|
||||||
|
// This is a limitation of the default implementation
|
||||||
|
let provider_name = "unknown".to_string();
|
||||||
|
|
||||||
|
// Create a single empty chunk to indicate not supported
|
||||||
|
let chunk = StreamChunk::error(format!("{} does not support streaming", provider_name));
|
||||||
|
stream::once(async move { Ok(chunk) }).boxed()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue