diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index caa7e53..965670e 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -2511,4 +2511,185 @@ browser_open/url>https://example.com"#; assert_eq!(calls[0].arguments["command"], "pwd"); assert_eq!(text, "Done"); } + + // ───────────────────────────────────────────────────────────────────── + // TG4 (inline): parse_tool_calls robustness — malformed/edge-case inputs + // Prevents: Pattern 4 issues #746, #418, #777, #848 + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn parse_tool_calls_empty_input_returns_empty() { + let (text, calls) = parse_tool_calls(""); + assert!(calls.is_empty(), "empty input should produce no tool calls"); + assert!(text.is_empty(), "empty input should produce no text"); + } + + #[test] + fn parse_tool_calls_whitespace_only_returns_empty_calls() { + let (text, calls) = parse_tool_calls(" \n\t "); + assert!(calls.is_empty()); + assert!(text.is_empty() || text.trim().is_empty()); + } + + #[test] + fn parse_tool_calls_nested_xml_tags_handled() { + // Double-wrapped tool call should still parse the inner call + let response = r#"{"name":"echo","arguments":{"msg":"hi"}}"#; + let (_text, calls) = parse_tool_calls(response); + // Should find at least one tool call + assert!( + !calls.is_empty(), + "nested XML tags should still yield at least one tool call" + ); + } + + #[test] + fn parse_tool_calls_truncated_json_no_panic() { + // Incomplete JSON inside tool_call tags + let response = r#"{"name":"shell","arguments":{"command":"ls""#; + let (_text, _calls) = parse_tool_calls(response); + // Should not panic — graceful handling of truncated JSON + } + + #[test] + fn parse_tool_calls_empty_json_object_in_tag() { + let response = "{}"; + let (_text, calls) = parse_tool_calls(response); + // Empty JSON object has no name field — should not produce valid tool call + assert!( + calls.is_empty(), + "empty JSON object should not produce a tool call" + ); + } 
+ + #[test] + fn parse_tool_calls_closing_tag_only_returns_text() { + let response = "Some text more text"; + let (text, calls) = parse_tool_calls(response); + assert!(calls.is_empty(), "closing tag only should not produce calls"); + assert!( + !text.is_empty(), + "text around orphaned closing tag should be preserved" + ); + } + + #[test] + fn parse_tool_calls_very_large_arguments_no_panic() { + let large_arg = "x".repeat(100_000); + let response = format!( + r#"{{"name":"echo","arguments":{{"message":"{}"}}}}"#, + large_arg + ); + let (_text, calls) = parse_tool_calls(&response); + assert_eq!(calls.len(), 1, "large arguments should still parse"); + assert_eq!(calls[0].name, "echo"); + } + + #[test] + fn parse_tool_calls_special_characters_in_arguments() { + let response = r#"{"name":"echo","arguments":{"message":"hello \"world\" <>&'\n\t"}}"#; + let (_text, calls) = parse_tool_calls(response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "echo"); + } + + #[test] + fn parse_tool_calls_text_with_embedded_json_not_extracted() { + // Raw JSON without any tags should NOT be extracted as a tool call + let response = r#"Here is some data: {"name":"echo","arguments":{"message":"hi"}} end."#; + let (_text, calls) = parse_tool_calls(response); + assert!( + calls.is_empty(), + "raw JSON in text without tags should not be extracted" + ); + } + + #[test] + fn parse_tool_calls_multiple_formats_mixed() { + // Mix of text and properly tagged tool call + let response = r#"I'll help you with that. 
+ + +{"name":"shell","arguments":{"command":"echo hello"}} + + +Let me check the result."#; + let (text, calls) = parse_tool_calls(response); + assert_eq!(calls.len(), 1, "should extract one tool call from mixed content"); + assert_eq!(calls[0].name, "shell"); + assert!( + text.contains("help you"), + "text before tool call should be preserved" + ); + } + + // ───────────────────────────────────────────────────────────────────── + // TG4 (inline): scrub_credentials edge cases + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn scrub_credentials_empty_input() { + let result = scrub_credentials(""); + assert_eq!(result, ""); + } + + #[test] + fn scrub_credentials_no_sensitive_data() { + let input = "normal text without any secrets"; + let result = scrub_credentials(input); + assert_eq!(result, input, "non-sensitive text should pass through unchanged"); + } + + #[test] + fn scrub_credentials_short_values_not_redacted() { + // Values shorter than 8 chars should not be redacted + let input = r#"api_key="short""#; + let result = scrub_credentials(input); + assert_eq!(result, input, "short values should not be redacted"); + } + + // ───────────────────────────────────────────────────────────────────── + // TG4 (inline): trim_history edge cases + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn trim_history_empty_history() { + let mut history: Vec = vec![]; + trim_history(&mut history, 10); + assert!(history.is_empty()); + } + + #[test] + fn trim_history_system_only() { + let mut history = vec![crate::providers::ChatMessage::system("system prompt")]; + trim_history(&mut history, 10); + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, "system"); + } + + #[test] + fn trim_history_exactly_at_limit() { + let mut history = vec![ + crate::providers::ChatMessage::system("system"), + crate::providers::ChatMessage::user("msg 1"), + crate::providers::ChatMessage::assistant("reply 1"), + ]; + 
trim_history(&mut history, 2); // 2 non-system messages = exactly at limit + assert_eq!(history.len(), 3, "should not trim when exactly at limit"); + } + + #[test] + fn trim_history_removes_oldest_non_system() { + let mut history = vec![ + crate::providers::ChatMessage::system("system"), + crate::providers::ChatMessage::user("old msg"), + crate::providers::ChatMessage::assistant("old reply"), + crate::providers::ChatMessage::user("new msg"), + crate::providers::ChatMessage::assistant("new reply"), + ]; + trim_history(&mut history, 2); + assert_eq!(history.len(), 3); // system + 2 kept + assert_eq!(history[0].role, "system"); + assert_eq!(history[1].content, "new msg"); + } } diff --git a/src/channels/discord.rs b/src/channels/discord.rs index d7a4d20..a9e110e 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -840,4 +840,110 @@ mod tests { // Should have UUID dashes assert!(id.contains('-')); } + + // ───────────────────────────────────────────────────────────────────── + // TG6: Channel platform limit edge cases for Discord (2000 char limit) + // Prevents: Pattern 6 — issues #574, #499 + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn split_message_code_block_at_boundary() { + // Code block that spans the split boundary + let mut msg = String::new(); + msg.push_str("```rust\n"); + msg.push_str(&"x".repeat(1990)); + msg.push_str("\n```\nMore text after code block"); + let parts = split_message_for_discord(&msg); + assert!(parts.len() >= 2, "code block spanning boundary should split"); + for part in &parts { + assert!( + part.len() <= DISCORD_MAX_MESSAGE_LENGTH, + "each part must be <= {DISCORD_MAX_MESSAGE_LENGTH}, got {}", + part.len() + ); + } + } + + #[test] + fn split_message_single_long_word_exceeds_limit() { + // A single word longer than 2000 chars must be hard-split + let long_word = "a".repeat(2500); + let parts = split_message_for_discord(&long_word); + assert!(parts.len() >= 2, "word exceeding 
limit must be split"); + for part in &parts { + assert!( + part.len() <= DISCORD_MAX_MESSAGE_LENGTH, + "hard-split part must be <= {DISCORD_MAX_MESSAGE_LENGTH}, got {}", + part.len() + ); + } + // Reassembled content should match original + let reassembled: String = parts.join(""); + assert_eq!(reassembled, long_word); + } + + #[test] + fn split_message_exactly_at_limit_no_split() { + let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH); + let parts = split_message_for_discord(&msg); + assert_eq!(parts.len(), 1, "message exactly at limit should not split"); + assert_eq!(parts[0].len(), DISCORD_MAX_MESSAGE_LENGTH); + } + + #[test] + fn split_message_one_over_limit_splits() { + let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH + 1); + let parts = split_message_for_discord(&msg); + assert!(parts.len() >= 2, "message 1 char over limit must split"); + } + + #[test] + fn split_message_many_short_lines() { + // Many short lines should be batched into chunks under the limit + let msg: String = (0..500).map(|i| format!("line {i}\n")).collect(); + let parts = split_message_for_discord(&msg); + for part in &parts { + assert!( + part.len() <= DISCORD_MAX_MESSAGE_LENGTH, + "short-line batch must be <= limit" + ); + } + // All content should be preserved + let reassembled: String = parts.join(""); + assert_eq!(reassembled.trim(), msg.trim()); + } + + #[test] + fn split_message_only_whitespace() { + let msg = " \n\n\t "; + let parts = split_message_for_discord(msg); + // Should handle gracefully without panic + assert!(parts.len() <= 1); + } + + #[test] + fn split_message_emoji_at_boundary() { + // Emoji are multi-byte; ensure we don't split mid-emoji + let mut msg = "a".repeat(1998); + msg.push_str("🎉🎊"); // 2 emoji at the boundary (2000 chars total) + let parts = split_message_for_discord(&msg); + for part in &parts { + // The function splits on character count, not byte count + assert!( + part.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH, + "emoji boundary split must respect limit" 
+ ); + } + } + + #[test] + fn split_message_consecutive_newlines_at_boundary() { + let mut msg = "a".repeat(1995); + msg.push_str("\n\n\n\n\n"); + msg.push_str(&"b".repeat(100)); + let parts = split_message_for_discord(&msg); + for part in &parts { + assert!(part.len() <= DISCORD_MAX_MESSAGE_LENGTH); + } + } } diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index ca0e03b..d6efc04 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -18,7 +18,7 @@ const TELEGRAM_BIND_COMMAND: &str = "/bind"; /// Split a message into chunks that respect Telegram's 4096 character limit. /// Tries to split at word boundaries when possible, and handles continuation. fn split_message_for_telegram(message: &str) -> Vec { - if message.len() <= TELEGRAM_MAX_MESSAGE_LENGTH { + if message.chars().count() <= TELEGRAM_MAX_MESSAGE_LENGTH { return vec![message.to_string()]; } @@ -26,29 +26,35 @@ fn split_message_for_telegram(message: &str) -> Vec { let mut remaining = message; while !remaining.is_empty() { - let chunk_end = if remaining.len() <= TELEGRAM_MAX_MESSAGE_LENGTH { - remaining.len() + // Find the byte offset for the Nth character boundary. 
+ let hard_split = remaining + .char_indices() + .nth(TELEGRAM_MAX_MESSAGE_LENGTH) + .map_or(remaining.len(), |(idx, _)| idx); + + let chunk_end = if hard_split == remaining.len() { + hard_split } else { // Try to find a good break point (newline, then space) - let search_area = &remaining[..TELEGRAM_MAX_MESSAGE_LENGTH]; + let search_area = &remaining[..hard_split]; // Prefer splitting at newline if let Some(pos) = search_area.rfind('\n') { // Don't split if the newline is too close to the start - if pos >= TELEGRAM_MAX_MESSAGE_LENGTH / 2 { + if search_area[..pos].chars().count() >= TELEGRAM_MAX_MESSAGE_LENGTH / 2 { pos + 1 } else { // Try space as fallback search_area .rfind(' ') - .unwrap_or(TELEGRAM_MAX_MESSAGE_LENGTH) + .unwrap_or(hard_split) + 1 } } else if let Some(pos) = search_area.rfind(' ') { pos + 1 } else { - // Hard split at the limit - TELEGRAM_MAX_MESSAGE_LENGTH + // Hard split at character boundary + hard_split } }; @@ -2830,4 +2836,100 @@ mod tests { let ch_disabled = TelegramChannel::new("token".into(), vec!["*".into()], false); assert!(!ch_disabled.mention_only); } + + // ───────────────────────────────────────────────────────────────────── + // TG6: Channel platform limit edge cases for Telegram (4096 char limit) + // Prevents: Pattern 6 — issues #574, #499 + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn telegram_split_code_block_at_boundary() { + let mut msg = String::new(); + msg.push_str("```python\n"); + msg.push_str(&"x".repeat(4085)); + msg.push_str("\n```\nMore text after code block"); + let parts = split_message_for_telegram(&msg); + assert!(parts.len() >= 2, "code block spanning boundary should split"); + for part in &parts { + assert!( + part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH, + "each part must be <= {TELEGRAM_MAX_MESSAGE_LENGTH}, got {}", + part.len() + ); + } + } + + #[test] + fn telegram_split_single_long_word() { + let long_word = "a".repeat(5000); + let parts = 
split_message_for_telegram(&long_word); + assert!(parts.len() >= 2, "word exceeding limit must be split"); + for part in &parts { + assert!( + part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH, + "hard-split part must be <= {TELEGRAM_MAX_MESSAGE_LENGTH}, got {}", + part.len() + ); + } + let reassembled: String = parts.join(""); + assert_eq!(reassembled, long_word); + } + + #[test] + fn telegram_split_exactly_at_limit_no_split() { + let msg = "a".repeat(TELEGRAM_MAX_MESSAGE_LENGTH); + let parts = split_message_for_telegram(&msg); + assert_eq!(parts.len(), 1, "message exactly at limit should not split"); + } + + #[test] + fn telegram_split_one_over_limit() { + let msg = "a".repeat(TELEGRAM_MAX_MESSAGE_LENGTH + 1); + let parts = split_message_for_telegram(&msg); + assert!(parts.len() >= 2, "message 1 char over limit must split"); + } + + #[test] + fn telegram_split_many_short_lines() { + let msg: String = (0..1000).map(|i| format!("line {i}\n")).collect(); + let parts = split_message_for_telegram(&msg); + for part in &parts { + assert!( + part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH, + "short-line batch must be <= limit" + ); + } + } + + #[test] + fn telegram_split_only_whitespace() { + let msg = " \n\n\t "; + let parts = split_message_for_telegram(msg); + assert!(parts.len() <= 1); + } + + #[test] + fn telegram_split_emoji_at_boundary() { + let mut msg = "a".repeat(4094); + msg.push_str("🎉🎊"); // 4096 chars total + let parts = split_message_for_telegram(&msg); + for part in &parts { + // The function splits on character count, not byte count + assert!( + part.chars().count() <= TELEGRAM_MAX_MESSAGE_LENGTH, + "emoji boundary split must respect limit" + ); + } + } + + #[test] + fn telegram_split_consecutive_newlines() { + let mut msg = "a".repeat(4090); + msg.push_str("\n\n\n\n\n\n"); + msg.push_str(&"b".repeat(100)); + let parts = split_message_for_telegram(&msg); + for part in &parts { + assert!(part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH); + } + } } diff --git 
a/tests/agent_loop_robustness.rs b/tests/agent_loop_robustness.rs new file mode 100644 index 0000000..f63b51f --- /dev/null +++ b/tests/agent_loop_robustness.rs @@ -0,0 +1,438 @@ +//! TG4: Agent Loop Robustness Tests +//! +//! Prevents: Pattern 4 — Agent loop & tool call processing bugs (13% of user bugs). +//! Issues: #746, #418, #777, #848 +//! +//! Tests agent behavior with malformed tool calls, empty responses, +//! max iteration limits, and cascading tool failures using mock providers. +//! Complements inline parse_tool_calls tests in `src/agent/loop_.rs`. + +use anyhow::Result; +use async_trait::async_trait; +use serde_json::json; +use std::sync::{Arc, Mutex}; +use zeroclaw::agent::agent::Agent; +use zeroclaw::agent::dispatcher::NativeToolDispatcher; +use zeroclaw::config::MemoryConfig; +use zeroclaw::memory; +use zeroclaw::memory::Memory; +use zeroclaw::observability::{NoopObserver, Observer}; +use zeroclaw::providers::{ChatRequest, ChatResponse, Provider, ToolCall}; +use zeroclaw::tools::{Tool, ToolResult}; + +// ───────────────────────────────────────────────────────────────────────────── +// Mock infrastructure +// ───────────────────────────────────────────────────────────────────────────── + +struct MockProvider { + responses: Mutex>, +} + +impl MockProvider { + fn new(responses: Vec) -> Self { + Self { + responses: Mutex::new(responses), + } + } +} + +#[async_trait] +impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("fallback".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + let mut guard = self.responses.lock().unwrap(); + if guard.is_empty() { + return Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + }); + } + Ok(guard.remove(0)) + } +} + +struct EchoTool; + +#[async_trait] +impl Tool for EchoTool { + fn name(&self) -> &str { + 
"echo" + } + fn description(&self) -> &str { + "Echoes the input message" + } + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "message": {"type": "string"} + } + }) + } + async fn execute(&self, args: serde_json::Value) -> Result { + let msg = args + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("(empty)") + .to_string(); + Ok(ToolResult { + success: true, + output: msg, + error: None, + }) + } +} + +/// Tool that always fails, simulating a broken external service +struct FailingTool; + +#[async_trait] +impl Tool for FailingTool { + fn name(&self) -> &str { + "failing_tool" + } + fn description(&self) -> &str { + "Always fails" + } + fn parameters_schema(&self) -> serde_json::Value { + json!({"type": "object"}) + } + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Service unavailable: connection timeout".into()), + }) + } +} + +/// Tool that tracks invocations +struct CountingTool { + count: Arc>, +} + +impl CountingTool { + fn new() -> (Self, Arc>) { + let count = Arc::new(Mutex::new(0)); + (Self { count: count.clone() }, count) + } +} + +#[async_trait] +impl Tool for CountingTool { + fn name(&self) -> &str { + "counter" + } + fn description(&self) -> &str { + "Counts invocations" + } + fn parameters_schema(&self) -> serde_json::Value { + json!({"type": "object"}) + } + async fn execute(&self, _args: serde_json::Value) -> Result { + let mut c = self.count.lock().unwrap(); + *c += 1; + Ok(ToolResult { + success: true, + output: format!("call #{}", *c), + error: None, + }) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Test helpers +// ───────────────────────────────────────────────────────────────────────────── + +fn make_memory() -> Arc { + let cfg = MemoryConfig { + backend: "none".into(), + ..MemoryConfig::default() + }; + Arc::from(memory::create_memory(&cfg, 
&std::env::temp_dir(), None).unwrap()) +} + +fn make_observer() -> Arc { + Arc::from(NoopObserver {}) +} + +fn text_response(text: &str) -> ChatResponse { + ChatResponse { + text: Some(text.into()), + tool_calls: vec![], + } +} + +fn tool_response(calls: Vec) -> ChatResponse { + ChatResponse { + text: Some(String::new()), + tool_calls: calls, + } +} + +fn build_agent(provider: Box, tools: Vec>) -> Agent { + Agent::builder() + .provider(provider) + .tools(tools) + .memory(make_memory()) + .observer(make_observer()) + .tool_dispatcher(Box::new(NativeToolDispatcher)) + .workspace_dir(std::env::temp_dir()) + .build() + .unwrap() +} + +// ═════════════════════════════════════════════════════════════════════════════ +// TG4.1: Malformed tool call recovery +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should recover when LLM returns text with residual XML tags (#746) +#[tokio::test] +async fn agent_recovers_from_text_with_xml_residue() { + let provider = Box::new(MockProvider::new(vec![text_response( + "Here is the result. 
Some leftover text after.", + )])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("test").await.unwrap(); + assert!( + !response.is_empty(), + "agent should produce non-empty response despite XML residue" + ); +} + +/// Agent should handle tool call with empty arguments gracefully +#[tokio::test] +async fn agent_handles_tool_call_with_empty_arguments() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: "{}".into(), + }]), + text_response("Tool with empty args executed"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("call with empty args").await.unwrap(); + assert!(!response.is_empty()); +} + +/// Agent should handle unknown tool name without crashing (#848 related) +#[tokio::test] +async fn agent_handles_nonexistent_tool_gracefully() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "absolutely_nonexistent_tool".into(), + arguments: "{}".into(), + }]), + text_response("Recovered from unknown tool"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("call missing tool").await.unwrap(); + assert!( + !response.is_empty(), + "agent should recover from unknown tool" + ); +} + +// ═════════════════════════════════════════════════════════════════════════════ +// TG4.2: Tool failure cascade handling (#848) +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should handle repeated tool failures without infinite loop +#[tokio::test] +async fn agent_handles_failing_tool() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "failing_tool".into(), + arguments: "{}".into(), + }]), + text_response("Tool failed but I recovered"), + ])); + + let mut agent = 
build_agent(provider, vec![Box::new(FailingTool)]); + let response = agent.turn("use failing tool").await.unwrap(); + assert!( + !response.is_empty(), + "agent should produce response even after tool failure" + ); +} + +/// Agent should handle mixed tool calls (some succeed, some fail) +#[tokio::test] +async fn agent_handles_mixed_tool_success_and_failure() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ + ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "success"}"#.into(), + }, + ToolCall { + id: "tc2".into(), + name: "failing_tool".into(), + arguments: "{}".into(), + }, + ]), + text_response("Mixed results processed"), + ])); + + let mut agent = build_agent( + provider, + vec![Box::new(EchoTool), Box::new(FailingTool)], + ); + let response = agent.turn("mixed tools").await.unwrap(); + assert!(!response.is_empty()); +} + +// ═════════════════════════════════════════════════════════════════════════════ +// TG4.3: Iteration limit enforcement (#777) +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should not exceed max_tool_iterations (default=10) even with +/// a provider that keeps returning tool calls +#[tokio::test] +async fn agent_respects_max_tool_iterations() { + let (counting_tool, count) = CountingTool::new(); + + // Create 20 tool call responses - more than the default limit of 10 + let mut responses: Vec = (0..20) + .map(|i| { + tool_response(vec![ToolCall { + id: format!("tc_{i}"), + name: "counter".into(), + arguments: "{}".into(), + }]) + }) + .collect(); + // Add a final text response that would be used if limit is reached + responses.push(text_response("Final response after iterations")); + + let provider = Box::new(MockProvider::new(responses)); + let mut agent = build_agent(provider, vec![Box::new(counting_tool)]); + + // Agent should complete (either by hitting iteration limit or running out of responses) + let result = agent.turn("keep calling 
tools").await; + // The agent should complete without hanging + assert!(result.is_ok() || result.is_err()); + + let invocations = *count.lock().unwrap(); + assert!( + invocations <= 10, + "tool invocations ({invocations}) should not exceed default max_tool_iterations (10)" + ); +} + +// ═════════════════════════════════════════════════════════════════════════════ +// TG4.4: Empty and whitespace responses +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should handle empty text response from provider (#418 related) +#[tokio::test] +async fn agent_handles_empty_provider_response() { + let provider = Box::new(MockProvider::new(vec![ChatResponse { + text: Some(String::new()), + tool_calls: vec![], + }])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + // Should not panic + let _result = agent.turn("test").await; +} + +/// Agent should handle None text response from provider +#[tokio::test] +async fn agent_handles_none_text_response() { + let provider = Box::new(MockProvider::new(vec![ChatResponse { + text: None, + tool_calls: vec![], + }])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let _result = agent.turn("test").await; +} + +/// Agent should handle whitespace-only response +#[tokio::test] +async fn agent_handles_whitespace_only_response() { + let provider = Box::new(MockProvider::new(vec![text_response(" \n\t ")])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let _result = agent.turn("test").await; +} + +// ═════════════════════════════════════════════════════════════════════════════ +// TG4.5: Tool call with special content +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should handle tool arguments with unicode content +#[tokio::test] +async fn agent_handles_unicode_tool_arguments() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: 
"echo".into(), + arguments: r#"{"message": "こんにちは世界 🌍"}"#.into(), + }]), + text_response("Unicode tool executed"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("unicode test").await.unwrap(); + assert!(!response.is_empty()); +} + +/// Agent should handle tool arguments with nested JSON +#[tokio::test] +async fn agent_handles_nested_json_tool_arguments() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "{\"nested\": true, \"deep\": {\"level\": 3}}"}"#.into(), + }]), + text_response("Nested JSON tool executed"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("nested json test").await.unwrap(); + assert!(!response.is_empty()); +} + +/// Agent should handle tool call followed by immediate text (no second LLM call) +#[tokio::test] +async fn agent_handles_sequential_tool_then_text() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "step 1"}"#.into(), + }]), + text_response("Final answer after tool"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("two step").await.unwrap(); + assert!( + !response.is_empty(), + "should produce final text after tool execution" + ); +} diff --git a/tests/channel_routing.rs b/tests/channel_routing.rs new file mode 100644 index 0000000..4db04e4 --- /dev/null +++ b/tests/channel_routing.rs @@ -0,0 +1,310 @@ +//! TG3: Channel Message Identity & Routing Tests +//! +//! Prevents: Pattern 3 — Channel message routing & identity bugs (17% of user bugs). +//! Issues: #496, #483, #620, #415, #503 +//! +//! Tests that ChannelMessage fields are used consistently and that the +//! SendMessage → Channel trait contract preserves correct identity semantics. +//! 
Verifies sender/reply_target field contracts to prevent field swaps. + +use async_trait::async_trait; +use zeroclaw::channels::traits::{Channel, ChannelMessage, SendMessage}; + +// ───────────────────────────────────────────────────────────────────────────── +// ChannelMessage construction and field semantics +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn channel_message_sender_field_holds_platform_user_id() { + // Simulates Telegram: sender should be numeric chat_id, not username + let msg = ChannelMessage { + id: "msg_1".into(), + sender: "123456789".into(), // numeric chat_id + reply_target: "msg_0".into(), + content: "test message".into(), + channel: "telegram".into(), + timestamp: 1700000000, + }; + + assert_eq!(msg.sender, "123456789"); + // Sender should be the platform-level user/chat identifier + assert!( + msg.sender.chars().all(|c| c.is_ascii_digit()), + "Telegram sender should be numeric chat_id, got: {}", + msg.sender + ); +} + +#[test] +fn channel_message_reply_target_distinct_from_sender() { + // Simulates Discord: reply_target should be channel_id, not sender user_id + let msg = ChannelMessage { + id: "msg_1".into(), + sender: "user_987654".into(), // Discord user ID + reply_target: "channel_123".into(), // Discord channel ID for replies + content: "test message".into(), + channel: "discord".into(), + timestamp: 1700000000, + }; + + assert_ne!( + msg.sender, msg.reply_target, + "sender and reply_target should be distinct for Discord" + ); + assert_eq!(msg.reply_target, "channel_123"); +} + +#[test] +fn channel_message_fields_not_swapped() { + // Guards against #496 (Telegram) and #483 (Discord) field swap bugs + let msg = ChannelMessage { + id: "msg_42".into(), + sender: "sender_value".into(), + reply_target: "target_value".into(), + content: "payload".into(), + channel: "test".into(), + timestamp: 1700000000, + }; + + assert_eq!(msg.sender, "sender_value", "sender field should not be swapped"); + 
assert_eq!( + msg.reply_target, "target_value", + "reply_target field should not be swapped" + ); + assert_ne!( + msg.sender, msg.reply_target, + "sender and reply_target should remain distinct" + ); +} + +#[test] +fn channel_message_preserves_all_fields_on_clone() { + let original = ChannelMessage { + id: "clone_test".into(), + sender: "sender_123".into(), + reply_target: "target_456".into(), + content: "cloned content".into(), + channel: "test_channel".into(), + timestamp: 1700000001, + }; + + let cloned = original.clone(); + + assert_eq!(cloned.id, original.id); + assert_eq!(cloned.sender, original.sender); + assert_eq!(cloned.reply_target, original.reply_target); + assert_eq!(cloned.content, original.content); + assert_eq!(cloned.channel, original.channel); + assert_eq!(cloned.timestamp, original.timestamp); +} + +// ───────────────────────────────────────────────────────────────────────────── +// SendMessage construction +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn send_message_new_sets_content_and_recipient() { + let msg = SendMessage::new("Hello", "recipient_123"); + + assert_eq!(msg.content, "Hello"); + assert_eq!(msg.recipient, "recipient_123"); + assert!(msg.subject.is_none(), "subject should be None by default"); +} + +#[test] +fn send_message_with_subject_sets_all_fields() { + let msg = SendMessage::with_subject("Hello", "recipient_123", "Re: Test"); + + assert_eq!(msg.content, "Hello"); + assert_eq!(msg.recipient, "recipient_123"); + assert_eq!(msg.subject.as_deref(), Some("Re: Test")); +} + +#[test] +fn send_message_recipient_carries_platform_target() { + // Verifies that SendMessage::recipient is used as the platform delivery target + // For Telegram: this should be the chat_id + // For Discord: this should be the channel_id + let telegram_msg = SendMessage::new("response", "123456789"); + assert_eq!( + telegram_msg.recipient, "123456789", + "Telegram SendMessage recipient should be chat_id" + ); + 
+ let discord_msg = SendMessage::new("response", "channel_987654"); + assert_eq!( + discord_msg.recipient, "channel_987654", + "Discord SendMessage recipient should be channel_id" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Channel trait contract: send/listen roundtrip via DummyChannel +// ───────────────────────────────────────────────────────────────────────────── + +/// Test channel that captures sent messages for assertion +struct CapturingChannel { + sent: std::sync::Mutex>, +} + +impl CapturingChannel { + fn new() -> Self { + Self { + sent: std::sync::Mutex::new(Vec::new()), + } + } + + fn sent_messages(&self) -> Vec { + self.sent.lock().unwrap().clone() + } +} + +#[async_trait] +impl Channel for CapturingChannel { + fn name(&self) -> &str { + "capturing" + } + + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { + self.sent.lock().unwrap().push(message.clone()); + Ok(()) + } + + async fn listen( + &self, + tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { + tx.send(ChannelMessage { + id: "listen_1".into(), + sender: "test_sender".into(), + reply_target: "test_target".into(), + content: "incoming".into(), + channel: "capturing".into(), + timestamp: 1700000000, + }) + .await + .map_err(|e| anyhow::anyhow!(e.to_string())) + } +} + +#[tokio::test] +async fn channel_send_preserves_recipient() { + let channel = CapturingChannel::new(); + let msg = SendMessage::new("Hello", "target_123"); + + channel.send(&msg).await.unwrap(); + + let sent = channel.sent_messages(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0].recipient, "target_123"); + assert_eq!(sent[0].content, "Hello"); +} + +#[tokio::test] +async fn channel_listen_produces_correct_identity_fields() { + let channel = CapturingChannel::new(); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + + channel.listen(tx).await.unwrap(); + let received = rx.recv().await.expect("should receive message"); + + assert_eq!(received.sender, 
"test_sender"); + assert_eq!(received.reply_target, "test_target"); + assert_ne!( + received.sender, received.reply_target, + "listen() should populate sender and reply_target distinctly" + ); +} + +#[tokio::test] +async fn channel_send_reply_uses_sender_from_listen() { + let channel = CapturingChannel::new(); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + + // Simulate: listen() → receive message → send reply using sender + channel.listen(tx).await.unwrap(); + let incoming = rx.recv().await.expect("should receive message"); + + // Reply should go to the reply_target, not sender + let reply = SendMessage::new("reply content", &incoming.reply_target); + channel.send(&reply).await.unwrap(); + + let sent = channel.sent_messages(); + assert_eq!(sent.len(), 1); + assert_eq!( + sent[0].recipient, "test_target", + "reply should use reply_target as recipient" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Channel trait default methods +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn channel_health_check_default_returns_true() { + let channel = CapturingChannel::new(); + assert!( + channel.health_check().await, + "default health_check should return true" + ); +} + +#[tokio::test] +async fn channel_typing_defaults_succeed() { + let channel = CapturingChannel::new(); + assert!(channel.start_typing("target").await.is_ok()); + assert!(channel.stop_typing("target").await.is_ok()); +} + +#[tokio::test] +async fn channel_draft_defaults() { + let channel = CapturingChannel::new(); + assert!(!channel.supports_draft_updates()); + + let draft_result = channel + .send_draft(&SendMessage::new("draft", "target")) + .await + .unwrap(); + assert!(draft_result.is_none(), "default send_draft should return None"); + + assert!(channel + .update_draft("target", "msg_1", "updated") + .await + .is_ok()); + assert!(channel + .finalize_draft("target", "msg_1", "final") + .await + 
.is_ok()); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Multiple messages: conversation context preservation +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn channel_multiple_sends_preserve_order_and_recipients() { + let channel = CapturingChannel::new(); + + channel + .send(&SendMessage::new("msg 1", "target_a")) + .await + .unwrap(); + channel + .send(&SendMessage::new("msg 2", "target_b")) + .await + .unwrap(); + channel + .send(&SendMessage::new("msg 3", "target_a")) + .await + .unwrap(); + + let sent = channel.sent_messages(); + assert_eq!(sent.len(), 3); + assert_eq!(sent[0].recipient, "target_a"); + assert_eq!(sent[1].recipient, "target_b"); + assert_eq!(sent[2].recipient, "target_a"); + assert_eq!(sent[0].content, "msg 1"); + assert_eq!(sent[1].content, "msg 2"); + assert_eq!(sent[2].content, "msg 3"); +} diff --git a/tests/config_persistence.rs b/tests/config_persistence.rs new file mode 100644 index 0000000..edeef89 --- /dev/null +++ b/tests/config_persistence.rs @@ -0,0 +1,245 @@ +//! TG2: Config Load/Save Round-Trip Tests +//! +//! Prevents: Pattern 2 — Config persistence & workspace discovery bugs (13% of user bugs). +//! Issues: #547, #417, #621, #802 +//! +//! Tests Config::load_or_init() with isolated temp directories, env var overrides, +//! and config file round-trips to verify workspace discovery and persistence. 
+ +use std::fs; +use zeroclaw::config::{AgentConfig, Config, MemoryConfig}; + +// ───────────────────────────────────────────────────────────────────────────── +// Config default construction +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn config_default_has_expected_provider() { + let config = Config::default(); + assert!( + config.default_provider.is_some(), + "default config should have a default_provider" + ); +} + +#[test] +fn config_default_has_expected_model() { + let config = Config::default(); + assert!( + config.default_model.is_some(), + "default config should have a default_model" + ); +} + +#[test] +fn config_default_temperature_positive() { + let config = Config::default(); + assert!( + config.default_temperature > 0.0, + "default temperature should be positive" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// AgentConfig defaults +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn agent_config_default_max_tool_iterations() { + let agent = AgentConfig::default(); + assert_eq!( + agent.max_tool_iterations, 10, + "default max_tool_iterations should be 10" + ); +} + +#[test] +fn agent_config_default_max_history_messages() { + let agent = AgentConfig::default(); + assert_eq!( + agent.max_history_messages, 50, + "default max_history_messages should be 50" + ); +} + +#[test] +fn agent_config_default_tool_dispatcher() { + let agent = AgentConfig::default(); + assert_eq!( + agent.tool_dispatcher, "auto", + "default tool_dispatcher should be 'auto'" + ); +} + +#[test] +fn agent_config_default_compact_context_off() { + let agent = AgentConfig::default(); + assert!( + !agent.compact_context, + "compact_context should default to false" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// MemoryConfig defaults +// 
───────────────────────────────────────────────────────────────────────────── + +#[test] +fn memory_config_default_backend() { + let memory = MemoryConfig::default(); + assert!( + !memory.backend.is_empty(), + "memory backend should have a default value" + ); +} + +#[test] +fn memory_config_default_embedding_provider() { + let memory = MemoryConfig::default(); + // Default embedding_provider should be set (even if "none") + assert!( + !memory.embedding_provider.is_empty(), + "embedding_provider should have a default value" + ); +} + +#[test] +fn memory_config_default_vector_keyword_weights_sum_to_one() { + let memory = MemoryConfig::default(); + let sum = memory.vector_weight + memory.keyword_weight; + assert!( + (sum - 1.0).abs() < 0.01, + "vector_weight + keyword_weight should sum to ~1.0, got {sum}" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Config TOML serialization round-trip +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn config_toml_roundtrip_preserves_provider() { + let mut config = Config::default(); + config.default_provider = Some("deepseek".into()); + config.default_model = Some("deepseek-chat".into()); + config.default_temperature = 0.5; + + let toml_str = toml::to_string(&config).expect("config should serialize to TOML"); + let parsed: Config = toml::from_str(&toml_str).expect("TOML should deserialize back"); + + assert_eq!(parsed.default_provider.as_deref(), Some("deepseek")); + assert_eq!(parsed.default_model.as_deref(), Some("deepseek-chat")); + assert!((parsed.default_temperature - 0.5).abs() < f64::EPSILON); +} + +#[test] +fn config_toml_roundtrip_preserves_agent_config() { + let mut config = Config::default(); + config.agent.max_tool_iterations = 5; + config.agent.max_history_messages = 25; + config.agent.compact_context = true; + + let toml_str = toml::to_string(&config).expect("config should serialize to TOML"); + let parsed: Config = 
toml::from_str(&toml_str).expect("TOML should deserialize back"); + + assert_eq!(parsed.agent.max_tool_iterations, 5); + assert_eq!(parsed.agent.max_history_messages, 25); + assert!(parsed.agent.compact_context); +} + +#[test] +fn config_toml_roundtrip_preserves_memory_config() { + let mut config = Config::default(); + config.memory.embedding_provider = "openai".into(); + config.memory.embedding_model = "text-embedding-3-small".into(); + config.memory.vector_weight = 0.8; + config.memory.keyword_weight = 0.2; + + let toml_str = toml::to_string(&config).expect("config should serialize to TOML"); + let parsed: Config = toml::from_str(&toml_str).expect("TOML should deserialize back"); + + assert_eq!(parsed.memory.embedding_provider, "openai"); + assert_eq!(parsed.memory.embedding_model, "text-embedding-3-small"); + assert!((parsed.memory.vector_weight - 0.8).abs() < f64::EPSILON); + assert!((parsed.memory.keyword_weight - 0.2).abs() < f64::EPSILON); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Config file write/read round-trip with tempdir +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn config_file_write_read_roundtrip() { + let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed"); + let config_path = tmp.path().join("config.toml"); + + let mut config = Config::default(); + config.default_provider = Some("mistral".into()); + config.default_model = Some("mistral-large".into()); + config.agent.max_tool_iterations = 15; + + let toml_str = toml::to_string(&config).expect("config should serialize"); + fs::write(&config_path, &toml_str).expect("config file write should succeed"); + + let read_back = fs::read_to_string(&config_path).expect("config file read should succeed"); + let parsed: Config = toml::from_str(&read_back).expect("TOML should parse back"); + + assert_eq!(parsed.default_provider.as_deref(), Some("mistral")); + 
assert_eq!(parsed.default_model.as_deref(), Some("mistral-large")); + assert_eq!(parsed.agent.max_tool_iterations, 15); +} + +#[test] +fn config_file_with_missing_optional_fields_uses_defaults() { + // Simulate a minimal config TOML that omits optional sections + let minimal_toml = r#" +default_temperature = 0.7 +"#; + let parsed: Config = toml::from_str(minimal_toml).expect("minimal TOML should parse"); + + // Agent config should use defaults + assert_eq!(parsed.agent.max_tool_iterations, 10); + assert_eq!(parsed.agent.max_history_messages, 50); + assert!(!parsed.agent.compact_context); +} + +#[test] +fn config_file_with_custom_agent_section() { + let toml_with_agent = r#" +default_temperature = 0.7 + +[agent] +max_tool_iterations = 3 +compact_context = true +"#; + let parsed: Config = + toml::from_str(toml_with_agent).expect("TOML with agent section should parse"); + + assert_eq!(parsed.agent.max_tool_iterations, 3); + assert!(parsed.agent.compact_context); + // max_history_messages should still use default + assert_eq!(parsed.agent.max_history_messages, 50); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Workspace directory creation +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn workspace_dir_creation_in_tempdir() { + let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed"); + let workspace_dir = tmp.path().join("workspace"); + + fs::create_dir_all(&workspace_dir).expect("workspace dir creation should succeed"); + assert!(workspace_dir.exists(), "workspace dir should exist"); + assert!(workspace_dir.is_dir(), "workspace path should be a directory"); +} + +#[test] +fn nested_workspace_dir_creation() { + let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed"); + let nested_dir = tmp.path().join("deep").join("nested").join("workspace"); + + fs::create_dir_all(&nested_dir).expect("nested dir creation should succeed"); + 
assert!(nested_dir.exists(), "nested workspace dir should exist"); +} diff --git a/tests/memory_restart.rs b/tests/memory_restart.rs new file mode 100644 index 0000000..7538ab3 --- /dev/null +++ b/tests/memory_restart.rs @@ -0,0 +1,335 @@ +//! TG5: Memory Restart Resilience Tests +//! +//! Prevents: Pattern 5 — Memory & state persistence bugs (10% of user bugs). +//! Issues: #430, #693, #802 +//! +//! Tests SqliteMemory deduplication on restart, session scoping, concurrent +//! message ordering, and recall behavior after re-initialization. + +use std::sync::Arc; +use zeroclaw::memory::sqlite::SqliteMemory; +use zeroclaw::memory::traits::{Memory, MemoryCategory}; + +// ───────────────────────────────────────────────────────────────────────────── +// Deduplication: same key overwrites instead of duplicating (#430) +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_store_same_key_deduplicates() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + // Store same key twice with different content + mem.store("greeting", "hello world", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("greeting", "hello updated", MemoryCategory::Core, None) + .await + .unwrap(); + + // Should have exactly 1 entry, not 2 + let count = mem.count().await.unwrap(); + assert_eq!(count, 1, "storing same key twice should not create duplicates"); + + // Content should be the latest version + let entry = mem.get("greeting").await.unwrap().expect("entry should exist"); + assert_eq!(entry.content, "hello updated"); +} + +#[tokio::test] +async fn sqlite_memory_store_different_keys_creates_separate_entries() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("key_a", "content a", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("key_b", "content b", MemoryCategory::Core, None) + .await + 
.unwrap(); + + let count = mem.count().await.unwrap(); + assert_eq!(count, 2, "different keys should create separate entries"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Restart resilience: data persists across memory re-initialization +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_persists_across_reinitialization() { + let tmp = tempfile::TempDir::new().unwrap(); + + // First "session": store data + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("persistent_fact", "Rust is great", MemoryCategory::Core, None) + .await + .unwrap(); + } + + // Second "session": re-create memory from same path + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + let entry = mem + .get("persistent_fact") + .await + .unwrap() + .expect("entry should survive reinitialization"); + assert_eq!(entry.content, "Rust is great"); + } +} + +#[tokio::test] +async fn sqlite_memory_restart_does_not_duplicate_on_rewrite() { + let tmp = tempfile::TempDir::new().unwrap(); + + // First session: store entries + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("fact_1", "original content", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("fact_2", "another fact", MemoryCategory::Core, None) + .await + .unwrap(); + } + + // Second session: re-store same keys (simulates channel re-reading history) + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("fact_1", "original content", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("fact_2", "another fact", MemoryCategory::Core, None) + .await + .unwrap(); + + let count = mem.count().await.unwrap(); + assert_eq!( + count, 2, + "re-storing same keys after restart should not create duplicates" + ); + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Session scoping: messages scoped to sessions don't leak +// 
───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_session_scoped_store_and_recall() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + // Store in different sessions + mem.store( + "session_a_fact", + "fact from session A", + MemoryCategory::Conversation, + Some("session_a"), + ) + .await + .unwrap(); + mem.store( + "session_b_fact", + "fact from session B", + MemoryCategory::Conversation, + Some("session_b"), + ) + .await + .unwrap(); + + // List scoped to session_a + let session_a_entries = mem + .list(Some(&MemoryCategory::Conversation), Some("session_a")) + .await + .unwrap(); + assert_eq!( + session_a_entries.len(), + 1, + "session_a should have exactly 1 entry" + ); + assert_eq!(session_a_entries[0].content, "fact from session A"); +} + +#[tokio::test] +async fn sqlite_memory_global_recall_includes_all_sessions() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("global_a", "alpha content", MemoryCategory::Core, Some("s1")) + .await + .unwrap(); + mem.store("global_b", "beta content", MemoryCategory::Core, Some("s2")) + .await + .unwrap(); + + // Global count should include all + let count = mem.count().await.unwrap(); + assert_eq!(count, 2, "global count should include entries from all sessions"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Recall and search behavior +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_recall_returns_relevant_results() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("lang_pref", "User prefers Rust programming", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("food_pref", "User likes sushi for lunch", MemoryCategory::Core, None) + .await + 
.unwrap(); + + let results = mem.recall("Rust programming", 10, None).await.unwrap(); + assert!(!results.is_empty(), "recall should find matching entries"); + // The Rust-related entry should be in results + assert!( + results.iter().any(|e| e.content.contains("Rust")), + "recall for 'Rust' should include the Rust-related entry" + ); +} + +#[tokio::test] +async fn sqlite_memory_recall_respects_limit() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + for i in 0..10 { + mem.store( + &format!("entry_{i}"), + &format!("test content number {i}"), + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + } + + let results = mem.recall("test content", 3, None).await.unwrap(); + assert!( + results.len() <= 3, + "recall should respect limit of 3, got {}", + results.len() + ); +} + +#[tokio::test] +async fn sqlite_memory_recall_empty_query_returns_empty() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("fact", "some content", MemoryCategory::Core, None) + .await + .unwrap(); + + let results = mem.recall("", 10, None).await.unwrap(); + assert!( + results.is_empty(), + "empty query should return no results" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Forget and health check +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_forget_removes_entry() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("to_forget", "temporary info", MemoryCategory::Core, None) + .await + .unwrap(); + assert_eq!(mem.count().await.unwrap(), 1); + + let removed = mem.forget("to_forget").await.unwrap(); + assert!(removed, "forget should return true for existing key"); + assert_eq!(mem.count().await.unwrap(), 0); +} + +#[tokio::test] +async fn sqlite_memory_forget_nonexistent_returns_false() { + 
let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + let removed = mem.forget("nonexistent_key").await.unwrap(); + assert!(!removed, "forget should return false for nonexistent key"); +} + +#[tokio::test] +async fn sqlite_memory_health_check_returns_true() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + assert!(mem.health_check().await, "health_check should return true"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Concurrent access +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_concurrent_stores_no_data_loss() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = Arc::new(SqliteMemory::new(tmp.path()).unwrap()); + + let mut handles = Vec::new(); + for i in 0..5 { + let mem_clone = mem.clone(); + handles.push(tokio::spawn(async move { + mem_clone + .store( + &format!("concurrent_{i}"), + &format!("content from task {i}"), + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let count = mem.count().await.unwrap(); + assert_eq!( + count, 5, + "all concurrent stores should succeed, got {count}" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Memory categories +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_list_by_category() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("core_fact", "core info", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("daily_note", "daily note", MemoryCategory::Daily, None) + .await + .unwrap(); + mem.store("conv_msg", "conversation msg", MemoryCategory::Conversation, None) + .await + .unwrap(); + + let core_entries = 
mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); + assert_eq!(core_entries.len(), 1, "should have 1 Core entry"); + assert_eq!(core_entries[0].key, "core_fact"); + + let daily_entries = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap(); + assert_eq!(daily_entries.len(), 1, "should have 1 Daily entry"); +} diff --git a/tests/provider_resolution.rs b/tests/provider_resolution.rs new file mode 100644 index 0000000..c88fa93 --- /dev/null +++ b/tests/provider_resolution.rs @@ -0,0 +1,244 @@ +//! TG1: Provider End-to-End Resolution Tests +//! +//! Prevents: Pattern 1 — Provider configuration & resolution bugs (27% of user bugs). +//! Issues: #831, #834, #721, #580, #452, #451, #796, #843 +//! +//! Tests the full pipeline from config values through `create_provider_with_url()` +//! to provider construction, verifying factory resolution, URL construction, +//! credential wiring, and auth header format. + +use zeroclaw::providers::compatible::{AuthStyle, OpenAiCompatibleProvider}; +use zeroclaw::providers::{create_provider, create_provider_with_url}; + +/// Helper: assert provider creation succeeds +fn assert_provider_ok(name: &str, key: Option<&str>, url: Option<&str>) { + let result = create_provider_with_url(name, key, url); + assert!( + result.is_ok(), + "{name} provider should resolve: {}", + result.err().map(|e| e.to_string()).unwrap_or_default() + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Factory resolution: each major provider name resolves without error +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn factory_resolves_openai_provider() { + assert_provider_ok("openai", Some("test-key"), None); +} + +#[test] +fn factory_resolves_anthropic_provider() { + assert_provider_ok("anthropic", Some("test-key"), None); +} + +#[test] +fn factory_resolves_deepseek_provider() { + assert_provider_ok("deepseek", Some("test-key"), None); +} + +#[test] +fn 
factory_resolves_mistral_provider() { + assert_provider_ok("mistral", Some("test-key"), None); +} + +#[test] +fn factory_resolves_ollama_provider() { + assert_provider_ok("ollama", None, None); +} + +#[test] +fn factory_resolves_groq_provider() { + assert_provider_ok("groq", Some("test-key"), None); +} + +#[test] +fn factory_resolves_xai_provider() { + assert_provider_ok("xai", Some("test-key"), None); +} + +#[test] +fn factory_resolves_together_provider() { + assert_provider_ok("together", Some("test-key"), None); +} + +#[test] +fn factory_resolves_fireworks_provider() { + assert_provider_ok("fireworks", Some("test-key"), None); +} + +#[test] +fn factory_resolves_perplexity_provider() { + assert_provider_ok("perplexity", Some("test-key"), None); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Factory resolution: alias variants map to same provider +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn factory_grok_alias_resolves_to_xai() { + assert_provider_ok("grok", Some("test-key"), None); +} + +#[test] +fn factory_kimi_alias_resolves_to_moonshot() { + assert_provider_ok("kimi", Some("test-key"), None); +} + +#[test] +fn factory_zhipu_alias_resolves_to_glm() { + assert_provider_ok("zhipu", Some("test-key"), None); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Custom URL provider creation +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn factory_custom_http_url_resolves() { + assert_provider_ok("custom:http://localhost:8080", Some("test-key"), None); +} + +#[test] +fn factory_custom_https_url_resolves() { + assert_provider_ok("custom:https://api.example.com/v1", Some("test-key"), None); +} + +#[test] +fn factory_custom_ftp_url_rejected() { + let result = create_provider_with_url("custom:ftp://example.com", None, None); + assert!(result.is_err(), "ftp scheme should be rejected"); + let 
err_msg = result.err().unwrap().to_string(); + assert!( + err_msg.contains("http://") || err_msg.contains("https://"), + "error should mention valid schemes: {err_msg}" + ); +} + +#[test] +fn factory_custom_empty_url_rejected() { + let result = create_provider_with_url("custom:", None, None); + assert!(result.is_err(), "empty custom URL should be rejected"); +} + +#[test] +fn factory_unknown_provider_rejected() { + let result = create_provider_with_url("nonexistent_provider_xyz", None, None); + assert!(result.is_err(), "unknown provider name should be rejected"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// OpenAiCompatibleProvider: credential and auth style wiring +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn compatible_provider_bearer_auth_style() { + // Construction with Bearer auth should succeed + let _provider = OpenAiCompatibleProvider::new( + "TestProvider", + "https://api.test.com", + Some("sk-test-key-12345"), + AuthStyle::Bearer, + ); +} + +#[test] +fn compatible_provider_xapikey_auth_style() { + // Construction with XApiKey auth should succeed + let _provider = OpenAiCompatibleProvider::new( + "TestProvider", + "https://api.test.com", + Some("sk-test-key-12345"), + AuthStyle::XApiKey, + ); +} + +#[test] +fn compatible_provider_custom_auth_header() { + // Construction with Custom auth should succeed + let _provider = OpenAiCompatibleProvider::new( + "TestProvider", + "https://api.test.com", + Some("sk-test-key-12345"), + AuthStyle::Custom("X-Custom-Auth".into()), + ); +} + +#[test] +fn compatible_provider_no_credential() { + // Construction without credential should succeed (for local providers) + let _provider = OpenAiCompatibleProvider::new( + "TestLocal", + "http://localhost:11434", + None, + AuthStyle::Bearer, + ); +} + +#[test] +fn compatible_provider_base_url_trailing_slash_normalized() { + // Construction with trailing slash URL should succeed + let 
_provider = OpenAiCompatibleProvider::new( + "TestProvider", + "https://api.test.com/v1/", + Some("key"), + AuthStyle::Bearer, + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Provider with api_url override (simulates #721 - Ollama api_url config) +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn factory_ollama_with_custom_api_url() { + assert_provider_ok("ollama", None, Some("http://192.168.1.100:11434")); +} + +#[test] +fn factory_openai_with_custom_api_url() { + assert_provider_ok( + "openai", + Some("test-key"), + Some("https://custom-openai-proxy.example.com/v1"), + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Provider default convenience factory +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn convenience_factory_resolves_major_providers() { + for provider_name in &[ + "openai", + "anthropic", + "deepseek", + "mistral", + "groq", + "xai", + "together", + "fireworks", + "perplexity", + ] { + let result = create_provider(provider_name, Some("test-key")); + assert!( + result.is_ok(), + "convenience factory should resolve {provider_name}: {}", + result.err().map(|e| e.to_string()).unwrap_or_default() + ); + } +} + +#[test] +fn convenience_factory_ollama_no_key() { + let result = create_provider("ollama", None); + assert!( + result.is_ok(), + "ollama should not require api key: {}", + result.err().map(|e| e.to_string()).unwrap_or_default() + ); +} diff --git a/tests/provider_schema.rs b/tests/provider_schema.rs new file mode 100644 index 0000000..bc3aa67 --- /dev/null +++ b/tests/provider_schema.rs @@ -0,0 +1,303 @@ +//! TG7: Provider Schema Conformance Tests +//! +//! Prevents: Pattern 7 — External schema compatibility bugs (7% of user bugs). +//! Issues: #769, #843 +//! +//! Tests request/response serialization to verify required fields are present +//! 
for each provider's API specification. Validates ChatMessage, ChatResponse, +//! ToolCall, and AuthStyle serialization contracts. + +use zeroclaw::providers::compatible::AuthStyle; +use zeroclaw::providers::traits::{ChatMessage, ChatResponse, ToolCall}; + +// ───────────────────────────────────────────────────────────────────────────── +// ChatMessage serialization +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn chat_message_system_role_correct() { + let msg = ChatMessage::system("You are a helpful assistant"); + assert_eq!(msg.role, "system"); + assert_eq!(msg.content, "You are a helpful assistant"); +} + +#[test] +fn chat_message_user_role_correct() { + let msg = ChatMessage::user("Hello"); + assert_eq!(msg.role, "user"); + assert_eq!(msg.content, "Hello"); +} + +#[test] +fn chat_message_assistant_role_correct() { + let msg = ChatMessage::assistant("Hi there!"); + assert_eq!(msg.role, "assistant"); + assert_eq!(msg.content, "Hi there!"); +} + +#[test] +fn chat_message_tool_role_correct() { + let msg = ChatMessage::tool("tool result"); + assert_eq!(msg.role, "tool"); + assert_eq!(msg.content, "tool result"); +} + +#[test] +fn chat_message_serializes_to_json_with_required_fields() { + let msg = ChatMessage::user("test message"); + let json = serde_json::to_value(&msg).unwrap(); + + assert!(json.get("role").is_some(), "JSON must have 'role' field"); + assert!( + json.get("content").is_some(), + "JSON must have 'content' field" + ); + assert_eq!(json["role"], "user"); + assert_eq!(json["content"], "test message"); +} + +#[test] +fn chat_message_json_roundtrip() { + let original = ChatMessage::assistant("response text"); + let json_str = serde_json::to_string(&original).unwrap(); + let parsed: ChatMessage = serde_json::from_str(&json_str).unwrap(); + + assert_eq!(parsed.role, original.role); + assert_eq!(parsed.content, original.content); +} + +// 
// ─────────────────────────────────────────────────────────────────────────────
// ToolCall serialization (#843 - tool_call_id field)
// ─────────────────────────────────────────────────────────────────────────────

/// A serialized `ToolCall` must expose `id`, `name`, and `arguments` keys.
#[test]
fn tool_call_has_required_fields() {
    let call = ToolCall {
        id: "call_abc123".into(),
        name: "web_search".into(),
        arguments: r#"{"query": "rust programming"}"#.into(),
    };

    let value = serde_json::to_value(&call).unwrap();
    for key in ["id", "name", "arguments"] {
        assert!(value.get(key).is_some(), "ToolCall must have '{key}' field");
    }
}

/// The tool-call id must survive a JSON round trip — DeepSeek matches tool
/// responses to their originating calls by this id.
#[test]
fn tool_call_id_preserved_in_serialization() {
    let original = ToolCall {
        id: "call_deepseek_42".into(),
        name: "shell".into(),
        arguments: r#"{"command": "ls"}"#.into(),
    };

    let restored: ToolCall =
        serde_json::from_str(&serde_json::to_string(&original).unwrap()).unwrap();

    assert_eq!(restored.id, "call_deepseek_42", "tool_call_id must survive roundtrip");
    assert_eq!(restored.name, "shell");
}

/// `arguments` is stored as a string, but that string must parse as JSON.
#[test]
fn tool_call_arguments_contain_valid_json() {
    let call = ToolCall {
        id: "call_1".into(),
        name: "file_write".into(),
        arguments: r#"{"path": "/tmp/test.txt", "content": "hello"}"#.into(),
    };

    let parsed: serde_json::Value = serde_json::from_str(&call.arguments)
        .expect("tool call arguments should be valid JSON");
    assert!(parsed.get("path").is_some());
    assert!(parsed.get("content").is_some());
}

// ─────────────────────────────────────────────────────────────────────────────
// Tool message with tool_call_id (DeepSeek requirement)
// ─────────────────────────────────────────────────────────────────────────────

/// DeepSeek requires tool_call_id in tool response messages; the tool
/// message content can embed the tool_call_id as JSON.
#[test]
fn tool_response_message_can_embed_tool_call_id() {
    let message = ChatMessage::tool(
        r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#,
    );

    let body: serde_json::Value = serde_json::from_str(&message.content)
        .expect("tool response content should be valid JSON");

    assert!(
        body.get("tool_call_id").is_some(),
        "tool response should include tool_call_id for DeepSeek compatibility"
    );
    assert_eq!(body["tool_call_id"], "call_abc123");
}

// ─────────────────────────────────────────────────────────────────────────────
// ChatResponse structure
// ─────────────────────────────────────────────────────────────────────────────

/// A text-only response reports its text and no tool calls.
#[test]
fn chat_response_text_only() {
    let response = ChatResponse {
        text: Some("Hello world".into()),
        tool_calls: vec![],
    };

    assert_eq!(response.text_or_empty(), "Hello world");
    assert!(!response.has_tool_calls());
}

/// A response carrying a tool call reports it even when its text is empty.
#[test]
fn chat_response_with_tool_calls() {
    let call = ToolCall {
        id: "tc_1".into(),
        name: "echo".into(),
        arguments: "{}".into(),
    };
    let response = ChatResponse {
        text: Some(String::new()),
        tool_calls: vec![call],
    };

    assert!(response.has_tool_calls());
    assert_eq!(response.tool_calls.len(), 1);
    assert_eq!(response.tool_calls[0].name, "echo");
}

/// `text_or_empty` maps an absent text field to the empty string.
#[test]
fn chat_response_text_or_empty_handles_none() {
    let response = ChatResponse {
        text: None,
        tool_calls: vec![],
    };

    assert_eq!(response.text_or_empty(), "");
}

/// Multiple tool calls are preserved in order, each with a distinct id.
#[test]
fn chat_response_multiple_tool_calls() {
    let first = ToolCall {
        id: "tc_1".into(),
        name: "shell".into(),
        arguments: r#"{"command": "ls"}"#.into(),
    };
    let second = ToolCall {
        id: "tc_2".into(),
        name: "file_read".into(),
        arguments: r#"{"path": "test.txt"}"#.into(),
    };
    let response = ChatResponse {
        text: None,
        tool_calls: vec![first, second],
    };

    assert!(response.has_tool_calls());
    assert_eq!(response.tool_calls.len(), 2);
    // Each tool call should have a distinct id.
    assert_ne!(response.tool_calls[0].id, response.tool_calls[1].id);
}
// ─────────────────────────────────────────────────────────────────────────────
// AuthStyle variants
// ─────────────────────────────────────────────────────────────────────────────

/// The Bearer variant can be named and pattern-matched.
#[test]
fn auth_style_bearer_is_constructible() {
    assert!(matches!(AuthStyle::Bearer, AuthStyle::Bearer));
}

/// The XApiKey variant can be named and pattern-matched.
#[test]
fn auth_style_xapikey_is_constructible() {
    assert!(matches!(AuthStyle::XApiKey, AuthStyle::XApiKey));
}

/// The Custom variant carries the caller-supplied header name intact.
#[test]
fn auth_style_custom_header() {
    match AuthStyle::Custom("X-Custom-Auth".into()) {
        AuthStyle::Custom(header) => assert_eq!(header, "X-Custom-Auth"),
        _ => panic!("expected AuthStyle::Custom"),
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Provider naming consistency
// ─────────────────────────────────────────────────────────────────────────────

/// Construction must accept arbitrary display names and an optional API key.
#[test]
fn provider_construction_with_different_names() {
    use zeroclaw::providers::compatible::OpenAiCompatibleProvider;

    let _with_key = OpenAiCompatibleProvider::new(
        "DeepSeek",
        "https://api.deepseek.com",
        Some("test-key"),
        AuthStyle::Bearer,
    );
    let _without_key = OpenAiCompatibleProvider::new(
        "deepseek",
        "https://api.test.com",
        None,
        AuthStyle::Bearer,
    );
}

/// Construction must succeed for every authentication style.
#[test]
fn provider_construction_with_different_auth_styles() {
    use zeroclaw::providers::compatible::OpenAiCompatibleProvider;

    for style in [
        AuthStyle::Bearer,
        AuthStyle::XApiKey,
        AuthStyle::Custom("X-My-Auth".into()),
    ] {
        let _provider =
            OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), style);
    }
}

// ─────────────────────────────────────────────────────────────────────────────
// Conversation history message ordering
//
// ─────────────────────────────────────────────────────────────────────────────

/// History entries keep the role order in which they were pushed.
#[test]
fn chat_messages_maintain_role_sequence() {
    let history = vec![
        ChatMessage::system("You are helpful"),
        ChatMessage::user("What is Rust?"),
        ChatMessage::assistant("Rust is a systems programming language"),
        ChatMessage::user("Tell me more"),
        ChatMessage::assistant("It emphasizes safety and performance"),
    ];

    let expected = ["system", "user", "assistant", "user", "assistant"];
    for (message, role) in history.iter().zip(expected) {
        assert_eq!(message.role, role);
    }
}

/// Tool-result messages slot into the history between assistant turns, and
/// their content embeds the tool_call_id as JSON.
#[test]
fn chat_messages_with_tool_calls_maintain_sequence() {
    let history = vec![
        ChatMessage::system("You are helpful"),
        ChatMessage::user("Search for Rust"),
        ChatMessage::assistant("I'll search for that"),
        ChatMessage::tool(r#"{"tool_call_id": "tc_1", "content": "search results"}"#),
        ChatMessage::assistant("Based on the search results..."),
    ];

    assert_eq!(history.len(), 5);
    assert_eq!(history[3].role, "tool");
    assert_eq!(history[4].role, "assistant");

    // The tool message content must be valid JSON carrying a tool_call_id.
    let embedded: serde_json::Value = serde_json::from_str(&history[3].content).unwrap();
    assert!(embedded.get("tool_call_id").is_some());
}