Merge pull request #865 from agorevski/feat/systematic-test-coverage-852
test: add systematic test coverage for 7 bug pattern groups (#852)
This commit is contained in:
commit
dce7280812
9 changed files with 2272 additions and 8 deletions
|
|
@ -2737,4 +2737,185 @@ browser_open/url>https://example.com"#;
|
||||||
assert_eq!(calls[0].arguments["command"], "pwd");
|
assert_eq!(calls[0].arguments["command"], "pwd");
|
||||||
assert_eq!(text, "Done");
|
assert_eq!(text, "Done");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
// TG4 (inline): parse_tool_calls robustness — malformed/edge-case inputs
|
||||||
|
// Prevents: Pattern 4 issues #746, #418, #777, #848
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_empty_input_returns_empty() {
|
||||||
|
let (text, calls) = parse_tool_calls("");
|
||||||
|
assert!(calls.is_empty(), "empty input should produce no tool calls");
|
||||||
|
assert!(text.is_empty(), "empty input should produce no text");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_whitespace_only_returns_empty_calls() {
|
||||||
|
let (text, calls) = parse_tool_calls(" \n\t ");
|
||||||
|
assert!(calls.is_empty());
|
||||||
|
assert!(text.is_empty() || text.trim().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_nested_xml_tags_handled() {
|
||||||
|
// Double-wrapped tool call should still parse the inner call
|
||||||
|
let response = r#"<tool_call><tool_call>{"name":"echo","arguments":{"msg":"hi"}}</tool_call></tool_call>"#;
|
||||||
|
let (_text, calls) = parse_tool_calls(response);
|
||||||
|
// Should find at least one tool call
|
||||||
|
assert!(
|
||||||
|
!calls.is_empty(),
|
||||||
|
"nested XML tags should still yield at least one tool call"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_truncated_json_no_panic() {
|
||||||
|
// Incomplete JSON inside tool_call tags
|
||||||
|
let response = r#"<tool_call>{"name":"shell","arguments":{"command":"ls"</tool_call>"#;
|
||||||
|
let (_text, _calls) = parse_tool_calls(response);
|
||||||
|
// Should not panic — graceful handling of truncated JSON
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_empty_json_object_in_tag() {
|
||||||
|
let response = "<tool_call>{}</tool_call>";
|
||||||
|
let (_text, calls) = parse_tool_calls(response);
|
||||||
|
// Empty JSON object has no name field — should not produce valid tool call
|
||||||
|
assert!(
|
||||||
|
calls.is_empty(),
|
||||||
|
"empty JSON object should not produce a tool call"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_closing_tag_only_returns_text() {
|
||||||
|
let response = "Some text </tool_call> more text";
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(calls.is_empty(), "closing tag only should not produce calls");
|
||||||
|
assert!(
|
||||||
|
!text.is_empty(),
|
||||||
|
"text around orphaned closing tag should be preserved"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_very_large_arguments_no_panic() {
|
||||||
|
let large_arg = "x".repeat(100_000);
|
||||||
|
let response = format!(
|
||||||
|
r#"<tool_call>{{"name":"echo","arguments":{{"message":"{}"}}}}</tool_call>"#,
|
||||||
|
large_arg
|
||||||
|
);
|
||||||
|
let (_text, calls) = parse_tool_calls(&response);
|
||||||
|
assert_eq!(calls.len(), 1, "large arguments should still parse");
|
||||||
|
assert_eq!(calls[0].name, "echo");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_special_characters_in_arguments() {
|
||||||
|
let response = r#"<tool_call>{"name":"echo","arguments":{"message":"hello \"world\" <>&'\n\t"}}</tool_call>"#;
|
||||||
|
let (_text, calls) = parse_tool_calls(response);
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "echo");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_text_with_embedded_json_not_extracted() {
|
||||||
|
// Raw JSON without any tags should NOT be extracted as a tool call
|
||||||
|
let response = r#"Here is some data: {"name":"echo","arguments":{"message":"hi"}} end."#;
|
||||||
|
let (_text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(
|
||||||
|
calls.is_empty(),
|
||||||
|
"raw JSON in text without tags should not be extracted"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_multiple_formats_mixed() {
|
||||||
|
// Mix of text and properly tagged tool call
|
||||||
|
let response = r#"I'll help you with that.
|
||||||
|
|
||||||
|
<tool_call>
|
||||||
|
{"name":"shell","arguments":{"command":"echo hello"}}
|
||||||
|
</tool_call>
|
||||||
|
|
||||||
|
Let me check the result."#;
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert_eq!(calls.len(), 1, "should extract one tool call from mixed content");
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
assert!(
|
||||||
|
text.contains("help you"),
|
||||||
|
"text before tool call should be preserved"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
// TG4 (inline): scrub_credentials edge cases
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scrub_credentials_empty_input() {
|
||||||
|
let result = scrub_credentials("");
|
||||||
|
assert_eq!(result, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scrub_credentials_no_sensitive_data() {
|
||||||
|
let input = "normal text without any secrets";
|
||||||
|
let result = scrub_credentials(input);
|
||||||
|
assert_eq!(result, input, "non-sensitive text should pass through unchanged");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scrub_credentials_short_values_not_redacted() {
|
||||||
|
// Values shorter than 8 chars should not be redacted
|
||||||
|
let input = r#"api_key="short""#;
|
||||||
|
let result = scrub_credentials(input);
|
||||||
|
assert_eq!(result, input, "short values should not be redacted");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
// TG4 (inline): trim_history edge cases
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn trim_history_empty_history() {
|
||||||
|
let mut history: Vec<crate::providers::ChatMessage> = vec![];
|
||||||
|
trim_history(&mut history, 10);
|
||||||
|
assert!(history.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn trim_history_system_only() {
|
||||||
|
let mut history = vec![crate::providers::ChatMessage::system("system prompt")];
|
||||||
|
trim_history(&mut history, 10);
|
||||||
|
assert_eq!(history.len(), 1);
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn trim_history_exactly_at_limit() {
|
||||||
|
let mut history = vec![
|
||||||
|
crate::providers::ChatMessage::system("system"),
|
||||||
|
crate::providers::ChatMessage::user("msg 1"),
|
||||||
|
crate::providers::ChatMessage::assistant("reply 1"),
|
||||||
|
];
|
||||||
|
trim_history(&mut history, 2); // 2 non-system messages = exactly at limit
|
||||||
|
assert_eq!(history.len(), 3, "should not trim when exactly at limit");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn trim_history_removes_oldest_non_system() {
|
||||||
|
let mut history = vec![
|
||||||
|
crate::providers::ChatMessage::system("system"),
|
||||||
|
crate::providers::ChatMessage::user("old msg"),
|
||||||
|
crate::providers::ChatMessage::assistant("old reply"),
|
||||||
|
crate::providers::ChatMessage::user("new msg"),
|
||||||
|
crate::providers::ChatMessage::assistant("new reply"),
|
||||||
|
];
|
||||||
|
trim_history(&mut history, 2);
|
||||||
|
assert_eq!(history.len(), 3); // system + 2 kept
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
|
assert_eq!(history[1].content, "new msg");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -843,4 +843,110 @@ mod tests {
|
||||||
// Should have UUID dashes
|
// Should have UUID dashes
|
||||||
assert!(id.contains('-'));
|
assert!(id.contains('-'));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
// TG6: Channel platform limit edge cases for Discord (2000 char limit)
|
||||||
|
// Prevents: Pattern 6 — issues #574, #499
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn split_message_code_block_at_boundary() {
|
||||||
|
// Code block that spans the split boundary
|
||||||
|
let mut msg = String::new();
|
||||||
|
msg.push_str("```rust\n");
|
||||||
|
msg.push_str(&"x".repeat(1990));
|
||||||
|
msg.push_str("\n```\nMore text after code block");
|
||||||
|
let parts = split_message_for_discord(&msg);
|
||||||
|
assert!(parts.len() >= 2, "code block spanning boundary should split");
|
||||||
|
for part in &parts {
|
||||||
|
assert!(
|
||||||
|
part.len() <= DISCORD_MAX_MESSAGE_LENGTH,
|
||||||
|
"each part must be <= {DISCORD_MAX_MESSAGE_LENGTH}, got {}",
|
||||||
|
part.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn split_message_single_long_word_exceeds_limit() {
|
||||||
|
// A single word longer than 2000 chars must be hard-split
|
||||||
|
let long_word = "a".repeat(2500);
|
||||||
|
let parts = split_message_for_discord(&long_word);
|
||||||
|
assert!(parts.len() >= 2, "word exceeding limit must be split");
|
||||||
|
for part in &parts {
|
||||||
|
assert!(
|
||||||
|
part.len() <= DISCORD_MAX_MESSAGE_LENGTH,
|
||||||
|
"hard-split part must be <= {DISCORD_MAX_MESSAGE_LENGTH}, got {}",
|
||||||
|
part.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// Reassembled content should match original
|
||||||
|
let reassembled: String = parts.join("");
|
||||||
|
assert_eq!(reassembled, long_word);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn split_message_exactly_at_limit_no_split() {
|
||||||
|
let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH);
|
||||||
|
let parts = split_message_for_discord(&msg);
|
||||||
|
assert_eq!(parts.len(), 1, "message exactly at limit should not split");
|
||||||
|
assert_eq!(parts[0].len(), DISCORD_MAX_MESSAGE_LENGTH);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn split_message_one_over_limit_splits() {
|
||||||
|
let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH + 1);
|
||||||
|
let parts = split_message_for_discord(&msg);
|
||||||
|
assert!(parts.len() >= 2, "message 1 char over limit must split");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn split_message_many_short_lines() {
|
||||||
|
// Many short lines should be batched into chunks under the limit
|
||||||
|
let msg: String = (0..500).map(|i| format!("line {i}\n")).collect();
|
||||||
|
let parts = split_message_for_discord(&msg);
|
||||||
|
for part in &parts {
|
||||||
|
assert!(
|
||||||
|
part.len() <= DISCORD_MAX_MESSAGE_LENGTH,
|
||||||
|
"short-line batch must be <= limit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
// All content should be preserved
|
||||||
|
let reassembled: String = parts.join("");
|
||||||
|
assert_eq!(reassembled.trim(), msg.trim());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn split_message_only_whitespace() {
|
||||||
|
let msg = " \n\n\t ";
|
||||||
|
let parts = split_message_for_discord(msg);
|
||||||
|
// Should handle gracefully without panic
|
||||||
|
assert!(parts.len() <= 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn split_message_emoji_at_boundary() {
|
||||||
|
// Emoji are multi-byte; ensure we don't split mid-emoji
|
||||||
|
let mut msg = "a".repeat(1998);
|
||||||
|
msg.push_str("🎉🎊"); // 2 emoji at the boundary (2000 chars total)
|
||||||
|
let parts = split_message_for_discord(&msg);
|
||||||
|
for part in &parts {
|
||||||
|
// The function splits on character count, not byte count
|
||||||
|
assert!(
|
||||||
|
part.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH,
|
||||||
|
"emoji boundary split must respect limit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn split_message_consecutive_newlines_at_boundary() {
|
||||||
|
let mut msg = "a".repeat(1995);
|
||||||
|
msg.push_str("\n\n\n\n\n");
|
||||||
|
msg.push_str(&"b".repeat(100));
|
||||||
|
let parts = split_message_for_discord(&msg);
|
||||||
|
for part in &parts {
|
||||||
|
assert!(part.len() <= DISCORD_MAX_MESSAGE_LENGTH);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ const TELEGRAM_BIND_COMMAND: &str = "/bind";
|
||||||
/// Split a message into chunks that respect Telegram's 4096 character limit.
|
/// Split a message into chunks that respect Telegram's 4096 character limit.
|
||||||
/// Tries to split at word boundaries when possible, and handles continuation.
|
/// Tries to split at word boundaries when possible, and handles continuation.
|
||||||
fn split_message_for_telegram(message: &str) -> Vec<String> {
|
fn split_message_for_telegram(message: &str) -> Vec<String> {
|
||||||
if message.len() <= TELEGRAM_MAX_MESSAGE_LENGTH {
|
if message.chars().count() <= TELEGRAM_MAX_MESSAGE_LENGTH {
|
||||||
return vec![message.to_string()];
|
return vec![message.to_string()];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -26,29 +26,35 @@ fn split_message_for_telegram(message: &str) -> Vec<String> {
|
||||||
let mut remaining = message;
|
let mut remaining = message;
|
||||||
|
|
||||||
while !remaining.is_empty() {
|
while !remaining.is_empty() {
|
||||||
let chunk_end = if remaining.len() <= TELEGRAM_MAX_MESSAGE_LENGTH {
|
// Find the byte offset for the Nth character boundary.
|
||||||
remaining.len()
|
let hard_split = remaining
|
||||||
|
.char_indices()
|
||||||
|
.nth(TELEGRAM_MAX_MESSAGE_LENGTH)
|
||||||
|
.map_or(remaining.len(), |(idx, _)| idx);
|
||||||
|
|
||||||
|
let chunk_end = if hard_split == remaining.len() {
|
||||||
|
hard_split
|
||||||
} else {
|
} else {
|
||||||
// Try to find a good break point (newline, then space)
|
// Try to find a good break point (newline, then space)
|
||||||
let search_area = &remaining[..TELEGRAM_MAX_MESSAGE_LENGTH];
|
let search_area = &remaining[..hard_split];
|
||||||
|
|
||||||
// Prefer splitting at newline
|
// Prefer splitting at newline
|
||||||
if let Some(pos) = search_area.rfind('\n') {
|
if let Some(pos) = search_area.rfind('\n') {
|
||||||
// Don't split if the newline is too close to the start
|
// Don't split if the newline is too close to the start
|
||||||
if pos >= TELEGRAM_MAX_MESSAGE_LENGTH / 2 {
|
if search_area[..pos].chars().count() >= TELEGRAM_MAX_MESSAGE_LENGTH / 2 {
|
||||||
pos + 1
|
pos + 1
|
||||||
} else {
|
} else {
|
||||||
// Try space as fallback
|
// Try space as fallback
|
||||||
search_area
|
search_area
|
||||||
.rfind(' ')
|
.rfind(' ')
|
||||||
.unwrap_or(TELEGRAM_MAX_MESSAGE_LENGTH)
|
.unwrap_or(hard_split)
|
||||||
+ 1
|
+ 1
|
||||||
}
|
}
|
||||||
} else if let Some(pos) = search_area.rfind(' ') {
|
} else if let Some(pos) = search_area.rfind(' ') {
|
||||||
pos + 1
|
pos + 1
|
||||||
} else {
|
} else {
|
||||||
// Hard split at the limit
|
// Hard split at character boundary
|
||||||
TELEGRAM_MAX_MESSAGE_LENGTH
|
hard_split
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -2825,4 +2831,100 @@ mod tests {
|
||||||
let ch_disabled = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
let ch_disabled = TelegramChannel::new("token".into(), vec!["*".into()], false);
|
||||||
assert!(!ch_disabled.mention_only);
|
assert!(!ch_disabled.mention_only);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
// TG6: Channel platform limit edge cases for Telegram (4096 char limit)
|
||||||
|
// Prevents: Pattern 6 — issues #574, #499
|
||||||
|
// ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_split_code_block_at_boundary() {
|
||||||
|
let mut msg = String::new();
|
||||||
|
msg.push_str("```python\n");
|
||||||
|
msg.push_str(&"x".repeat(4085));
|
||||||
|
msg.push_str("\n```\nMore text after code block");
|
||||||
|
let parts = split_message_for_telegram(&msg);
|
||||||
|
assert!(parts.len() >= 2, "code block spanning boundary should split");
|
||||||
|
for part in &parts {
|
||||||
|
assert!(
|
||||||
|
part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH,
|
||||||
|
"each part must be <= {TELEGRAM_MAX_MESSAGE_LENGTH}, got {}",
|
||||||
|
part.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_split_single_long_word() {
|
||||||
|
let long_word = "a".repeat(5000);
|
||||||
|
let parts = split_message_for_telegram(&long_word);
|
||||||
|
assert!(parts.len() >= 2, "word exceeding limit must be split");
|
||||||
|
for part in &parts {
|
||||||
|
assert!(
|
||||||
|
part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH,
|
||||||
|
"hard-split part must be <= {TELEGRAM_MAX_MESSAGE_LENGTH}, got {}",
|
||||||
|
part.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let reassembled: String = parts.join("");
|
||||||
|
assert_eq!(reassembled, long_word);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_split_exactly_at_limit_no_split() {
|
||||||
|
let msg = "a".repeat(TELEGRAM_MAX_MESSAGE_LENGTH);
|
||||||
|
let parts = split_message_for_telegram(&msg);
|
||||||
|
assert_eq!(parts.len(), 1, "message exactly at limit should not split");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_split_one_over_limit() {
|
||||||
|
let msg = "a".repeat(TELEGRAM_MAX_MESSAGE_LENGTH + 1);
|
||||||
|
let parts = split_message_for_telegram(&msg);
|
||||||
|
assert!(parts.len() >= 2, "message 1 char over limit must split");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_split_many_short_lines() {
|
||||||
|
let msg: String = (0..1000).map(|i| format!("line {i}\n")).collect();
|
||||||
|
let parts = split_message_for_telegram(&msg);
|
||||||
|
for part in &parts {
|
||||||
|
assert!(
|
||||||
|
part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH,
|
||||||
|
"short-line batch must be <= limit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_split_only_whitespace() {
|
||||||
|
let msg = " \n\n\t ";
|
||||||
|
let parts = split_message_for_telegram(msg);
|
||||||
|
assert!(parts.len() <= 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_split_emoji_at_boundary() {
|
||||||
|
let mut msg = "a".repeat(4094);
|
||||||
|
msg.push_str("🎉🎊"); // 4096 chars total
|
||||||
|
let parts = split_message_for_telegram(&msg);
|
||||||
|
for part in &parts {
|
||||||
|
// The function splits on character count, not byte count
|
||||||
|
assert!(
|
||||||
|
part.chars().count() <= TELEGRAM_MAX_MESSAGE_LENGTH,
|
||||||
|
"emoji boundary split must respect limit"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_split_consecutive_newlines() {
|
||||||
|
let mut msg = "a".repeat(4090);
|
||||||
|
msg.push_str("\n\n\n\n\n\n");
|
||||||
|
msg.push_str(&"b".repeat(100));
|
||||||
|
let parts = split_message_for_telegram(&msg);
|
||||||
|
for part in &parts {
|
||||||
|
assert!(part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
438
tests/agent_loop_robustness.rs
Normal file
438
tests/agent_loop_robustness.rs
Normal file
|
|
@ -0,0 +1,438 @@
|
||||||
|
//! TG4: Agent Loop Robustness Tests
|
||||||
|
//!
|
||||||
|
//! Prevents: Pattern 4 — Agent loop & tool call processing bugs (13% of user bugs).
|
||||||
|
//! Issues: #746, #418, #777, #848
|
||||||
|
//!
|
||||||
|
//! Tests agent behavior with malformed tool calls, empty responses,
|
||||||
|
//! max iteration limits, and cascading tool failures using mock providers.
|
||||||
|
//! Complements inline parse_tool_calls tests in `src/agent/loop_.rs`.
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
use zeroclaw::agent::agent::Agent;
|
||||||
|
use zeroclaw::agent::dispatcher::NativeToolDispatcher;
|
||||||
|
use zeroclaw::config::MemoryConfig;
|
||||||
|
use zeroclaw::memory;
|
||||||
|
use zeroclaw::memory::Memory;
|
||||||
|
use zeroclaw::observability::{NoopObserver, Observer};
|
||||||
|
use zeroclaw::providers::{ChatRequest, ChatResponse, Provider, ToolCall};
|
||||||
|
use zeroclaw::tools::{Tool, ToolResult};
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Mock infrastructure
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
struct MockProvider {
|
||||||
|
responses: Mutex<Vec<ChatResponse>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MockProvider {
|
||||||
|
fn new(responses: Vec<ChatResponse>) -> Self {
|
||||||
|
Self {
|
||||||
|
responses: Mutex::new(responses),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for MockProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> Result<String> {
|
||||||
|
Ok("fallback".into())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
_request: ChatRequest<'_>,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> Result<ChatResponse> {
|
||||||
|
let mut guard = self.responses.lock().unwrap();
|
||||||
|
if guard.is_empty() {
|
||||||
|
return Ok(ChatResponse {
|
||||||
|
text: Some("done".into()),
|
||||||
|
tool_calls: vec![],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(guard.remove(0))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct EchoTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for EchoTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"echo"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Echoes the input message"
|
||||||
|
}
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"message": {"type": "string"}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
|
||||||
|
let msg = args
|
||||||
|
.get("message")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("(empty)")
|
||||||
|
.to_string();
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: msg,
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool that always fails, simulating a broken external service
|
||||||
|
struct FailingTool;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for FailingTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"failing_tool"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Always fails"
|
||||||
|
}
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({"type": "object"})
|
||||||
|
}
|
||||||
|
async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult> {
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Service unavailable: connection timeout".into()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Tool that tracks invocations
|
||||||
|
struct CountingTool {
|
||||||
|
count: Arc<Mutex<usize>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CountingTool {
|
||||||
|
fn new() -> (Self, Arc<Mutex<usize>>) {
|
||||||
|
let count = Arc::new(Mutex::new(0));
|
||||||
|
(Self { count: count.clone() }, count)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for CountingTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"counter"
|
||||||
|
}
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Counts invocations"
|
||||||
|
}
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({"type": "object"})
|
||||||
|
}
|
||||||
|
async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult> {
|
||||||
|
let mut c = self.count.lock().unwrap();
|
||||||
|
*c += 1;
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!("call #{}", *c),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Test helpers
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
fn make_memory() -> Arc<dyn Memory> {
|
||||||
|
let cfg = MemoryConfig {
|
||||||
|
backend: "none".into(),
|
||||||
|
..MemoryConfig::default()
|
||||||
|
};
|
||||||
|
Arc::from(memory::create_memory(&cfg, &std::env::temp_dir(), None).unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn make_observer() -> Arc<dyn Observer> {
|
||||||
|
Arc::from(NoopObserver {})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn text_response(text: &str) -> ChatResponse {
|
||||||
|
ChatResponse {
|
||||||
|
text: Some(text.into()),
|
||||||
|
tool_calls: vec![],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tool_response(calls: Vec<ToolCall>) -> ChatResponse {
|
||||||
|
ChatResponse {
|
||||||
|
text: Some(String::new()),
|
||||||
|
tool_calls: calls,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_agent(provider: Box<dyn Provider>, tools: Vec<Box<dyn Tool>>) -> Agent {
|
||||||
|
Agent::builder()
|
||||||
|
.provider(provider)
|
||||||
|
.tools(tools)
|
||||||
|
.memory(make_memory())
|
||||||
|
.observer(make_observer())
|
||||||
|
.tool_dispatcher(Box::new(NativeToolDispatcher))
|
||||||
|
.workspace_dir(std::env::temp_dir())
|
||||||
|
.build()
|
||||||
|
.unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
// TG4.1: Malformed tool call recovery
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/// Agent should recover when LLM returns text with residual XML tags (#746)
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_recovers_from_text_with_xml_residue() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![text_response(
|
||||||
|
"Here is the result. Some leftover </tool_call> text after.",
|
||||||
|
)]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
let response = agent.turn("test").await.unwrap();
|
||||||
|
assert!(
|
||||||
|
!response.is_empty(),
|
||||||
|
"agent should produce non-empty response despite XML residue"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent should handle tool call with empty arguments gracefully
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_tool_call_with_empty_arguments() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![
|
||||||
|
tool_response(vec![ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
arguments: "{}".into(),
|
||||||
|
}]),
|
||||||
|
text_response("Tool with empty args executed"),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
let response = agent.turn("call with empty args").await.unwrap();
|
||||||
|
assert!(!response.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent should handle unknown tool name without crashing (#848 related)
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_nonexistent_tool_gracefully() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![
|
||||||
|
tool_response(vec![ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "absolutely_nonexistent_tool".into(),
|
||||||
|
arguments: "{}".into(),
|
||||||
|
}]),
|
||||||
|
text_response("Recovered from unknown tool"),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
let response = agent.turn("call missing tool").await.unwrap();
|
||||||
|
assert!(
|
||||||
|
!response.is_empty(),
|
||||||
|
"agent should recover from unknown tool"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
// TG4.2: Tool failure cascade handling (#848)
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/// Agent should handle repeated tool failures without infinite loop
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_failing_tool() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![
|
||||||
|
tool_response(vec![ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "failing_tool".into(),
|
||||||
|
arguments: "{}".into(),
|
||||||
|
}]),
|
||||||
|
text_response("Tool failed but I recovered"),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(FailingTool)]);
|
||||||
|
let response = agent.turn("use failing tool").await.unwrap();
|
||||||
|
assert!(
|
||||||
|
!response.is_empty(),
|
||||||
|
"agent should produce response even after tool failure"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent should handle mixed tool calls (some succeed, some fail)
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_mixed_tool_success_and_failure() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![
|
||||||
|
tool_response(vec![
|
||||||
|
ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
arguments: r#"{"message": "success"}"#.into(),
|
||||||
|
},
|
||||||
|
ToolCall {
|
||||||
|
id: "tc2".into(),
|
||||||
|
name: "failing_tool".into(),
|
||||||
|
arguments: "{}".into(),
|
||||||
|
},
|
||||||
|
]),
|
||||||
|
text_response("Mixed results processed"),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(
|
||||||
|
provider,
|
||||||
|
vec![Box::new(EchoTool), Box::new(FailingTool)],
|
||||||
|
);
|
||||||
|
let response = agent.turn("mixed tools").await.unwrap();
|
||||||
|
assert!(!response.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
// TG4.3: Iteration limit enforcement (#777)
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/// Agent should not exceed max_tool_iterations (default=10) even with
|
||||||
|
/// a provider that keeps returning tool calls
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_respects_max_tool_iterations() {
|
||||||
|
let (counting_tool, count) = CountingTool::new();
|
||||||
|
|
||||||
|
// Create 20 tool call responses - more than the default limit of 10
|
||||||
|
let mut responses: Vec<ChatResponse> = (0..20)
|
||||||
|
.map(|i| {
|
||||||
|
tool_response(vec![ToolCall {
|
||||||
|
id: format!("tc_{i}"),
|
||||||
|
name: "counter".into(),
|
||||||
|
arguments: "{}".into(),
|
||||||
|
}])
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
// Add a final text response that would be used if limit is reached
|
||||||
|
responses.push(text_response("Final response after iterations"));
|
||||||
|
|
||||||
|
let provider = Box::new(MockProvider::new(responses));
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(counting_tool)]);
|
||||||
|
|
||||||
|
// Agent should complete (either by hitting iteration limit or running out of responses)
|
||||||
|
let result = agent.turn("keep calling tools").await;
|
||||||
|
// The agent should complete without hanging
|
||||||
|
assert!(result.is_ok() || result.is_err());
|
||||||
|
|
||||||
|
let invocations = *count.lock().unwrap();
|
||||||
|
assert!(
|
||||||
|
invocations <= 10,
|
||||||
|
"tool invocations ({invocations}) should not exceed default max_tool_iterations (10)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
// TG4.4: Empty and whitespace responses
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/// Agent should handle empty text response from provider (#418 related)
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_empty_provider_response() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![ChatResponse {
|
||||||
|
text: Some(String::new()),
|
||||||
|
tool_calls: vec![],
|
||||||
|
}]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
// Should not panic
|
||||||
|
let _result = agent.turn("test").await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent should handle None text response from provider
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_none_text_response() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![ChatResponse {
|
||||||
|
text: None,
|
||||||
|
tool_calls: vec![],
|
||||||
|
}]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
let _result = agent.turn("test").await;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent should handle whitespace-only response
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_whitespace_only_response() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![text_response(" \n\t ")]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
let _result = agent.turn("test").await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
// TG4.5: Tool call with special content
|
||||||
|
// ═════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
/// Agent should handle tool arguments with unicode content
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_unicode_tool_arguments() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![
|
||||||
|
tool_response(vec![ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
arguments: r#"{"message": "こんにちは世界 🌍"}"#.into(),
|
||||||
|
}]),
|
||||||
|
text_response("Unicode tool executed"),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
let response = agent.turn("unicode test").await.unwrap();
|
||||||
|
assert!(!response.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent should handle tool arguments with nested JSON
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_nested_json_tool_arguments() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![
|
||||||
|
tool_response(vec![ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
arguments: r#"{"message": "{\"nested\": true, \"deep\": {\"level\": 3}}"}"#.into(),
|
||||||
|
}]),
|
||||||
|
text_response("Nested JSON tool executed"),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
let response = agent.turn("nested json test").await.unwrap();
|
||||||
|
assert!(!response.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Agent should handle tool call followed by immediate text (no second LLM call)
|
||||||
|
#[tokio::test]
|
||||||
|
async fn agent_handles_sequential_tool_then_text() {
|
||||||
|
let provider = Box::new(MockProvider::new(vec![
|
||||||
|
tool_response(vec![ToolCall {
|
||||||
|
id: "tc1".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
arguments: r#"{"message": "step 1"}"#.into(),
|
||||||
|
}]),
|
||||||
|
text_response("Final answer after tool"),
|
||||||
|
]));
|
||||||
|
|
||||||
|
let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
|
||||||
|
let response = agent.turn("two step").await.unwrap();
|
||||||
|
assert!(
|
||||||
|
!response.is_empty(),
|
||||||
|
"should produce final text after tool execution"
|
||||||
|
);
|
||||||
|
}
|
||||||
310
tests/channel_routing.rs
Normal file
310
tests/channel_routing.rs
Normal file
|
|
@ -0,0 +1,310 @@
|
||||||
|
//! TG3: Channel Message Identity & Routing Tests
|
||||||
|
//!
|
||||||
|
//! Prevents: Pattern 3 — Channel message routing & identity bugs (17% of user bugs).
|
||||||
|
//! Issues: #496, #483, #620, #415, #503
|
||||||
|
//!
|
||||||
|
//! Tests that ChannelMessage fields are used consistently and that the
|
||||||
|
//! SendMessage → Channel trait contract preserves correct identity semantics.
|
||||||
|
//! Verifies sender/reply_target field contracts to prevent field swaps.
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use zeroclaw::channels::traits::{Channel, ChannelMessage, SendMessage};
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// ChannelMessage construction and field semantics
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn channel_message_sender_field_holds_platform_user_id() {
|
||||||
|
// Simulates Telegram: sender should be numeric chat_id, not username
|
||||||
|
let msg = ChannelMessage {
|
||||||
|
id: "msg_1".into(),
|
||||||
|
sender: "123456789".into(), // numeric chat_id
|
||||||
|
reply_target: "msg_0".into(),
|
||||||
|
content: "test message".into(),
|
||||||
|
channel: "telegram".into(),
|
||||||
|
timestamp: 1700000000,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(msg.sender, "123456789");
|
||||||
|
// Sender should be the platform-level user/chat identifier
|
||||||
|
assert!(
|
||||||
|
msg.sender.chars().all(|c| c.is_ascii_digit()),
|
||||||
|
"Telegram sender should be numeric chat_id, got: {}",
|
||||||
|
msg.sender
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn channel_message_reply_target_distinct_from_sender() {
|
||||||
|
// Simulates Discord: reply_target should be channel_id, not sender user_id
|
||||||
|
let msg = ChannelMessage {
|
||||||
|
id: "msg_1".into(),
|
||||||
|
sender: "user_987654".into(), // Discord user ID
|
||||||
|
reply_target: "channel_123".into(), // Discord channel ID for replies
|
||||||
|
content: "test message".into(),
|
||||||
|
channel: "discord".into(),
|
||||||
|
timestamp: 1700000000,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_ne!(
|
||||||
|
msg.sender, msg.reply_target,
|
||||||
|
"sender and reply_target should be distinct for Discord"
|
||||||
|
);
|
||||||
|
assert_eq!(msg.reply_target, "channel_123");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn channel_message_fields_not_swapped() {
|
||||||
|
// Guards against #496 (Telegram) and #483 (Discord) field swap bugs
|
||||||
|
let msg = ChannelMessage {
|
||||||
|
id: "msg_42".into(),
|
||||||
|
sender: "sender_value".into(),
|
||||||
|
reply_target: "target_value".into(),
|
||||||
|
content: "payload".into(),
|
||||||
|
channel: "test".into(),
|
||||||
|
timestamp: 1700000000,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(msg.sender, "sender_value", "sender field should not be swapped");
|
||||||
|
assert_eq!(
|
||||||
|
msg.reply_target, "target_value",
|
||||||
|
"reply_target field should not be swapped"
|
||||||
|
);
|
||||||
|
assert_ne!(
|
||||||
|
msg.sender, msg.reply_target,
|
||||||
|
"sender and reply_target should remain distinct"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn channel_message_preserves_all_fields_on_clone() {
|
||||||
|
let original = ChannelMessage {
|
||||||
|
id: "clone_test".into(),
|
||||||
|
sender: "sender_123".into(),
|
||||||
|
reply_target: "target_456".into(),
|
||||||
|
content: "cloned content".into(),
|
||||||
|
channel: "test_channel".into(),
|
||||||
|
timestamp: 1700000001,
|
||||||
|
};
|
||||||
|
|
||||||
|
let cloned = original.clone();
|
||||||
|
|
||||||
|
assert_eq!(cloned.id, original.id);
|
||||||
|
assert_eq!(cloned.sender, original.sender);
|
||||||
|
assert_eq!(cloned.reply_target, original.reply_target);
|
||||||
|
assert_eq!(cloned.content, original.content);
|
||||||
|
assert_eq!(cloned.channel, original.channel);
|
||||||
|
assert_eq!(cloned.timestamp, original.timestamp);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// SendMessage construction
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn send_message_new_sets_content_and_recipient() {
|
||||||
|
let msg = SendMessage::new("Hello", "recipient_123");
|
||||||
|
|
||||||
|
assert_eq!(msg.content, "Hello");
|
||||||
|
assert_eq!(msg.recipient, "recipient_123");
|
||||||
|
assert!(msg.subject.is_none(), "subject should be None by default");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn send_message_with_subject_sets_all_fields() {
|
||||||
|
let msg = SendMessage::with_subject("Hello", "recipient_123", "Re: Test");
|
||||||
|
|
||||||
|
assert_eq!(msg.content, "Hello");
|
||||||
|
assert_eq!(msg.recipient, "recipient_123");
|
||||||
|
assert_eq!(msg.subject.as_deref(), Some("Re: Test"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn send_message_recipient_carries_platform_target() {
|
||||||
|
// Verifies that SendMessage::recipient is used as the platform delivery target
|
||||||
|
// For Telegram: this should be the chat_id
|
||||||
|
// For Discord: this should be the channel_id
|
||||||
|
let telegram_msg = SendMessage::new("response", "123456789");
|
||||||
|
assert_eq!(
|
||||||
|
telegram_msg.recipient, "123456789",
|
||||||
|
"Telegram SendMessage recipient should be chat_id"
|
||||||
|
);
|
||||||
|
|
||||||
|
let discord_msg = SendMessage::new("response", "channel_987654");
|
||||||
|
assert_eq!(
|
||||||
|
discord_msg.recipient, "channel_987654",
|
||||||
|
"Discord SendMessage recipient should be channel_id"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Channel trait contract: send/listen roundtrip via DummyChannel
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Test channel that captures sent messages for assertion
|
||||||
|
struct CapturingChannel {
|
||||||
|
sent: std::sync::Mutex<Vec<SendMessage>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CapturingChannel {
|
||||||
|
fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
sent: std::sync::Mutex::new(Vec::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sent_messages(&self) -> Vec<SendMessage> {
|
||||||
|
self.sent.lock().unwrap().clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Channel for CapturingChannel {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"capturing"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
|
||||||
|
self.sent.lock().unwrap().push(message.clone());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn listen(
|
||||||
|
&self,
|
||||||
|
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
tx.send(ChannelMessage {
|
||||||
|
id: "listen_1".into(),
|
||||||
|
sender: "test_sender".into(),
|
||||||
|
reply_target: "test_target".into(),
|
||||||
|
content: "incoming".into(),
|
||||||
|
channel: "capturing".into(),
|
||||||
|
timestamp: 1700000000,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!(e.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_send_preserves_recipient() {
|
||||||
|
let channel = CapturingChannel::new();
|
||||||
|
let msg = SendMessage::new("Hello", "target_123");
|
||||||
|
|
||||||
|
channel.send(&msg).await.unwrap();
|
||||||
|
|
||||||
|
let sent = channel.sent_messages();
|
||||||
|
assert_eq!(sent.len(), 1);
|
||||||
|
assert_eq!(sent[0].recipient, "target_123");
|
||||||
|
assert_eq!(sent[0].content, "Hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_listen_produces_correct_identity_fields() {
|
||||||
|
let channel = CapturingChannel::new();
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
|
||||||
|
|
||||||
|
channel.listen(tx).await.unwrap();
|
||||||
|
let received = rx.recv().await.expect("should receive message");
|
||||||
|
|
||||||
|
assert_eq!(received.sender, "test_sender");
|
||||||
|
assert_eq!(received.reply_target, "test_target");
|
||||||
|
assert_ne!(
|
||||||
|
received.sender, received.reply_target,
|
||||||
|
"listen() should populate sender and reply_target distinctly"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_send_reply_uses_sender_from_listen() {
|
||||||
|
let channel = CapturingChannel::new();
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
|
||||||
|
|
||||||
|
// Simulate: listen() → receive message → send reply using sender
|
||||||
|
channel.listen(tx).await.unwrap();
|
||||||
|
let incoming = rx.recv().await.expect("should receive message");
|
||||||
|
|
||||||
|
// Reply should go to the reply_target, not sender
|
||||||
|
let reply = SendMessage::new("reply content", &incoming.reply_target);
|
||||||
|
channel.send(&reply).await.unwrap();
|
||||||
|
|
||||||
|
let sent = channel.sent_messages();
|
||||||
|
assert_eq!(sent.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
sent[0].recipient, "test_target",
|
||||||
|
"reply should use reply_target as recipient"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Channel trait default methods
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_health_check_default_returns_true() {
|
||||||
|
let channel = CapturingChannel::new();
|
||||||
|
assert!(
|
||||||
|
channel.health_check().await,
|
||||||
|
"default health_check should return true"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_typing_defaults_succeed() {
|
||||||
|
let channel = CapturingChannel::new();
|
||||||
|
assert!(channel.start_typing("target").await.is_ok());
|
||||||
|
assert!(channel.stop_typing("target").await.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_draft_defaults() {
|
||||||
|
let channel = CapturingChannel::new();
|
||||||
|
assert!(!channel.supports_draft_updates());
|
||||||
|
|
||||||
|
let draft_result = channel
|
||||||
|
.send_draft(&SendMessage::new("draft", "target"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(draft_result.is_none(), "default send_draft should return None");
|
||||||
|
|
||||||
|
assert!(channel
|
||||||
|
.update_draft("target", "msg_1", "updated")
|
||||||
|
.await
|
||||||
|
.is_ok());
|
||||||
|
assert!(channel
|
||||||
|
.finalize_draft("target", "msg_1", "final")
|
||||||
|
.await
|
||||||
|
.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Multiple messages: conversation context preservation
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn channel_multiple_sends_preserve_order_and_recipients() {
|
||||||
|
let channel = CapturingChannel::new();
|
||||||
|
|
||||||
|
channel
|
||||||
|
.send(&SendMessage::new("msg 1", "target_a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
channel
|
||||||
|
.send(&SendMessage::new("msg 2", "target_b"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
channel
|
||||||
|
.send(&SendMessage::new("msg 3", "target_a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let sent = channel.sent_messages();
|
||||||
|
assert_eq!(sent.len(), 3);
|
||||||
|
assert_eq!(sent[0].recipient, "target_a");
|
||||||
|
assert_eq!(sent[1].recipient, "target_b");
|
||||||
|
assert_eq!(sent[2].recipient, "target_a");
|
||||||
|
assert_eq!(sent[0].content, "msg 1");
|
||||||
|
assert_eq!(sent[1].content, "msg 2");
|
||||||
|
assert_eq!(sent[2].content, "msg 3");
|
||||||
|
}
|
||||||
245
tests/config_persistence.rs
Normal file
245
tests/config_persistence.rs
Normal file
|
|
@ -0,0 +1,245 @@
|
||||||
|
//! TG2: Config Load/Save Round-Trip Tests
|
||||||
|
//!
|
||||||
|
//! Prevents: Pattern 2 — Config persistence & workspace discovery bugs (13% of user bugs).
|
||||||
|
//! Issues: #547, #417, #621, #802
|
||||||
|
//!
|
||||||
|
//! Tests Config::load_or_init() with isolated temp directories, env var overrides,
|
||||||
|
//! and config file round-trips to verify workspace discovery and persistence.
|
||||||
|
|
||||||
|
use std::fs;
|
||||||
|
use zeroclaw::config::{AgentConfig, Config, MemoryConfig};
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Config default construction
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_default_has_expected_provider() {
|
||||||
|
let config = Config::default();
|
||||||
|
assert!(
|
||||||
|
config.default_provider.is_some(),
|
||||||
|
"default config should have a default_provider"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_default_has_expected_model() {
|
||||||
|
let config = Config::default();
|
||||||
|
assert!(
|
||||||
|
config.default_model.is_some(),
|
||||||
|
"default config should have a default_model"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_default_temperature_positive() {
|
||||||
|
let config = Config::default();
|
||||||
|
assert!(
|
||||||
|
config.default_temperature > 0.0,
|
||||||
|
"default temperature should be positive"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// AgentConfig defaults
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn agent_config_default_max_tool_iterations() {
|
||||||
|
let agent = AgentConfig::default();
|
||||||
|
assert_eq!(
|
||||||
|
agent.max_tool_iterations, 10,
|
||||||
|
"default max_tool_iterations should be 10"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn agent_config_default_max_history_messages() {
|
||||||
|
let agent = AgentConfig::default();
|
||||||
|
assert_eq!(
|
||||||
|
agent.max_history_messages, 50,
|
||||||
|
"default max_history_messages should be 50"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn agent_config_default_tool_dispatcher() {
|
||||||
|
let agent = AgentConfig::default();
|
||||||
|
assert_eq!(
|
||||||
|
agent.tool_dispatcher, "auto",
|
||||||
|
"default tool_dispatcher should be 'auto'"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn agent_config_default_compact_context_off() {
|
||||||
|
let agent = AgentConfig::default();
|
||||||
|
assert!(
|
||||||
|
!agent.compact_context,
|
||||||
|
"compact_context should default to false"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// MemoryConfig defaults
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn memory_config_default_backend() {
|
||||||
|
let memory = MemoryConfig::default();
|
||||||
|
assert!(
|
||||||
|
!memory.backend.is_empty(),
|
||||||
|
"memory backend should have a default value"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn memory_config_default_embedding_provider() {
|
||||||
|
let memory = MemoryConfig::default();
|
||||||
|
// Default embedding_provider should be set (even if "none")
|
||||||
|
assert!(
|
||||||
|
!memory.embedding_provider.is_empty(),
|
||||||
|
"embedding_provider should have a default value"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn memory_config_default_vector_keyword_weights_sum_to_one() {
|
||||||
|
let memory = MemoryConfig::default();
|
||||||
|
let sum = memory.vector_weight + memory.keyword_weight;
|
||||||
|
assert!(
|
||||||
|
(sum - 1.0).abs() < 0.01,
|
||||||
|
"vector_weight + keyword_weight should sum to ~1.0, got {sum}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Config TOML serialization round-trip
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_toml_roundtrip_preserves_provider() {
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.default_provider = Some("deepseek".into());
|
||||||
|
config.default_model = Some("deepseek-chat".into());
|
||||||
|
config.default_temperature = 0.5;
|
||||||
|
|
||||||
|
let toml_str = toml::to_string(&config).expect("config should serialize to TOML");
|
||||||
|
let parsed: Config = toml::from_str(&toml_str).expect("TOML should deserialize back");
|
||||||
|
|
||||||
|
assert_eq!(parsed.default_provider.as_deref(), Some("deepseek"));
|
||||||
|
assert_eq!(parsed.default_model.as_deref(), Some("deepseek-chat"));
|
||||||
|
assert!((parsed.default_temperature - 0.5).abs() < f64::EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_toml_roundtrip_preserves_agent_config() {
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.agent.max_tool_iterations = 5;
|
||||||
|
config.agent.max_history_messages = 25;
|
||||||
|
config.agent.compact_context = true;
|
||||||
|
|
||||||
|
let toml_str = toml::to_string(&config).expect("config should serialize to TOML");
|
||||||
|
let parsed: Config = toml::from_str(&toml_str).expect("TOML should deserialize back");
|
||||||
|
|
||||||
|
assert_eq!(parsed.agent.max_tool_iterations, 5);
|
||||||
|
assert_eq!(parsed.agent.max_history_messages, 25);
|
||||||
|
assert!(parsed.agent.compact_context);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_toml_roundtrip_preserves_memory_config() {
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.memory.embedding_provider = "openai".into();
|
||||||
|
config.memory.embedding_model = "text-embedding-3-small".into();
|
||||||
|
config.memory.vector_weight = 0.8;
|
||||||
|
config.memory.keyword_weight = 0.2;
|
||||||
|
|
||||||
|
let toml_str = toml::to_string(&config).expect("config should serialize to TOML");
|
||||||
|
let parsed: Config = toml::from_str(&toml_str).expect("TOML should deserialize back");
|
||||||
|
|
||||||
|
assert_eq!(parsed.memory.embedding_provider, "openai");
|
||||||
|
assert_eq!(parsed.memory.embedding_model, "text-embedding-3-small");
|
||||||
|
assert!((parsed.memory.vector_weight - 0.8).abs() < f64::EPSILON);
|
||||||
|
assert!((parsed.memory.keyword_weight - 0.2).abs() < f64::EPSILON);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Config file write/read round-trip with tempdir
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_file_write_read_roundtrip() {
|
||||||
|
let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed");
|
||||||
|
let config_path = tmp.path().join("config.toml");
|
||||||
|
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.default_provider = Some("mistral".into());
|
||||||
|
config.default_model = Some("mistral-large".into());
|
||||||
|
config.agent.max_tool_iterations = 15;
|
||||||
|
|
||||||
|
let toml_str = toml::to_string(&config).expect("config should serialize");
|
||||||
|
fs::write(&config_path, &toml_str).expect("config file write should succeed");
|
||||||
|
|
||||||
|
let read_back = fs::read_to_string(&config_path).expect("config file read should succeed");
|
||||||
|
let parsed: Config = toml::from_str(&read_back).expect("TOML should parse back");
|
||||||
|
|
||||||
|
assert_eq!(parsed.default_provider.as_deref(), Some("mistral"));
|
||||||
|
assert_eq!(parsed.default_model.as_deref(), Some("mistral-large"));
|
||||||
|
assert_eq!(parsed.agent.max_tool_iterations, 15);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_file_with_missing_optional_fields_uses_defaults() {
|
||||||
|
// Simulate a minimal config TOML that omits optional sections
|
||||||
|
let minimal_toml = r#"
|
||||||
|
default_temperature = 0.7
|
||||||
|
"#;
|
||||||
|
let parsed: Config = toml::from_str(minimal_toml).expect("minimal TOML should parse");
|
||||||
|
|
||||||
|
// Agent config should use defaults
|
||||||
|
assert_eq!(parsed.agent.max_tool_iterations, 10);
|
||||||
|
assert_eq!(parsed.agent.max_history_messages, 50);
|
||||||
|
assert!(!parsed.agent.compact_context);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_file_with_custom_agent_section() {
|
||||||
|
let toml_with_agent = r#"
|
||||||
|
default_temperature = 0.7
|
||||||
|
|
||||||
|
[agent]
|
||||||
|
max_tool_iterations = 3
|
||||||
|
compact_context = true
|
||||||
|
"#;
|
||||||
|
let parsed: Config =
|
||||||
|
toml::from_str(toml_with_agent).expect("TOML with agent section should parse");
|
||||||
|
|
||||||
|
assert_eq!(parsed.agent.max_tool_iterations, 3);
|
||||||
|
assert!(parsed.agent.compact_context);
|
||||||
|
// max_history_messages should still use default
|
||||||
|
assert_eq!(parsed.agent.max_history_messages, 50);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Workspace directory creation
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn workspace_dir_creation_in_tempdir() {
|
||||||
|
let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed");
|
||||||
|
let workspace_dir = tmp.path().join("workspace");
|
||||||
|
|
||||||
|
fs::create_dir_all(&workspace_dir).expect("workspace dir creation should succeed");
|
||||||
|
assert!(workspace_dir.exists(), "workspace dir should exist");
|
||||||
|
assert!(workspace_dir.is_dir(), "workspace path should be a directory");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn nested_workspace_dir_creation() {
|
||||||
|
let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed");
|
||||||
|
let nested_dir = tmp.path().join("deep").join("nested").join("workspace");
|
||||||
|
|
||||||
|
fs::create_dir_all(&nested_dir).expect("nested dir creation should succeed");
|
||||||
|
assert!(nested_dir.exists(), "nested workspace dir should exist");
|
||||||
|
}
|
||||||
335
tests/memory_restart.rs
Normal file
335
tests/memory_restart.rs
Normal file
|
|
@ -0,0 +1,335 @@
|
||||||
|
//! TG5: Memory Restart Resilience Tests
|
||||||
|
//!
|
||||||
|
//! Prevents: Pattern 5 — Memory & state persistence bugs (10% of user bugs).
|
||||||
|
//! Issues: #430, #693, #802
|
||||||
|
//!
|
||||||
|
//! Tests SqliteMemory deduplication on restart, session scoping, concurrent
|
||||||
|
//! message ordering, and recall behavior after re-initialization.
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
use zeroclaw::memory::sqlite::SqliteMemory;
|
||||||
|
use zeroclaw::memory::traits::{Memory, MemoryCategory};
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Deduplication: same key overwrites instead of duplicating (#430)
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_store_same_key_deduplicates() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
// Store same key twice with different content
|
||||||
|
mem.store("greeting", "hello world", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("greeting", "hello updated", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Should have exactly 1 entry, not 2
|
||||||
|
let count = mem.count().await.unwrap();
|
||||||
|
assert_eq!(count, 1, "storing same key twice should not create duplicates");
|
||||||
|
|
||||||
|
// Content should be the latest version
|
||||||
|
let entry = mem.get("greeting").await.unwrap().expect("entry should exist");
|
||||||
|
assert_eq!(entry.content, "hello updated");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_store_different_keys_creates_separate_entries() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
mem.store("key_a", "content a", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("key_b", "content b", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let count = mem.count().await.unwrap();
|
||||||
|
assert_eq!(count, 2, "different keys should create separate entries");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Restart resilience: data persists across memory re-initialization
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_persists_across_reinitialization() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// First "session": store data
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
mem.store("persistent_fact", "Rust is great", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second "session": re-create memory from same path
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
let entry = mem
|
||||||
|
.get("persistent_fact")
|
||||||
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.expect("entry should survive reinitialization");
|
||||||
|
assert_eq!(entry.content, "Rust is great");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_restart_does_not_duplicate_on_rewrite() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// First session: store entries
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
mem.store("fact_1", "original content", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("fact_2", "another fact", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second session: re-store same keys (simulates channel re-reading history)
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
mem.store("fact_1", "original content", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("fact_2", "another fact", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let count = mem.count().await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
count, 2,
|
||||||
|
"re-storing same keys after restart should not create duplicates"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Session scoping: messages scoped to sessions don't leak
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_session_scoped_store_and_recall() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
// Store in different sessions
|
||||||
|
mem.store(
|
||||||
|
"session_a_fact",
|
||||||
|
"fact from session A",
|
||||||
|
MemoryCategory::Conversation,
|
||||||
|
Some("session_a"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store(
|
||||||
|
"session_b_fact",
|
||||||
|
"fact from session B",
|
||||||
|
MemoryCategory::Conversation,
|
||||||
|
Some("session_b"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// List scoped to session_a
|
||||||
|
let session_a_entries = mem
|
||||||
|
.list(Some(&MemoryCategory::Conversation), Some("session_a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
session_a_entries.len(),
|
||||||
|
1,
|
||||||
|
"session_a should have exactly 1 entry"
|
||||||
|
);
|
||||||
|
assert_eq!(session_a_entries[0].content, "fact from session A");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_global_recall_includes_all_sessions() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
mem.store("global_a", "alpha content", MemoryCategory::Core, Some("s1"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("global_b", "beta content", MemoryCategory::Core, Some("s2"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Global count should include all
|
||||||
|
let count = mem.count().await.unwrap();
|
||||||
|
assert_eq!(count, 2, "global count should include entries from all sessions");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Recall and search behavior
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_recall_returns_relevant_results() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
mem.store("lang_pref", "User prefers Rust programming", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("food_pref", "User likes sushi for lunch", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let results = mem.recall("Rust programming", 10, None).await.unwrap();
|
||||||
|
assert!(!results.is_empty(), "recall should find matching entries");
|
||||||
|
// The Rust-related entry should be in results
|
||||||
|
assert!(
|
||||||
|
results.iter().any(|e| e.content.contains("Rust")),
|
||||||
|
"recall for 'Rust' should include the Rust-related entry"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_recall_respects_limit() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
for i in 0..10 {
|
||||||
|
mem.store(
|
||||||
|
&format!("entry_{i}"),
|
||||||
|
&format!("test content number {i}"),
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let results = mem.recall("test content", 3, None).await.unwrap();
|
||||||
|
assert!(
|
||||||
|
results.len() <= 3,
|
||||||
|
"recall should respect limit of 3, got {}",
|
||||||
|
results.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_recall_empty_query_returns_empty() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
mem.store("fact", "some content", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let results = mem.recall("", 10, None).await.unwrap();
|
||||||
|
assert!(
|
||||||
|
results.is_empty(),
|
||||||
|
"empty query should return no results"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Forget and health check
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_forget_removes_entry() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
mem.store("to_forget", "temporary info", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(mem.count().await.unwrap(), 1);
|
||||||
|
|
||||||
|
let removed = mem.forget("to_forget").await.unwrap();
|
||||||
|
assert!(removed, "forget should return true for existing key");
|
||||||
|
assert_eq!(mem.count().await.unwrap(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_forget_nonexistent_returns_false() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
let removed = mem.forget("nonexistent_key").await.unwrap();
|
||||||
|
assert!(!removed, "forget should return false for nonexistent key");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_health_check_returns_true() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
assert!(mem.health_check().await, "health_check should return true");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Concurrent access
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_concurrent_stores_no_data_loss() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = Arc::new(SqliteMemory::new(tmp.path()).unwrap());
|
||||||
|
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for i in 0..5 {
|
||||||
|
let mem_clone = mem.clone();
|
||||||
|
handles.push(tokio::spawn(async move {
|
||||||
|
mem_clone
|
||||||
|
.store(
|
||||||
|
&format!("concurrent_{i}"),
|
||||||
|
&format!("content from task {i}"),
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for handle in handles {
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
let count = mem.count().await.unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
count, 5,
|
||||||
|
"all concurrent stores should succeed, got {count}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Memory categories
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn sqlite_memory_list_by_category() {
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
||||||
|
mem.store("core_fact", "core info", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("daily_note", "daily note", MemoryCategory::Daily, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("conv_msg", "conversation msg", MemoryCategory::Conversation, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let core_entries = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
|
assert_eq!(core_entries.len(), 1, "should have 1 Core entry");
|
||||||
|
assert_eq!(core_entries[0].key, "core_fact");
|
||||||
|
|
||||||
|
let daily_entries = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
|
assert_eq!(daily_entries.len(), 1, "should have 1 Daily entry");
|
||||||
|
}
|
||||||
244
tests/provider_resolution.rs
Normal file
244
tests/provider_resolution.rs
Normal file
|
|
@ -0,0 +1,244 @@
|
||||||
|
//! TG1: Provider End-to-End Resolution Tests
|
||||||
|
//!
|
||||||
|
//! Prevents: Pattern 1 — Provider configuration & resolution bugs (27% of user bugs).
|
||||||
|
//! Issues: #831, #834, #721, #580, #452, #451, #796, #843
|
||||||
|
//!
|
||||||
|
//! Tests the full pipeline from config values through `create_provider_with_url()`
|
||||||
|
//! to provider construction, verifying factory resolution, URL construction,
|
||||||
|
//! credential wiring, and auth header format.
|
||||||
|
|
||||||
|
use zeroclaw::providers::compatible::{AuthStyle, OpenAiCompatibleProvider};
|
||||||
|
use zeroclaw::providers::{create_provider, create_provider_with_url};
|
||||||
|
|
||||||
|
/// Helper: assert provider creation succeeds
|
||||||
|
fn assert_provider_ok(name: &str, key: Option<&str>, url: Option<&str>) {
|
||||||
|
let result = create_provider_with_url(name, key, url);
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"{name} provider should resolve: {}",
|
||||||
|
result.err().map(|e| e.to_string()).unwrap_or_default()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Factory resolution: each major provider name resolves without error
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_openai_provider() {
|
||||||
|
assert_provider_ok("openai", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_anthropic_provider() {
|
||||||
|
assert_provider_ok("anthropic", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_deepseek_provider() {
|
||||||
|
assert_provider_ok("deepseek", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_mistral_provider() {
|
||||||
|
assert_provider_ok("mistral", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_ollama_provider() {
|
||||||
|
assert_provider_ok("ollama", None, None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_groq_provider() {
|
||||||
|
assert_provider_ok("groq", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_xai_provider() {
|
||||||
|
assert_provider_ok("xai", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_together_provider() {
|
||||||
|
assert_provider_ok("together", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_fireworks_provider() {
|
||||||
|
assert_provider_ok("fireworks", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_resolves_perplexity_provider() {
|
||||||
|
assert_provider_ok("perplexity", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Factory resolution: alias variants map to same provider
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_grok_alias_resolves_to_xai() {
|
||||||
|
assert_provider_ok("grok", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_kimi_alias_resolves_to_moonshot() {
|
||||||
|
assert_provider_ok("kimi", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_zhipu_alias_resolves_to_glm() {
|
||||||
|
assert_provider_ok("zhipu", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Custom URL provider creation
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_custom_http_url_resolves() {
|
||||||
|
assert_provider_ok("custom:http://localhost:8080", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_custom_https_url_resolves() {
|
||||||
|
assert_provider_ok("custom:https://api.example.com/v1", Some("test-key"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_custom_ftp_url_rejected() {
|
||||||
|
let result = create_provider_with_url("custom:ftp://example.com", None, None);
|
||||||
|
assert!(result.is_err(), "ftp scheme should be rejected");
|
||||||
|
let err_msg = result.err().unwrap().to_string();
|
||||||
|
assert!(
|
||||||
|
err_msg.contains("http://") || err_msg.contains("https://"),
|
||||||
|
"error should mention valid schemes: {err_msg}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_custom_empty_url_rejected() {
|
||||||
|
let result = create_provider_with_url("custom:", None, None);
|
||||||
|
assert!(result.is_err(), "empty custom URL should be rejected");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_unknown_provider_rejected() {
|
||||||
|
let result = create_provider_with_url("nonexistent_provider_xyz", None, None);
|
||||||
|
assert!(result.is_err(), "unknown provider name should be rejected");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// OpenAiCompatibleProvider: credential and auth style wiring
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compatible_provider_bearer_auth_style() {
|
||||||
|
// Construction with Bearer auth should succeed
|
||||||
|
let _provider = OpenAiCompatibleProvider::new(
|
||||||
|
"TestProvider",
|
||||||
|
"https://api.test.com",
|
||||||
|
Some("sk-test-key-12345"),
|
||||||
|
AuthStyle::Bearer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compatible_provider_xapikey_auth_style() {
|
||||||
|
// Construction with XApiKey auth should succeed
|
||||||
|
let _provider = OpenAiCompatibleProvider::new(
|
||||||
|
"TestProvider",
|
||||||
|
"https://api.test.com",
|
||||||
|
Some("sk-test-key-12345"),
|
||||||
|
AuthStyle::XApiKey,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compatible_provider_custom_auth_header() {
|
||||||
|
// Construction with Custom auth should succeed
|
||||||
|
let _provider = OpenAiCompatibleProvider::new(
|
||||||
|
"TestProvider",
|
||||||
|
"https://api.test.com",
|
||||||
|
Some("sk-test-key-12345"),
|
||||||
|
AuthStyle::Custom("X-Custom-Auth".into()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compatible_provider_no_credential() {
|
||||||
|
// Construction without credential should succeed (for local providers)
|
||||||
|
let _provider = OpenAiCompatibleProvider::new(
|
||||||
|
"TestLocal",
|
||||||
|
"http://localhost:11434",
|
||||||
|
None,
|
||||||
|
AuthStyle::Bearer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compatible_provider_base_url_trailing_slash_normalized() {
|
||||||
|
// Construction with trailing slash URL should succeed
|
||||||
|
let _provider = OpenAiCompatibleProvider::new(
|
||||||
|
"TestProvider",
|
||||||
|
"https://api.test.com/v1/",
|
||||||
|
Some("key"),
|
||||||
|
AuthStyle::Bearer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Provider with api_url override (simulates #721 - Ollama api_url config)
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_ollama_with_custom_api_url() {
|
||||||
|
assert_provider_ok("ollama", None, Some("http://192.168.1.100:11434"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_openai_with_custom_api_url() {
|
||||||
|
assert_provider_ok(
|
||||||
|
"openai",
|
||||||
|
Some("test-key"),
|
||||||
|
Some("https://custom-openai-proxy.example.com/v1"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Provider default convenience factory
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convenience_factory_resolves_major_providers() {
|
||||||
|
for provider_name in &[
|
||||||
|
"openai",
|
||||||
|
"anthropic",
|
||||||
|
"deepseek",
|
||||||
|
"mistral",
|
||||||
|
"groq",
|
||||||
|
"xai",
|
||||||
|
"together",
|
||||||
|
"fireworks",
|
||||||
|
"perplexity",
|
||||||
|
] {
|
||||||
|
let result = create_provider(provider_name, Some("test-key"));
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"convenience factory should resolve {provider_name}: {}",
|
||||||
|
result.err().map(|e| e.to_string()).unwrap_or_default()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convenience_factory_ollama_no_key() {
|
||||||
|
let result = create_provider("ollama", None);
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"ollama should not require api key: {}",
|
||||||
|
result.err().map(|e| e.to_string()).unwrap_or_default()
|
||||||
|
);
|
||||||
|
}
|
||||||
303
tests/provider_schema.rs
Normal file
303
tests/provider_schema.rs
Normal file
|
|
@ -0,0 +1,303 @@
|
||||||
|
//! TG7: Provider Schema Conformance Tests
|
||||||
|
//!
|
||||||
|
//! Prevents: Pattern 7 — External schema compatibility bugs (7% of user bugs).
|
||||||
|
//! Issues: #769, #843
|
||||||
|
//!
|
||||||
|
//! Tests request/response serialization to verify required fields are present
|
||||||
|
//! for each provider's API specification. Validates ChatMessage, ChatResponse,
|
||||||
|
//! ToolCall, and AuthStyle serialization contracts.
|
||||||
|
|
||||||
|
use zeroclaw::providers::compatible::AuthStyle;
|
||||||
|
use zeroclaw::providers::traits::{ChatMessage, ChatResponse, ToolCall};
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// ChatMessage serialization
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_message_system_role_correct() {
|
||||||
|
let msg = ChatMessage::system("You are a helpful assistant");
|
||||||
|
assert_eq!(msg.role, "system");
|
||||||
|
assert_eq!(msg.content, "You are a helpful assistant");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_message_user_role_correct() {
|
||||||
|
let msg = ChatMessage::user("Hello");
|
||||||
|
assert_eq!(msg.role, "user");
|
||||||
|
assert_eq!(msg.content, "Hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_message_assistant_role_correct() {
|
||||||
|
let msg = ChatMessage::assistant("Hi there!");
|
||||||
|
assert_eq!(msg.role, "assistant");
|
||||||
|
assert_eq!(msg.content, "Hi there!");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_message_tool_role_correct() {
|
||||||
|
let msg = ChatMessage::tool("tool result");
|
||||||
|
assert_eq!(msg.role, "tool");
|
||||||
|
assert_eq!(msg.content, "tool result");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_message_serializes_to_json_with_required_fields() {
|
||||||
|
let msg = ChatMessage::user("test message");
|
||||||
|
let json = serde_json::to_value(&msg).unwrap();
|
||||||
|
|
||||||
|
assert!(json.get("role").is_some(), "JSON must have 'role' field");
|
||||||
|
assert!(
|
||||||
|
json.get("content").is_some(),
|
||||||
|
"JSON must have 'content' field"
|
||||||
|
);
|
||||||
|
assert_eq!(json["role"], "user");
|
||||||
|
assert_eq!(json["content"], "test message");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_message_json_roundtrip() {
|
||||||
|
let original = ChatMessage::assistant("response text");
|
||||||
|
let json_str = serde_json::to_string(&original).unwrap();
|
||||||
|
let parsed: ChatMessage = serde_json::from_str(&json_str).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(parsed.role, original.role);
|
||||||
|
assert_eq!(parsed.content, original.content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// ToolCall serialization (#843 - tool_call_id field)
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_call_has_required_fields() {
|
||||||
|
let tc = ToolCall {
|
||||||
|
id: "call_abc123".into(),
|
||||||
|
name: "web_search".into(),
|
||||||
|
arguments: r#"{"query": "rust programming"}"#.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let json = serde_json::to_value(&tc).unwrap();
|
||||||
|
assert!(json.get("id").is_some(), "ToolCall must have 'id' field");
|
||||||
|
assert!(json.get("name").is_some(), "ToolCall must have 'name' field");
|
||||||
|
assert!(
|
||||||
|
json.get("arguments").is_some(),
|
||||||
|
"ToolCall must have 'arguments' field"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_call_id_preserved_in_serialization() {
|
||||||
|
let tc = ToolCall {
|
||||||
|
id: "call_deepseek_42".into(),
|
||||||
|
name: "shell".into(),
|
||||||
|
arguments: r#"{"command": "ls"}"#.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let json_str = serde_json::to_string(&tc).unwrap();
|
||||||
|
let parsed: ToolCall = serde_json::from_str(&json_str).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(parsed.id, "call_deepseek_42", "tool_call_id must survive roundtrip");
|
||||||
|
assert_eq!(parsed.name, "shell");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_call_arguments_contain_valid_json() {
|
||||||
|
let tc = ToolCall {
|
||||||
|
id: "call_1".into(),
|
||||||
|
name: "file_write".into(),
|
||||||
|
arguments: r#"{"path": "/tmp/test.txt", "content": "hello"}"#.into(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Arguments should parse as valid JSON
|
||||||
|
let args: serde_json::Value = serde_json::from_str(&tc.arguments)
|
||||||
|
.expect("tool call arguments should be valid JSON");
|
||||||
|
assert!(args.get("path").is_some());
|
||||||
|
assert!(args.get("content").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Tool message with tool_call_id (DeepSeek requirement)
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_response_message_can_embed_tool_call_id() {
|
||||||
|
// DeepSeek requires tool_call_id in tool response messages.
|
||||||
|
// The tool message content can embed the tool_call_id as JSON.
|
||||||
|
let tool_response = ChatMessage::tool(
|
||||||
|
r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#,
|
||||||
|
);
|
||||||
|
|
||||||
|
let parsed: serde_json::Value = serde_json::from_str(&tool_response.content)
|
||||||
|
.expect("tool response content should be valid JSON");
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
parsed.get("tool_call_id").is_some(),
|
||||||
|
"tool response should include tool_call_id for DeepSeek compatibility"
|
||||||
|
);
|
||||||
|
assert_eq!(parsed["tool_call_id"], "call_abc123");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// ChatResponse structure
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_response_text_only() {
|
||||||
|
let resp = ChatResponse {
|
||||||
|
text: Some("Hello world".into()),
|
||||||
|
tool_calls: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(resp.text_or_empty(), "Hello world");
|
||||||
|
assert!(!resp.has_tool_calls());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_response_with_tool_calls() {
|
||||||
|
let resp = ChatResponse {
|
||||||
|
text: Some(String::new()),
|
||||||
|
tool_calls: vec![ToolCall {
|
||||||
|
id: "tc_1".into(),
|
||||||
|
name: "echo".into(),
|
||||||
|
arguments: "{}".into(),
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(resp.has_tool_calls());
|
||||||
|
assert_eq!(resp.tool_calls.len(), 1);
|
||||||
|
assert_eq!(resp.tool_calls[0].name, "echo");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_response_text_or_empty_handles_none() {
|
||||||
|
let resp = ChatResponse {
|
||||||
|
text: None,
|
||||||
|
tool_calls: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(resp.text_or_empty(), "");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_response_multiple_tool_calls() {
|
||||||
|
let resp = ChatResponse {
|
||||||
|
text: None,
|
||||||
|
tool_calls: vec![
|
||||||
|
ToolCall {
|
||||||
|
id: "tc_1".into(),
|
||||||
|
name: "shell".into(),
|
||||||
|
arguments: r#"{"command": "ls"}"#.into(),
|
||||||
|
},
|
||||||
|
ToolCall {
|
||||||
|
id: "tc_2".into(),
|
||||||
|
name: "file_read".into(),
|
||||||
|
arguments: r#"{"path": "test.txt"}"#.into(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(resp.has_tool_calls());
|
||||||
|
assert_eq!(resp.tool_calls.len(), 2);
|
||||||
|
// Each tool call should have a distinct id
|
||||||
|
assert_ne!(resp.tool_calls[0].id, resp.tool_calls[1].id);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// AuthStyle variants
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_style_bearer_is_constructible() {
|
||||||
|
let style = AuthStyle::Bearer;
|
||||||
|
assert!(matches!(style, AuthStyle::Bearer));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_style_xapikey_is_constructible() {
|
||||||
|
let style = AuthStyle::XApiKey;
|
||||||
|
assert!(matches!(style, AuthStyle::XApiKey));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_style_custom_header() {
|
||||||
|
let style = AuthStyle::Custom("X-Custom-Auth".into());
|
||||||
|
if let AuthStyle::Custom(header) = style {
|
||||||
|
assert_eq!(header, "X-Custom-Auth");
|
||||||
|
} else {
|
||||||
|
panic!("expected AuthStyle::Custom");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Provider naming consistency
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_construction_with_different_names() {
|
||||||
|
use zeroclaw::providers::compatible::OpenAiCompatibleProvider;
|
||||||
|
|
||||||
|
// Construction with various names should succeed
|
||||||
|
let _p1 = OpenAiCompatibleProvider::new(
|
||||||
|
"DeepSeek",
|
||||||
|
"https://api.deepseek.com",
|
||||||
|
Some("test-key"),
|
||||||
|
AuthStyle::Bearer,
|
||||||
|
);
|
||||||
|
let _p2 = OpenAiCompatibleProvider::new(
|
||||||
|
"deepseek",
|
||||||
|
"https://api.test.com",
|
||||||
|
None,
|
||||||
|
AuthStyle::Bearer,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_construction_with_different_auth_styles() {
|
||||||
|
use zeroclaw::providers::compatible::OpenAiCompatibleProvider;
|
||||||
|
|
||||||
|
let _bearer = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Bearer);
|
||||||
|
let _xapi = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::XApiKey);
|
||||||
|
let _custom = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Custom("X-My-Auth".into()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Conversation history message ordering
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_messages_maintain_role_sequence() {
|
||||||
|
let history = vec![
|
||||||
|
ChatMessage::system("You are helpful"),
|
||||||
|
ChatMessage::user("What is Rust?"),
|
||||||
|
ChatMessage::assistant("Rust is a systems programming language"),
|
||||||
|
ChatMessage::user("Tell me more"),
|
||||||
|
ChatMessage::assistant("It emphasizes safety and performance"),
|
||||||
|
];
|
||||||
|
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
|
assert_eq!(history[1].role, "user");
|
||||||
|
assert_eq!(history[2].role, "assistant");
|
||||||
|
assert_eq!(history[3].role, "user");
|
||||||
|
assert_eq!(history[4].role, "assistant");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_messages_with_tool_calls_maintain_sequence() {
|
||||||
|
let history = vec![
|
||||||
|
ChatMessage::system("You are helpful"),
|
||||||
|
ChatMessage::user("Search for Rust"),
|
||||||
|
ChatMessage::assistant("I'll search for that"),
|
||||||
|
ChatMessage::tool(r#"{"tool_call_id": "tc_1", "content": "search results"}"#),
|
||||||
|
ChatMessage::assistant("Based on the search results..."),
|
||||||
|
];
|
||||||
|
|
||||||
|
assert_eq!(history.len(), 5);
|
||||||
|
assert_eq!(history[3].role, "tool");
|
||||||
|
assert_eq!(history[4].role, "assistant");
|
||||||
|
|
||||||
|
// Verify tool message content is valid JSON with tool_call_id
|
||||||
|
let tool_content: serde_json::Value = serde_json::from_str(&history[3].content).unwrap();
|
||||||
|
assert!(tool_content.get("tool_call_id").is_some());
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue