diff --git a/src/channels/discord.rs b/src/channels/discord.rs index b9e4da6..5e83b4d 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -158,7 +158,12 @@ impl Channel for DiscordChannel { tracing::info!("Discord: connected and identified"); - // Spawn heartbeat task + // Track the last sequence number for heartbeats and resume. + // Only accessed in the select! loop below, so a plain i64 suffices. + let mut sequence: i64 = -1; + + // Spawn heartbeat timer — sends a tick signal, actual heartbeat + // is assembled in the select! loop where `sequence` lives. let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1); let hb_interval = heartbeat_interval; tokio::spawn(async move { @@ -176,7 +181,8 @@ impl Channel for DiscordChannel { loop { tokio::select! { _ = hb_rx.recv() => { - let hb = json!({"op": 1, "d": null}); + let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; + let hb = json!({"op": 1, "d": d}); if write.send(Message::Text(hb.to_string())).await.is_err() { break; } @@ -193,6 +199,36 @@ impl Channel for DiscordChannel { Err(_) => continue, }; + // Track sequence number from all dispatch events + if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) { + sequence = s; + } + + let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0); + + match op { + // Op 1: Server requests an immediate heartbeat + 1 => { + let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; + let hb = json!({"op": 1, "d": d}); + if write.send(Message::Text(hb.to_string())).await.is_err() { + break; + } + continue; + } + // Op 7: Reconnect + 7 => { + tracing::warn!("Discord: received Reconnect (op 7), closing for restart"); + break; + } + // Op 9: Invalid Session + 9 => { + tracing::warn!("Discord: received Invalid Session (op 9), closing for restart"); + break; + } + _ => {} + } + // Only handle MESSAGE_CREATE (opcode 0, type "MESSAGE_CREATE") let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or(""); if event_type != "MESSAGE_CREATE" { diff --git a/src/config/schema.rs b/src/config/schema.rs index ecc0b9b..c6b02d2 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -9,7 +9,11 @@ use std::path::PathBuf; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Config { + /// Workspace directory - computed from home, not serialized + #[serde(skip)] pub workspace_dir: PathBuf, + /// Path to config.toml - computed from home, not serialized + #[serde(skip)] pub config_path: PathBuf, pub api_key: Option, pub default_provider: Option, @@ -694,11 +698,16 @@ impl Config { if config_path.exists() { let contents = fs::read_to_string(&config_path).context("Failed to read config file")?; - let config: Config = + let mut config: Config = toml::from_str(&contents).context("Failed to parse config file")?; + // Set computed paths that are skipped during serialization + config.config_path = config_path.clone(); + config.workspace_dir = zeroclaw_dir.join("workspace"); Ok(config) } else { - let config = Config::default(); + let mut config = Config::default(); + config.config_path = config_path.clone(); + config.workspace_dir = zeroclaw_dir.join("workspace"); config.save()?; Ok(config) } diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index 93e6914..b56f337 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -129,13 +129,17 @@ impl SqliteMemory { } } - /// Simple content hash for embedding cache + /// Deterministic content hash for embedding cache. + /// Uses SHA-256 (truncated) instead of DefaultHasher, which is + /// explicitly documented as unstable across Rust versions. fn content_hash(text: &str) -> String { - use std::collections::hash_map::DefaultHasher; - use std::hash::{Hash, Hasher}; - let mut hasher = DefaultHasher::new(); - text.hash(&mut hasher); - format!("{:016x}", hasher.finish()) + use sha2::{Digest, Sha256}; + let hash = Sha256::digest(text.as_bytes()); + // First 8 bytes → 16 hex chars, matching previous format length + format!( + "{:016x}", + u64::from_be_bytes(hash[..8].try_into().expect("SHA-256 always produces >= 8 bytes")) + ) } /// Get embedding from cache, or compute + cache it diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index c81bac0..d9da513 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -50,7 +50,10 @@ impl AnthropicProvider { .map(str::trim) .filter(|k| !k.is_empty()) .map(ToString::to_string), +<<<<<<< HEAD +======= base_url, +>>>>>>> origin/main client: Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -92,7 +95,11 @@ impl Provider for AnthropicProvider { let mut request = self .client +<<<<<<< HEAD + .post("https://api.anthropic.com/v1/messages") +======= .post(format!("{}/v1/messages", self.base_url)) +>>>>>>> origin/main .header("anthropic-version", "2023-06-01") .header("content-type", "application/json") .json(&request); @@ -129,14 +136,20 @@ mod tests { let p = AnthropicProvider::new(Some("sk-ant-test123")); assert!(p.credential.is_some()); assert_eq!(p.credential.as_deref(), Some("sk-ant-test123")); +<<<<<<< HEAD +======= assert_eq!(p.base_url, "https://api.anthropic.com"); +>>>>>>> origin/main } #[test] fn creates_without_key() { let p = AnthropicProvider::new(None); assert!(p.credential.is_none()); +<<<<<<< HEAD +======= assert_eq!(p.base_url, "https://api.anthropic.com"); +>>>>>>> origin/main } #[test] @@ -150,6 +163,8 @@ mod tests { let p = AnthropicProvider::new(Some(" sk-ant-test123 ")); assert!(p.credential.is_some()); assert_eq!(p.credential.as_deref(), Some("sk-ant-test123")); +<<<<<<< HEAD +======= } #[test] @@ -169,6 +184,7 @@ mod tests { fn default_base_url_when_none_provided() { let p = AnthropicProvider::with_base_url(None, None); assert_eq!(p.base_url, "https://api.anthropic.com"); +>>>>>>> origin/main } #[tokio::test] diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index e55e1f0..6aac0e2 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -43,6 +43,28 @@ impl OpenAiCompatibleProvider { .unwrap_or_else(|_| Client::new()), } } + + /// Build the full URL for chat completions, detecting if base_url already includes the path. + /// This allows custom providers with non-standard endpoints (e.g., VolcEngine ARK uses + /// `/api/coding/v3/chat/completions` instead of `/v1/chat/completions`). + fn chat_completions_url(&self) -> String { + // If base_url already contains "chat/completions", use it as-is + if self.base_url.contains("chat/completions") { + self.base_url.clone() + } else { + format!("{}/chat/completions", self.base_url) + } + } + + /// Build the full URL for responses API, detecting if base_url already includes the path. + fn responses_url(&self) -> String { + // If base_url already contains "responses", use it as-is + if self.base_url.contains("responses") { + self.base_url.clone() + } else { + format!("{}/v1/responses", self.base_url) + } + } } #[derive(Debug, Serialize)] @@ -177,7 +199,7 @@ impl OpenAiCompatibleProvider { stream: Some(false), }; - let url = format!("{}/v1/responses", self.base_url); + let url = self.responses_url(); let response = self .apply_auth_header(self.client.post(&url).json(&request), api_key) @@ -232,7 +254,7 @@ impl Provider for OpenAiCompatibleProvider { temperature, }; - let url = format!("{}/v1/chat/completions", self.base_url); + let url = self.chat_completions_url(); let response = self .apply_auth_header(self.client.post(&url).json(&request), api_key) @@ -421,4 +443,85 @@ mod tests { Some("Fallback text") ); } + + // ══════════════════════════════════════════════════════════ + // Custom endpoint path tests (Issue #114) + // ══════════════════════════════════════════════════════════ + + #[test] + fn chat_completions_url_standard_openai() { + // Standard OpenAI-compatible providers get /chat/completions appended + let p = make_provider("openai", "https://api.openai.com/v1", None); + assert_eq!(p.chat_completions_url(), "https://api.openai.com/v1/chat/completions"); + } + + #[test] + fn chat_completions_url_trailing_slash() { + // Trailing slash is stripped, then /chat/completions appended + let p = make_provider("test", "https://api.example.com/v1/", None); + assert_eq!(p.chat_completions_url(), "https://api.example.com/v1/chat/completions"); + } + + #[test] + fn chat_completions_url_volcengine_ark() { + // VolcEngine ARK uses custom path - should use as-is + let p = make_provider( + "volcengine", + "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions", + None, + ); + assert_eq!( + p.chat_completions_url(), + "https://ark.cn-beijing.volces.com/api/coding/v3/chat/completions" + ); + } + + #[test] + fn chat_completions_url_custom_full_endpoint() { + // Custom provider with full endpoint path + let p = make_provider( + "custom", + "https://my-api.example.com/v2/llm/chat/completions", + None, + ); + assert_eq!( + p.chat_completions_url(), + "https://my-api.example.com/v2/llm/chat/completions" + ); + } + + #[test] + fn responses_url_standard() { + // Standard providers get /v1/responses appended + let p = make_provider("test", "https://api.example.com", None); + assert_eq!(p.responses_url(), "https://api.example.com/v1/responses"); + } + + #[test] + fn responses_url_custom_full_endpoint() { + // Custom provider with full responses endpoint + let p = make_provider( + "custom", + "https://my-api.example.com/api/v2/responses", + None, + ); + assert_eq!( + p.responses_url(), + "https://my-api.example.com/api/v2/responses" + ); + } + + #[test] + fn chat_completions_url_without_v1() { + // Provider configured without /v1 in base URL + let p = make_provider("test", "https://api.example.com", None); + assert_eq!(p.chat_completions_url(), "https://api.example.com/chat/completions"); + } + + #[test] + fn chat_completions_url_base_with_v1() { + // Provider configured with /v1 in base URL + let p = make_provider("test", "https://api.example.com/v1", None); + assert_eq!(p.chat_completions_url(), "https://api.example.com/v1/chat/completions"); + } }