diff --git a/docs/channels-reference.md b/docs/channels-reference.md index 61d0604..49defc1 100644 --- a/docs/channels-reference.md +++ b/docs/channels-reference.md @@ -138,8 +138,17 @@ Field names differ by channel: [channels_config.telegram] bot_token = "123456:telegram-token" allowed_users = ["*"] +stream_mode = "off" # optional: off | partial +draft_update_interval_ms = 1000 # optional: edit throttle for partial streaming +mention_only = false # optional: require @mention in groups +interrupt_on_new_message = false # optional: cancel in-flight same-sender same-chat request ``` +Telegram notes: + +- `interrupt_on_new_message = true` preserves interrupted user turns in conversation history, then restarts generation on the newest message. +- Interruption scope is strict: same sender in the same chat. Messages from different chats are processed independently. + ### 4.2 Discord ```toml diff --git a/docs/config-reference.md b/docs/config-reference.md index 2e8278e..47aa0bc 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -188,6 +188,8 @@ Notes: - If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower. - Values below `30` are clamped to `30` to avoid immediate timeout churn. - When a timeout occurs, users receive: `⚠️ Request timed out while waiting for the model. Please try again.` +- Telegram-only interruption behavior is controlled with `channels_config.telegram.interrupt_on_new_message` (default `false`). + When enabled, a newer message from the same sender in the same chat cancels the in-flight request and preserves interrupted user context. See detailed channel matrix and allowlist behavior in [channels-reference.md](channels-reference.md). diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index f5402b0..79e52ba 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -16,6 +16,7 @@ use std::fmt::Write; use std::io::Write as _; use std::sync::{Arc, LazyLock}; use std::time::Instant; +use tokio_util::sync::CancellationToken; use uuid::Uuid; /// Minimum characters per chunk when relaying LLM text to a streaming draft. @@ -823,6 +824,21 @@ struct ParsedToolCall { arguments: serde_json::Value, } +#[derive(Debug)] +pub(crate) struct ToolLoopCancelled; + +impl std::fmt::Display for ToolLoopCancelled { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("tool loop cancelled") + } +} + +impl std::error::Error for ToolLoopCancelled {} + +pub(crate) fn is_tool_loop_cancelled(err: &anyhow::Error) -> bool { + err.chain().any(|source| source.is::()) +} + /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. /// When `silent` is true, suppresses stdout (for channel use). @@ -853,6 +869,7 @@ pub(crate) async fn agent_turn( multimodal_config, max_tool_iterations, None, + None, ) .await } @@ -873,6 +890,7 @@ pub(crate) async fn run_tool_call_loop( channel_name: &str, multimodal_config: &crate::config::MultimodalConfig, max_tool_iterations: usize, + cancellation_token: Option, on_delta: Option>, ) -> Result { let max_iterations = if max_tool_iterations == 0 { @@ -886,6 +904,13 @@ pub(crate) async fn run_tool_call_loop( let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty(); for _iteration in 0..max_iterations { + if cancellation_token + .as_ref() + .is_some_and(CancellationToken::is_cancelled) + { + return Err(ToolLoopCancelled.into()); + } + let image_marker_count = multimodal::count_image_markers(history); if image_marker_count > 0 && !provider.supports_vision() { return Err(ProviderCapabilityError { @@ -917,18 +942,26 @@ pub(crate) async fn run_tool_call_loop( None }; + let chat_future = provider.chat( + ChatRequest { + messages: &prepared_messages.messages, + tools: request_tools, + }, + model, + temperature, + ); + + let chat_result = if let Some(token) = cancellation_token.as_ref() { + tokio::select! { + () = token.cancelled() => return Err(ToolLoopCancelled.into()), + result = chat_future => result, + } + } else { + chat_future.await + }; + let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) = - match provider - .chat( - ChatRequest { - messages: &prepared_messages.messages, - tools: request_tools, - }, - model, - temperature, - ) - .await - { + match chat_result { Ok(resp) => { observer.record_event(&ObserverEvent::LlmResponse { provider: provider_name.to_string(), @@ -994,6 +1027,12 @@ pub(crate) async fn run_tool_call_loop( // STREAM_CHUNK_MIN_CHARS characters for progressive draft updates. let mut chunk = String::new(); for word in display_text.split_inclusive(char::is_whitespace) { + if cancellation_token + .as_ref() + .is_some_and(CancellationToken::is_cancelled) + { + return Err(ToolLoopCancelled.into()); + } chunk.push_str(word); if chunk.len() >= STREAM_CHUNK_MIN_CHARS && tx.send(std::mem::take(&mut chunk)).await.is_err() @@ -1056,7 +1095,17 @@ pub(crate) async fn run_tool_call_loop( }); let start = Instant::now(); let result = if let Some(tool) = find_tool(tools_registry, &call.name) { - match tool.execute(call.arguments.clone()).await { + let tool_future = tool.execute(call.arguments.clone()); + let tool_result = if let Some(token) = cancellation_token.as_ref() { + tokio::select! { + () = token.cancelled() => return Err(ToolLoopCancelled.into()), + result = tool_future => result, + } + } else { + tool_future.await + }; + + match tool_result { Ok(r) => { observer.record_event(&ObserverEvent::ToolCall { tool: call.name.clone(), @@ -1435,6 +1484,7 @@ pub async fn run( &config.multimodal, config.agent.max_tool_iterations, None, + None, ) .await?; final_output = response.clone(); @@ -1553,6 +1603,7 @@ pub async fn run( &config.multimodal, config.agent.max_tool_iterations, None, + None, ) .await { @@ -1900,6 +1951,7 @@ mod tests { &crate::config::MultimodalConfig::default(), 3, None, + None, ) .await .expect_err("provider without vision support should fail"); @@ -1943,6 +1995,7 @@ mod tests { &multimodal, 3, None, + None, ) .await .expect_err("oversized payload must fail"); @@ -1980,6 +2033,7 @@ mod tests { &crate::config::MultimodalConfig::default(), 3, None, + None, ) .await .expect("valid multimodal payload should pass"); @@ -2809,7 +2863,10 @@ browser_open/url>https://example.com"#; fn parse_tool_calls_closing_tag_only_returns_text() { let response = "Some text more text"; let (text, calls) = parse_tool_calls(response); - assert!(calls.is_empty(), "closing tag only should not produce calls"); + assert!( + calls.is_empty(), + "closing tag only should not produce calls" + ); assert!( !text.is_empty(), "text around orphaned closing tag should be preserved" @@ -2858,7 +2915,11 @@ browser_open/url>https://example.com"#; Let me check the result."#; let (text, calls) = parse_tool_calls(response); - assert_eq!(calls.len(), 1, "should extract one tool call from mixed content"); + assert_eq!( + calls.len(), + 1, + "should extract one tool call from mixed content" + ); assert_eq!(calls[0].name, "shell"); assert!( text.contains("help you"), @@ -2880,7 +2941,10 @@ Let me check the result."#; fn scrub_credentials_no_sensitive_data() { let input = "normal text without any secrets"; let result = scrub_credentials(input); - assert_eq!(result, input, "non-sensitive text should pass through unchanged"); + assert_eq!( + result, input, + "non-sensitive text should pass through unchanged" + ); } #[test] diff --git a/src/channels/discord.rs b/src/channels/discord.rs index b5226f3..3ae69ba 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -857,7 +857,10 @@ mod tests { msg.push_str(&"x".repeat(1990)); msg.push_str("\n```\nMore text after code block"); let parts = split_message_for_discord(&msg); - assert!(parts.len() >= 2, "code block spanning boundary should split"); + assert!( + parts.len() >= 2, + "code block spanning boundary should split" + ); for part in &parts { assert!( part.len() <= DISCORD_MAX_MESSAGE_LENGTH, diff --git a/src/channels/mod.rs b/src/channels/mod.rs index a23356f..80a82ae 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -56,6 +56,7 @@ use std::collections::HashMap; use std::fmt::Write; use std::path::{Path, PathBuf}; use std::process::Command; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use tokio_util::sync::CancellationToken; @@ -141,9 +142,43 @@ struct ChannelRuntimeContext { provider_runtime_options: providers::ProviderRuntimeOptions, workspace_dir: Arc, message_timeout_secs: u64, + interrupt_on_new_message: bool, multimodal: crate::config::MultimodalConfig, } +#[derive(Clone)] +struct InFlightSenderTaskState { + task_id: u64, + cancellation: CancellationToken, + completion: Arc, +} + +struct InFlightTaskCompletion { + done: AtomicBool, + notify: tokio::sync::Notify, +} + +impl InFlightTaskCompletion { + fn new() -> Self { + Self { + done: AtomicBool::new(false), + notify: tokio::sync::Notify::new(), + } + } + + fn mark_done(&self) { + self.done.store(true, Ordering::Release); + self.notify.notify_waiters(); + } + + async fn wait(&self) { + if self.done.load(Ordering::Acquire) { + return; + } + self.notify.notified().await; + } +} + fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { format!("{}_{}_{}", msg.channel, msg.sender, msg.id) } @@ -152,6 +187,10 @@ fn conversation_history_key(msg: &traits::ChannelMessage) -> String { format!("{}_{}", msg.channel, msg.sender) } +fn interruption_scope_key(msg: &traits::ChannelMessage) -> String { + format!("{}_{}_{}", msg.channel, msg.reply_target, msg.sender) +} + fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { match channel_name { "telegram" => Some( @@ -292,6 +331,18 @@ fn compact_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) -> bool true } +fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) { + let mut histories = ctx + .conversation_histories + .lock() + .unwrap_or_else(|e| e.into_inner()); + let turns = histories.entry(sender_key.to_string()).or_default(); + turns.push(turn); + while turns.len() > MAX_CHANNEL_HISTORY { + turns.remove(0); + } +} + fn should_skip_memory_context_entry(key: &str, content: &str) -> bool { if memory::is_assistant_autosave_key(key) { return true; @@ -657,7 +708,15 @@ fn spawn_scoped_typing_task( handle } -async fn process_channel_message(ctx: Arc, msg: traits::ChannelMessage) { +async fn process_channel_message( + ctx: Arc, + msg: traits::ChannelMessage, + cancellation_token: CancellationToken, +) { + if cancellation_token.is_cancelled() { + return; + } + println!( " 💬 [{}] from {}: {}", msg.channel, @@ -717,7 +776,13 @@ async fn process_channel_message(ctx: Arc, msg: traits::C println!(" ⏳ Processing message..."); let started_at = Instant::now(); - // Build history from per-sender conversation cache + // Preserve user turn before the LLM call so interrupted requests keep context. + append_sender_turn( + ctx.as_ref(), + &history_key, + ChatMessage::user(&enriched_message), + ); + let mut prior_turns = ctx .conversation_histories .lock() @@ -728,18 +793,15 @@ async fn process_channel_message(ctx: Arc, msg: traits::C let mut history = vec![ChatMessage::system(ctx.system_prompt.as_str())]; history.append(&mut prior_turns); - history.push(ChatMessage::user(&enriched_message)); if let Some(instructions) = channel_delivery_instructions(&msg.channel) { history.push(ChatMessage::system(instructions)); } - // Determine if this channel supports streaming draft updates let use_streaming = target_channel .as_ref() - .map_or(false, |ch| ch.supports_draft_updates()); + .is_some_and(|ch| ch.supports_draft_updates()); - // Set up streaming channel if supported let (delta_tx, delta_rx) = if use_streaming { let (tx, rx) = tokio::sync::mpsc::channel::(64); (Some(tx), Some(rx)) @@ -747,7 +809,6 @@ async fn process_channel_message(ctx: Arc, msg: traits::C (None, None) }; - // Send initial draft message if streaming let draft_message_id = if use_streaming { if let Some(channel) = target_channel.as_ref() { match channel @@ -769,7 +830,6 @@ async fn process_channel_message(ctx: Arc, msg: traits::C None }; - // Spawn a task to forward streaming deltas to draft updates let draft_updater = if let (Some(mut rx), Some(draft_id_ref), Some(channel_ref)) = ( delta_rx, draft_message_id.as_deref(), @@ -804,27 +864,34 @@ async fn process_channel_message(ctx: Arc, msg: traits::C _ => None, }; - let llm_result = tokio::time::timeout( - Duration::from_secs(ctx.message_timeout_secs), - run_tool_call_loop( - active_provider.as_ref(), - &mut history, - ctx.tools_registry.as_ref(), - ctx.observer.as_ref(), - route.provider.as_str(), - route.model.as_str(), - ctx.temperature, - true, - None, - msg.channel.as_str(), - &ctx.multimodal, - ctx.max_tool_iterations, - delta_tx, - ), - ) - .await; + enum LlmExecutionResult { + Completed(Result, tokio::time::error::Elapsed>), + Cancelled, + } + + let llm_result = tokio::select! { + () = cancellation_token.cancelled() => LlmExecutionResult::Cancelled, + result = tokio::time::timeout( + Duration::from_secs(ctx.message_timeout_secs), + run_tool_call_loop( + active_provider.as_ref(), + &mut history, + ctx.tools_registry.as_ref(), + ctx.observer.as_ref(), + route.provider.as_str(), + route.model.as_str(), + ctx.temperature, + true, + None, + msg.channel.as_str(), + &ctx.multimodal, + ctx.max_tool_iterations, + Some(cancellation_token.clone()), + delta_tx, + ), + ) => LlmExecutionResult::Completed(result), + }; - // Wait for draft updater to finish if let Some(handle) = draft_updater { let _ = handle.await; } @@ -837,21 +904,26 @@ async fn process_channel_message(ctx: Arc, msg: traits::C } match llm_result { - Ok(Ok(response)) => { - // Save user + assistant turn to per-sender history + LlmExecutionResult::Cancelled => { + tracing::info!( + channel = %msg.channel, + sender = %msg.sender, + "Cancelled in-flight channel request due to newer message" + ); + if let (Some(channel), Some(draft_id)) = + (target_channel.as_ref(), draft_message_id.as_deref()) { - let mut histories = ctx - .conversation_histories - .lock() - .unwrap_or_else(|e| e.into_inner()); - let turns = histories.entry(history_key).or_default(); - turns.push(ChatMessage::user(&enriched_message)); - turns.push(ChatMessage::assistant(&response)); - // Trim to MAX_CHANNEL_HISTORY (keep recent turns) - while turns.len() > MAX_CHANNEL_HISTORY { - turns.remove(0); + if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await { + tracing::debug!("Failed to cancel draft on {}: {err}", channel.name()); } } + } + LlmExecutionResult::Completed(Ok(Ok(response))) => { + append_sender_turn( + ctx.as_ref(), + &history_key, + ChatMessage::assistant(&response), + ); println!( " 🤖 Reply ({}ms): {}", started_at.elapsed().as_millis(), @@ -882,7 +954,24 @@ async fn process_channel_message(ctx: Arc, msg: traits::C } } } - Ok(Err(e)) => { + LlmExecutionResult::Completed(Ok(Err(e))) => { + if crate::agent::loop_::is_tool_loop_cancelled(&e) || cancellation_token.is_cancelled() + { + tracing::info!( + channel = %msg.channel, + sender = %msg.sender, + "Cancelled in-flight channel request due to newer message" + ); + if let (Some(channel), Some(draft_id)) = + (target_channel.as_ref(), draft_message_id.as_deref()) + { + if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await { + tracing::debug!("Failed to cancel draft on {}: {err}", channel.name()); + } + } + return; + } + if is_context_window_overflow_error(&e) { let compacted = compact_sender_history(ctx.as_ref(), &history_key); let error_text = if compacted { @@ -931,7 +1020,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C } } } - Err(_) => { + LlmExecutionResult::Completed(Err(_)) => { let timeout_msg = format!("LLM response timed out after {}s", ctx.message_timeout_secs); eprintln!( " ❌ {} (elapsed: {}ms)", @@ -965,6 +1054,11 @@ async fn run_message_dispatch_loop( ) { let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages)); let mut workers = tokio::task::JoinSet::new(); + let in_flight_by_sender = Arc::new(tokio::sync::Mutex::new(HashMap::< + String, + InFlightSenderTaskState, + >::new())); + let task_sequence = Arc::new(AtomicU64::new(1)); while let Some(msg) = rx.recv().await { let permit = match Arc::clone(&semaphore).acquire_owned().await { @@ -973,9 +1067,54 @@ async fn run_message_dispatch_loop( }; let worker_ctx = Arc::clone(&ctx); + let in_flight = Arc::clone(&in_flight_by_sender); + let task_sequence = Arc::clone(&task_sequence); workers.spawn(async move { let _permit = permit; - process_channel_message(worker_ctx, msg).await; + let interrupt_enabled = + worker_ctx.interrupt_on_new_message && msg.channel == "telegram"; + let sender_scope_key = interruption_scope_key(&msg); + let cancellation_token = CancellationToken::new(); + let completion = Arc::new(InFlightTaskCompletion::new()); + let task_id = task_sequence.fetch_add(1, Ordering::Relaxed); + + if interrupt_enabled { + let previous = { + let mut active = in_flight.lock().await; + active.insert( + sender_scope_key.clone(), + InFlightSenderTaskState { + task_id, + cancellation: cancellation_token.clone(), + completion: Arc::clone(&completion), + }, + ) + }; + + if let Some(previous) = previous { + tracing::info!( + channel = %msg.channel, + sender = %msg.sender, + "Interrupting previous in-flight request for sender" + ); + previous.cancellation.cancel(); + previous.completion.wait().await; + } + } + + process_channel_message(worker_ctx, msg, cancellation_token).await; + + if interrupt_enabled { + let mut active = in_flight.lock().await; + if active + .get(&sender_scope_key) + .is_some_and(|state| state.task_id == task_id) + { + active.remove(&sender_scope_key); + } + } + + completion.mark_done(); }); while let Some(result) = workers.try_join_next() { @@ -2101,6 +2240,11 @@ pub async fn start_channels(config: Config) -> Result<()> { provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider)); let message_timeout_secs = effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs); + let interrupt_on_new_message = config + .channels_config + .telegram + .as_ref() + .is_some_and(|tg| tg.interrupt_on_new_message); let runtime_ctx = Arc::new(ChannelRuntimeContext { channels_by_name, @@ -2124,6 +2268,7 @@ pub async fn start_channels(config: Config) -> Result<()> { provider_runtime_options, workspace_dir: Arc::new(config.workspace_dir.clone()), message_timeout_secs, + interrupt_on_new_message, multimodal: config.multimodal.clone(), }); @@ -2245,6 +2390,7 @@ mod tests { api_key: None, api_url: None, reliability: Arc::new(crate::config::ReliabilityConfig::default()), + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), @@ -2527,6 +2673,43 @@ mod tests { } } + struct DelayedHistoryCaptureProvider { + delay: Duration, + calls: std::sync::Mutex>>, + } + + #[async_trait::async_trait] + impl Provider for DelayedHistoryCaptureProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("fallback".to_string()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let snapshot = messages + .iter() + .map(|m| (m.role.clone(), m.content.clone())) + .collect::>(); + let call_index = { + let mut calls = self.calls.lock().unwrap_or_else(|e| e.into_inner()); + calls.push(snapshot); + calls.len() + }; + tokio::time::sleep(self.delay).await; + Ok(format!("response-{call_index}")) + } + } + struct MockPriceTool; #[derive(Default)] @@ -2630,6 +2813,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -2644,6 +2828,7 @@ mod tests { timestamp: 1, thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2685,6 +2870,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -2699,6 +2885,7 @@ mod tests { timestamp: 2, thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2749,6 +2936,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -2763,6 +2951,7 @@ mod tests { timestamp: 1, thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2834,6 +3023,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -2848,6 +3038,7 @@ mod tests { timestamp: 2, thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2895,6 +3086,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -2909,6 +3101,7 @@ mod tests { timestamp: 1, thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2951,6 +3144,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -2965,6 +3159,7 @@ mod tests { timestamp: 2, thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -3058,6 +3253,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -3100,6 +3296,171 @@ mod tests { assert_eq!(sent_messages.len(), 2); } + #[tokio::test] + async fn message_dispatch_interrupts_in_flight_telegram_request_and_preserves_context() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let provider_impl = Arc::new(DelayedHistoryCaptureProvider { + delay: Duration::from_millis(250), + calls: std::sync::Mutex::new(Vec::new()), + }); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: provider_impl.clone(), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 10, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: true, + multimodal: crate::config::MultimodalConfig::default(), + }); + + let (tx, rx) = tokio::sync::mpsc::channel::(8); + let send_task = tokio::spawn(async move { + tx.send(traits::ChannelMessage { + id: "msg-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "forwarded content".to_string(), + channel: "telegram".to_string(), + timestamp: 1, + thread_ts: None, + }) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(40)).await; + tx.send(traits::ChannelMessage { + id: "msg-2".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "summarize this".to_string(), + channel: "telegram".to_string(), + timestamp: 2, + thread_ts: None, + }) + .await + .unwrap(); + }); + + run_message_dispatch_loop(rx, runtime_ctx, 4).await; + send_task.await.unwrap(); + + let sent_messages = channel_impl.sent_messages.lock().await; + assert_eq!(sent_messages.len(), 1); + assert!(sent_messages[0].starts_with("chat-1:")); + assert!(sent_messages[0].contains("response-2")); + drop(sent_messages); + + let calls = provider_impl + .calls + .lock() + .unwrap_or_else(|e| e.into_inner()); + assert_eq!(calls.len(), 2); + let second_call = &calls[1]; + assert!(second_call + .iter() + .any(|(role, content)| { role == "user" && content.contains("forwarded content") })); + assert!(second_call + .iter() + .any(|(role, content)| { role == "user" && content.contains("summarize this") })); + assert!( + !second_call.iter().any(|(role, _)| role == "assistant"), + "cancelled turn should not persist an assistant response" + ); + } + + #[tokio::test] + async fn message_dispatch_interrupt_scope_is_same_sender_same_chat() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(SlowProvider { + delay: Duration::from_millis(180), + }), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 10, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: true, + multimodal: crate::config::MultimodalConfig::default(), + }); + + let (tx, rx) = tokio::sync::mpsc::channel::(8); + let send_task = tokio::spawn(async move { + tx.send(traits::ChannelMessage { + id: "msg-a".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "first chat".to_string(), + channel: "telegram".to_string(), + timestamp: 1, + thread_ts: None, + }) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(30)).await; + tx.send(traits::ChannelMessage { + id: "msg-b".to_string(), + sender: "alice".to_string(), + reply_target: "chat-2".to_string(), + content: "second chat".to_string(), + channel: "telegram".to_string(), + timestamp: 2, + thread_ts: None, + }) + .await + .unwrap(); + }); + + run_message_dispatch_loop(rx, runtime_ctx, 4).await; + send_task.await.unwrap(); + + let sent_messages = channel_impl.sent_messages.lock().await; + assert_eq!(sent_messages.len(), 2); + assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-1:"))); + assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-2:"))); + } + #[tokio::test] async fn process_channel_message_cancels_scoped_typing_task() { let channel_impl = Arc::new(RecordingChannel::default()); @@ -3132,6 +3493,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -3146,6 +3508,7 @@ mod tests { timestamp: 1, thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -3579,6 +3942,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, multimodal: crate::config::MultimodalConfig::default(), }); @@ -3593,6 +3957,7 @@ mod tests { timestamp: 1, thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -3607,6 +3972,7 @@ mod tests { timestamp: 2, thread_ts: None, }, + CancellationToken::new(), ) .await; diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index fa7f130..3954ef5 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -45,10 +45,7 @@ fn split_message_for_telegram(message: &str) -> Vec { pos + 1 } else { // Try space as fallback - search_area - .rfind(' ') - .unwrap_or(hard_split) - + 1 + search_area.rfind(' ').unwrap_or(hard_split) + 1 } } else if let Some(pos) = search_area.rfind(' ') { pos + 1 @@ -1632,6 +1629,37 @@ impl Channel for TelegramChannel { .await } + async fn cancel_draft(&self, recipient: &str, message_id: &str) -> anyhow::Result<()> { + let (chat_id, _) = Self::parse_reply_target(recipient); + self.last_draft_edit.lock().remove(&chat_id); + + let message_id = match message_id.parse::() { + Ok(id) => id, + Err(e) => { + tracing::debug!("Invalid Telegram draft message_id '{message_id}': {e}"); + return Ok(()); + } + }; + + let response = self + .client + .post(self.api_url("deleteMessage")) + .json(&serde_json::json!({ + "chat_id": chat_id, + "message_id": message_id, + })) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + tracing::debug!("Telegram deleteMessage failed ({status}): {body}"); + } + + Ok(()) + } + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { // Strip tool_call tags before processing to prevent Markdown parsing failures let content = strip_tool_call_tags(&message.content); @@ -2844,7 +2872,10 @@ mod tests { msg.push_str(&"x".repeat(4085)); msg.push_str("\n```\nMore text after code block"); let parts = split_message_for_telegram(&msg); - assert!(parts.len() >= 2, "code block spanning boundary should split"); + assert!( + parts.len() >= 2, + "code block spanning boundary should split" + ); for part in &parts { assert!( part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH, diff --git a/src/channels/traits.rs b/src/channels/traits.rs index 1f072ca..67546ce 100644 --- a/src/channels/traits.rs +++ b/src/channels/traits.rs @@ -113,6 +113,11 @@ pub trait Channel: Send + Sync { ) -> anyhow::Result<()> { Ok(()) } + + /// Cancel and remove a previously sent draft message if the channel supports it. + async fn cancel_draft(&self, _recipient: &str, _message_id: &str) -> anyhow::Result<()> { + Ok(()) + } } #[cfg(test)] @@ -198,6 +203,7 @@ mod tests { .finalize_draft("bob", "msg_1", "final text") .await .is_ok()); + assert!(channel.cancel_draft("bob", "msg_1").await.is_ok()); } #[tokio::test] diff --git a/src/config/mod.rs b/src/config/mod.rs index d67a80d..8187eec 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -36,6 +36,7 @@ mod tests { allowed_users: vec!["alice".into()], stream_mode: StreamMode::default(), draft_update_interval_ms: 1000, + interrupt_on_new_message: false, mention_only: false, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index dbd84f0..6a87620 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -2125,6 +2125,10 @@ pub struct TelegramConfig { /// Minimum interval (ms) between draft message edits to avoid rate limits. #[serde(default = "default_draft_update_interval_ms")] pub draft_update_interval_ms: u64, + /// When true, a newer Telegram message from the same sender in the same chat + /// cancels the in-flight request and starts a fresh response with preserved history. + #[serde(default)] + pub interrupt_on_new_message: bool, /// When true, only respond to messages that @-mention the bot in groups. /// Direct messages are always processed. #[serde(default)] @@ -3520,6 +3524,7 @@ default_temperature = 0.7 allowed_users: vec!["user1".into()], stream_mode: StreamMode::default(), draft_update_interval_ms: default_draft_update_interval_ms(), + interrupt_on_new_message: false, mention_only: false, }), discord: None, @@ -3852,6 +3857,7 @@ tool_dispatcher = "xml" allowed_users: vec!["alice".into(), "bob".into()], stream_mode: StreamMode::Partial, draft_update_interval_ms: 500, + interrupt_on_new_message: true, mention_only: false, }; let json = serde_json::to_string(&tc).unwrap(); @@ -3860,6 +3866,7 @@ tool_dispatcher = "xml" assert_eq!(parsed.allowed_users.len(), 2); assert_eq!(parsed.stream_mode, StreamMode::Partial); assert_eq!(parsed.draft_update_interval_ms, 500); + assert!(parsed.interrupt_on_new_message); } #[test] @@ -3868,6 +3875,7 @@ tool_dispatcher = "xml" let parsed: TelegramConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.stream_mode, StreamMode::Off); assert_eq!(parsed.draft_update_interval_ms, 1000); + assert!(!parsed.interrupt_on_new_message); } #[test] diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index 9483141..a2dfee2 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -321,6 +321,7 @@ mod tests { allowed_users: vec![], stream_mode: crate::config::StreamMode::default(), draft_update_interval_ms: 1000, + interrupt_on_new_message: false, mention_only: false, }); assert!(has_supervised_channels(&config)); diff --git a/src/integrations/registry.rs b/src/integrations/registry.rs index cc91082..9e28f5c 100644 --- a/src/integrations/registry.rs +++ b/src/integrations/registry.rs @@ -790,6 +790,7 @@ mod tests { allowed_users: vec!["user".into()], stream_mode: StreamMode::default(), draft_update_interval_ms: 1000, + interrupt_on_new_message: false, mention_only: false, }); let entries = all_integrations(); diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index f52a416..eee75bd 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -2793,6 +2793,7 @@ fn setup_channels() -> Result { allowed_users, stream_mode: StreamMode::default(), draft_update_interval_ms: 1000, + interrupt_on_new_message: false, mention_only: false, }); } diff --git a/tests/agent_loop_robustness.rs b/tests/agent_loop_robustness.rs index f63b51f..fadcd9f 100644 --- a/tests/agent_loop_robustness.rs +++ b/tests/agent_loop_robustness.rs @@ -128,7 +128,12 @@ struct CountingTool { impl CountingTool { fn new() -> (Self, Arc>) { let count = Arc::new(Mutex::new(0)); - (Self { count: count.clone() }, count) + ( + Self { + count: count.clone(), + }, + count, + ) } } @@ -295,10 +300,7 @@ async fn agent_handles_mixed_tool_success_and_failure() { text_response("Mixed results processed"), ])); - let mut agent = build_agent( - provider, - vec![Box::new(EchoTool), Box::new(FailingTool)], - ); + let mut agent = build_agent(provider, vec![Box::new(EchoTool), Box::new(FailingTool)]); let response = agent.turn("mixed tools").await.unwrap(); assert!(!response.is_empty()); } diff --git a/tests/channel_routing.rs b/tests/channel_routing.rs index 4db04e4..178c85a 100644 --- a/tests/channel_routing.rs +++ b/tests/channel_routing.rs @@ -24,6 +24,7 @@ fn channel_message_sender_field_holds_platform_user_id() { content: "test message".into(), channel: "telegram".into(), timestamp: 1700000000, + thread_ts: None, }; assert_eq!(msg.sender, "123456789"); @@ -40,11 +41,12 @@ fn channel_message_reply_target_distinct_from_sender() { // Simulates Discord: reply_target should be channel_id, not sender user_id let msg = ChannelMessage { id: "msg_1".into(), - sender: "user_987654".into(), // Discord user ID + sender: "user_987654".into(), // Discord user ID reply_target: "channel_123".into(), // Discord channel ID for replies content: "test message".into(), channel: "discord".into(), timestamp: 1700000000, + thread_ts: None, }; assert_ne!( @@ -64,9 +66,13 @@ fn channel_message_fields_not_swapped() { content: "payload".into(), channel: "test".into(), timestamp: 1700000000, + thread_ts: None, }; - assert_eq!(msg.sender, "sender_value", "sender field should not be swapped"); + assert_eq!( + msg.sender, "sender_value", + "sender field should not be swapped" + ); assert_eq!( msg.reply_target, "target_value", "reply_target field should not be swapped" @@ -86,6 +92,7 @@ fn channel_message_preserves_all_fields_on_clone() { content: "cloned content".into(), channel: "test_channel".into(), timestamp: 1700000001, + thread_ts: None, }; let cloned = original.clone(); @@ -170,10 +177,7 @@ impl Channel for CapturingChannel { Ok(()) } - async fn listen( - &self, - tx: tokio::sync::mpsc::Sender, - ) -> anyhow::Result<()> { + async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { tx.send(ChannelMessage { id: "listen_1".into(), sender: "test_sender".into(), @@ -181,6 +185,7 @@ impl Channel for CapturingChannel { content: "incoming".into(), channel: "capturing".into(), timestamp: 1700000000, + thread_ts: None, }) .await .map_err(|e| anyhow::anyhow!(e.to_string())) @@ -266,7 +271,10 @@ async fn channel_draft_defaults() { .send_draft(&SendMessage::new("draft", "target")) .await .unwrap(); - assert!(draft_result.is_none(), "default send_draft should return None"); + assert!( + draft_result.is_none(), + "default send_draft should return None" + ); assert!(channel .update_draft("target", "msg_1", "updated") diff --git a/tests/config_persistence.rs b/tests/config_persistence.rs index edeef89..079b9df 100644 --- a/tests/config_persistence.rs +++ b/tests/config_persistence.rs @@ -232,7 +232,10 @@ fn workspace_dir_creation_in_tempdir() { fs::create_dir_all(&workspace_dir).expect("workspace dir creation should succeed"); assert!(workspace_dir.exists(), "workspace dir should exist"); - assert!(workspace_dir.is_dir(), "workspace path should be a directory"); + assert!( + workspace_dir.is_dir(), + "workspace path should be a directory" + ); } #[test] diff --git a/tests/memory_restart.rs b/tests/memory_restart.rs index 7538ab3..fe63f16 100644 --- a/tests/memory_restart.rs +++ b/tests/memory_restart.rs @@ -29,10 +29,17 @@ async fn sqlite_memory_store_same_key_deduplicates() { // Should have exactly 1 entry, not 2 let count = mem.count().await.unwrap(); - assert_eq!(count, 1, "storing same key twice should not create duplicates"); + assert_eq!( + count, 1, + "storing same key twice should not create duplicates" + ); // Content should be the latest version - let entry = mem.get("greeting").await.unwrap().expect("entry should exist"); + let entry = mem + .get("greeting") + .await + .unwrap() + .expect("entry should exist"); assert_eq!(entry.content, "hello updated"); } @@ -63,9 +70,14 @@ async fn sqlite_memory_persists_across_reinitialization() { // First "session": store data { let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("persistent_fact", "Rust is great", MemoryCategory::Core, None) - .await - .unwrap(); + mem.store( + "persistent_fact", + "Rust is great", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); } // Second "session": re-create memory from same path @@ -158,16 +170,24 @@ async fn sqlite_memory_global_recall_includes_all_sessions() { let tmp = tempfile::TempDir::new().unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("global_a", "alpha content", MemoryCategory::Core, Some("s1")) - .await - .unwrap(); + mem.store( + "global_a", + "alpha content", + MemoryCategory::Core, + Some("s1"), + ) + .await + .unwrap(); mem.store("global_b", "beta content", MemoryCategory::Core, Some("s2")) .await .unwrap(); // Global count should include all let count = mem.count().await.unwrap(); - assert_eq!(count, 2, "global count should include entries from all sessions"); + assert_eq!( + count, 2, + "global count should include entries from all sessions" + ); } // ───────────────────────────────────────────────────────────────────────────── @@ -179,12 +199,22 @@ async fn sqlite_memory_recall_returns_relevant_results() { let tmp = tempfile::TempDir::new().unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("lang_pref", "User prefers Rust programming", MemoryCategory::Core, None) - .await - .unwrap(); - mem.store("food_pref", "User likes sushi for lunch", MemoryCategory::Core, None) - .await - .unwrap(); + mem.store( + "lang_pref", + "User prefers Rust programming", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + mem.store( + "food_pref", + "User likes sushi for lunch", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); let results = mem.recall("Rust programming", 10, None).await.unwrap(); assert!(!results.is_empty(), "recall should find matching entries"); @@ -229,10 +259,7 @@ async fn sqlite_memory_recall_empty_query_returns_empty() { .unwrap(); let results = mem.recall("", 10, None).await.unwrap(); - assert!( - results.is_empty(), - "empty query should return no results" - ); + assert!(results.is_empty(), "empty query should return no results"); } // ───────────────────────────────────────────────────────────────────────────── @@ -322,9 +349,14 @@ async fn sqlite_memory_list_by_category() { mem.store("daily_note", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - mem.store("conv_msg", "conversation msg", MemoryCategory::Conversation, None) - .await - .unwrap(); + mem.store( + "conv_msg", + "conversation msg", + MemoryCategory::Conversation, + None, + ) + .await + .unwrap(); let core_entries = mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); assert_eq!(core_entries.len(), 1, "should have 1 Core entry"); diff --git a/tests/provider_schema.rs b/tests/provider_schema.rs index bc3aa67..84e2c84 100644 --- a/tests/provider_schema.rs +++ b/tests/provider_schema.rs @@ -80,7 +80,10 @@ fn tool_call_has_required_fields() { let json = serde_json::to_value(&tc).unwrap(); assert!(json.get("id").is_some(), "ToolCall must have 'id' field"); - assert!(json.get("name").is_some(), "ToolCall must have 'name' field"); + assert!( + json.get("name").is_some(), + "ToolCall must have 'name' field" + ); assert!( json.get("arguments").is_some(), "ToolCall must have 'arguments' field" @@ -98,7 +101,10 @@ fn tool_call_id_preserved_in_serialization() { let json_str = serde_json::to_string(&tc).unwrap(); let parsed: ToolCall = serde_json::from_str(&json_str).unwrap(); - assert_eq!(parsed.id, "call_deepseek_42", "tool_call_id must survive roundtrip"); + assert_eq!( + parsed.id, "call_deepseek_42", + "tool_call_id must survive roundtrip" + ); assert_eq!(parsed.name, "shell"); } @@ -111,8 +117,8 @@ fn tool_call_arguments_contain_valid_json() { }; // Arguments should parse as valid JSON - let args: serde_json::Value = serde_json::from_str(&tc.arguments) - .expect("tool call arguments should be valid JSON"); + let args: serde_json::Value = + serde_json::from_str(&tc.arguments).expect("tool call arguments should be valid JSON"); assert!(args.get("path").is_some()); assert!(args.get("content").is_some()); } @@ -125,9 +131,8 @@ fn tool_call_arguments_contain_valid_json() { fn tool_response_message_can_embed_tool_call_id() { // DeepSeek requires tool_call_id in tool response messages. // The tool message content can embed the tool_call_id as JSON. - let tool_response = ChatMessage::tool( - r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#, - ); + let tool_response = + ChatMessage::tool(r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#); let parsed: serde_json::Value = serde_json::from_str(&tool_response.content) .expect("tool response content should be valid JSON"); @@ -245,21 +250,32 @@ fn provider_construction_with_different_names() { Some("test-key"), AuthStyle::Bearer, ); - let _p2 = OpenAiCompatibleProvider::new( - "deepseek", - "https://api.test.com", - None, - AuthStyle::Bearer, - ); + let _p2 = + OpenAiCompatibleProvider::new("deepseek", "https://api.test.com", None, AuthStyle::Bearer); } #[test] fn provider_construction_with_different_auth_styles() { use zeroclaw::providers::compatible::OpenAiCompatibleProvider; - let _bearer = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Bearer); - let _xapi = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::XApiKey); - let _custom = OpenAiCompatibleProvider::new("Test", "https://api.test.com", Some("key"), AuthStyle::Custom("X-My-Auth".into())); + let _bearer = OpenAiCompatibleProvider::new( + "Test", + "https://api.test.com", + Some("key"), + AuthStyle::Bearer, + ); + let _xapi = OpenAiCompatibleProvider::new( + "Test", + "https://api.test.com", + Some("key"), + AuthStyle::XApiKey, + ); + let _custom = OpenAiCompatibleProvider::new( + "Test", + "https://api.test.com", + Some("key"), + AuthStyle::Custom("X-My-Auth".into()), + ); } // ─────────────────────────────────────────────────────────────────────────────