diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 9b299ea..39f4b39 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -396,6 +396,8 @@ pub async fn run( mem.clone(), composio_key, &config.browser, + &config.agents, + config.api_key.as_deref(), ); // ── Resolve provider ───────────────────────────────────────── @@ -470,6 +472,14 @@ pub async fn run( "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.", )); } + if !config.agents.is_empty() { + tool_descs.push(( + "delegate", + "Delegate a subtask to a specialized agent. Use when: a task benefits from a different model \ + (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single \ + prompt and returns its response.", + )); + } let mut system_prompt = crate::channels::build_system_prompt( &config.workspace_dir, model_name, diff --git a/src/config/mod.rs b/src/config/mod.rs index b442538..b18a699 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,8 +1,9 @@ pub mod schema; pub use schema::{ - AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, - DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, - MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, - RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig, + AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DelegateAgentConfig, + DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, + IdentityConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, + ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, + WebhookConfig, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index 1912334..7b4a198 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -2,6 +2,7 @@ use crate::security::AutonomyLevel; use anyhow::{Context, Result}; use directories::UserDirs; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::fs::{self, File, OpenOptions}; use std::io::Write; use std::path::{Path, PathBuf}; @@ -63,6 +64,22 @@ pub struct Config { #[serde(default)] pub identity: IdentityConfig, + + /// Named delegate agents for agent-to-agent handoff. + /// + /// ```toml + /// [agents.researcher] + /// provider = "gemini" + /// model = "gemini-2.0-flash" + /// system_prompt = "You are a research assistant..." + /// + /// [agents.coder] + /// provider = "openrouter" + /// model = "anthropic/claude-sonnet-4-20250514" + /// system_prompt = "You are a coding assistant..." + /// ``` + #[serde(default)] + pub agents: HashMap, } // ── Identity (AIEOS / OpenClaw format) ────────────────────────── @@ -94,6 +111,36 @@ impl Default for IdentityConfig { } } +// ── Agent delegation ───────────────────────────────────────────── + +/// Configuration for a named delegate agent that can be invoked via the +/// `delegate` tool. Each agent uses its own provider/model combination +/// and system prompt, enabling multi-agent workflows with specialization. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DelegateAgentConfig { + /// Provider name (e.g. "gemini", "openrouter", "ollama") + pub provider: String, + /// Model identifier for the provider + pub model: String, + /// System prompt defining the agent's role and capabilities + #[serde(default)] + pub system_prompt: Option, + /// Optional API key override (uses default if not set). + /// Stored encrypted when `secrets.encrypt = true`. + #[serde(default)] + pub api_key: Option, + /// Temperature override (uses 0.7 if not set) + #[serde(default)] + pub temperature: Option, + /// Maximum delegation depth to prevent infinite recursion (default: 3) + #[serde(default = "default_max_delegation_depth")] + pub max_depth: u32, +} + +fn default_max_delegation_depth() -> u32 { + 3 +} + // ── Gateway security ───────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] @@ -832,6 +879,7 @@ impl Default for Config { secrets: SecretsConfig::default(), browser: BrowserConfig::default(), identity: IdentityConfig::default(), + agents: HashMap::new(), } } } @@ -858,6 +906,19 @@ impl Config { // Set computed paths that are skipped during serialization config.config_path = config_path.clone(); config.workspace_dir = zeroclaw_dir.join("workspace"); + + // Decrypt agent API keys if encryption is enabled + let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt); + for agent in config.agents.values_mut() { + if let Some(ref encrypted_key) = agent.api_key { + agent.api_key = Some( + store + .decrypt(encrypted_key) + .context("Failed to decrypt agent API key")?, + ); + } + } + Ok(config) } else { let mut config = Config::default(); @@ -928,7 +989,27 @@ impl Config { } pub fn save(&self) -> Result<()> { - let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?; + // Encrypt agent API keys before serialization + let mut config_to_save = self.clone(); + let zeroclaw_dir = self + .config_path + .parent() + .context("Config path must have a parent directory")?; + let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt); + for agent in config_to_save.agents.values_mut() { + if let Some(ref plaintext_key) = agent.api_key { + if !crate::security::SecretStore::is_encrypted(plaintext_key) { + agent.api_key = Some( + store + .encrypt(plaintext_key) + .context("Failed to encrypt agent API key")?, + ); + } + } + } + + let toml_str = + toml::to_string_pretty(&config_to_save).context("Failed to serialize config")?; let parent_dir = self .config_path @@ -1013,6 +1094,7 @@ fn sync_directory(_path: &Path) -> Result<()> { mod tests { use super::*; use std::path::PathBuf; + use tempfile::TempDir; // ── Defaults ───────────────────────────────────────────── @@ -1142,6 +1224,7 @@ mod tests { secrets: SecretsConfig::default(), browser: BrowserConfig::default(), identity: IdentityConfig::default(), + agents: HashMap::new(), }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -1213,6 +1296,7 @@ default_temperature = 0.7 secrets: SecretsConfig::default(), browser: BrowserConfig::default(), identity: IdentityConfig::default(), + agents: HashMap::new(), }; config.save().unwrap(); @@ -1967,4 +2051,171 @@ default_temperature = 0.7 assert!(!g.allow_public_bind); assert!(g.paired_tokens.is_empty()); } + + // ══════════════════════════════════════════════════════════ + // AGENT DELEGATION CONFIG TESTS + // ══════════════════════════════════════════════════════════ + + #[test] + fn agents_config_default_empty() { + let c = Config::default(); + assert!(c.agents.is_empty()); + } + + #[test] + fn agents_config_backward_compat_missing_section() { + let minimal = r#" +workspace_dir = "/tmp/ws" +config_path = "/tmp/config.toml" +default_temperature = 0.7 +"#; + let parsed: Config = toml::from_str(minimal).unwrap(); + assert!(parsed.agents.is_empty()); + } + + #[test] + fn agents_config_toml_roundtrip() { + let toml_str = r#" +default_temperature = 0.7 + +[agents.researcher] +provider = "gemini" +model = "gemini-2.0-flash" +system_prompt = "You are a research assistant." +max_depth = 2 + +[agents.coder] +provider = "openrouter" +model = "anthropic/claude-sonnet-4-20250514" +"#; + let parsed: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(parsed.agents.len(), 2); + + let researcher = &parsed.agents["researcher"]; + assert_eq!(researcher.provider, "gemini"); + assert_eq!(researcher.model, "gemini-2.0-flash"); + assert_eq!( + researcher.system_prompt.as_deref(), + Some("You are a research assistant.") + ); + assert_eq!(researcher.max_depth, 2); + assert!(researcher.api_key.is_none()); + assert!(researcher.temperature.is_none()); + + let coder = &parsed.agents["coder"]; + assert_eq!(coder.provider, "openrouter"); + assert_eq!(coder.model, "anthropic/claude-sonnet-4-20250514"); + assert!(coder.system_prompt.is_none()); + assert_eq!(coder.max_depth, 3); // default + } + + #[test] + fn agents_config_with_api_key_and_temperature() { + let toml_str = r#" +[agents.fast] +provider = "groq" +model = "llama-3.3-70b-versatile" +api_key = "gsk-test-key" +temperature = 0.3 +"#; + let parsed: HashMap = toml::from_str::(toml_str) + .unwrap()["agents"] + .clone() + .try_into() + .unwrap(); + let fast = &parsed["fast"]; + assert_eq!(fast.api_key.as_deref(), Some("gsk-test-key")); + assert!((fast.temperature.unwrap() - 0.3).abs() < f64::EPSILON); + } + + #[test] + fn agent_api_key_encrypted_on_save_and_decrypted_on_load() { + let tmp = TempDir::new().unwrap(); + let zeroclaw_dir = tmp.path(); + let config_path = zeroclaw_dir.join("config.toml"); + + // Create a config with a plaintext agent API key + let mut agents = HashMap::new(); + agents.insert( + "test_agent".to_string(), + DelegateAgentConfig { + provider: "openrouter".to_string(), + model: "test-model".to_string(), + system_prompt: None, + api_key: Some("sk-super-secret".to_string()), + temperature: None, + max_depth: 3, + }, + ); + let mut config = Config { + config_path: config_path.clone(), + workspace_dir: zeroclaw_dir.join("workspace"), + secrets: SecretsConfig { encrypt: true }, + agents, + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config.save().unwrap(); + + // Read the raw TOML and verify the key is encrypted (not plaintext) + let raw = std::fs::read_to_string(&config_path).unwrap(); + assert!( + !raw.contains("sk-super-secret"), + "Plaintext API key should not appear in saved config" + ); + assert!( + raw.contains("enc2:"), + "Encrypted key should use enc2: prefix" + ); + + // Parse and decrypt — simulate load_or_init by reading + decrypting + let store = crate::security::SecretStore::new(zeroclaw_dir, true); + let mut loaded: Config = toml::from_str(&raw).unwrap(); + for agent in loaded.agents.values_mut() { + if let Some(ref encrypted_key) = agent.api_key { + agent.api_key = Some(store.decrypt(encrypted_key).unwrap()); + } + } + assert_eq!( + loaded.agents["test_agent"].api_key.as_deref(), + Some("sk-super-secret"), + "Decrypted key should match original" + ); + } + + #[test] + fn agent_api_key_not_encrypted_when_disabled() { + let tmp = TempDir::new().unwrap(); + let zeroclaw_dir = tmp.path(); + let config_path = zeroclaw_dir.join("config.toml"); + + let mut agents = HashMap::new(); + agents.insert( + "test_agent".to_string(), + DelegateAgentConfig { + provider: "openrouter".to_string(), + model: "test-model".to_string(), + system_prompt: None, + api_key: Some("sk-plaintext-ok".to_string()), + temperature: None, + max_depth: 3, + }, + ); + let config = Config { + config_path: config_path.clone(), + workspace_dir: zeroclaw_dir.join("workspace"), + secrets: SecretsConfig { encrypt: false }, + agents, + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config.save().unwrap(); + + let raw = std::fs::read_to_string(&config_path).unwrap(); + assert!( + raw.contains("sk-plaintext-ok"), + "With encryption disabled, key should remain plaintext" + ); + assert!(!raw.contains("enc2:"), "No encryption prefix when disabled"); + } } diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 3a74a50..28ae154 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -106,6 +106,7 @@ pub fn run_wizard() -> Result { secrets: secrets_config, browser: BrowserConfig::default(), identity: crate::config::IdentityConfig::default(), + agents: std::collections::HashMap::new(), }; println!( @@ -297,6 +298,7 @@ pub fn run_quick_setup( secrets: SecretsConfig::default(), browser: BrowserConfig::default(), identity: crate::config::IdentityConfig::default(), + agents: std::collections::HashMap::new(), }; config.save()?; diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs new file mode 100644 index 0000000..c2660a4 --- /dev/null +++ b/src/tools/delegate.rs @@ -0,0 +1,426 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::DelegateAgentConfig; +use crate::providers::{self, Provider}; +use async_trait::async_trait; +use serde_json::json; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +/// Default timeout for sub-agent provider calls. +const DELEGATE_TIMEOUT_SECS: u64 = 120; + +/// Tool that delegates a subtask to a named agent with a different +/// provider/model configuration. Enables multi-agent workflows where +/// a primary agent can hand off specialized work (research, coding, +/// summarization) to purpose-built sub-agents. +pub struct DelegateTool { + agents: Arc>, + /// Global API key fallback (from config.api_key) + fallback_api_key: Option, + /// Depth at which this tool instance lives in the delegation chain. + depth: u32, +} + +impl DelegateTool { + pub fn new( + agents: HashMap, + fallback_api_key: Option, + ) -> Self { + Self { + agents: Arc::new(agents), + fallback_api_key, + depth: 0, + } + } + + /// Create a DelegateTool for a sub-agent (with incremented depth). + /// When sub-agents eventually get their own tool registry, construct + /// their DelegateTool via this method with `depth: parent.depth + 1`. + pub fn with_depth( + agents: HashMap, + fallback_api_key: Option, + depth: u32, + ) -> Self { + Self { + agents: Arc::new(agents), + fallback_api_key, + depth, + } + } +} + +#[async_trait] +impl Tool for DelegateTool { + fn name(&self) -> &str { + "delegate" + } + + fn description(&self) -> &str { + "Delegate a subtask to a specialized agent. Use when: a task benefits from a different model \ + (e.g. fast summarization, deep reasoning, code generation). The sub-agent runs a single \ + prompt and returns its response." + } + + fn parameters_schema(&self) -> serde_json::Value { + let agent_names: Vec<&str> = self.agents.keys().map(|s: &String| s.as_str()).collect(); + json!({ + "type": "object", + "additionalProperties": false, + "properties": { + "agent": { + "type": "string", + "minLength": 1, + "description": format!( + "Name of the agent to delegate to. Available: {}", + if agent_names.is_empty() { + "(none configured)".to_string() + } else { + agent_names.join(", ") + } + ) + }, + "prompt": { + "type": "string", + "minLength": 1, + "description": "The task/prompt to send to the sub-agent" + }, + "context": { + "type": "string", + "description": "Optional context to prepend (e.g. relevant code, prior findings)" + } + }, + "required": ["agent", "prompt"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let agent_name = args + .get("agent") + .and_then(|v| v.as_str()) + .map(str::trim) + .ok_or_else(|| anyhow::anyhow!("Missing 'agent' parameter"))?; + + if agent_name.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'agent' parameter must not be empty".into()), + }); + } + + let prompt = args + .get("prompt") + .and_then(|v| v.as_str()) + .map(str::trim) + .ok_or_else(|| anyhow::anyhow!("Missing 'prompt' parameter"))?; + + if prompt.is_empty() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'prompt' parameter must not be empty".into()), + }); + } + + let context = args + .get("context") + .and_then(|v| v.as_str()) + .map(str::trim) + .unwrap_or(""); + + // Look up agent config + let agent_config = match self.agents.get(agent_name) { + Some(cfg) => cfg, + None => { + let available: Vec<&str> = + self.agents.keys().map(|s: &String| s.as_str()).collect(); + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown agent '{agent_name}'. Available agents: {}", + if available.is_empty() { + "(none configured)".to_string() + } else { + available.join(", ") + } + )), + }); + } + }; + + // Check recursion depth (immutable — set at construction, incremented for sub-agents) + if self.depth >= agent_config.max_depth { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Delegation depth limit reached ({depth}/{max}). \ + Cannot delegate further to prevent infinite loops.", + depth = self.depth, + max = agent_config.max_depth + )), + }); + } + + // Create provider for this agent + let api_key = agent_config + .api_key + .as_deref() + .or(self.fallback_api_key.as_deref()); + + let provider: Box = + match providers::create_provider(&agent_config.provider, api_key) { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Failed to create provider '{}' for agent '{agent_name}': {e}", + agent_config.provider + )), + }); + } + }; + + // Build the message + let full_prompt = if context.is_empty() { + prompt.to_string() + } else { + format!("[Context]\n{context}\n\n[Task]\n{prompt}") + }; + + let temperature = agent_config.temperature.unwrap_or(0.7); + + // Wrap the provider call in a timeout to prevent indefinite blocking + let result = tokio::time::timeout( + Duration::from_secs(DELEGATE_TIMEOUT_SECS), + provider.chat_with_system( + agent_config.system_prompt.as_deref(), + &full_prompt, + &agent_config.model, + temperature, + ), + ) + .await; + + let result = match result { + Ok(inner) => inner, + Err(_elapsed) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Agent '{agent_name}' timed out after {DELEGATE_TIMEOUT_SECS}s" + )), + }); + } + }; + + match result { + Ok(response) => Ok(ToolResult { + success: true, + output: format!( + "[Agent '{agent_name}' ({provider}/{model})]\n{response}", + provider = agent_config.provider, + model = agent_config.model + ), + error: None, + }), + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Agent '{agent_name}' failed: {e}",)), + }), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn sample_agents() -> HashMap { + let mut agents = HashMap::new(); + agents.insert( + "researcher".to_string(), + DelegateAgentConfig { + provider: "ollama".to_string(), + model: "llama3".to_string(), + system_prompt: Some("You are a research assistant.".to_string()), + api_key: None, + temperature: Some(0.3), + max_depth: 3, + }, + ); + agents.insert( + "coder".to_string(), + DelegateAgentConfig { + provider: "openrouter".to_string(), + model: "anthropic/claude-sonnet-4-20250514".to_string(), + system_prompt: None, + api_key: Some("sk-test".to_string()), + temperature: None, + max_depth: 2, + }, + ); + agents + } + + #[test] + fn name_and_schema() { + let tool = DelegateTool::new(sample_agents(), None); + assert_eq!(tool.name(), "delegate"); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["agent"].is_object()); + assert!(schema["properties"]["prompt"].is_object()); + assert!(schema["properties"]["context"].is_object()); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&json!("agent"))); + assert!(required.contains(&json!("prompt"))); + assert_eq!(schema["additionalProperties"], json!(false)); + assert_eq!(schema["properties"]["agent"]["minLength"], json!(1)); + assert_eq!(schema["properties"]["prompt"]["minLength"], json!(1)); + } + + #[test] + fn description_not_empty() { + let tool = DelegateTool::new(sample_agents(), None); + assert!(!tool.description().is_empty()); + } + + #[test] + fn schema_lists_agent_names() { + let tool = DelegateTool::new(sample_agents(), None); + let schema = tool.parameters_schema(); + let desc = schema["properties"]["agent"]["description"] + .as_str() + .unwrap(); + assert!(desc.contains("researcher") || desc.contains("coder")); + } + + #[tokio::test] + async fn missing_agent_param() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool.execute(json!({"prompt": "test"})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn missing_prompt_param() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool.execute(json!({"agent": "researcher"})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn unknown_agent_returns_error() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool + .execute(json!({"agent": "nonexistent", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("Unknown agent")); + } + + #[tokio::test] + async fn depth_limit_enforced() { + let tool = DelegateTool::with_depth(sample_agents(), None, 3); + let result = tool + .execute(json!({"agent": "researcher", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("depth limit")); + } + + #[tokio::test] + async fn depth_limit_per_agent() { + // coder has max_depth=2, so depth=2 should be blocked + let tool = DelegateTool::with_depth(sample_agents(), None, 2); + let result = tool + .execute(json!({"agent": "coder", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("depth limit")); + } + + #[test] + fn empty_agents_schema() { + let tool = DelegateTool::new(HashMap::new(), None); + let schema = tool.parameters_schema(); + let desc = schema["properties"]["agent"]["description"] + .as_str() + .unwrap(); + assert!(desc.contains("none configured")); + } + + #[tokio::test] + async fn invalid_provider_returns_error() { + let mut agents = HashMap::new(); + agents.insert( + "broken".to_string(), + DelegateAgentConfig { + provider: "totally-invalid-provider".to_string(), + model: "model".to_string(), + system_prompt: None, + api_key: None, + temperature: None, + max_depth: 3, + }, + ); + let tool = DelegateTool::new(agents, None); + let result = tool + .execute(json!({"agent": "broken", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("Failed to create provider")); + } + + #[tokio::test] + async fn blank_agent_rejected() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool + .execute(json!({"agent": " ", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("must not be empty")); + } + + #[tokio::test] + async fn blank_prompt_rejected() { + let tool = DelegateTool::new(sample_agents(), None); + let result = tool + .execute(json!({"agent": "researcher", "prompt": " \t "})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("must not be empty")); + } + + #[tokio::test] + async fn whitespace_agent_name_trimmed_and_found() { + let tool = DelegateTool::new(sample_agents(), None); + // " researcher " with surrounding whitespace — after trim becomes "researcher" + let result = tool + .execute(json!({"agent": " researcher ", "prompt": "test"})) + .await + .unwrap(); + // Should find "researcher" after trim — will fail at provider level + // since ollama isn't running, but must NOT get "Unknown agent". + assert!( + result.error.is_none() + || !result + .error + .as_deref() + .unwrap_or("") + .contains("Unknown agent") + ); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 446c1ee..c2814c0 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,6 +1,7 @@ pub mod browser; pub mod browser_open; pub mod composio; +pub mod delegate; pub mod file_read; pub mod file_write; pub mod image_info; @@ -14,6 +15,7 @@ pub mod traits; pub use browser::BrowserTool; pub use browser_open::BrowserOpenTool; pub use composio::ComposioTool; +pub use delegate::DelegateTool; pub use file_read::FileReadTool; pub use file_write::FileWriteTool; pub use image_info::ImageInfoTool; @@ -26,9 +28,11 @@ pub use traits::Tool; #[allow(unused_imports)] pub use traits::{ToolResult, ToolSpec}; +use crate::config::DelegateAgentConfig; use crate::memory::Memory; use crate::runtime::{NativeRuntime, RuntimeAdapter}; use crate::security::SecurityPolicy; +use std::collections::HashMap; use std::sync::Arc; /// Create the default tool registry @@ -54,6 +58,8 @@ pub fn all_tools( memory: Arc, composio_key: Option<&str>, browser_config: &crate::config::BrowserConfig, + agents: &HashMap, + fallback_api_key: Option<&str>, ) -> Vec> { all_tools_with_runtime( security, @@ -61,6 +67,8 @@ pub fn all_tools( memory, composio_key, browser_config, + agents, + fallback_api_key, ) } @@ -71,6 +79,8 @@ pub fn all_tools_with_runtime( memory: Arc, composio_key: Option<&str>, browser_config: &crate::config::BrowserConfig, + agents: &HashMap, + fallback_api_key: Option<&str>, ) -> Vec> { let mut tools: Vec> = vec![ Box::new(ShellTool::new(security.clone(), runtime)), @@ -105,6 +115,14 @@ pub fn all_tools_with_runtime( } } + // Add delegation tool when agents are configured + if !agents.is_empty() { + tools.push(Box::new(DelegateTool::new( + agents.clone(), + fallback_api_key.map(String::from), + ))); + } + tools } @@ -138,7 +156,7 @@ mod tests { session_name: None, }; - let tools = all_tools(&security, mem, None, &browser); + let tools = all_tools(&security, mem, None, &browser, &HashMap::new(), None); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"browser_open")); } @@ -160,7 +178,7 @@ mod tests { session_name: None, }; - let tools = all_tools(&security, mem, None, &browser); + let tools = all_tools(&security, mem, None, &browser, &HashMap::new(), None); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"browser_open")); } @@ -258,4 +276,53 @@ mod tests { assert_eq!(parsed.name, "test"); assert_eq!(parsed.description, "A test tool"); } + + #[test] + fn all_tools_includes_delegate_when_agents_configured() { + let tmp = TempDir::new().unwrap(); + let security = Arc::new(SecurityPolicy::default()); + let mem_cfg = MemoryConfig { + backend: "markdown".into(), + ..MemoryConfig::default() + }; + let mem: Arc = + Arc::from(crate::memory::create_memory(&mem_cfg, tmp.path(), None).unwrap()); + + let browser = BrowserConfig::default(); + + let mut agents = HashMap::new(); + agents.insert( + "researcher".to_string(), + DelegateAgentConfig { + provider: "ollama".to_string(), + model: "llama3".to_string(), + system_prompt: None, + api_key: None, + temperature: None, + max_depth: 3, + }, + ); + + let tools = all_tools(&security, mem, None, &browser, &agents, Some("sk-test")); + let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); + assert!(names.contains(&"delegate")); + } + + #[test] + fn all_tools_excludes_delegate_when_no_agents() { + let tmp = TempDir::new().unwrap(); + let security = Arc::new(SecurityPolicy::default()); + let mem_cfg = MemoryConfig { + backend: "markdown".into(), + ..MemoryConfig::default() + }; + let mem: Arc = + Arc::from(crate::memory::create_memory(&mem_cfg, tmp.path(), None).unwrap()); + + let browser = BrowserConfig::default(); + + let tools = all_tools(&security, mem, None, &browser, &HashMap::new(), None); + let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); + assert!(!names.contains(&"delegate")); + } }