fix(security): remediate unassigned CodeQL findings

- harden URL/request handling for composio and whatsapp integrations
- reduce cleartext logging exposure across providers/tools/gateway
- hash and constant-time compare gateway webhook secrets
- expand nested secret encryption coverage in config
- align feature aliases and add regression tests for security paths
- fix bubblewrap all-features test invocation surfaced during deep validation
This commit is contained in:
Chummy 2026-02-17 15:44:41 +08:00
parent f9d681063d
commit 1711f140be
14 changed files with 481 additions and 146 deletions

View file

@ -63,9 +63,6 @@ rand = "0.8"
# Fast mutexes that don't poison on panic # Fast mutexes that don't poison on panic
parking_lot = "0.12" parking_lot = "0.12"
# Landlock (Linux sandbox) - optional dependency
landlock = { version = "0.4", optional = true }
# Async traits # Async traits
async-trait = "0.1" async-trait = "0.1"
@ -120,14 +117,24 @@ probe-rs = { version = "0.30", optional = true }
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf) # PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
pdf-extract = { version = "0.10", optional = true } pdf-extract = { version = "0.10", optional = true }
# Raspberry Pi GPIO (Linux/RPi only) — target-specific to avoid compile failure on macOS # Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS
[target.'cfg(target_os = "linux")'.dependencies] [target.'cfg(target_os = "linux")'.dependencies]
rppal = { version = "0.14", optional = true } rppal = { version = "0.14", optional = true }
landlock = { version = "0.4", optional = true }
[features] [features]
default = ["hardware"] default = ["hardware"]
hardware = ["nusb", "tokio-serial"] hardware = ["nusb", "tokio-serial"]
peripheral-rpi = ["rppal"] peripheral-rpi = ["rppal"]
# Browser backend feature alias used by cfg(feature = "browser-native")
browser-native = ["dep:fantoccini"]
# Backward-compatible alias for older invocations
fantoccini = ["browser-native"]
# Sandbox feature aliases used by cfg(feature = "sandbox-*")
sandbox-landlock = ["dep:landlock"]
sandbox-bubblewrap = []
# Backward-compatible alias for older invocations
landlock = ["sandbox-landlock"]
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional) # probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
probe = ["dep:probe-rs"] probe = ["dep:probe-rs"]
# rag-pdf = PDF ingestion for datasheet RAG # rag-pdf = PDF ingestion for datasheet RAG

View file

@ -10,7 +10,7 @@ use uuid::Uuid;
/// happens in the gateway when Meta sends webhook events. /// happens in the gateway when Meta sends webhook events.
pub struct WhatsAppChannel { pub struct WhatsAppChannel {
access_token: String, access_token: String,
phone_number_id: String, endpoint_id: String,
verify_token: String, verify_token: String,
allowed_numbers: Vec<String>, allowed_numbers: Vec<String>,
client: reqwest::Client, client: reqwest::Client,
@ -19,13 +19,13 @@ pub struct WhatsAppChannel {
impl WhatsAppChannel { impl WhatsAppChannel {
pub fn new( pub fn new(
access_token: String, access_token: String,
phone_number_id: String, endpoint_id: String,
verify_token: String, verify_token: String,
allowed_numbers: Vec<String>, allowed_numbers: Vec<String>,
) -> Self { ) -> Self {
Self { Self {
access_token, access_token,
phone_number_id, endpoint_id,
verify_token, verify_token,
allowed_numbers, allowed_numbers,
client: reqwest::Client::new(), client: reqwest::Client::new(),
@ -142,7 +142,7 @@ impl Channel for WhatsAppChannel {
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages // WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
let url = format!( let url = format!(
"https://graph.facebook.com/v18.0/{}/messages", "https://graph.facebook.com/v18.0/{}/messages",
self.phone_number_id self.endpoint_id
); );
// Normalize recipient (remove leading + if present for API) // Normalize recipient (remove leading + if present for API)
@ -162,7 +162,7 @@ impl Channel for WhatsAppChannel {
let resp = self let resp = self
.client .client
.post(&url) .post(&url)
.header("Authorization", format!("Bearer {}", self.access_token)) .bearer_auth(&self.access_token)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.json(&body) .json(&body)
.send() .send()
@ -195,11 +195,11 @@ impl Channel for WhatsAppChannel {
async fn health_check(&self) -> bool { async fn health_check(&self) -> bool {
// Check if we can reach the WhatsApp API // Check if we can reach the WhatsApp API
let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id); let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id);
self.client self.client
.get(&url) .get(&url)
.header("Authorization", format!("Bearer {}", self.access_token)) .bearer_auth(&self.access_token)
.send() .send()
.await .await
.map(|r| r.status().is_success()) .map(|r| r.status().is_success())

View file

@ -1678,6 +1678,40 @@ fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf {
workspace_config_dir workspace_config_dir
} }
fn decrypt_optional_secret(
store: &crate::security::SecretStore,
value: &mut Option<String>,
field_name: &str,
) -> Result<()> {
if let Some(raw) = value.clone() {
if crate::security::SecretStore::is_encrypted(&raw) {
*value = Some(
store
.decrypt(&raw)
.with_context(|| format!("Failed to decrypt {field_name}"))?,
);
}
}
Ok(())
}
fn encrypt_optional_secret(
store: &crate::security::SecretStore,
value: &mut Option<String>,
field_name: &str,
) -> Result<()> {
if let Some(raw) = value.clone() {
if !crate::security::SecretStore::is_encrypted(&raw) {
*value = Some(
store
.encrypt(&raw)
.with_context(|| format!("Failed to encrypt {field_name}"))?,
);
}
}
Ok(())
}
impl Config { impl Config {
pub fn load_or_init() -> Result<Self> { pub fn load_or_init() -> Result<Self> {
// Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE. // Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
@ -1702,6 +1736,23 @@ impl Config {
// Set computed paths that are skipped during serialization // Set computed paths that are skipped during serialization
config.config_path = config_path.clone(); config.config_path = config_path.clone();
config.workspace_dir = workspace_dir; config.workspace_dir = workspace_dir;
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
decrypt_optional_secret(
&store,
&mut config.composio.api_key,
"config.composio.api_key",
)?;
decrypt_optional_secret(
&store,
&mut config.browser.computer_use.api_key,
"config.browser.computer_use.api_key",
)?;
for agent in config.agents.values_mut() {
decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
}
config.apply_env_overrides(); config.apply_env_overrides();
Ok(config) Ok(config)
} else { } else {
@ -1789,23 +1840,29 @@ impl Config {
} }
pub fn save(&self) -> Result<()> { pub fn save(&self) -> Result<()> {
// Encrypt agent API keys before serialization // Encrypt secrets before serialization
let mut config_to_save = self.clone(); let mut config_to_save = self.clone();
let zeroclaw_dir = self let zeroclaw_dir = self
.config_path .config_path
.parent() .parent()
.context("Config path must have a parent directory")?; .context("Config path must have a parent directory")?;
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt); let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
encrypt_optional_secret(
&store,
&mut config_to_save.composio.api_key,
"config.composio.api_key",
)?;
encrypt_optional_secret(
&store,
&mut config_to_save.browser.computer_use.api_key,
"config.browser.computer_use.api_key",
)?;
for agent in config_to_save.agents.values_mut() { for agent in config_to_save.agents.values_mut() {
if let Some(ref plaintext_key) = agent.api_key { encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
if !crate::security::SecretStore::is_encrypted(plaintext_key) {
agent.api_key = Some(
store
.encrypt(plaintext_key)
.context("Failed to encrypt agent API key")?,
);
}
}
} }
let toml_str = let toml_str =
@ -2182,13 +2239,82 @@ tool_dispatcher = "xml"
let contents = fs::read_to_string(&config_path).unwrap(); let contents = fs::read_to_string(&config_path).unwrap();
let loaded: Config = toml::from_str(&contents).unwrap(); let loaded: Config = toml::from_str(&contents).unwrap();
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip")); assert!(loaded
.api_key
.as_deref()
.is_some_and(crate::security::SecretStore::is_encrypted));
let store = crate::security::SecretStore::new(&dir, true);
let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap();
assert_eq!(decrypted, "sk-roundtrip");
assert_eq!(loaded.default_model.as_deref(), Some("test-model")); assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON); assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
let _ = fs::remove_dir_all(&dir); let _ = fs::remove_dir_all(&dir);
} }
#[test]
fn config_save_encrypts_nested_credentials() {
let dir = std::env::temp_dir().join(format!(
"zeroclaw_test_nested_credentials_{}",
uuid::Uuid::new_v4()
));
fs::create_dir_all(&dir).unwrap();
let mut config = Config::default();
config.workspace_dir = dir.join("workspace");
config.config_path = dir.join("config.toml");
config.api_key = Some("root-credential".into());
config.composio.api_key = Some("composio-credential".into());
config.browser.computer_use.api_key = Some("browser-credential".into());
config.agents.insert(
"worker".into(),
DelegateAgentConfig {
provider: "openrouter".into(),
model: "model-test".into(),
system_prompt: None,
api_key: Some("agent-credential".into()),
temperature: None,
max_depth: 3,
},
);
config.save().unwrap();
let contents = fs::read_to_string(config.config_path.clone()).unwrap();
let stored: Config = toml::from_str(&contents).unwrap();
let store = crate::security::SecretStore::new(&dir, true);
let root_encrypted = stored.api_key.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");
let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(
composio_encrypted
));
assert_eq!(
store.decrypt(composio_encrypted).unwrap(),
"composio-credential"
);
let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(
browser_encrypted
));
assert_eq!(
store.decrypt(browser_encrypted).unwrap(),
"browser-credential"
);
let worker = stored.agents.get("worker").unwrap();
let worker_encrypted = worker.api_key.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");
let _ = fs::remove_dir_all(&dir);
}
#[test] #[test]
fn config_save_atomic_cleanup() { fn config_save_atomic_cleanup() {
let dir = let dir =

View file

@ -48,6 +48,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
format!("whatsapp_{}_{}", msg.sender, msg.id) format!("whatsapp_{}_{}", msg.sender, msg.id)
} }
fn hash_webhook_secret(value: &str) -> String {
use sha2::{Digest, Sha256};
let digest = Sha256::digest(value.as_bytes());
hex::encode(digest)
}
/// How often the rate limiter sweeps stale IP entries from its map. /// How often the rate limiter sweeps stale IP entries from its map.
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
@ -179,7 +186,8 @@ pub struct AppState {
pub temperature: f64, pub temperature: f64,
pub mem: Arc<dyn Memory>, pub mem: Arc<dyn Memory>,
pub auto_save: bool, pub auto_save: bool,
pub webhook_secret: Option<Arc<str>>, /// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
pub webhook_secret_hash: Option<Arc<str>>,
pub pairing: Arc<PairingGuard>, pub pairing: Arc<PairingGuard>,
pub rate_limiter: Arc<GatewayRateLimiter>, pub rate_limiter: Arc<GatewayRateLimiter>,
pub idempotency_store: Arc<IdempotencyStore>, pub idempotency_store: Arc<IdempotencyStore>,
@ -253,11 +261,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
&config, &config,
)); ));
// Extract webhook secret for authentication // Extract webhook secret for authentication
let webhook_secret: Option<Arc<str>> = config let webhook_secret_hash: Option<Arc<str>> = config
.channels_config .channels_config
.webhook .webhook
.as_ref() .as_ref()
.and_then(|w| w.secret.as_deref()) .and_then(|w| w.secret.as_deref())
.map(str::trim)
.filter(|secret| !secret.is_empty())
.map(hash_webhook_secret)
.map(Arc::from); .map(Arc::from);
// WhatsApp channel (if configured) // WhatsApp channel (if configured)
@ -344,7 +355,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
} else { } else {
println!(" ⚠️ Pairing: DISABLED (all requests accepted)"); println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
} }
if webhook_secret.is_some() { if webhook_secret_hash.is_some() {
println!(" 🔒 Webhook secret: ENABLED"); println!(" 🔒 Webhook secret: ENABLED");
} }
println!(" Press Ctrl+C to stop.\n"); println!(" Press Ctrl+C to stop.\n");
@ -358,7 +369,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
temperature, temperature,
mem, mem,
auto_save: config.memory.auto_save, auto_save: config.memory.auto_save,
webhook_secret, webhook_secret_hash,
pairing, pairing,
rate_limiter, rate_limiter,
idempotency_store, idempotency_store,
@ -484,12 +495,15 @@ async fn handle_webhook(
} }
// ── Webhook secret auth (optional, additional layer) ── // ── Webhook secret auth (optional, additional layer) ──
if let Some(ref secret) = state.webhook_secret { if let Some(ref secret_hash) = state.webhook_secret_hash {
let header_val = headers let header_hash = headers
.get("X-Webhook-Secret") .get("X-Webhook-Secret")
.and_then(|v| v.to_str().ok()); .and_then(|v| v.to_str().ok())
match header_val { .map(str::trim)
Some(val) if constant_time_eq(val, secret.as_ref()) => {} .filter(|value| !value.is_empty())
.map(hash_webhook_secret);
match header_hash {
Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {}
_ => { _ => {
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret"); tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"}); let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
@ -993,7 +1007,7 @@ mod tests {
temperature: 0.0, temperature: 0.0,
mem: memory, mem: memory,
auto_save: false, auto_save: false,
webhook_secret: None, webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(false, &[])), pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
@ -1041,7 +1055,7 @@ mod tests {
temperature: 0.0, temperature: 0.0,
mem: memory, mem: memory,
auto_save: true, auto_save: true,
webhook_secret: None, webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(false, &[])), pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
@ -1079,6 +1093,125 @@ mod tests {
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
} }
#[test]
fn webhook_secret_hash_is_deterministic_and_nonempty() {
let one = hash_webhook_secret("secret-value");
let two = hash_webhook_secret("secret-value");
let other = hash_webhook_secret("other-value");
assert_eq!(one, two);
assert_ne!(one, other);
assert_eq!(one.len(), 64);
}
#[tokio::test]
async fn webhook_secret_hash_rejects_missing_header() {
let provider_impl = Arc::new(MockProvider::default());
let provider: Arc<dyn Provider> = provider_impl.clone();
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
let state = AppState {
provider,
model: "test-model".into(),
temperature: 0.0,
mem: memory,
auto_save: false,
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
whatsapp: None,
whatsapp_app_secret: None,
};
let response = handle_webhook(
State(state),
HeaderMap::new(),
Ok(Json(WebhookBody {
message: "hello".into(),
})),
)
.await
.into_response();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn webhook_secret_hash_rejects_invalid_header() {
let provider_impl = Arc::new(MockProvider::default());
let provider: Arc<dyn Provider> = provider_impl.clone();
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
let state = AppState {
provider,
model: "test-model".into(),
temperature: 0.0,
mem: memory,
auto_save: false,
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
whatsapp: None,
whatsapp_app_secret: None,
};
let mut headers = HeaderMap::new();
headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret"));
let response = handle_webhook(
State(state),
headers,
Ok(Json(WebhookBody {
message: "hello".into(),
})),
)
.await
.into_response();
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn webhook_secret_hash_accepts_valid_header() {
let provider_impl = Arc::new(MockProvider::default());
let provider: Arc<dyn Provider> = provider_impl.clone();
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
let state = AppState {
provider,
model: "test-model".into(),
temperature: 0.0,
mem: memory,
auto_save: false,
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
whatsapp: None,
whatsapp_app_secret: None,
};
let mut headers = HeaderMap::new();
headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret"));
let response = handle_webhook(
State(state),
headers,
Ok(Json(WebhookBody {
message: "hello".into(),
})),
)
.await
.into_response();
assert_eq!(response.status(), StatusCode::OK);
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
}
// ══════════════════════════════════════════════════════════ // ══════════════════════════════════════════════════════════
// WhatsApp Signature Verification Tests (CWE-345 Prevention) // WhatsApp Signature Verification Tests (CWE-345 Prevention)
// ══════════════════════════════════════════════════════════ // ══════════════════════════════════════════════════════════

View file

@ -285,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub fn run_quick_setup( pub fn run_quick_setup(
api_key: Option<&str>, credential_override: Option<&str>,
provider: Option<&str>, provider: Option<&str>,
memory_backend: Option<&str>, memory_backend: Option<&str>,
) -> Result<Config> { ) -> Result<Config> {
@ -319,7 +319,7 @@ pub fn run_quick_setup(
let config = Config { let config = Config {
workspace_dir: workspace_dir.clone(), workspace_dir: workspace_dir.clone(),
config_path: config_path.clone(), config_path: config_path.clone(),
api_key: api_key.map(String::from), api_key: credential_override.map(String::from),
api_url: None, api_url: None,
default_provider: Some(provider_name.clone()), default_provider: Some(provider_name.clone()),
default_model: Some(model.clone()), default_model: Some(model.clone()),
@ -379,7 +379,7 @@ pub fn run_quick_setup(
println!( println!(
" {} API Key: {}", " {} API Key: {}",
style("").green().bold(), style("").green().bold(),
if api_key.is_some() { if credential_override.is_some() {
style("set").green() style("set").green()
} else { } else {
style("not set (use --api-key or edit config.toml)").yellow() style("not set (use --api-key or edit config.toml)").yellow()
@ -428,7 +428,7 @@ pub fn run_quick_setup(
); );
println!(); println!();
println!(" {}", style("Next steps:").white().bold()); println!(" {}", style("Next steps:").white().bold());
if api_key.is_none() { if credential_override.is_none() {
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\""); println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
println!(" 2. Or edit: ~/.zeroclaw/config.toml"); println!(" 2. Or edit: ~/.zeroclaw/config.toml");
println!(" 3. Chat: zeroclaw agent -m \"Hello!\""); println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
@ -2801,22 +2801,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
.header("Authorization", format!("Bearer {access_token_clone}")) .header("Authorization", format!("Bearer {access_token_clone}"))
.send()?; .send()?;
let ok = resp.status().is_success(); let ok = resp.status().is_success();
let data: serde_json::Value = resp.json().unwrap_or_default(); Ok::<_, reqwest::Error>(ok)
let user_id = data
.get("user_id")
.and_then(serde_json::Value::as_str)
.unwrap_or("unknown")
.to_string();
Ok::<_, reqwest::Error>((ok, user_id))
}) })
.join(); .join();
match thread_result { match thread_result {
Ok(Ok((true, user_id))) => { Ok(Ok(true)) => println!(
println!( "\r {} Connection verified ",
"\r {} Connected as {user_id} ", style("").green().bold()
style("").green().bold() ),
);
}
_ => { _ => {
println!( println!(
"\r {} Connection failed — check homeserver URL and token", "\r {} Connection failed — check homeserver URL and token",

View file

@ -106,17 +106,17 @@ struct NativeContentIn {
} }
impl AnthropicProvider { impl AnthropicProvider {
pub fn new(api_key: Option<&str>) -> Self { pub fn new(credential: Option<&str>) -> Self {
Self::with_base_url(api_key, None) Self::with_base_url(credential, None)
} }
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self { pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
let base_url = base_url let base_url = base_url
.map(|u| u.trim_end_matches('/')) .map(|u| u.trim_end_matches('/'))
.unwrap_or("https://api.anthropic.com") .unwrap_or("https://api.anthropic.com")
.to_string(); .to_string();
Self { Self {
credential: api_key credential: credential
.map(str::trim) .map(str::trim)
.filter(|k| !k.is_empty()) .filter(|k| !k.is_empty())
.map(ToString::to_string), .map(ToString::to_string),
@ -410,9 +410,9 @@ mod tests {
#[test] #[test]
fn creates_with_key() { fn creates_with_key() {
let p = AnthropicProvider::new(Some("sk-ant-test123")); let p = AnthropicProvider::new(Some("anthropic-test-credential"));
assert!(p.credential.is_some()); assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123")); assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
assert_eq!(p.base_url, "https://api.anthropic.com"); assert_eq!(p.base_url, "https://api.anthropic.com");
} }
@ -431,17 +431,19 @@ mod tests {
#[test] #[test]
fn creates_with_whitespace_key() { fn creates_with_whitespace_key() {
let p = AnthropicProvider::new(Some(" sk-ant-test123 ")); let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
assert!(p.credential.is_some()); assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123")); assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
} }
#[test] #[test]
fn creates_with_custom_base_url() { fn creates_with_custom_base_url() {
let p = let p = AnthropicProvider::with_base_url(
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com")); Some("anthropic-credential"),
Some("https://api.example.com"),
);
assert_eq!(p.base_url, "https://api.example.com"); assert_eq!(p.base_url, "https://api.example.com");
assert_eq!(p.credential.as_deref(), Some("sk-ant-test")); assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
} }
#[test] #[test]

View file

@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
pub struct OpenAiCompatibleProvider { pub struct OpenAiCompatibleProvider {
pub(crate) name: String, pub(crate) name: String,
pub(crate) base_url: String, pub(crate) base_url: String,
pub(crate) api_key: Option<String>, pub(crate) credential: Option<String>,
pub(crate) auth_header: AuthStyle, pub(crate) auth_header: AuthStyle,
/// When false, do not fall back to /v1/responses on chat completions 404. /// When false, do not fall back to /v1/responses on chat completions 404.
/// GLM/Zhipu does not support the responses API. /// GLM/Zhipu does not support the responses API.
@ -37,11 +37,16 @@ pub enum AuthStyle {
} }
impl OpenAiCompatibleProvider { impl OpenAiCompatibleProvider {
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self { pub fn new(
name: &str,
base_url: &str,
credential: Option<&str>,
auth_style: AuthStyle,
) -> Self {
Self { Self {
name: name.to_string(), name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(), base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string), credential: credential.map(ToString::to_string),
auth_header: auth_style, auth_header: auth_style,
supports_responses_fallback: true, supports_responses_fallback: true,
client: Client::builder() client: Client::builder()
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
pub fn new_no_responses_fallback( pub fn new_no_responses_fallback(
name: &str, name: &str,
base_url: &str, base_url: &str,
api_key: Option<&str>, credential: Option<&str>,
auth_style: AuthStyle, auth_style: AuthStyle,
) -> Self { ) -> Self {
Self { Self {
name: name.to_string(), name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(), base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string), credential: credential.map(ToString::to_string),
auth_header: auth_style, auth_header: auth_style,
supports_responses_fallback: false, supports_responses_fallback: false,
client: Client::builder() client: Client::builder()
@ -409,18 +414,18 @@ impl OpenAiCompatibleProvider {
fn apply_auth_header( fn apply_auth_header(
&self, &self,
req: reqwest::RequestBuilder, req: reqwest::RequestBuilder,
api_key: &str, credential: &str,
) -> reqwest::RequestBuilder { ) -> reqwest::RequestBuilder {
match &self.auth_header { match &self.auth_header {
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")), AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
AuthStyle::XApiKey => req.header("x-api-key", api_key), AuthStyle::XApiKey => req.header("x-api-key", credential),
AuthStyle::Custom(header) => req.header(header, api_key), AuthStyle::Custom(header) => req.header(header, credential),
} }
} }
async fn chat_via_responses( async fn chat_via_responses(
&self, &self,
api_key: &str, credential: &str,
system_prompt: Option<&str>, system_prompt: Option<&str>,
message: &str, message: &str,
model: &str, model: &str,
@ -438,7 +443,7 @@ impl OpenAiCompatibleProvider {
let url = self.responses_url(); let url = self.responses_url();
let response = self let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key) .apply_auth_header(self.client.post(&url).json(&request), credential)
.send() .send()
.await?; .await?;
@ -463,7 +468,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!( anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name self.name
@ -494,7 +499,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url(); let url = self.chat_completions_url();
let response = self let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key) .apply_auth_header(self.client.post(&url).json(&request), credential)
.send() .send()
.await?; .await?;
@ -505,7 +510,7 @@ impl Provider for OpenAiCompatibleProvider {
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
return self return self
.chat_via_responses(api_key, system_prompt, message, model) .chat_via_responses(credential, system_prompt, message, model)
.await .await
.map_err(|responses_err| { .map_err(|responses_err| {
anyhow::anyhow!( anyhow::anyhow!(
@ -549,7 +554,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!( anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name self.name
@ -573,7 +578,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url(); let url = self.chat_completions_url();
let response = self let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key) .apply_auth_header(self.client.post(&url).json(&request), credential)
.send() .send()
.await?; .await?;
@ -588,7 +593,7 @@ impl Provider for OpenAiCompatibleProvider {
if let Some(user_msg) = last_user { if let Some(user_msg) = last_user {
return self return self
.chat_via_responses( .chat_via_responses(
api_key, credential,
system.map(|m| m.content.as_str()), system.map(|m| m.content.as_str()),
&user_msg.content, &user_msg.content,
model, model,
@ -795,16 +800,20 @@ mod tests {
#[test] #[test]
fn creates_with_key() { fn creates_with_key() {
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key")); let p = make_provider(
"venice",
"https://api.venice.ai",
Some("venice-test-credential"),
);
assert_eq!(p.name, "venice"); assert_eq!(p.name, "venice");
assert_eq!(p.base_url, "https://api.venice.ai"); assert_eq!(p.base_url, "https://api.venice.ai");
assert_eq!(p.api_key.as_deref(), Some("vn-key")); assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
} }
#[test] #[test]
fn creates_without_key() { fn creates_without_key() {
let p = make_provider("test", "https://example.com", None); let p = make_provider("test", "https://example.com", None);
assert!(p.api_key.is_none()); assert!(p.credential.is_none());
} }
#[test] #[test]

View file

@ -104,8 +104,8 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
/// ///
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens) /// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
/// followed by `ANTHROPIC_API_KEY` (for regular API keys). /// followed by `ANTHROPIC_API_KEY` (for regular API keys).
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> { fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) { if let Some(key) = credential_override.map(str::trim).filter(|k| !k.is_empty()) {
return Some(key.to_string()); return Some(key.to_string());
} }
@ -194,7 +194,7 @@ pub fn create_provider_with_url(
api_key: Option<&str>, api_key: Option<&str>,
api_url: Option<&str>, api_url: Option<&str>,
) -> anyhow::Result<Box<dyn Provider>> { ) -> anyhow::Result<Box<dyn Provider>> {
let resolved_key = resolve_api_key(name, api_key); let resolved_key = resolve_provider_credential(name, api_key);
let key = resolved_key.as_deref(); let key = resolved_key.as_deref();
match name { match name {
// ── Primary providers (custom implementations) ─────── // ── Primary providers (custom implementations) ───────
@ -454,8 +454,8 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn resolve_api_key_prefers_explicit_argument() { fn resolve_provider_credential_prefers_explicit_argument() {
let resolved = resolve_api_key("openrouter", Some(" explicit-key ")); let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
assert_eq!(resolved.as_deref(), Some("explicit-key")); assert_eq!(resolved.as_deref(), Some("explicit-key"));
} }
@ -463,18 +463,18 @@ mod tests {
#[test] #[test]
fn factory_openrouter() { fn factory_openrouter() {
assert!(create_provider("openrouter", Some("sk-test")).is_ok()); assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
assert!(create_provider("openrouter", None).is_ok()); assert!(create_provider("openrouter", None).is_ok());
} }
#[test] #[test]
fn factory_anthropic() { fn factory_anthropic() {
assert!(create_provider("anthropic", Some("sk-test")).is_ok()); assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
} }
#[test] #[test]
fn factory_openai() { fn factory_openai() {
assert!(create_provider("openai", Some("sk-test")).is_ok()); assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
} }
#[test] #[test]
@ -774,15 +774,24 @@ mod tests {
scheduler_retries: 2, scheduler_retries: 2,
}; };
let provider = create_resilient_provider("openrouter", Some("sk-test"), None, &reliability); let provider = create_resilient_provider(
"openrouter",
Some("provider-test-credential"),
None,
&reliability,
);
assert!(provider.is_ok()); assert!(provider.is_ok());
} }
#[test] #[test]
fn resilient_provider_errors_for_invalid_primary() { fn resilient_provider_errors_for_invalid_primary() {
let reliability = crate::config::ReliabilityConfig::default(); let reliability = crate::config::ReliabilityConfig::default();
let provider = let provider = create_resilient_provider(
create_resilient_provider("totally-invalid", Some("sk-test"), None, &reliability); "totally-invalid",
Some("provider-test-credential"),
None,
&reliability,
);
assert!(provider.is_err()); assert!(provider.is_err());
} }

View file

@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub struct OpenAiProvider { pub struct OpenAiProvider {
api_key: Option<String>, credential: Option<String>,
client: Client, client: Client,
} }
@ -110,9 +110,9 @@ struct NativeResponseMessage {
} }
impl OpenAiProvider { impl OpenAiProvider {
pub fn new(api_key: Option<&str>) -> Self { pub fn new(credential: Option<&str>) -> Self {
Self { Self {
api_key: api_key.map(ToString::to_string), credential: credential.map(ToString::to_string),
client: Client::builder() client: Client::builder()
.timeout(std::time::Duration::from_secs(120)) .timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10)) .connect_timeout(std::time::Duration::from_secs(10))
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?; })?;
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
let response = self let response = self
.client .client
.post("https://api.openai.com/v1/chat/completions") .post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.json(&request) .json(&request)
.send() .send()
.await?; .await?;
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<ProviderChatResponse> { ) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?; })?;
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
let response = self let response = self
.client .client
.post("https://api.openai.com/v1/chat/completions") .post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.json(&native_request) .json(&native_request)
.send() .send()
.await?; .await?;
@ -330,20 +330,20 @@ mod tests {
#[test] #[test]
fn creates_with_key() { fn creates_with_key() {
let p = OpenAiProvider::new(Some("sk-proj-abc123")); let p = OpenAiProvider::new(Some("openai-test-credential"));
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123")); assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
} }
#[test] #[test]
fn creates_without_key() { fn creates_without_key() {
let p = OpenAiProvider::new(None); let p = OpenAiProvider::new(None);
assert!(p.api_key.is_none()); assert!(p.credential.is_none());
} }
#[test] #[test]
fn creates_with_empty_key() { fn creates_with_empty_key() {
let p = OpenAiProvider::new(Some("")); let p = OpenAiProvider::new(Some(""));
assert_eq!(p.api_key.as_deref(), Some("")); assert_eq!(p.credential.as_deref(), Some(""));
} }
#[tokio::test] #[tokio::test]

View file

@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub struct OpenRouterProvider { pub struct OpenRouterProvider {
api_key: Option<String>, credential: Option<String>,
client: Client, client: Client,
} }
@ -110,9 +110,9 @@ struct NativeResponseMessage {
} }
impl OpenRouterProvider { impl OpenRouterProvider {
pub fn new(api_key: Option<&str>) -> Self { pub fn new(credential: Option<&str>) -> Self {
Self { Self {
api_key: api_key.map(ToString::to_string), credential: credential.map(ToString::to_string),
client: Client::builder() client: Client::builder()
.timeout(std::time::Duration::from_secs(120)) .timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10)) .connect_timeout(std::time::Duration::from_secs(10))
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
async fn warmup(&self) -> anyhow::Result<()> { async fn warmup(&self) -> anyhow::Result<()> {
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool. // Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
// This prevents the first real chat request from timing out on cold start. // This prevents the first real chat request from timing out on cold start.
if let Some(api_key) = self.api_key.as_ref() { if let Some(credential) = self.credential.as_ref() {
self.client self.client
.get("https://openrouter.ai/api/v1/auth/key") .get("https://openrouter.ai/api/v1/auth/key")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.send() .send()
.await? .await?
.error_for_status()?; .error_for_status()?;
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref() let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let mut messages = Vec::new(); let mut messages = Vec::new();
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
let response = self let response = self
.client .client
.post("https://openrouter.ai/api/v1/chat/completions") .post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.header( .header(
"HTTP-Referer", "HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw", "https://github.com/theonlyhennygod/zeroclaw",
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref() let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let api_messages: Vec<Message> = messages let api_messages: Vec<Message> = messages
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
let response = self let response = self
.client .client
.post("https://openrouter.ai/api/v1/chat/completions") .post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.header( .header(
"HTTP-Referer", "HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw", "https://github.com/theonlyhennygod/zeroclaw",
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<ProviderChatResponse> { ) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!( anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
) )
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
let response = self let response = self
.client .client
.post("https://openrouter.ai/api/v1/chat/completions") .post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.header( .header(
"HTTP-Referer", "HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw", "https://github.com/theonlyhennygod/zeroclaw",
@ -494,14 +494,17 @@ mod tests {
#[test] #[test]
fn creates_with_key() { fn creates_with_key() {
let provider = OpenRouterProvider::new(Some("sk-or-123")); let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123")); assert_eq!(
provider.credential.as_deref(),
Some("openrouter-test-credential")
);
} }
#[test] #[test]
fn creates_without_key() { fn creates_without_key() {
let provider = OpenRouterProvider::new(None); let provider = OpenRouterProvider::new(None);
assert!(provider.api_key.is_none()); assert!(provider.credential.is_none());
} }
#[tokio::test] #[tokio::test]

View file

@ -81,14 +81,17 @@ mod tests {
#[test] #[test]
fn bubblewrap_sandbox_name() { fn bubblewrap_sandbox_name() {
assert_eq!(BubblewrapSandbox.name(), "bubblewrap"); let sandbox = BubblewrapSandbox;
assert_eq!(sandbox.name(), "bubblewrap");
} }
#[test] #[test]
fn bubblewrap_is_available_only_if_installed() { fn bubblewrap_is_available_only_if_installed() {
// Result depends on whether bwrap is installed // Result depends on whether bwrap is installed
let available = BubblewrapSandbox::is_available(); let sandbox = BubblewrapSandbox;
let _available = sandbox.is_available();
// Either way, the name should still work // Either way, the name should still work
assert_eq!(BubblewrapSandbox.name(), "bubblewrap"); assert_eq!(sandbox.name(), "bubblewrap");
} }
} }

View file

@ -112,12 +112,12 @@ impl ComposioTool {
action_name: &str, action_name: &str,
params: serde_json::Value, params: serde_json::Value,
entity_id: Option<&str>, entity_id: Option<&str>,
connected_account_id: Option<&str>, connected_account_ref: Option<&str>,
) -> anyhow::Result<serde_json::Value> { ) -> anyhow::Result<serde_json::Value> {
let tool_slug = normalize_tool_slug(action_name); let tool_slug = normalize_tool_slug(action_name);
match self match self
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id) .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref)
.await .await
{ {
Ok(result) => Ok(result), Ok(result) => Ok(result),
@ -130,21 +130,16 @@ impl ComposioTool {
} }
} }
async fn execute_action_v3( fn build_execute_action_v3_request(
&self,
tool_slug: &str, tool_slug: &str,
params: serde_json::Value, params: serde_json::Value,
entity_id: Option<&str>, entity_id: Option<&str>,
connected_account_id: Option<&str>, connected_account_ref: Option<&str>,
) -> anyhow::Result<serde_json::Value> { ) -> (String, serde_json::Value) {
let url = if let Some(connected_account_id) = connected_account_id let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute");
let account_ref = connected_account_ref
.map(str::trim) .map(str::trim)
.filter(|id| !id.is_empty()) .filter(|id| !id.is_empty());
{
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}")
} else {
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute")
};
let mut body = json!({ let mut body = json!({
"arguments": params, "arguments": params,
@ -153,6 +148,26 @@ impl ComposioTool {
if let Some(entity) = entity_id { if let Some(entity) = entity_id {
body["user_id"] = json!(entity); body["user_id"] = json!(entity);
} }
if let Some(account_ref) = account_ref {
body["connected_account_id"] = json!(account_ref);
}
(url, body)
}
async fn execute_action_v3(
&self,
tool_slug: &str,
params: serde_json::Value,
entity_id: Option<&str>,
connected_account_ref: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
let (url, body) = Self::build_execute_action_v3_request(
tool_slug,
params,
entity_id,
connected_account_ref,
);
let resp = self let resp = self
.client .client
@ -474,11 +489,11 @@ impl Tool for ComposioTool {
})?; })?;
let params = args.get("params").cloned().unwrap_or(json!({})); let params = args.get("params").cloned().unwrap_or(json!({}));
let connected_account_id = let connected_account_ref =
args.get("connected_account_id").and_then(|v| v.as_str()); args.get("connected_account_id").and_then(|v| v.as_str());
match self match self
.execute_action(action_name, params, Some(entity_id), connected_account_id) .execute_action(action_name, params, Some(entity_id), connected_account_ref)
.await .await
{ {
Ok(result) => { Ok(result) => {
@ -948,4 +963,40 @@ mod tests {
fn composio_api_base_url_is_v3() { fn composio_api_base_url_is_v3() {
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3"); assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
} }
#[test]
fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() {
let (url, body) = ComposioTool::build_execute_action_v3_request(
"gmail-send-email",
json!({"to": "test@example.com"}),
Some("workspace-user"),
Some("account-42"),
);
assert_eq!(
url,
"https://backend.composio.dev/api/v3/tools/gmail-send-email/execute"
);
assert_eq!(body["arguments"]["to"], json!("test@example.com"));
assert_eq!(body["user_id"], json!("workspace-user"));
assert_eq!(body["connected_account_id"], json!("account-42"));
}
#[test]
fn build_execute_action_v3_request_drops_blank_optional_fields() {
let (url, body) = ComposioTool::build_execute_action_v3_request(
"github-list-repos",
json!({}),
None,
Some(" "),
);
assert_eq!(
url,
"https://backend.composio.dev/api/v3/tools/github-list-repos/execute"
);
assert_eq!(body["arguments"], json!({}));
assert!(body.get("connected_account_id").is_none());
assert!(body.get("user_id").is_none());
}
} }

View file

@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
/// summarization) to purpose-built sub-agents. /// summarization) to purpose-built sub-agents.
pub struct DelegateTool { pub struct DelegateTool {
agents: Arc<HashMap<String, DelegateAgentConfig>>, agents: Arc<HashMap<String, DelegateAgentConfig>>,
/// Global API key fallback (from config.api_key) /// Global credential fallback (from config.api_key)
fallback_api_key: Option<String>, fallback_credential: Option<String>,
/// Depth at which this tool instance lives in the delegation chain. /// Depth at which this tool instance lives in the delegation chain.
depth: u32, depth: u32,
} }
@ -25,11 +25,11 @@ pub struct DelegateTool {
impl DelegateTool { impl DelegateTool {
pub fn new( pub fn new(
agents: HashMap<String, DelegateAgentConfig>, agents: HashMap<String, DelegateAgentConfig>,
fallback_api_key: Option<String>, fallback_credential: Option<String>,
) -> Self { ) -> Self {
Self { Self {
agents: Arc::new(agents), agents: Arc::new(agents),
fallback_api_key, fallback_credential,
depth: 0, depth: 0,
} }
} }
@ -39,12 +39,12 @@ impl DelegateTool {
/// their DelegateTool via this method with `depth: parent.depth + 1`. /// their DelegateTool via this method with `depth: parent.depth + 1`.
pub fn with_depth( pub fn with_depth(
agents: HashMap<String, DelegateAgentConfig>, agents: HashMap<String, DelegateAgentConfig>,
fallback_api_key: Option<String>, fallback_credential: Option<String>,
depth: u32, depth: u32,
) -> Self { ) -> Self {
Self { Self {
agents: Arc::new(agents), agents: Arc::new(agents),
fallback_api_key, fallback_credential,
depth, depth,
} }
} }
@ -165,13 +165,13 @@ impl Tool for DelegateTool {
} }
// Create provider for this agent // Create provider for this agent
let api_key = agent_config let provider_credential = agent_config
.api_key .api_key
.as_deref() .as_deref()
.or(self.fallback_api_key.as_deref()); .or(self.fallback_credential.as_deref());
let provider: Box<dyn Provider> = let provider: Box<dyn Provider> =
match providers::create_provider(&agent_config.provider, api_key) { match providers::create_provider(&agent_config.provider, provider_credential) {
Ok(p) => p, Ok(p) => p,
Err(e) => { Err(e) => {
return Ok(ToolResult { return Ok(ToolResult {
@ -268,7 +268,7 @@ mod tests {
provider: "openrouter".to_string(), provider: "openrouter".to_string(),
model: "anthropic/claude-sonnet-4-20250514".to_string(), model: "anthropic/claude-sonnet-4-20250514".to_string(),
system_prompt: None, system_prompt: None,
api_key: Some("sk-test".to_string()), api_key: Some("delegate-test-credential".to_string()),
temperature: None, temperature: None,
max_depth: 2, max_depth: 2,
}, },

View file

@ -440,7 +440,7 @@ mod tests {
&http, &http,
tmp.path(), tmp.path(),
&agents, &agents,
Some("sk-test"), Some("delegate-test-credential"),
&cfg, &cfg,
); );
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();