diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 2282e66..acf62a4 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,8 +10,14 @@ use crate::channels::{Channel, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; -use crate::providers::{self, ChatResponse, Provider}; -use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; +use crate::observability::{self, Observer}; +use crate::providers::{self, ChatMessage, Provider}; +use crate::runtime; +use crate::security::{ + pairing::{constant_time_eq, is_public_bind, PairingGuard}, + SecurityPolicy, +}; +use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; use anyhow::Result; use axum::{ @@ -45,29 +51,33 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String format!("whatsapp_{}_{}", msg.sender, msg.id) } -fn gateway_reply_from_response(response: ChatResponse) -> String { - let has_tool_calls = response.has_tool_calls(); - let tool_call_count = response.tool_calls.len(); - let mut reply = response.text.unwrap_or_default(); - - if has_tool_calls { - tracing::warn!( - tool_call_count, - "Provider requested tool calls in gateway mode; tool calls are not executed here" - ); - if reply.trim().is_empty() { - reply = "I need to use tools to answer that, but tool execution is not enabled for gateway requests yet." - .to_string(); - } - } - +fn normalize_gateway_reply(reply: String) -> String { if reply.trim().is_empty() { - reply = "Model returned an empty response.".to_string(); + return "Model returned an empty response.".to_string(); } reply } +async fn gateway_agent_reply(state: &AppState, message: &str) -> Result { + let mut history = vec![ + ChatMessage::system(state.system_prompt.as_str()), + ChatMessage::user(message), + ]; + + let reply = crate::agent::loop_::run_tool_call_loop( + state.provider.as_ref(), + &mut history, + state.tools_registry.as_ref(), + state.observer.as_ref(), + &state.model, + state.temperature, + ) + .await?; + + Ok(normalize_gateway_reply(reply)) +} + #[derive(Debug)] struct SlidingWindowRateLimiter { limit_per_window: u32, @@ -182,6 +192,9 @@ fn client_key_from_headers(headers: &HeaderMap) -> String { #[derive(Clone)] pub struct AppState { pub provider: Arc, + pub observer: Arc, + pub tools_registry: Arc>>, + pub system_prompt: Arc, pub model: String, pub temperature: f64, pub mem: Arc, @@ -228,6 +241,47 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { &config.workspace_dir, config.api_key.as_deref(), )?); + let observer: Arc = + Arc::from(observability::create_observer(&config.observability)); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + + let composio_key = if config.composio.enabled { + config.composio.api_key.as_deref() + } else { + None + }; + + let tools_registry = Arc::new(tools::all_tools_with_runtime( + &security, + runtime, + Arc::clone(&mem), + composio_key, + &config.browser, + &config.agents, + config.api_key.as_deref(), + )); + let skills = crate::skills::load_skills(&config.workspace_dir); + let tool_descs: Vec<(&str, &str)> = tools_registry + .iter() + .map(|tool| (tool.name(), tool.description())) + .collect(); + + let mut system_prompt = crate::channels::build_system_prompt( + &config.workspace_dir, + &model, + &tool_descs, + &skills, + Some(&config.identity), + ); + system_prompt.push_str(&crate::agent::loop_::build_tool_instructions( + tools_registry.as_ref(), + )); + let system_prompt = Arc::new(system_prompt); // Extract webhook secret for authentication let webhook_secret: Option> = config @@ -331,6 +385,9 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { // Build shared state let state = AppState { provider, + observer, + tools_registry, + system_prompt, model, temperature, mem, @@ -514,13 +571,8 @@ async fn handle_webhook( .await; } - match state - .provider - .chat(message, &state.model, state.temperature) - .await - { - Ok(response) => { - let reply = gateway_reply_from_response(response); + match gateway_agent_reply(&state, message).await { + Ok(reply) => { let body = serde_json::json!({"response": reply, "model": state.model}); (StatusCode::OK, Json(body)) } @@ -669,13 +721,8 @@ async fn handle_whatsapp_message( } // Call the LLM - match state - .provider - .chat(&msg.content, &state.model, state.temperature) - .await - { - Ok(response) => { - let reply = gateway_reply_from_response(response); + match gateway_agent_reply(&state, &msg.content).await { + Ok(reply) => { // Send reply via WhatsApp if let Err(e) = wa.send(&reply, &msg.sender).await { tracing::error!("Failed to send WhatsApp reply: {e}"); @@ -847,9 +894,9 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); - Ok(ChatResponse::with_text("ok")) + Ok(crate::providers::ChatResponse::with_text("ok")) } } @@ -910,25 +957,36 @@ mod tests { } } - #[tokio::test] - async fn webhook_idempotency_skips_duplicate_provider_calls() { - let provider_impl = Arc::new(MockProvider::default()); - let provider: Arc = provider_impl.clone(); - let memory: Arc = Arc::new(MockMemory); - - let state = AppState { + fn test_app_state( + provider: Arc, + memory: Arc, + auto_save: bool, + ) -> AppState { + AppState { provider, + observer: Arc::new(crate::observability::NoopObserver), + tools_registry: Arc::new(Vec::new()), + system_prompt: Arc::new("test-system-prompt".into()), model: "test-model".into(), temperature: 0.0, mem: memory, - auto_save: false, + auto_save, webhook_secret: None, pairing: Arc::new(PairingGuard::new(false, &[])), rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), whatsapp: None, whatsapp_app_secret: None, - }; + } + } + + #[tokio::test] + async fn webhook_idempotency_skips_duplicate_provider_calls() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = test_app_state(provider, memory, false); let mut headers = HeaderMap::new(); headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123")); @@ -964,19 +1022,7 @@ mod tests { let tracking_impl = Arc::new(TrackingMemory::default()); let memory: Arc = tracking_impl.clone(); - let state = AppState { - provider, - model: "test-model".into(), - temperature: 0.0, - mem: memory, - auto_save: true, - webhook_secret: None, - pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), - whatsapp: None, - whatsapp_app_secret: None, - }; + let state = test_app_state(provider, memory, true); let headers = HeaderMap::new(); @@ -1008,6 +1054,110 @@ mod tests { assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); } + #[derive(Default)] + struct StructuredToolCallProvider { + calls: AtomicUsize, + } + + #[async_trait] + impl Provider for StructuredToolCallProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let turn = self.calls.fetch_add(1, Ordering::SeqCst); + + if turn == 0 { + return Ok(crate::providers::ChatResponse { + text: Some("Running tool...".into()), + tool_calls: vec![crate::providers::ToolCall { + id: "call_1".into(), + name: "mock_tool".into(), + arguments: r#"{"query":"gateway"}"#.into(), + }], + }); + } + + Ok(crate::providers::ChatResponse::with_text( + "Gateway tool result ready.", + )) + } + } + + struct MockTool { + calls: Arc, + } + + #[async_trait] + impl Tool for MockTool { + fn name(&self) -> &str { + "mock_tool" + } + + fn description(&self) -> &str { + "Mock tool for gateway tests" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "query": {"type": "string"} + }, + "required": ["query"] + }) + } + + async fn execute( + &self, + args: serde_json::Value, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + assert_eq!(args["query"], "gateway"); + + Ok(crate::tools::ToolResult { + success: true, + output: "ok".into(), + error: None, + }) + } + } + + #[tokio::test] + async fn webhook_executes_structured_tool_calls() { + let provider_impl = Arc::new(StructuredToolCallProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let tool_calls = Arc::new(AtomicUsize::new(0)); + let tools: Vec> = vec![Box::new(MockTool { + calls: Arc::clone(&tool_calls), + })]; + + let mut state = test_app_state(provider, memory, false); + state.tools_registry = Arc::new(tools); + + let response = handle_webhook( + State(state), + HeaderMap::new(), + Ok(Json(WebhookBody { + message: "please use tool".into(), + })), + ) + .await + .into_response(); + + assert_eq!(response.status(), StatusCode::OK); + let payload = response.into_body().collect().await.unwrap().to_bytes(); + let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap(); + assert_eq!(parsed["response"], "Gateway tool result ready."); + assert_eq!(tool_calls.load(Ordering::SeqCst), 1); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); + } + // ══════════════════════════════════════════════════════════ // WhatsApp Signature Verification Tests (CWE-345 Prevention) // ══════════════════════════════════════════════════════════ diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 5911904..7c30650 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -8,6 +8,7 @@ pub mod reliable; pub mod router; pub mod traits; +#[allow(unused_imports)] pub use traits::{ChatMessage, ChatResponse, Provider, ToolCall}; use compatible::{AuthStyle, OpenAiCompatibleProvider}; diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index c2660a4..f205a58 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -220,15 +220,27 @@ impl Tool for DelegateTool { }; match result { - Ok(response) => Ok(ToolResult { - success: true, - output: format!( - "[Agent '{agent_name}' ({provider}/{model})]\n{response}", - provider = agent_config.provider, - model = agent_config.model - ), - error: None, - }), + Ok(response) => { + let has_tool_calls = response.has_tool_calls(); + let mut rendered = response.text.unwrap_or_default(); + if rendered.trim().is_empty() { + if has_tool_calls { + rendered = "[Tool-only response; no text content]".to_string(); + } else { + rendered = "[Empty response]".to_string(); + } + } + + Ok(ToolResult { + success: true, + output: format!( + "[Agent '{agent_name}' ({provider}/{model})]\n{rendered}", + provider = agent_config.provider, + model = agent_config.model + ), + error: None, + }) + } Err(e) => Ok(ToolResult { success: false, output: String::new(),