fix(security): remediate unassigned CodeQL findings
- harden URL/request handling for composio and whatsapp integrations - reduce cleartext logging exposure across providers/tools/gateway - hash and constant-time compare gateway webhook secrets - expand nested secret encryption coverage in config - align feature aliases and add regression tests for security paths - fix bubblewrap all-features test invocation surfaced during deep validation
This commit is contained in:
parent
f9d681063d
commit
1711f140be
14 changed files with 481 additions and 146 deletions
15
Cargo.toml
15
Cargo.toml
|
|
@ -63,9 +63,6 @@ rand = "0.8"
|
||||||
# Fast mutexes that don't poison on panic
|
# Fast mutexes that don't poison on panic
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
|
|
||||||
# Landlock (Linux sandbox) - optional dependency
|
|
||||||
landlock = { version = "0.4", optional = true }
|
|
||||||
|
|
||||||
# Async traits
|
# Async traits
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
|
|
||||||
|
|
@ -120,14 +117,24 @@ probe-rs = { version = "0.30", optional = true }
|
||||||
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
|
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
|
||||||
pdf-extract = { version = "0.10", optional = true }
|
pdf-extract = { version = "0.10", optional = true }
|
||||||
|
|
||||||
# Raspberry Pi GPIO (Linux/RPi only) — target-specific to avoid compile failure on macOS
|
# Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS
|
||||||
[target.'cfg(target_os = "linux")'.dependencies]
|
[target.'cfg(target_os = "linux")'.dependencies]
|
||||||
rppal = { version = "0.14", optional = true }
|
rppal = { version = "0.14", optional = true }
|
||||||
|
landlock = { version = "0.4", optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["hardware"]
|
default = ["hardware"]
|
||||||
hardware = ["nusb", "tokio-serial"]
|
hardware = ["nusb", "tokio-serial"]
|
||||||
peripheral-rpi = ["rppal"]
|
peripheral-rpi = ["rppal"]
|
||||||
|
# Browser backend feature alias used by cfg(feature = "browser-native")
|
||||||
|
browser-native = ["dep:fantoccini"]
|
||||||
|
# Backward-compatible alias for older invocations
|
||||||
|
fantoccini = ["browser-native"]
|
||||||
|
# Sandbox feature aliases used by cfg(feature = "sandbox-*")
|
||||||
|
sandbox-landlock = ["dep:landlock"]
|
||||||
|
sandbox-bubblewrap = []
|
||||||
|
# Backward-compatible alias for older invocations
|
||||||
|
landlock = ["sandbox-landlock"]
|
||||||
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
|
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
|
||||||
probe = ["dep:probe-rs"]
|
probe = ["dep:probe-rs"]
|
||||||
# rag-pdf = PDF ingestion for datasheet RAG
|
# rag-pdf = PDF ingestion for datasheet RAG
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ use uuid::Uuid;
|
||||||
/// happens in the gateway when Meta sends webhook events.
|
/// happens in the gateway when Meta sends webhook events.
|
||||||
pub struct WhatsAppChannel {
|
pub struct WhatsAppChannel {
|
||||||
access_token: String,
|
access_token: String,
|
||||||
phone_number_id: String,
|
endpoint_id: String,
|
||||||
verify_token: String,
|
verify_token: String,
|
||||||
allowed_numbers: Vec<String>,
|
allowed_numbers: Vec<String>,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
|
@ -19,13 +19,13 @@ pub struct WhatsAppChannel {
|
||||||
impl WhatsAppChannel {
|
impl WhatsAppChannel {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
access_token: String,
|
access_token: String,
|
||||||
phone_number_id: String,
|
endpoint_id: String,
|
||||||
verify_token: String,
|
verify_token: String,
|
||||||
allowed_numbers: Vec<String>,
|
allowed_numbers: Vec<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
access_token,
|
access_token,
|
||||||
phone_number_id,
|
endpoint_id,
|
||||||
verify_token,
|
verify_token,
|
||||||
allowed_numbers,
|
allowed_numbers,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
|
|
@ -142,7 +142,7 @@ impl Channel for WhatsAppChannel {
|
||||||
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
|
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
|
||||||
let url = format!(
|
let url = format!(
|
||||||
"https://graph.facebook.com/v18.0/{}/messages",
|
"https://graph.facebook.com/v18.0/{}/messages",
|
||||||
self.phone_number_id
|
self.endpoint_id
|
||||||
);
|
);
|
||||||
|
|
||||||
// Normalize recipient (remove leading + if present for API)
|
// Normalize recipient (remove leading + if present for API)
|
||||||
|
|
@ -162,7 +162,7 @@ impl Channel for WhatsAppChannel {
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
.bearer_auth(&self.access_token)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.json(&body)
|
.json(&body)
|
||||||
.send()
|
.send()
|
||||||
|
|
@ -195,11 +195,11 @@ impl Channel for WhatsAppChannel {
|
||||||
|
|
||||||
async fn health_check(&self) -> bool {
|
async fn health_check(&self) -> bool {
|
||||||
// Check if we can reach the WhatsApp API
|
// Check if we can reach the WhatsApp API
|
||||||
let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id);
|
let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id);
|
||||||
|
|
||||||
self.client
|
self.client
|
||||||
.get(&url)
|
.get(&url)
|
||||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
.bearer_auth(&self.access_token)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map(|r| r.status().is_success())
|
.map(|r| r.status().is_success())
|
||||||
|
|
|
||||||
|
|
@ -1678,6 +1678,40 @@ fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf {
|
||||||
workspace_config_dir
|
workspace_config_dir
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn decrypt_optional_secret(
|
||||||
|
store: &crate::security::SecretStore,
|
||||||
|
value: &mut Option<String>,
|
||||||
|
field_name: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
if let Some(raw) = value.clone() {
|
||||||
|
if crate::security::SecretStore::is_encrypted(&raw) {
|
||||||
|
*value = Some(
|
||||||
|
store
|
||||||
|
.decrypt(&raw)
|
||||||
|
.with_context(|| format!("Failed to decrypt {field_name}"))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encrypt_optional_secret(
|
||||||
|
store: &crate::security::SecretStore,
|
||||||
|
value: &mut Option<String>,
|
||||||
|
field_name: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
if let Some(raw) = value.clone() {
|
||||||
|
if !crate::security::SecretStore::is_encrypted(&raw) {
|
||||||
|
*value = Some(
|
||||||
|
store
|
||||||
|
.encrypt(&raw)
|
||||||
|
.with_context(|| format!("Failed to encrypt {field_name}"))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
pub fn load_or_init() -> Result<Self> {
|
pub fn load_or_init() -> Result<Self> {
|
||||||
// Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
|
// Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
|
||||||
|
|
@ -1702,6 +1736,23 @@ impl Config {
|
||||||
// Set computed paths that are skipped during serialization
|
// Set computed paths that are skipped during serialization
|
||||||
config.config_path = config_path.clone();
|
config.config_path = config_path.clone();
|
||||||
config.workspace_dir = workspace_dir;
|
config.workspace_dir = workspace_dir;
|
||||||
|
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
|
||||||
|
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
|
||||||
|
decrypt_optional_secret(
|
||||||
|
&store,
|
||||||
|
&mut config.composio.api_key,
|
||||||
|
"config.composio.api_key",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
decrypt_optional_secret(
|
||||||
|
&store,
|
||||||
|
&mut config.browser.computer_use.api_key,
|
||||||
|
"config.browser.computer_use.api_key",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
for agent in config.agents.values_mut() {
|
||||||
|
decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
|
||||||
|
}
|
||||||
config.apply_env_overrides();
|
config.apply_env_overrides();
|
||||||
Ok(config)
|
Ok(config)
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -1789,23 +1840,29 @@ impl Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn save(&self) -> Result<()> {
|
pub fn save(&self) -> Result<()> {
|
||||||
// Encrypt agent API keys before serialization
|
// Encrypt secrets before serialization
|
||||||
let mut config_to_save = self.clone();
|
let mut config_to_save = self.clone();
|
||||||
let zeroclaw_dir = self
|
let zeroclaw_dir = self
|
||||||
.config_path
|
.config_path
|
||||||
.parent()
|
.parent()
|
||||||
.context("Config path must have a parent directory")?;
|
.context("Config path must have a parent directory")?;
|
||||||
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
|
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
|
||||||
|
|
||||||
|
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
|
||||||
|
encrypt_optional_secret(
|
||||||
|
&store,
|
||||||
|
&mut config_to_save.composio.api_key,
|
||||||
|
"config.composio.api_key",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
encrypt_optional_secret(
|
||||||
|
&store,
|
||||||
|
&mut config_to_save.browser.computer_use.api_key,
|
||||||
|
"config.browser.computer_use.api_key",
|
||||||
|
)?;
|
||||||
|
|
||||||
for agent in config_to_save.agents.values_mut() {
|
for agent in config_to_save.agents.values_mut() {
|
||||||
if let Some(ref plaintext_key) = agent.api_key {
|
encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
|
||||||
if !crate::security::SecretStore::is_encrypted(plaintext_key) {
|
|
||||||
agent.api_key = Some(
|
|
||||||
store
|
|
||||||
.encrypt(plaintext_key)
|
|
||||||
.context("Failed to encrypt agent API key")?,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let toml_str =
|
let toml_str =
|
||||||
|
|
@ -2182,13 +2239,82 @@ tool_dispatcher = "xml"
|
||||||
|
|
||||||
let contents = fs::read_to_string(&config_path).unwrap();
|
let contents = fs::read_to_string(&config_path).unwrap();
|
||||||
let loaded: Config = toml::from_str(&contents).unwrap();
|
let loaded: Config = toml::from_str(&contents).unwrap();
|
||||||
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip"));
|
assert!(loaded
|
||||||
|
.api_key
|
||||||
|
.as_deref()
|
||||||
|
.is_some_and(crate::security::SecretStore::is_encrypted));
|
||||||
|
let store = crate::security::SecretStore::new(&dir, true);
|
||||||
|
let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap();
|
||||||
|
assert_eq!(decrypted, "sk-roundtrip");
|
||||||
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
|
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
|
||||||
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
|
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
|
||||||
|
|
||||||
let _ = fs::remove_dir_all(&dir);
|
let _ = fs::remove_dir_all(&dir);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_save_encrypts_nested_credentials() {
|
||||||
|
let dir = std::env::temp_dir().join(format!(
|
||||||
|
"zeroclaw_test_nested_credentials_{}",
|
||||||
|
uuid::Uuid::new_v4()
|
||||||
|
));
|
||||||
|
fs::create_dir_all(&dir).unwrap();
|
||||||
|
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.workspace_dir = dir.join("workspace");
|
||||||
|
config.config_path = dir.join("config.toml");
|
||||||
|
config.api_key = Some("root-credential".into());
|
||||||
|
config.composio.api_key = Some("composio-credential".into());
|
||||||
|
config.browser.computer_use.api_key = Some("browser-credential".into());
|
||||||
|
|
||||||
|
config.agents.insert(
|
||||||
|
"worker".into(),
|
||||||
|
DelegateAgentConfig {
|
||||||
|
provider: "openrouter".into(),
|
||||||
|
model: "model-test".into(),
|
||||||
|
system_prompt: None,
|
||||||
|
api_key: Some("agent-credential".into()),
|
||||||
|
temperature: None,
|
||||||
|
max_depth: 3,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
config.save().unwrap();
|
||||||
|
|
||||||
|
let contents = fs::read_to_string(config.config_path.clone()).unwrap();
|
||||||
|
let stored: Config = toml::from_str(&contents).unwrap();
|
||||||
|
let store = crate::security::SecretStore::new(&dir, true);
|
||||||
|
|
||||||
|
let root_encrypted = stored.api_key.as_deref().unwrap();
|
||||||
|
assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
|
||||||
|
assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");
|
||||||
|
|
||||||
|
let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
|
||||||
|
assert!(crate::security::SecretStore::is_encrypted(
|
||||||
|
composio_encrypted
|
||||||
|
));
|
||||||
|
assert_eq!(
|
||||||
|
store.decrypt(composio_encrypted).unwrap(),
|
||||||
|
"composio-credential"
|
||||||
|
);
|
||||||
|
|
||||||
|
let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap();
|
||||||
|
assert!(crate::security::SecretStore::is_encrypted(
|
||||||
|
browser_encrypted
|
||||||
|
));
|
||||||
|
assert_eq!(
|
||||||
|
store.decrypt(browser_encrypted).unwrap(),
|
||||||
|
"browser-credential"
|
||||||
|
);
|
||||||
|
|
||||||
|
let worker = stored.agents.get("worker").unwrap();
|
||||||
|
let worker_encrypted = worker.api_key.as_deref().unwrap();
|
||||||
|
assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
|
||||||
|
assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");
|
||||||
|
|
||||||
|
let _ = fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn config_save_atomic_cleanup() {
|
fn config_save_atomic_cleanup() {
|
||||||
let dir =
|
let dir =
|
||||||
|
|
|
||||||
|
|
@ -48,6 +48,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
|
||||||
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn hash_webhook_secret(value: &str) -> String {
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
let digest = Sha256::digest(value.as_bytes());
|
||||||
|
hex::encode(digest)
|
||||||
|
}
|
||||||
|
|
||||||
/// How often the rate limiter sweeps stale IP entries from its map.
|
/// How often the rate limiter sweeps stale IP entries from its map.
|
||||||
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
||||||
|
|
||||||
|
|
@ -179,7 +186,8 @@ pub struct AppState {
|
||||||
pub temperature: f64,
|
pub temperature: f64,
|
||||||
pub mem: Arc<dyn Memory>,
|
pub mem: Arc<dyn Memory>,
|
||||||
pub auto_save: bool,
|
pub auto_save: bool,
|
||||||
pub webhook_secret: Option<Arc<str>>,
|
/// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
|
||||||
|
pub webhook_secret_hash: Option<Arc<str>>,
|
||||||
pub pairing: Arc<PairingGuard>,
|
pub pairing: Arc<PairingGuard>,
|
||||||
pub rate_limiter: Arc<GatewayRateLimiter>,
|
pub rate_limiter: Arc<GatewayRateLimiter>,
|
||||||
pub idempotency_store: Arc<IdempotencyStore>,
|
pub idempotency_store: Arc<IdempotencyStore>,
|
||||||
|
|
@ -253,11 +261,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
&config,
|
&config,
|
||||||
));
|
));
|
||||||
// Extract webhook secret for authentication
|
// Extract webhook secret for authentication
|
||||||
let webhook_secret: Option<Arc<str>> = config
|
let webhook_secret_hash: Option<Arc<str>> = config
|
||||||
.channels_config
|
.channels_config
|
||||||
.webhook
|
.webhook
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.and_then(|w| w.secret.as_deref())
|
.and_then(|w| w.secret.as_deref())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|secret| !secret.is_empty())
|
||||||
|
.map(hash_webhook_secret)
|
||||||
.map(Arc::from);
|
.map(Arc::from);
|
||||||
|
|
||||||
// WhatsApp channel (if configured)
|
// WhatsApp channel (if configured)
|
||||||
|
|
@ -344,7 +355,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
} else {
|
} else {
|
||||||
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
||||||
}
|
}
|
||||||
if webhook_secret.is_some() {
|
if webhook_secret_hash.is_some() {
|
||||||
println!(" 🔒 Webhook secret: ENABLED");
|
println!(" 🔒 Webhook secret: ENABLED");
|
||||||
}
|
}
|
||||||
println!(" Press Ctrl+C to stop.\n");
|
println!(" Press Ctrl+C to stop.\n");
|
||||||
|
|
@ -358,7 +369,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
temperature,
|
temperature,
|
||||||
mem,
|
mem,
|
||||||
auto_save: config.memory.auto_save,
|
auto_save: config.memory.auto_save,
|
||||||
webhook_secret,
|
webhook_secret_hash,
|
||||||
pairing,
|
pairing,
|
||||||
rate_limiter,
|
rate_limiter,
|
||||||
idempotency_store,
|
idempotency_store,
|
||||||
|
|
@ -484,12 +495,15 @@ async fn handle_webhook(
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Webhook secret auth (optional, additional layer) ──
|
// ── Webhook secret auth (optional, additional layer) ──
|
||||||
if let Some(ref secret) = state.webhook_secret {
|
if let Some(ref secret_hash) = state.webhook_secret_hash {
|
||||||
let header_val = headers
|
let header_hash = headers
|
||||||
.get("X-Webhook-Secret")
|
.get("X-Webhook-Secret")
|
||||||
.and_then(|v| v.to_str().ok());
|
.and_then(|v| v.to_str().ok())
|
||||||
match header_val {
|
.map(str::trim)
|
||||||
Some(val) if constant_time_eq(val, secret.as_ref()) => {}
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(hash_webhook_secret);
|
||||||
|
match header_hash {
|
||||||
|
Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {}
|
||||||
_ => {
|
_ => {
|
||||||
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
||||||
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
||||||
|
|
@ -993,7 +1007,7 @@ mod tests {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
mem: memory,
|
mem: memory,
|
||||||
auto_save: false,
|
auto_save: false,
|
||||||
webhook_secret: None,
|
webhook_secret_hash: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
|
@ -1041,7 +1055,7 @@ mod tests {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
mem: memory,
|
mem: memory,
|
||||||
auto_save: true,
|
auto_save: true,
|
||||||
webhook_secret: None,
|
webhook_secret_hash: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
|
@ -1079,6 +1093,125 @@ mod tests {
|
||||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn webhook_secret_hash_is_deterministic_and_nonempty() {
|
||||||
|
let one = hash_webhook_secret("secret-value");
|
||||||
|
let two = hash_webhook_secret("secret-value");
|
||||||
|
let other = hash_webhook_secret("other-value");
|
||||||
|
|
||||||
|
assert_eq!(one, two);
|
||||||
|
assert_ne!(one, other);
|
||||||
|
assert_eq!(one.len(), 64);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_secret_hash_rejects_missing_header() {
|
||||||
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: false,
|
||||||
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = handle_webhook(
|
||||||
|
State(state),
|
||||||
|
HeaderMap::new(),
|
||||||
|
Ok(Json(WebhookBody {
|
||||||
|
message: "hello".into(),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_secret_hash_rejects_invalid_header() {
|
||||||
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: false,
|
||||||
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret"));
|
||||||
|
|
||||||
|
let response = handle_webhook(
|
||||||
|
State(state),
|
||||||
|
headers,
|
||||||
|
Ok(Json(WebhookBody {
|
||||||
|
message: "hello".into(),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_secret_hash_accepts_valid_header() {
|
||||||
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: false,
|
||||||
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret"));
|
||||||
|
|
||||||
|
let response = handle_webhook(
|
||||||
|
State(state),
|
||||||
|
headers,
|
||||||
|
Ok(Json(WebhookBody {
|
||||||
|
message: "hello".into(),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
|
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
|
||||||
|
|
@ -285,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub fn run_quick_setup(
|
pub fn run_quick_setup(
|
||||||
api_key: Option<&str>,
|
credential_override: Option<&str>,
|
||||||
provider: Option<&str>,
|
provider: Option<&str>,
|
||||||
memory_backend: Option<&str>,
|
memory_backend: Option<&str>,
|
||||||
) -> Result<Config> {
|
) -> Result<Config> {
|
||||||
|
|
@ -319,7 +319,7 @@ pub fn run_quick_setup(
|
||||||
let config = Config {
|
let config = Config {
|
||||||
workspace_dir: workspace_dir.clone(),
|
workspace_dir: workspace_dir.clone(),
|
||||||
config_path: config_path.clone(),
|
config_path: config_path.clone(),
|
||||||
api_key: api_key.map(String::from),
|
api_key: credential_override.map(String::from),
|
||||||
api_url: None,
|
api_url: None,
|
||||||
default_provider: Some(provider_name.clone()),
|
default_provider: Some(provider_name.clone()),
|
||||||
default_model: Some(model.clone()),
|
default_model: Some(model.clone()),
|
||||||
|
|
@ -379,7 +379,7 @@ pub fn run_quick_setup(
|
||||||
println!(
|
println!(
|
||||||
" {} API Key: {}",
|
" {} API Key: {}",
|
||||||
style("✓").green().bold(),
|
style("✓").green().bold(),
|
||||||
if api_key.is_some() {
|
if credential_override.is_some() {
|
||||||
style("set").green()
|
style("set").green()
|
||||||
} else {
|
} else {
|
||||||
style("not set (use --api-key or edit config.toml)").yellow()
|
style("not set (use --api-key or edit config.toml)").yellow()
|
||||||
|
|
@ -428,7 +428,7 @@ pub fn run_quick_setup(
|
||||||
);
|
);
|
||||||
println!();
|
println!();
|
||||||
println!(" {}", style("Next steps:").white().bold());
|
println!(" {}", style("Next steps:").white().bold());
|
||||||
if api_key.is_none() {
|
if credential_override.is_none() {
|
||||||
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
|
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
|
||||||
println!(" 2. Or edit: ~/.zeroclaw/config.toml");
|
println!(" 2. Or edit: ~/.zeroclaw/config.toml");
|
||||||
println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
|
println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
|
||||||
|
|
@ -2801,22 +2801,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
.header("Authorization", format!("Bearer {access_token_clone}"))
|
.header("Authorization", format!("Bearer {access_token_clone}"))
|
||||||
.send()?;
|
.send()?;
|
||||||
let ok = resp.status().is_success();
|
let ok = resp.status().is_success();
|
||||||
let data: serde_json::Value = resp.json().unwrap_or_default();
|
Ok::<_, reqwest::Error>(ok)
|
||||||
let user_id = data
|
|
||||||
.get("user_id")
|
|
||||||
.and_then(serde_json::Value::as_str)
|
|
||||||
.unwrap_or("unknown")
|
|
||||||
.to_string();
|
|
||||||
Ok::<_, reqwest::Error>((ok, user_id))
|
|
||||||
})
|
})
|
||||||
.join();
|
.join();
|
||||||
match thread_result {
|
match thread_result {
|
||||||
Ok(Ok((true, user_id))) => {
|
Ok(Ok(true)) => println!(
|
||||||
println!(
|
"\r {} Connection verified ",
|
||||||
"\r {} Connected as {user_id} ",
|
|
||||||
style("✅").green().bold()
|
style("✅").green().bold()
|
||||||
);
|
),
|
||||||
}
|
|
||||||
_ => {
|
_ => {
|
||||||
println!(
|
println!(
|
||||||
"\r {} Connection failed — check homeserver URL and token",
|
"\r {} Connection failed — check homeserver URL and token",
|
||||||
|
|
|
||||||
|
|
@ -106,17 +106,17 @@ struct NativeContentIn {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicProvider {
|
impl AnthropicProvider {
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(credential: Option<&str>) -> Self {
|
||||||
Self::with_base_url(api_key, None)
|
Self::with_base_url(credential, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self {
|
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
|
||||||
let base_url = base_url
|
let base_url = base_url
|
||||||
.map(|u| u.trim_end_matches('/'))
|
.map(|u| u.trim_end_matches('/'))
|
||||||
.unwrap_or("https://api.anthropic.com")
|
.unwrap_or("https://api.anthropic.com")
|
||||||
.to_string();
|
.to_string();
|
||||||
Self {
|
Self {
|
||||||
credential: api_key
|
credential: credential
|
||||||
.map(str::trim)
|
.map(str::trim)
|
||||||
.filter(|k| !k.is_empty())
|
.filter(|k| !k.is_empty())
|
||||||
.map(ToString::to_string),
|
.map(ToString::to_string),
|
||||||
|
|
@ -410,9 +410,9 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_key() {
|
fn creates_with_key() {
|
||||||
let p = AnthropicProvider::new(Some("sk-ant-test123"));
|
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
|
||||||
assert!(p.credential.is_some());
|
assert!(p.credential.is_some());
|
||||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||||
assert_eq!(p.base_url, "https://api.anthropic.com");
|
assert_eq!(p.base_url, "https://api.anthropic.com");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -431,17 +431,19 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_whitespace_key() {
|
fn creates_with_whitespace_key() {
|
||||||
let p = AnthropicProvider::new(Some(" sk-ant-test123 "));
|
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
|
||||||
assert!(p.credential.is_some());
|
assert!(p.credential.is_some());
|
||||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_custom_base_url() {
|
fn creates_with_custom_base_url() {
|
||||||
let p =
|
let p = AnthropicProvider::with_base_url(
|
||||||
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
|
Some("anthropic-credential"),
|
||||||
|
Some("https://api.example.com"),
|
||||||
|
);
|
||||||
assert_eq!(p.base_url, "https://api.example.com");
|
assert_eq!(p.base_url, "https://api.example.com");
|
||||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
|
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
|
||||||
pub struct OpenAiCompatibleProvider {
|
pub struct OpenAiCompatibleProvider {
|
||||||
pub(crate) name: String,
|
pub(crate) name: String,
|
||||||
pub(crate) base_url: String,
|
pub(crate) base_url: String,
|
||||||
pub(crate) api_key: Option<String>,
|
pub(crate) credential: Option<String>,
|
||||||
pub(crate) auth_header: AuthStyle,
|
pub(crate) auth_header: AuthStyle,
|
||||||
/// When false, do not fall back to /v1/responses on chat completions 404.
|
/// When false, do not fall back to /v1/responses on chat completions 404.
|
||||||
/// GLM/Zhipu does not support the responses API.
|
/// GLM/Zhipu does not support the responses API.
|
||||||
|
|
@ -37,11 +37,16 @@ pub enum AuthStyle {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiCompatibleProvider {
|
impl OpenAiCompatibleProvider {
|
||||||
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
|
pub fn new(
|
||||||
|
name: &str,
|
||||||
|
base_url: &str,
|
||||||
|
credential: Option<&str>,
|
||||||
|
auth_style: AuthStyle,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
api_key: api_key.map(ToString::to_string),
|
credential: credential.map(ToString::to_string),
|
||||||
auth_header: auth_style,
|
auth_header: auth_style,
|
||||||
supports_responses_fallback: true,
|
supports_responses_fallback: true,
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
|
|
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
|
||||||
pub fn new_no_responses_fallback(
|
pub fn new_no_responses_fallback(
|
||||||
name: &str,
|
name: &str,
|
||||||
base_url: &str,
|
base_url: &str,
|
||||||
api_key: Option<&str>,
|
credential: Option<&str>,
|
||||||
auth_style: AuthStyle,
|
auth_style: AuthStyle,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
api_key: api_key.map(ToString::to_string),
|
credential: credential.map(ToString::to_string),
|
||||||
auth_header: auth_style,
|
auth_header: auth_style,
|
||||||
supports_responses_fallback: false,
|
supports_responses_fallback: false,
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
|
|
@ -409,18 +414,18 @@ impl OpenAiCompatibleProvider {
|
||||||
fn apply_auth_header(
|
fn apply_auth_header(
|
||||||
&self,
|
&self,
|
||||||
req: reqwest::RequestBuilder,
|
req: reqwest::RequestBuilder,
|
||||||
api_key: &str,
|
credential: &str,
|
||||||
) -> reqwest::RequestBuilder {
|
) -> reqwest::RequestBuilder {
|
||||||
match &self.auth_header {
|
match &self.auth_header {
|
||||||
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")),
|
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
|
||||||
AuthStyle::XApiKey => req.header("x-api-key", api_key),
|
AuthStyle::XApiKey => req.header("x-api-key", credential),
|
||||||
AuthStyle::Custom(header) => req.header(header, api_key),
|
AuthStyle::Custom(header) => req.header(header, credential),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_via_responses(
|
async fn chat_via_responses(
|
||||||
&self,
|
&self,
|
||||||
api_key: &str,
|
credential: &str,
|
||||||
system_prompt: Option<&str>,
|
system_prompt: Option<&str>,
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
|
|
@ -438,7 +443,7 @@ impl OpenAiCompatibleProvider {
|
||||||
let url = self.responses_url();
|
let url = self.responses_url();
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
@ -463,7 +468,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||||
self.name
|
self.name
|
||||||
|
|
@ -494,7 +499,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
let url = self.chat_completions_url();
|
let url = self.chat_completions_url();
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
@ -505,7 +510,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(api_key, system_prompt, message, model)
|
.chat_via_responses(credential, system_prompt, message, model)
|
||||||
.await
|
.await
|
||||||
.map_err(|responses_err| {
|
.map_err(|responses_err| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
|
|
@ -549,7 +554,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||||
self.name
|
self.name
|
||||||
|
|
@ -573,7 +578,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
let url = self.chat_completions_url();
|
let url = self.chat_completions_url();
|
||||||
let response = self
|
let response = self
|
||||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
@ -588,7 +593,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
if let Some(user_msg) = last_user {
|
if let Some(user_msg) = last_user {
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(
|
.chat_via_responses(
|
||||||
api_key,
|
credential,
|
||||||
system.map(|m| m.content.as_str()),
|
system.map(|m| m.content.as_str()),
|
||||||
&user_msg.content,
|
&user_msg.content,
|
||||||
model,
|
model,
|
||||||
|
|
@ -795,16 +800,20 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_key() {
|
fn creates_with_key() {
|
||||||
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
|
let p = make_provider(
|
||||||
|
"venice",
|
||||||
|
"https://api.venice.ai",
|
||||||
|
Some("venice-test-credential"),
|
||||||
|
);
|
||||||
assert_eq!(p.name, "venice");
|
assert_eq!(p.name, "venice");
|
||||||
assert_eq!(p.base_url, "https://api.venice.ai");
|
assert_eq!(p.base_url, "https://api.venice.ai");
|
||||||
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
|
assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_without_key() {
|
fn creates_without_key() {
|
||||||
let p = make_provider("test", "https://example.com", None);
|
let p = make_provider("test", "https://example.com", None);
|
||||||
assert!(p.api_key.is_none());
|
assert!(p.credential.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -104,8 +104,8 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
|
||||||
///
|
///
|
||||||
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
|
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
|
||||||
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
|
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
|
||||||
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
|
||||||
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
|
if let Some(key) = credential_override.map(str::trim).filter(|k| !k.is_empty()) {
|
||||||
return Some(key.to_string());
|
return Some(key.to_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -194,7 +194,7 @@ pub fn create_provider_with_url(
|
||||||
api_key: Option<&str>,
|
api_key: Option<&str>,
|
||||||
api_url: Option<&str>,
|
api_url: Option<&str>,
|
||||||
) -> anyhow::Result<Box<dyn Provider>> {
|
) -> anyhow::Result<Box<dyn Provider>> {
|
||||||
let resolved_key = resolve_api_key(name, api_key);
|
let resolved_key = resolve_provider_credential(name, api_key);
|
||||||
let key = resolved_key.as_deref();
|
let key = resolved_key.as_deref();
|
||||||
match name {
|
match name {
|
||||||
// ── Primary providers (custom implementations) ───────
|
// ── Primary providers (custom implementations) ───────
|
||||||
|
|
@ -454,8 +454,8 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_api_key_prefers_explicit_argument() {
|
fn resolve_provider_credential_prefers_explicit_argument() {
|
||||||
let resolved = resolve_api_key("openrouter", Some(" explicit-key "));
|
let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
|
||||||
assert_eq!(resolved.as_deref(), Some("explicit-key"));
|
assert_eq!(resolved.as_deref(), Some("explicit-key"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -463,18 +463,18 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_openrouter() {
|
fn factory_openrouter() {
|
||||||
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
|
assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
|
||||||
assert!(create_provider("openrouter", None).is_ok());
|
assert!(create_provider("openrouter", None).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_anthropic() {
|
fn factory_anthropic() {
|
||||||
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
|
assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_openai() {
|
fn factory_openai() {
|
||||||
assert!(create_provider("openai", Some("sk-test")).is_ok());
|
assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -774,15 +774,24 @@ mod tests {
|
||||||
scheduler_retries: 2,
|
scheduler_retries: 2,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider = create_resilient_provider("openrouter", Some("sk-test"), None, &reliability);
|
let provider = create_resilient_provider(
|
||||||
|
"openrouter",
|
||||||
|
Some("provider-test-credential"),
|
||||||
|
None,
|
||||||
|
&reliability,
|
||||||
|
);
|
||||||
assert!(provider.is_ok());
|
assert!(provider.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resilient_provider_errors_for_invalid_primary() {
|
fn resilient_provider_errors_for_invalid_primary() {
|
||||||
let reliability = crate::config::ReliabilityConfig::default();
|
let reliability = crate::config::ReliabilityConfig::default();
|
||||||
let provider =
|
let provider = create_resilient_provider(
|
||||||
create_resilient_provider("totally-invalid", Some("sk-test"), None, &reliability);
|
"totally-invalid",
|
||||||
|
Some("provider-test-credential"),
|
||||||
|
None,
|
||||||
|
&reliability,
|
||||||
|
);
|
||||||
assert!(provider.is_err());
|
assert!(provider.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub struct OpenAiProvider {
|
pub struct OpenAiProvider {
|
||||||
api_key: Option<String>,
|
credential: Option<String>,
|
||||||
client: Client,
|
client: Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiProvider {
|
impl OpenAiProvider {
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(credential: Option<&str>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
api_key: api_key.map(ToString::to_string),
|
credential: credential.map(ToString::to_string),
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(120))
|
.timeout(std::time::Duration::from_secs(120))
|
||||||
.connect_timeout(std::time::Duration::from_secs(10))
|
.connect_timeout(std::time::Duration::from_secs(10))
|
||||||
|
|
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://api.openai.com/v1/chat/completions")
|
.post("https://api.openai.com/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.json(&request)
|
.json(&request)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ProviderChatResponse> {
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://api.openai.com/v1/chat/completions")
|
.post("https://api.openai.com/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.json(&native_request)
|
.json(&native_request)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
@ -330,20 +330,20 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_key() {
|
fn creates_with_key() {
|
||||||
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
|
let p = OpenAiProvider::new(Some("openai-test-credential"));
|
||||||
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
|
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_without_key() {
|
fn creates_without_key() {
|
||||||
let p = OpenAiProvider::new(None);
|
let p = OpenAiProvider::new(None);
|
||||||
assert!(p.api_key.is_none());
|
assert!(p.credential.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_empty_key() {
|
fn creates_with_empty_key() {
|
||||||
let p = OpenAiProvider::new(Some(""));
|
let p = OpenAiProvider::new(Some(""));
|
||||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
assert_eq!(p.credential.as_deref(), Some(""));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub struct OpenRouterProvider {
|
pub struct OpenRouterProvider {
|
||||||
api_key: Option<String>,
|
credential: Option<String>,
|
||||||
client: Client,
|
client: Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenRouterProvider {
|
impl OpenRouterProvider {
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(credential: Option<&str>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
api_key: api_key.map(ToString::to_string),
|
credential: credential.map(ToString::to_string),
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(120))
|
.timeout(std::time::Duration::from_secs(120))
|
||||||
.connect_timeout(std::time::Duration::from_secs(10))
|
.connect_timeout(std::time::Duration::from_secs(10))
|
||||||
|
|
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
|
||||||
async fn warmup(&self) -> anyhow::Result<()> {
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
|
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
|
||||||
// This prevents the first real chat request from timing out on cold start.
|
// This prevents the first real chat request from timing out on cold start.
|
||||||
if let Some(api_key) = self.api_key.as_ref() {
|
if let Some(credential) = self.credential.as_ref() {
|
||||||
self.client
|
self.client
|
||||||
.get("https://openrouter.ai/api/v1/auth/key")
|
.get("https://openrouter.ai/api/v1/auth/key")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.send()
|
.send()
|
||||||
.await?
|
.await?
|
||||||
.error_for_status()?;
|
.error_for_status()?;
|
||||||
|
|
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref()
|
let credential = self.credential.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
|
|
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.header(
|
.header(
|
||||||
"HTTP-Referer",
|
"HTTP-Referer",
|
||||||
"https://github.com/theonlyhennygod/zeroclaw",
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
|
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref()
|
let credential = self.credential.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||||
|
|
||||||
let api_messages: Vec<Message> = messages
|
let api_messages: Vec<Message> = messages
|
||||||
|
|
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.header(
|
.header(
|
||||||
"HTTP-Referer",
|
"HTTP-Referer",
|
||||||
"https://github.com/theonlyhennygod/zeroclaw",
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
|
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ProviderChatResponse> {
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||||
)
|
)
|
||||||
|
|
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.header(
|
.header(
|
||||||
"HTTP-Referer",
|
"HTTP-Referer",
|
||||||
"https://github.com/theonlyhennygod/zeroclaw",
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
|
@ -494,14 +494,17 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_key() {
|
fn creates_with_key() {
|
||||||
let provider = OpenRouterProvider::new(Some("sk-or-123"));
|
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
|
||||||
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
|
assert_eq!(
|
||||||
|
provider.credential.as_deref(),
|
||||||
|
Some("openrouter-test-credential")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_without_key() {
|
fn creates_without_key() {
|
||||||
let provider = OpenRouterProvider::new(None);
|
let provider = OpenRouterProvider::new(None);
|
||||||
assert!(provider.api_key.is_none());
|
assert!(provider.credential.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
||||||
|
|
@ -81,14 +81,17 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn bubblewrap_sandbox_name() {
|
fn bubblewrap_sandbox_name() {
|
||||||
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
|
let sandbox = BubblewrapSandbox;
|
||||||
|
assert_eq!(sandbox.name(), "bubblewrap");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn bubblewrap_is_available_only_if_installed() {
|
fn bubblewrap_is_available_only_if_installed() {
|
||||||
// Result depends on whether bwrap is installed
|
// Result depends on whether bwrap is installed
|
||||||
let available = BubblewrapSandbox::is_available();
|
let sandbox = BubblewrapSandbox;
|
||||||
|
let _available = sandbox.is_available();
|
||||||
|
|
||||||
// Either way, the name should still work
|
// Either way, the name should still work
|
||||||
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
|
assert_eq!(sandbox.name(), "bubblewrap");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -112,12 +112,12 @@ impl ComposioTool {
|
||||||
action_name: &str,
|
action_name: &str,
|
||||||
params: serde_json::Value,
|
params: serde_json::Value,
|
||||||
entity_id: Option<&str>,
|
entity_id: Option<&str>,
|
||||||
connected_account_id: Option<&str>,
|
connected_account_ref: Option<&str>,
|
||||||
) -> anyhow::Result<serde_json::Value> {
|
) -> anyhow::Result<serde_json::Value> {
|
||||||
let tool_slug = normalize_tool_slug(action_name);
|
let tool_slug = normalize_tool_slug(action_name);
|
||||||
|
|
||||||
match self
|
match self
|
||||||
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id)
|
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(result) => Ok(result),
|
Ok(result) => Ok(result),
|
||||||
|
|
@ -130,21 +130,16 @@ impl ComposioTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_action_v3(
|
fn build_execute_action_v3_request(
|
||||||
&self,
|
|
||||||
tool_slug: &str,
|
tool_slug: &str,
|
||||||
params: serde_json::Value,
|
params: serde_json::Value,
|
||||||
entity_id: Option<&str>,
|
entity_id: Option<&str>,
|
||||||
connected_account_id: Option<&str>,
|
connected_account_ref: Option<&str>,
|
||||||
) -> anyhow::Result<serde_json::Value> {
|
) -> (String, serde_json::Value) {
|
||||||
let url = if let Some(connected_account_id) = connected_account_id
|
let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute");
|
||||||
|
let account_ref = connected_account_ref
|
||||||
.map(str::trim)
|
.map(str::trim)
|
||||||
.filter(|id| !id.is_empty())
|
.filter(|id| !id.is_empty());
|
||||||
{
|
|
||||||
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}")
|
|
||||||
} else {
|
|
||||||
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute")
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"arguments": params,
|
"arguments": params,
|
||||||
|
|
@ -153,6 +148,26 @@ impl ComposioTool {
|
||||||
if let Some(entity) = entity_id {
|
if let Some(entity) = entity_id {
|
||||||
body["user_id"] = json!(entity);
|
body["user_id"] = json!(entity);
|
||||||
}
|
}
|
||||||
|
if let Some(account_ref) = account_ref {
|
||||||
|
body["connected_account_id"] = json!(account_ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
(url, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_action_v3(
|
||||||
|
&self,
|
||||||
|
tool_slug: &str,
|
||||||
|
params: serde_json::Value,
|
||||||
|
entity_id: Option<&str>,
|
||||||
|
connected_account_ref: Option<&str>,
|
||||||
|
) -> anyhow::Result<serde_json::Value> {
|
||||||
|
let (url, body) = Self::build_execute_action_v3_request(
|
||||||
|
tool_slug,
|
||||||
|
params,
|
||||||
|
entity_id,
|
||||||
|
connected_account_ref,
|
||||||
|
);
|
||||||
|
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
|
|
@ -474,11 +489,11 @@ impl Tool for ComposioTool {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let params = args.get("params").cloned().unwrap_or(json!({}));
|
let params = args.get("params").cloned().unwrap_or(json!({}));
|
||||||
let connected_account_id =
|
let connected_account_ref =
|
||||||
args.get("connected_account_id").and_then(|v| v.as_str());
|
args.get("connected_account_id").and_then(|v| v.as_str());
|
||||||
|
|
||||||
match self
|
match self
|
||||||
.execute_action(action_name, params, Some(entity_id), connected_account_id)
|
.execute_action(action_name, params, Some(entity_id), connected_account_ref)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
|
|
@ -948,4 +963,40 @@ mod tests {
|
||||||
fn composio_api_base_url_is_v3() {
|
fn composio_api_base_url_is_v3() {
|
||||||
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
|
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() {
|
||||||
|
let (url, body) = ComposioTool::build_execute_action_v3_request(
|
||||||
|
"gmail-send-email",
|
||||||
|
json!({"to": "test@example.com"}),
|
||||||
|
Some("workspace-user"),
|
||||||
|
Some("account-42"),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
url,
|
||||||
|
"https://backend.composio.dev/api/v3/tools/gmail-send-email/execute"
|
||||||
|
);
|
||||||
|
assert_eq!(body["arguments"]["to"], json!("test@example.com"));
|
||||||
|
assert_eq!(body["user_id"], json!("workspace-user"));
|
||||||
|
assert_eq!(body["connected_account_id"], json!("account-42"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_execute_action_v3_request_drops_blank_optional_fields() {
|
||||||
|
let (url, body) = ComposioTool::build_execute_action_v3_request(
|
||||||
|
"github-list-repos",
|
||||||
|
json!({}),
|
||||||
|
None,
|
||||||
|
Some(" "),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
url,
|
||||||
|
"https://backend.composio.dev/api/v3/tools/github-list-repos/execute"
|
||||||
|
);
|
||||||
|
assert_eq!(body["arguments"], json!({}));
|
||||||
|
assert!(body.get("connected_account_id").is_none());
|
||||||
|
assert!(body.get("user_id").is_none());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
|
||||||
/// summarization) to purpose-built sub-agents.
|
/// summarization) to purpose-built sub-agents.
|
||||||
pub struct DelegateTool {
|
pub struct DelegateTool {
|
||||||
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
||||||
/// Global API key fallback (from config.api_key)
|
/// Global credential fallback (from config.api_key)
|
||||||
fallback_api_key: Option<String>,
|
fallback_credential: Option<String>,
|
||||||
/// Depth at which this tool instance lives in the delegation chain.
|
/// Depth at which this tool instance lives in the delegation chain.
|
||||||
depth: u32,
|
depth: u32,
|
||||||
}
|
}
|
||||||
|
|
@ -25,11 +25,11 @@ pub struct DelegateTool {
|
||||||
impl DelegateTool {
|
impl DelegateTool {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
agents: HashMap<String, DelegateAgentConfig>,
|
agents: HashMap<String, DelegateAgentConfig>,
|
||||||
fallback_api_key: Option<String>,
|
fallback_credential: Option<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
agents: Arc::new(agents),
|
agents: Arc::new(agents),
|
||||||
fallback_api_key,
|
fallback_credential,
|
||||||
depth: 0,
|
depth: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -39,12 +39,12 @@ impl DelegateTool {
|
||||||
/// their DelegateTool via this method with `depth: parent.depth + 1`.
|
/// their DelegateTool via this method with `depth: parent.depth + 1`.
|
||||||
pub fn with_depth(
|
pub fn with_depth(
|
||||||
agents: HashMap<String, DelegateAgentConfig>,
|
agents: HashMap<String, DelegateAgentConfig>,
|
||||||
fallback_api_key: Option<String>,
|
fallback_credential: Option<String>,
|
||||||
depth: u32,
|
depth: u32,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
agents: Arc::new(agents),
|
agents: Arc::new(agents),
|
||||||
fallback_api_key,
|
fallback_credential,
|
||||||
depth,
|
depth,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -165,13 +165,13 @@ impl Tool for DelegateTool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create provider for this agent
|
// Create provider for this agent
|
||||||
let api_key = agent_config
|
let provider_credential = agent_config
|
||||||
.api_key
|
.api_key
|
||||||
.as_deref()
|
.as_deref()
|
||||||
.or(self.fallback_api_key.as_deref());
|
.or(self.fallback_credential.as_deref());
|
||||||
|
|
||||||
let provider: Box<dyn Provider> =
|
let provider: Box<dyn Provider> =
|
||||||
match providers::create_provider(&agent_config.provider, api_key) {
|
match providers::create_provider(&agent_config.provider, provider_credential) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
|
|
@ -268,7 +268,7 @@ mod tests {
|
||||||
provider: "openrouter".to_string(),
|
provider: "openrouter".to_string(),
|
||||||
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
||||||
system_prompt: None,
|
system_prompt: None,
|
||||||
api_key: Some("sk-test".to_string()),
|
api_key: Some("delegate-test-credential".to_string()),
|
||||||
temperature: None,
|
temperature: None,
|
||||||
max_depth: 2,
|
max_depth: 2,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -440,7 +440,7 @@ mod tests {
|
||||||
&http,
|
&http,
|
||||||
tmp.path(),
|
tmp.path(),
|
||||||
&agents,
|
&agents,
|
||||||
Some("sk-test"),
|
Some("delegate-test-credential"),
|
||||||
&cfg,
|
&cfg,
|
||||||
);
|
);
|
||||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue