fix(channels): interrupt in-flight telegram requests on newer sender messages
This commit is contained in:
parent
d9a94fc763
commit
ef82c7dbcd
17 changed files with 669 additions and 115 deletions
|
|
@ -128,7 +128,12 @@ struct CountingTool {
|
|||
impl CountingTool {
|
||||
fn new() -> (Self, Arc<Mutex<usize>>) {
|
||||
let count = Arc::new(Mutex::new(0));
|
||||
(Self { count: count.clone() }, count)
|
||||
(
|
||||
Self {
|
||||
count: count.clone(),
|
||||
},
|
||||
count,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -295,10 +300,7 @@ async fn agent_handles_mixed_tool_success_and_failure() {
|
|||
text_response("Mixed results processed"),
|
||||
]));
|
||||
|
||||
let mut agent = build_agent(
|
||||
provider,
|
||||
vec![Box::new(EchoTool), Box::new(FailingTool)],
|
||||
);
|
||||
let mut agent = build_agent(provider, vec![Box::new(EchoTool), Box::new(FailingTool)]);
|
||||
let response = agent.turn("mixed tools").await.unwrap();
|
||||
assert!(!response.is_empty());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ fn channel_message_sender_field_holds_platform_user_id() {
|
|||
content: "test message".into(),
|
||||
channel: "telegram".into(),
|
||||
timestamp: 1700000000,
|
||||
thread_ts: None,
|
||||
};
|
||||
|
||||
assert_eq!(msg.sender, "123456789");
|
||||
|
|
@ -40,11 +41,12 @@ fn channel_message_reply_target_distinct_from_sender() {
|
|||
// Simulates Discord: reply_target should be channel_id, not sender user_id
|
||||
let msg = ChannelMessage {
|
||||
id: "msg_1".into(),
|
||||
sender: "user_987654".into(), // Discord user ID
|
||||
sender: "user_987654".into(), // Discord user ID
|
||||
reply_target: "channel_123".into(), // Discord channel ID for replies
|
||||
content: "test message".into(),
|
||||
channel: "discord".into(),
|
||||
timestamp: 1700000000,
|
||||
thread_ts: None,
|
||||
};
|
||||
|
||||
assert_ne!(
|
||||
|
|
@ -64,9 +66,13 @@ fn channel_message_fields_not_swapped() {
|
|||
content: "payload".into(),
|
||||
channel: "test".into(),
|
||||
timestamp: 1700000000,
|
||||
thread_ts: None,
|
||||
};
|
||||
|
||||
assert_eq!(msg.sender, "sender_value", "sender field should not be swapped");
|
||||
assert_eq!(
|
||||
msg.sender, "sender_value",
|
||||
"sender field should not be swapped"
|
||||
);
|
||||
assert_eq!(
|
||||
msg.reply_target, "target_value",
|
||||
"reply_target field should not be swapped"
|
||||
|
|
@ -86,6 +92,7 @@ fn channel_message_preserves_all_fields_on_clone() {
|
|||
content: "cloned content".into(),
|
||||
channel: "test_channel".into(),
|
||||
timestamp: 1700000001,
|
||||
thread_ts: None,
|
||||
};
|
||||
|
||||
let cloned = original.clone();
|
||||
|
|
@ -170,10 +177,7 @@ impl Channel for CapturingChannel {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn listen(
|
||||
&self,
|
||||
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
|
||||
) -> anyhow::Result<()> {
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
tx.send(ChannelMessage {
|
||||
id: "listen_1".into(),
|
||||
sender: "test_sender".into(),
|
||||
|
|
@ -181,6 +185,7 @@ impl Channel for CapturingChannel {
|
|||
content: "incoming".into(),
|
||||
channel: "capturing".into(),
|
||||
timestamp: 1700000000,
|
||||
thread_ts: None,
|
||||
})
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!(e.to_string()))
|
||||
|
|
@ -266,7 +271,10 @@ async fn channel_draft_defaults() {
|
|||
.send_draft(&SendMessage::new("draft", "target"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(draft_result.is_none(), "default send_draft should return None");
|
||||
assert!(
|
||||
draft_result.is_none(),
|
||||
"default send_draft should return None"
|
||||
);
|
||||
|
||||
assert!(channel
|
||||
.update_draft("target", "msg_1", "updated")
|
||||
|
|
|
|||
|
|
@ -232,7 +232,10 @@ fn workspace_dir_creation_in_tempdir() {
|
|||
|
||||
fs::create_dir_all(&workspace_dir).expect("workspace dir creation should succeed");
|
||||
assert!(workspace_dir.exists(), "workspace dir should exist");
|
||||
assert!(workspace_dir.is_dir(), "workspace path should be a directory");
|
||||
assert!(
|
||||
workspace_dir.is_dir(),
|
||||
"workspace path should be a directory"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -29,10 +29,17 @@ async fn sqlite_memory_store_same_key_deduplicates() {
|
|||
|
||||
// Should have exactly 1 entry, not 2
|
||||
let count = mem.count().await.unwrap();
|
||||
assert_eq!(count, 1, "storing same key twice should not create duplicates");
|
||||
assert_eq!(
|
||||
count, 1,
|
||||
"storing same key twice should not create duplicates"
|
||||
);
|
||||
|
||||
// Content should be the latest version
|
||||
let entry = mem.get("greeting").await.unwrap().expect("entry should exist");
|
||||
let entry = mem
|
||||
.get("greeting")
|
||||
.await
|
||||
.unwrap()
|
||||
.expect("entry should exist");
|
||||
assert_eq!(entry.content, "hello updated");
|
||||
}
|
||||
|
||||
|
|
@ -63,9 +70,14 @@ async fn sqlite_memory_persists_across_reinitialization() {
|
|||
// First "session": store data
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("persistent_fact", "Rust is great", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"persistent_fact",
|
||||
"Rust is great",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Second "session": re-create memory from same path
|
||||
|
|
@ -158,16 +170,24 @@ async fn sqlite_memory_global_recall_includes_all_sessions() {
|
|||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
||||
mem.store("global_a", "alpha content", MemoryCategory::Core, Some("s1"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"global_a",
|
||||
"alpha content",
|
||||
MemoryCategory::Core,
|
||||
Some("s1"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("global_b", "beta content", MemoryCategory::Core, Some("s2"))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Global count should include all
|
||||
let count = mem.count().await.unwrap();
|
||||
assert_eq!(count, 2, "global count should include entries from all sessions");
|
||||
assert_eq!(
|
||||
count, 2,
|
||||
"global count should include entries from all sessions"
|
||||
);
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
|
@ -179,12 +199,22 @@ async fn sqlite_memory_recall_returns_relevant_results() {
|
|||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
||||
mem.store("lang_pref", "User prefers Rust programming", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("food_pref", "User likes sushi for lunch", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"lang_pref",
|
||||
"User prefers Rust programming",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"food_pref",
|
||||
"User likes sushi for lunch",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust programming", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty(), "recall should find matching entries");
|
||||
|
|
@ -229,10 +259,7 @@ async fn sqlite_memory_recall_empty_query_returns_empty() {
|
|||
.unwrap();
|
||||
|
||||
let results = mem.recall("", 10, None).await.unwrap();
|
||||
assert!(
|
||||
results.is_empty(),
|
||||
"empty query should return no results"
|
||||
);
|
||||
assert!(results.is_empty(), "empty query should return no results");
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
|
@ -322,9 +349,14 @@ async fn sqlite_memory_list_by_category() {
|
|||
mem.store("daily_note", "daily note", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("conv_msg", "conversation msg", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"conv_msg",
|
||||
"conversation msg",
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core_entries = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
assert_eq!(core_entries.len(), 1, "should have 1 Core entry");
|
||||
|
|
|
|||
|
|
@ -80,7 +80,10 @@ fn tool_call_has_required_fields() {
|
|||
|
||||
let json = serde_json::to_value(&tc).unwrap();
|
||||
assert!(json.get("id").is_some(), "ToolCall must have 'id' field");
|
||||
assert!(json.get("name").is_some(), "ToolCall must have 'name' field");
|
||||
assert!(
|
||||
json.get("name").is_some(),
|
||||
"ToolCall must have 'name' field"
|
||||
);
|
||||
assert!(
|
||||
json.get("arguments").is_some(),
|
||||
"ToolCall must have 'arguments' field"
|
||||
|
|
@ -98,7 +101,10 @@ fn tool_call_id_preserved_in_serialization() {
|
|||
let json_str = serde_json::to_string(&tc).unwrap();
|
||||
let parsed: ToolCall = serde_json::from_str(&json_str).unwrap();
|
||||
|
||||
assert_eq!(parsed.id, "call_deepseek_42", "tool_call_id must survive roundtrip");
|
||||
assert_eq!(
|
||||
parsed.id, "call_deepseek_42",
|
||||
"tool_call_id must survive roundtrip"
|
||||
);
|
||||
assert_eq!(parsed.name, "shell");
|
||||
}
|
||||
|
||||
|
|
@ -111,8 +117,8 @@ fn tool_call_arguments_contain_valid_json() {
|
|||
};
|
||||
|
||||
// Arguments should parse as valid JSON
|
||||
let args: serde_json::Value = serde_json::from_str(&tc.arguments)
|
||||
.expect("tool call arguments should be valid JSON");
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&tc.arguments).expect("tool call arguments should be valid JSON");
|
||||
assert!(args.get("path").is_some());
|
||||
assert!(args.get("content").is_some());
|
||||
}
|
||||
|
|
@ -125,9 +131,8 @@ fn tool_call_arguments_contain_valid_json() {
|
|||
fn tool_response_message_can_embed_tool_call_id() {
|
||||
// DeepSeek requires tool_call_id in tool response messages.
|
||||
// The tool message content can embed the tool_call_id as JSON.
|
||||
let tool_response = ChatMessage::tool(
|
||||
r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#,
|
||||
);
|
||||
let tool_response =
|
||||
ChatMessage::tool(r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#);
|
||||
|
||||
let parsed: serde_json::Value = serde_json::from_str(&tool_response.content)
|
||||
.expect("tool response content should be valid JSON");
|
||||
|
|
@ -245,21 +250,32 @@ fn provider_construction_with_different_names() {
|
|||
Some("test-key"),
|
||||
AuthStyle::Bearer,
|
||||
);
|
||||
let _p2 = OpenAiCompatibleProvider::new(
|
||||
"deepseek",
|
||||
"https://api.test.com",
|
||||
None,
|
||||
AuthStyle::Bearer,
|
||||
);
|
||||
let _p2 =
|
||||
OpenAiCompatibleProvider::new("deepseek", "https://api.test.com", None, AuthStyle::Bearer);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_construction_with_different_auth_styles() {
|
||||
use zeroclaw::providers::compatible::OpenAiCompatibleProvider;
|
||||
|
||||
let _bearer = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Bearer);
|
||||
let _xapi = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::XApiKey);
|
||||
let _custom = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Custom("X-My-Auth".into()));
|
||||
let _bearer = OpenAiCompatibleProvider::new(
|
||||
"Test",
|
||||
"https://api.test.com",
|
||||
Some("key"),
|
||||
AuthStyle::Bearer,
|
||||
);
|
||||
let _xapi = OpenAiCompatibleProvider::new(
|
||||
"Test",
|
||||
"https://api.test.com",
|
||||
Some("key"),
|
||||
AuthStyle::XApiKey,
|
||||
);
|
||||
let _custom = OpenAiCompatibleProvider::new(
|
||||
"Test",
|
||||
"https://api.test.com",
|
||||
Some("key"),
|
||||
AuthStyle::Custom("X-My-Auth".into()),
|
||||
);
|
||||
}
|
||||
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue