fix(channels): interrupt in-flight telegram requests on newer sender messages

This commit is contained in:
Chummy 2026-02-20 01:26:38 +08:00
parent d9a94fc763
commit ef82c7dbcd
17 changed files with 669 additions and 115 deletions

View file

@ -128,7 +128,12 @@ struct CountingTool {
impl CountingTool {
fn new() -> (Self, Arc<Mutex<usize>>) {
let count = Arc::new(Mutex::new(0));
(Self { count: count.clone() }, count)
(
Self {
count: count.clone(),
},
count,
)
}
}
@ -295,10 +300,7 @@ async fn agent_handles_mixed_tool_success_and_failure() {
text_response("Mixed results processed"),
]));
let mut agent = build_agent(
provider,
vec![Box::new(EchoTool), Box::new(FailingTool)],
);
let mut agent = build_agent(provider, vec![Box::new(EchoTool), Box::new(FailingTool)]);
let response = agent.turn("mixed tools").await.unwrap();
assert!(!response.is_empty());
}

View file

@ -24,6 +24,7 @@ fn channel_message_sender_field_holds_platform_user_id() {
content: "test message".into(),
channel: "telegram".into(),
timestamp: 1700000000,
thread_ts: None,
};
assert_eq!(msg.sender, "123456789");
@ -40,11 +41,12 @@ fn channel_message_reply_target_distinct_from_sender() {
// Simulates Discord: reply_target should be channel_id, not sender user_id
let msg = ChannelMessage {
id: "msg_1".into(),
sender: "user_987654".into(), // Discord user ID
sender: "user_987654".into(), // Discord user ID
reply_target: "channel_123".into(), // Discord channel ID for replies
content: "test message".into(),
channel: "discord".into(),
timestamp: 1700000000,
thread_ts: None,
};
assert_ne!(
@ -64,9 +66,13 @@ fn channel_message_fields_not_swapped() {
content: "payload".into(),
channel: "test".into(),
timestamp: 1700000000,
thread_ts: None,
};
assert_eq!(msg.sender, "sender_value", "sender field should not be swapped");
assert_eq!(
msg.sender, "sender_value",
"sender field should not be swapped"
);
assert_eq!(
msg.reply_target, "target_value",
"reply_target field should not be swapped"
@ -86,6 +92,7 @@ fn channel_message_preserves_all_fields_on_clone() {
content: "cloned content".into(),
channel: "test_channel".into(),
timestamp: 1700000001,
thread_ts: None,
};
let cloned = original.clone();
@ -170,10 +177,7 @@ impl Channel for CapturingChannel {
Ok(())
}
async fn listen(
&self,
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
) -> anyhow::Result<()> {
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
tx.send(ChannelMessage {
id: "listen_1".into(),
sender: "test_sender".into(),
@ -181,6 +185,7 @@ impl Channel for CapturingChannel {
content: "incoming".into(),
channel: "capturing".into(),
timestamp: 1700000000,
thread_ts: None,
})
.await
.map_err(|e| anyhow::anyhow!(e.to_string()))
@ -266,7 +271,10 @@ async fn channel_draft_defaults() {
.send_draft(&SendMessage::new("draft", "target"))
.await
.unwrap();
assert!(draft_result.is_none(), "default send_draft should return None");
assert!(
draft_result.is_none(),
"default send_draft should return None"
);
assert!(channel
.update_draft("target", "msg_1", "updated")

View file

@ -232,7 +232,10 @@ fn workspace_dir_creation_in_tempdir() {
fs::create_dir_all(&workspace_dir).expect("workspace dir creation should succeed");
assert!(workspace_dir.exists(), "workspace dir should exist");
assert!(workspace_dir.is_dir(), "workspace path should be a directory");
assert!(
workspace_dir.is_dir(),
"workspace path should be a directory"
);
}
#[test]

View file

@ -29,10 +29,17 @@ async fn sqlite_memory_store_same_key_deduplicates() {
// Should have exactly 1 entry, not 2
let count = mem.count().await.unwrap();
assert_eq!(count, 1, "storing same key twice should not create duplicates");
assert_eq!(
count, 1,
"storing same key twice should not create duplicates"
);
// Content should be the latest version
let entry = mem.get("greeting").await.unwrap().expect("entry should exist");
let entry = mem
.get("greeting")
.await
.unwrap()
.expect("entry should exist");
assert_eq!(entry.content, "hello updated");
}
@ -63,9 +70,14 @@ async fn sqlite_memory_persists_across_reinitialization() {
// First "session": store data
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("persistent_fact", "Rust is great", MemoryCategory::Core, None)
.await
.unwrap();
mem.store(
"persistent_fact",
"Rust is great",
MemoryCategory::Core,
None,
)
.await
.unwrap();
}
// Second "session": re-create memory from same path
@ -158,16 +170,24 @@ async fn sqlite_memory_global_recall_includes_all_sessions() {
let tmp = tempfile::TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("global_a", "alpha content", MemoryCategory::Core, Some("s1"))
.await
.unwrap();
mem.store(
"global_a",
"alpha content",
MemoryCategory::Core,
Some("s1"),
)
.await
.unwrap();
mem.store("global_b", "beta content", MemoryCategory::Core, Some("s2"))
.await
.unwrap();
// Global count should include all
let count = mem.count().await.unwrap();
assert_eq!(count, 2, "global count should include entries from all sessions");
assert_eq!(
count, 2,
"global count should include entries from all sessions"
);
}
// ─────────────────────────────────────────────────────────────────────────────
@ -179,12 +199,22 @@ async fn sqlite_memory_recall_returns_relevant_results() {
let tmp = tempfile::TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("lang_pref", "User prefers Rust programming", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("food_pref", "User likes sushi for lunch", MemoryCategory::Core, None)
.await
.unwrap();
mem.store(
"lang_pref",
"User prefers Rust programming",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store(
"food_pref",
"User likes sushi for lunch",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let results = mem.recall("Rust programming", 10, None).await.unwrap();
assert!(!results.is_empty(), "recall should find matching entries");
@ -229,10 +259,7 @@ async fn sqlite_memory_recall_empty_query_returns_empty() {
.unwrap();
let results = mem.recall("", 10, None).await.unwrap();
assert!(
results.is_empty(),
"empty query should return no results"
);
assert!(results.is_empty(), "empty query should return no results");
}
// ─────────────────────────────────────────────────────────────────────────────
@ -322,9 +349,14 @@ async fn sqlite_memory_list_by_category() {
mem.store("daily_note", "daily note", MemoryCategory::Daily, None)
.await
.unwrap();
mem.store("conv_msg", "conversation msg", MemoryCategory::Conversation, None)
.await
.unwrap();
mem.store(
"conv_msg",
"conversation msg",
MemoryCategory::Conversation,
None,
)
.await
.unwrap();
let core_entries = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
assert_eq!(core_entries.len(), 1, "should have 1 Core entry");

View file

@ -80,7 +80,10 @@ fn tool_call_has_required_fields() {
let json = serde_json::to_value(&tc).unwrap();
assert!(json.get("id").is_some(), "ToolCall must have 'id' field");
assert!(json.get("name").is_some(), "ToolCall must have 'name' field");
assert!(
json.get("name").is_some(),
"ToolCall must have 'name' field"
);
assert!(
json.get("arguments").is_some(),
"ToolCall must have 'arguments' field"
@ -98,7 +101,10 @@ fn tool_call_id_preserved_in_serialization() {
let json_str = serde_json::to_string(&tc).unwrap();
let parsed: ToolCall = serde_json::from_str(&json_str).unwrap();
assert_eq!(parsed.id, "call_deepseek_42", "tool_call_id must survive roundtrip");
assert_eq!(
parsed.id, "call_deepseek_42",
"tool_call_id must survive roundtrip"
);
assert_eq!(parsed.name, "shell");
}
@ -111,8 +117,8 @@ fn tool_call_arguments_contain_valid_json() {
};
// Arguments should parse as valid JSON
let args: serde_json::Value = serde_json::from_str(&tc.arguments)
.expect("tool call arguments should be valid JSON");
let args: serde_json::Value =
serde_json::from_str(&tc.arguments).expect("tool call arguments should be valid JSON");
assert!(args.get("path").is_some());
assert!(args.get("content").is_some());
}
@ -125,9 +131,8 @@ fn tool_call_arguments_contain_valid_json() {
fn tool_response_message_can_embed_tool_call_id() {
// DeepSeek requires tool_call_id in tool response messages.
// The tool message content can embed the tool_call_id as JSON.
let tool_response = ChatMessage::tool(
r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#,
);
let tool_response =
ChatMessage::tool(r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#);
let parsed: serde_json::Value = serde_json::from_str(&tool_response.content)
.expect("tool response content should be valid JSON");
@ -245,21 +250,32 @@ fn provider_construction_with_different_names() {
Some("test-key"),
AuthStyle::Bearer,
);
let _p2 = OpenAiCompatibleProvider::new(
"deepseek",
"https://api.test.com",
None,
AuthStyle::Bearer,
);
let _p2 =
OpenAiCompatibleProvider::new("deepseek", "https://api.test.com", None, AuthStyle::Bearer);
}
#[test]
fn provider_construction_with_different_auth_styles() {
use zeroclaw::providers::compatible::OpenAiCompatibleProvider;
let _bearer = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Bearer);
let _xapi = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::XApiKey);
let _custom = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Custom("X-My-Auth".into()));
let _bearer = OpenAiCompatibleProvider::new(
"Test",
"https://api.test.com",
Some("key"),
AuthStyle::Bearer,
);
let _xapi = OpenAiCompatibleProvider::new(
"Test",
"https://api.test.com",
Some("key"),
AuthStyle::XApiKey,
);
let _custom = OpenAiCompatibleProvider::new(
"Test",
"https://api.test.com",
Some("key"),
AuthStyle::Custom("X-My-Auth".into()),
);
}
// ─────────────────────────────────────────────────────────────────────────────