From d7cca4b150705c6e22d6c2ea9425688cc6b5cbdd Mon Sep 17 00:00:00 2001 From: Chummy Date: Mon, 16 Feb 2026 23:38:29 +0800 Subject: [PATCH 01/12] feat: unify scheduled tasks from #337 and #338 with security-first integration Unifies scheduled task capabilities and consolidates overlapping implementations from #337 and #338 into a single security-first integration path. Co-authored-by: Edvard Co-authored-by: stawky --- src/agent/loop_.rs | 5 + src/channels/mod.rs | 5 + src/config/mod.rs | 4 +- src/config/schema.rs | 43 ++++ src/cron/mod.rs | 420 +++++++++++++++++++++++++++------ src/cron/scheduler.rs | 13 +- src/gateway/mod.rs | 1 + src/lib.rs | 17 ++ src/main.rs | 17 ++ src/onboard/wizard.rs | 2 + src/tools/mod.rs | 25 +- src/tools/schedule.rs | 522 ++++++++++++++++++++++++++++++++++++++++++ 12 files changed, 1006 insertions(+), 68 deletions(-) create mode 100644 src/tools/schedule.rs diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index a8368c6..2558bfa 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -598,6 +598,7 @@ pub async fn run( &config.workspace_dir, &config.agents, config.api_key.as_deref(), + &config, ); // ── Resolve provider ───────────────────────────────────────── @@ -672,6 +673,10 @@ pub async fn run( "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.", )); } + tool_descs.push(( + "schedule", + "Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.", + )); if !config.agents.is_empty() { tool_descs.push(( "delegate", diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 1acc502..21f99d0 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -730,6 +730,7 @@ pub async fn start_channels(config: Config) -> Result<()> { &config.workspace_dir, &config.agents, config.api_key.as_deref(), + &config, )); // Build system prompt from workspace identity files + skills @@ -776,6 +777,10 @@ pub async fn start_channels(config: Config) -> Result<()> { "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.", )); } + tool_descs.push(( + "schedule", + "Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.", + )); if !config.agents.is_empty() { tool_descs.push(( "delegate", diff --git a/src/config/mod.rs b/src/config/mod.rs index d8980c0..a61c29c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -6,8 +6,8 @@ pub use schema::{ DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, - SandboxBackend, SandboxConfig, SecretsConfig, SecurityConfig, SlackConfig, TelegramConfig, - TunnelConfig, WebhookConfig, + SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig, + TelegramConfig, TunnelConfig, WebhookConfig, }; #[cfg(test)] diff --git a/src/config/schema.rs b/src/config/schema.rs index bc27e4e..8d2ec55 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -34,6 +34,9 @@ pub struct Config { #[serde(default)] pub reliability: ReliabilityConfig, + #[serde(default)] + pub scheduler: SchedulerConfig, + /// Model routing rules — route `hint:` to specific provider+model combos. #[serde(default)] pub model_routes: Vec, @@ -697,6 +700,43 @@ impl Default for ReliabilityConfig { } } +// ── Scheduler ──────────────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SchedulerConfig { + /// Enable the built-in scheduler loop. + #[serde(default = "default_scheduler_enabled")] + pub enabled: bool, + /// Maximum number of persisted scheduled tasks. + #[serde(default = "default_scheduler_max_tasks")] + pub max_tasks: usize, + /// Maximum tasks executed per scheduler polling cycle. + #[serde(default = "default_scheduler_max_concurrent")] + pub max_concurrent: usize, +} + +fn default_scheduler_enabled() -> bool { + true +} + +fn default_scheduler_max_tasks() -> usize { + 64 +} + +fn default_scheduler_max_concurrent() -> usize { + 4 +} + +impl Default for SchedulerConfig { + fn default() -> Self { + Self { + enabled: default_scheduler_enabled(), + max_tasks: default_scheduler_max_tasks(), + max_concurrent: default_scheduler_max_concurrent(), + } + } +} + // ── Model routing ──────────────────────────────────────────────── /// Route a task hint to a specific provider + model. @@ -1148,6 +1188,7 @@ impl Default for Config { autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), + scheduler: SchedulerConfig::default(), model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), @@ -1485,6 +1526,7 @@ mod tests { ..RuntimeConfig::default() }, reliability: ReliabilityConfig::default(), + scheduler: SchedulerConfig::default(), model_routes: Vec::new(), heartbeat: HeartbeatConfig { enabled: true, @@ -1578,6 +1620,7 @@ default_temperature = 0.7 autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), + scheduler: SchedulerConfig::default(), model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), diff --git a/src/cron/mod.rs b/src/cron/mod.rs index 444445f..4fe0c39 100644 --- a/src/cron/mod.rs +++ b/src/cron/mod.rs @@ -16,6 +16,8 @@ pub struct CronJob { pub next_run: DateTime, pub last_run: Option>, pub last_status: Option, + pub paused: bool, + pub one_shot: bool, } #[allow(clippy::needless_pass_by_value)] @@ -27,6 +29,7 @@ pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<( println!("No scheduled tasks yet."); println!("\nUsage:"); println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'"); + println!(" zeroclaw cron once 30m 'echo reminder'"); return Ok(()); } @@ -36,13 +39,20 @@ pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<( .last_run .map_or_else(|| "never".into(), |d| d.to_rfc3339()); let last_status = job.last_status.unwrap_or_else(|| "n/a".into()); + let flags = match (job.paused, job.one_shot) { + (true, true) => " [paused, one-shot]", + (true, false) => " [paused]", + (false, true) => " [one-shot]", + (false, false) => "", + }; println!( - "- {} | {} | next={} | last={} ({})\n cmd: {}", + "- {} | {} | next={} | last={} ({}){}\n cmd: {}", job.id, job.expression, job.next_run.to_rfc3339(), last_run, last_status, + flags, job.command ); } @@ -59,19 +69,41 @@ pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<( println!(" Cmd : {}", job.command); Ok(()) } - crate::CronCommands::Remove { id } => remove_job(config, &id), + crate::CronCommands::Once { delay, command } => { + let job = add_once(config, &delay, &command)?; + println!("✅ Added one-shot task {}", job.id); + println!(" Runs at: {}", job.next_run.to_rfc3339()); + println!(" Cmd : {}", job.command); + Ok(()) + } + crate::CronCommands::Remove { id } => { + remove_job(config, &id)?; + println!("✅ Removed cron job {id}"); + Ok(()) + } + crate::CronCommands::Pause { id } => { + pause_job(config, &id)?; + println!("⏸️ Paused job {id}"); + Ok(()) + } + crate::CronCommands::Resume { id } => { + resume_job(config, &id)?; + println!("▶️ Resumed job {id}"); + Ok(()) + } } } pub fn add_job(config: &Config, expression: &str, command: &str) -> Result { + check_max_tasks(config)?; let now = Utc::now(); let next_run = next_run_for(expression, now)?; let id = Uuid::new_v4().to_string(); with_connection(config, |conn| { conn.execute( - "INSERT INTO cron_jobs (id, expression, command, created_at, next_run) - VALUES (?1, ?2, ?3, ?4, ?5)", + "INSERT INTO cron_jobs (id, expression, command, created_at, next_run, paused, one_shot) + VALUES (?1, ?2, ?3, ?4, ?5, 0, 0)", params![ id, expression, @@ -91,43 +123,169 @@ pub fn add_job(config: &Config, expression: &str, command: &str) -> Result, command: &str) -> Result { + add_one_shot_job_with_expression(config, run_at, command, "@once".to_string()) +} + +pub fn add_once(config: &Config, delay: &str, command: &str) -> Result { + let duration = parse_duration(delay)?; + let run_at = Utc::now() + duration; + add_one_shot_job_with_expression(config, run_at, command, format!("@once:{delay}")) +} + +pub fn add_once_at(config: &Config, at: DateTime, command: &str) -> Result { + add_one_shot_job_with_expression(config, at, command, format!("@at:{}", at.to_rfc3339())) +} + +fn add_one_shot_job_with_expression( + config: &Config, + run_at: DateTime, + command: &str, + expression: String, +) -> Result { + check_max_tasks(config)?; + let now = Utc::now(); + if run_at <= now { + anyhow::bail!("Scheduled time must be in the future"); + } + + let id = Uuid::new_v4().to_string(); + + with_connection(config, |conn| { + conn.execute( + "INSERT INTO cron_jobs (id, expression, command, created_at, next_run, paused, one_shot) + VALUES (?1, ?2, ?3, ?4, ?5, 0, 1)", + params![id, expression, command, now.to_rfc3339(), run_at.to_rfc3339()], + ) + .context("Failed to insert one-shot task")?; + Ok(()) + })?; + + Ok(CronJob { + id, + expression, + command: command.to_string(), + next_run: run_at, + last_run: None, + last_status: None, + paused: false, + one_shot: true, + }) +} + +pub fn get_job(config: &Config, id: &str) -> Result> { + with_connection(config, |conn| { + let mut stmt = conn.prepare( + "SELECT id, expression, command, next_run, last_run, last_status, paused, one_shot + FROM cron_jobs WHERE id = ?1", + )?; + + let mut rows = stmt.query_map(params![id], |row| Ok(parse_job_row(row)))?; + + match rows.next() { + Some(Ok(job_result)) => Ok(Some(job_result?)), + Some(Err(e)) => Err(e.into()), + None => Ok(None), + } + }) +} + +pub fn pause_job(config: &Config, id: &str) -> Result<()> { + let changed = with_connection(config, |conn| { + conn.execute("UPDATE cron_jobs SET paused = 1 WHERE id = ?1", params![id]) + .context("Failed to pause cron job") + })?; + + if changed == 0 { + anyhow::bail!("Cron job '{id}' not found"); + } + + Ok(()) +} + +pub fn resume_job(config: &Config, id: &str) -> Result<()> { + let changed = with_connection(config, |conn| { + conn.execute("UPDATE cron_jobs SET paused = 0 WHERE id = ?1", params![id]) + .context("Failed to resume cron job") + })?; + + if changed == 0 { + anyhow::bail!("Cron job '{id}' not found"); + } + + Ok(()) +} + +fn check_max_tasks(config: &Config) -> Result<()> { + let count = with_connection(config, |conn| { + let mut stmt = conn.prepare("SELECT COUNT(*) FROM cron_jobs")?; + let count: i64 = stmt.query_row([], |row| row.get(0))?; + usize::try_from(count).context("Unexpected negative task count") + })?; + + if count >= config.scheduler.max_tasks { + anyhow::bail!( + "Maximum number of scheduled tasks ({}) reached", + config.scheduler.max_tasks + ); + } + + Ok(()) +} + +fn parse_duration(input: &str) -> Result { + let input = input.trim(); + if input.is_empty() { + anyhow::bail!("Empty delay string"); + } + + let (num_str, unit) = if input.ends_with(|c: char| c.is_ascii_alphabetic()) { + let split = input.len() - 1; + (&input[..split], &input[split..]) + } else { + (input, "m") + }; + + let n: u64 = num_str + .trim() + .parse() + .with_context(|| format!("Invalid duration number: {num_str}"))?; + + let multiplier: u64 = match unit { + "s" => 1, + "m" => 60, + "h" => 3600, + "d" => 86400, + "w" => 604_800, + _ => anyhow::bail!("Unknown duration unit '{unit}', expected s/m/h/d/w"), + }; + + let secs = n + .checked_mul(multiplier) + .filter(|&s| i64::try_from(s).is_ok()) + .ok_or_else(|| anyhow::anyhow!("Duration value too large: {input}"))?; + + #[allow(clippy::cast_possible_wrap)] + Ok(chrono::Duration::seconds(secs as i64)) +} + pub fn list_jobs(config: &Config) -> Result> { with_connection(config, |conn| { let mut stmt = conn.prepare( - "SELECT id, expression, command, next_run, last_run, last_status + "SELECT id, expression, command, next_run, last_run, last_status, paused, one_shot FROM cron_jobs ORDER BY next_run ASC", )?; - let rows = stmt.query_map([], |row| { - let next_run_raw: String = row.get(3)?; - let last_run_raw: Option = row.get(4)?; - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - next_run_raw, - last_run_raw, - row.get::<_, Option>(5)?, - )) - })?; + let rows = stmt.query_map([], |row| Ok(parse_job_row(row)))?; let mut jobs = Vec::new(); for row in rows { - let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?; - jobs.push(CronJob { - id, - expression, - command, - next_run: parse_rfc3339(&next_run_raw)?, - last_run: match last_run_raw { - Some(raw) => Some(parse_rfc3339(&raw)?), - None => None, - }, - last_status, - }); + jobs.push(row??); } Ok(jobs) }) @@ -143,44 +301,21 @@ pub fn remove_job(config: &Config, id: &str) -> Result<()> { anyhow::bail!("Cron job '{id}' not found"); } - println!("✅ Removed cron job {id}"); Ok(()) } pub fn due_jobs(config: &Config, now: DateTime) -> Result> { with_connection(config, |conn| { let mut stmt = conn.prepare( - "SELECT id, expression, command, next_run, last_run, last_status - FROM cron_jobs WHERE next_run <= ?1 ORDER BY next_run ASC", + "SELECT id, expression, command, next_run, last_run, last_status, paused, one_shot + FROM cron_jobs WHERE next_run <= ?1 AND paused = 0 ORDER BY next_run ASC", )?; - let rows = stmt.query_map(params![now.to_rfc3339()], |row| { - let next_run_raw: String = row.get(3)?; - let last_run_raw: Option = row.get(4)?; - Ok(( - row.get::<_, String>(0)?, - row.get::<_, String>(1)?, - row.get::<_, String>(2)?, - next_run_raw, - last_run_raw, - row.get::<_, Option>(5)?, - )) - })?; + let rows = stmt.query_map(params![now.to_rfc3339()], |row| Ok(parse_job_row(row)))?; let mut jobs = Vec::new(); for row in rows { - let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?; - jobs.push(CronJob { - id, - expression, - command, - next_run: parse_rfc3339(&next_run_raw)?, - last_run: match last_run_raw { - Some(raw) => Some(parse_rfc3339(&raw)?), - None => None, - }, - last_status, - }); + jobs.push(row??); } Ok(jobs) }) @@ -192,6 +327,15 @@ pub fn reschedule_after_run( success: bool, output: &str, ) -> Result<()> { + if job.one_shot { + with_connection(config, |conn| { + conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![job.id]) + .context("Failed to remove one-shot task after execution")?; + Ok(()) + })?; + return Ok(()); + } + let now = Utc::now(); let next_run = next_run_for(&job.expression, now)?; let status = if success { "ok" } else { "error" }; @@ -229,9 +373,7 @@ fn normalize_expression(expression: &str) -> Result { let field_count = expression.split_whitespace().count(); match field_count { - // standard crontab syntax: minute hour day month weekday 5 => Ok(format!("0 {expression}")), - // crate-native syntax includes seconds (+ optional year) 6 | 7 => Ok(expression.to_string()), _ => anyhow::bail!( "Invalid cron expression: {expression} (expected 5, 6, or 7 fields, got {field_count})" @@ -239,6 +381,31 @@ fn normalize_expression(expression: &str) -> Result { } } +fn parse_job_row(row: &rusqlite::Row<'_>) -> Result { + let id: String = row.get(0)?; + let expression: String = row.get(1)?; + let command: String = row.get(2)?; + let next_run_raw: String = row.get(3)?; + let last_run_raw: Option = row.get(4)?; + let last_status: Option = row.get(5)?; + let paused: bool = row.get(6)?; + let one_shot: bool = row.get(7)?; + + Ok(CronJob { + id, + expression, + command, + next_run: parse_rfc3339(&next_run_raw)?, + last_run: match last_run_raw { + Some(raw) => Some(parse_rfc3339(&raw)?), + None => None, + }, + last_status, + paused, + one_shot, + }) +} + fn parse_rfc3339(raw: &str) -> Result> { let parsed = DateTime::parse_from_rfc3339(raw) .with_context(|| format!("Invalid RFC3339 timestamp in cron DB: {raw}"))?; @@ -255,7 +422,6 @@ fn with_connection(config: &Config, f: impl FnOnce(&Connection) -> Result) let conn = Connection::open(&db_path) .with_context(|| format!("Failed to open cron DB: {}", db_path.display()))?; - // ── Production-grade PRAGMA tuning ────────────────────── conn.execute_batch( "PRAGMA journal_mode = WAL; PRAGMA synchronous = NORMAL; @@ -274,12 +440,19 @@ fn with_connection(config: &Config, f: impl FnOnce(&Connection) -> Result) next_run TEXT NOT NULL, last_run TEXT, last_status TEXT, - last_output TEXT + last_output TEXT, + paused INTEGER NOT NULL DEFAULT 0, + one_shot INTEGER NOT NULL DEFAULT 0 ); CREATE INDEX IF NOT EXISTS idx_cron_jobs_next_run ON cron_jobs(next_run);", ) .context("Failed to initialize cron schema")?; + for column in ["paused", "one_shot"] { + let alter = format!("ALTER TABLE cron_jobs ADD COLUMN {column} INTEGER NOT NULL DEFAULT 0"); + let _ = conn.execute_batch(&alter); + } + f(&conn) } @@ -309,6 +482,8 @@ mod tests { assert_eq!(job.expression, "*/5 * * * *"); assert_eq!(job.command, "echo ok"); + assert!(!job.one_shot); + assert!(!job.paused); } #[test] @@ -335,18 +510,72 @@ mod tests { } #[test] - fn due_jobs_filters_by_timestamp() { + fn add_once_creates_one_shot_job() { let tmp = TempDir::new().unwrap(); let config = test_config(&tmp); - let _job = add_job(&config, "* * * * *", "echo due").unwrap(); + let job = add_once(&config, "30m", "echo once").unwrap(); + assert!(job.one_shot); + assert!(job.expression.starts_with("@once:")); + + let fetched = get_job(&config, &job.id).unwrap().unwrap(); + assert!(fetched.one_shot); + assert!(!fetched.paused); + } + + #[test] + fn add_once_at_rejects_past_timestamp() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let run_at = Utc::now() - ChronoDuration::minutes(1); + let err = add_once_at(&config, run_at, "echo past").unwrap_err(); + assert!(err.to_string().contains("future")); + } + + #[test] + fn get_job_found_and_missing() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/5 * * * *", "echo found").unwrap(); + let found = get_job(&config, &job.id).unwrap(); + assert!(found.is_some()); + assert_eq!(found.unwrap().id, job.id); + + let missing = get_job(&config, "nonexistent").unwrap(); + assert!(missing.is_none()); + } + + #[test] + fn pause_resume_roundtrip() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/5 * * * *", "echo pause").unwrap(); + pause_job(&config, &job.id).unwrap(); + assert!(get_job(&config, &job.id).unwrap().unwrap().paused); + + resume_job(&config, &job.id).unwrap(); + assert!(!get_job(&config, &job.id).unwrap().unwrap().paused); + } + + #[test] + fn due_jobs_filters_by_timestamp_and_skips_paused() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let active = add_job(&config, "* * * * *", "echo due").unwrap(); + let paused = add_job(&config, "* * * * *", "echo paused").unwrap(); + pause_job(&config, &paused.id).unwrap(); let due_now = due_jobs(&config, Utc::now()).unwrap(); - assert!(due_now.is_empty(), "new job should not be due immediately"); + assert!(due_now.is_empty(), "new jobs should not be due immediately"); let far_future = Utc::now() + ChronoDuration::days(365); let due_future = due_jobs(&config, far_future).unwrap(); - assert_eq!(due_future.len(), 1, "job should be due in far future"); + assert_eq!(due_future.len(), 1); + assert_eq!(due_future[0].id, active.id); } #[test] @@ -362,4 +591,67 @@ mod tests { assert_eq!(stored.last_status.as_deref(), Some("error")); assert!(stored.last_run.is_some()); } + + #[test] + fn reschedule_after_run_removes_one_shot_jobs() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let run_at = Utc::now() + ChronoDuration::minutes(1); + let job = add_one_shot_job(&config, run_at, "echo once").unwrap(); + reschedule_after_run(&config, &job, true, "ok").unwrap(); + + assert!(get_job(&config, &job.id).unwrap().is_none()); + } + + #[test] + fn scheduler_columns_migrate_from_old_schema() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let db_path = config.workspace_dir.join("cron").join("jobs.db"); + std::fs::create_dir_all(db_path.parent().unwrap()).unwrap(); + + { + let conn = rusqlite::Connection::open(&db_path).unwrap(); + conn.execute_batch( + "CREATE TABLE cron_jobs ( + id TEXT PRIMARY KEY, + expression TEXT NOT NULL, + command TEXT NOT NULL, + created_at TEXT NOT NULL, + next_run TEXT NOT NULL, + last_run TEXT, + last_status TEXT, + last_output TEXT + );", + ) + .unwrap(); + conn.execute( + "INSERT INTO cron_jobs (id, expression, command, created_at, next_run) + VALUES ('old-job', '* * * * *', 'echo old', '2025-01-01T00:00:00Z', '2030-01-01T00:00:00Z')", + [], + ) + .unwrap(); + } + + let jobs = list_jobs(&config).unwrap(); + assert_eq!(jobs.len(), 1); + assert_eq!(jobs[0].id, "old-job"); + assert!(!jobs[0].paused); + assert!(!jobs[0].one_shot); + } + + #[test] + fn max_tasks_limit_is_enforced() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.scheduler.max_tasks = 1; + + let _first = add_job(&config, "*/10 * * * *", "echo first").unwrap(); + let err = add_job(&config, "*/11 * * * *", "echo second").unwrap_err(); + assert!(err + .to_string() + .contains("Maximum number of scheduled tasks")); + } } diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index bab1965..bdb5f0b 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -9,9 +9,18 @@ use tokio::time::{self, Duration}; const MIN_POLL_SECONDS: u64 = 5; pub async fn run(config: Config) -> Result<()> { + if !config.scheduler.enabled { + tracing::info!("Scheduler disabled by config"); + crate::health::mark_component_ok("scheduler"); + loop { + time::sleep(Duration::from_secs(3600)).await; + } + } + let poll_secs = config.reliability.scheduler_poll_secs.max(MIN_POLL_SECONDS); let mut interval = time::interval(Duration::from_secs(poll_secs)); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + let max_concurrent = config.scheduler.max_concurrent.max(1); crate::health::mark_component_ok("scheduler"); @@ -27,7 +36,7 @@ pub async fn run(config: Config) -> Result<()> { } }; - for job in jobs { + for job in jobs.into_iter().take(max_concurrent) { crate::health::mark_component_ok("scheduler"); let (success, output) = execute_job_with_retry(&config, &security, &job).await; @@ -224,6 +233,8 @@ mod tests { next_run: Utc::now(), last_run: None, last_status: None, + paused: false, + one_shot: false, } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 8eaa57c..104d4de 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -267,6 +267,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { &config.workspace_dir, &config.agents, config.api_key.as_deref(), + &config, )); let skills = crate::skills::load_skills(&config.workspace_dir); let tool_descs: Vec<(&str, &str)> = tools_registry diff --git a/src/lib.rs b/src/lib.rs index 619190b..61a2bc6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -147,11 +147,28 @@ pub enum CronCommands { /// Command to run command: String, }, + /// Add a one-shot delayed task (e.g. "30m", "2h", "1d") + Once { + /// Delay duration + delay: String, + /// Command to run + command: String, + }, /// Remove a scheduled task Remove { /// Task ID id: String, }, + /// Pause a scheduled task + Pause { + /// Task ID + id: String, + }, + /// Resume a paused task + Resume { + /// Task ID + id: String, + }, } /// Integration subcommands diff --git a/src/main.rs b/src/main.rs index 426fdfd..3253594 100644 --- a/src/main.rs +++ b/src/main.rs @@ -234,11 +234,28 @@ enum CronCommands { /// Command to run command: String, }, + /// Add a one-shot delayed task (e.g. "30m", "2h", "1d") + Once { + /// Delay duration + delay: String, + /// Command to run + command: String, + }, /// Remove a scheduled task Remove { /// Task ID id: String, }, + /// Pause a scheduled task + Pause { + /// Task ID + id: String, + }, + /// Resume a paused task + Resume { + /// Task ID + id: String, + }, } #[derive(Subcommand, Debug)] diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 0447d23..7fbcc44 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -110,6 +110,7 @@ pub fn run_wizard() -> Result { autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), + scheduler: crate::config::SchedulerConfig::default(), model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config, @@ -305,6 +306,7 @@ pub fn run_quick_setup( autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), + scheduler: crate::config::SchedulerConfig::default(), model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 22e8d1a..b5cd67a 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -10,6 +10,7 @@ pub mod image_info; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; +pub mod schedule; pub mod screenshot; pub mod shell; pub mod traits; @@ -26,6 +27,7 @@ pub use image_info::ImageInfoTool; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; +pub use schedule::ScheduleTool; pub use screenshot::ScreenshotTool; pub use shell::ShellTool; pub use traits::Tool; @@ -67,6 +69,7 @@ pub fn all_tools( workspace_dir: &std::path::Path, agents: &HashMap, fallback_api_key: Option<&str>, + config: &crate::config::Config, ) -> Vec> { all_tools_with_runtime( security, @@ -78,6 +81,7 @@ pub fn all_tools( workspace_dir, agents, fallback_api_key, + config, ) } @@ -93,6 +97,7 @@ pub fn all_tools_with_runtime( workspace_dir: &std::path::Path, agents: &HashMap, fallback_api_key: Option<&str>, + config: &crate::config::Config, ) -> Vec> { let mut tools: Vec> = vec![ Box::new(ShellTool::new(security.clone(), runtime)), @@ -101,6 +106,7 @@ pub fn all_tools_with_runtime( Box::new(MemoryStoreTool::new(memory.clone())), Box::new(MemoryRecallTool::new(memory.clone())), Box::new(MemoryForgetTool::new(memory)), + Box::new(ScheduleTool::new(security.clone(), config.clone())), Box::new(GitOperationsTool::new( security.clone(), workspace_dir.to_path_buf(), @@ -158,9 +164,17 @@ pub fn all_tools_with_runtime( #[cfg(test)] mod tests { use super::*; - use crate::config::{BrowserConfig, MemoryConfig}; + use crate::config::{BrowserConfig, Config, MemoryConfig}; use tempfile::TempDir; + fn test_config(tmp: &TempDir) -> Config { + Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + } + } + #[test] fn default_tools_has_three() { let security = Arc::new(SecurityPolicy::default()); @@ -186,6 +200,7 @@ mod tests { ..BrowserConfig::default() }; let http = crate::config::HttpRequestConfig::default(); + let cfg = test_config(&tmp); let tools = all_tools( &security, @@ -196,9 +211,11 @@ mod tests { tmp.path(), &HashMap::new(), None, + &cfg, ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"browser_open")); + assert!(names.contains(&"schedule")); } #[test] @@ -219,6 +236,7 @@ mod tests { ..BrowserConfig::default() }; let http = crate::config::HttpRequestConfig::default(); + let cfg = test_config(&tmp); let tools = all_tools( &security, @@ -229,6 +247,7 @@ mod tests { tmp.path(), &HashMap::new(), None, + &cfg, ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"browser_open")); @@ -341,6 +360,7 @@ mod tests { let browser = BrowserConfig::default(); let http = crate::config::HttpRequestConfig::default(); + let cfg = test_config(&tmp); let mut agents = HashMap::new(); agents.insert( @@ -364,6 +384,7 @@ mod tests { tmp.path(), &agents, Some("sk-test"), + &cfg, ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"delegate")); @@ -382,6 +403,7 @@ mod tests { let browser = BrowserConfig::default(); let http = crate::config::HttpRequestConfig::default(); + let cfg = test_config(&tmp); let tools = all_tools( &security, @@ -392,6 +414,7 @@ mod tests { tmp.path(), &HashMap::new(), None, + &cfg, ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"delegate")); diff --git a/src/tools/schedule.rs b/src/tools/schedule.rs new file mode 100644 index 0000000..43234b8 --- /dev/null +++ b/src/tools/schedule.rs @@ -0,0 +1,522 @@ +use super::traits::{Tool, ToolResult}; +use crate::config::Config; +use crate::cron; +use crate::security::SecurityPolicy; +use anyhow::Result; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use serde_json::json; +use std::sync::Arc; + +/// Tool that lets the agent manage recurring and one-shot scheduled tasks. +pub struct ScheduleTool { + security: Arc, + config: Config, +} + +impl ScheduleTool { + pub fn new(security: Arc, config: Config) -> Self { + Self { security, config } + } +} + +#[async_trait] +impl Tool for ScheduleTool { + fn name(&self) -> &str { + "schedule" + } + + fn description(&self) -> &str { + "Manage scheduled tasks. Actions: create/add/once/list/get/cancel/remove/pause/resume" + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["create", "add", "once", "list", "get", "cancel", "remove", "pause", "resume"], + "description": "Action to perform" + }, + "expression": { + "type": "string", + "description": "Cron expression for recurring tasks (e.g. '*/5 * * * *')." + }, + "delay": { + "type": "string", + "description": "Delay for one-shot tasks (e.g. '30m', '2h', '1d')." + }, + "run_at": { + "type": "string", + "description": "Absolute RFC3339 time for one-shot tasks (e.g. '2030-01-01T00:00:00Z')." + }, + "command": { + "type": "string", + "description": "Shell command to execute. Required for create/add/once." + }, + "id": { + "type": "string", + "description": "Task ID. Required for get/cancel/remove/pause/resume." + } + }, + "required": ["action"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> Result { + let action = args + .get("action") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'action' parameter"))?; + + match action { + "list" => self.handle_list(), + "get" => { + let id = args + .get("id") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'id' parameter for get action"))?; + self.handle_get(id) + } + "create" | "add" | "once" => { + if let Some(blocked) = self.enforce_mutation_allowed(action) { + return Ok(blocked); + } + self.handle_create_like(action, &args) + } + "cancel" | "remove" => { + if let Some(blocked) = self.enforce_mutation_allowed(action) { + return Ok(blocked); + } + let id = args + .get("id") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'id' parameter for cancel action"))?; + Ok(self.handle_cancel(id)) + } + "pause" => { + if let Some(blocked) = self.enforce_mutation_allowed(action) { + return Ok(blocked); + } + let id = args + .get("id") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'id' parameter for pause action"))?; + Ok(self.handle_pause_resume(id, true)) + } + "resume" => { + if let Some(blocked) = self.enforce_mutation_allowed(action) { + return Ok(blocked); + } + let id = args + .get("id") + .and_then(|value| value.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'id' parameter for resume action"))?; + Ok(self.handle_pause_resume(id, false)) + } + other => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Unknown action '{other}'. Use create/add/once/list/get/cancel/remove/pause/resume." + )), + }), + } + } +} + +impl ScheduleTool { + fn enforce_mutation_allowed(&self, action: &str) -> Option { + if !self.security.can_act() { + return Some(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Security policy: read-only mode, cannot perform '{action}'" + )), + }); + } + + if !self.security.record_action() { + return Some(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".to_string()), + }); + } + + None + } + + fn handle_list(&self) -> Result { + let jobs = cron::list_jobs(&self.config)?; + if jobs.is_empty() { + return Ok(ToolResult { + success: true, + output: "No scheduled jobs.".to_string(), + error: None, + }); + } + + let mut lines = Vec::with_capacity(jobs.len()); + for job in jobs { + let flags = match (job.paused, job.one_shot) { + (true, true) => " [paused, one-shot]", + (true, false) => " [paused]", + (false, true) => " [one-shot]", + (false, false) => "", + }; + let last_run = job + .last_run + .map_or_else(|| "never".to_string(), |value| value.to_rfc3339()); + let last_status = job.last_status.unwrap_or_else(|| "n/a".to_string()); + lines.push(format!( + "- {} | {} | next={} | last={} ({}){} | cmd: {}", + job.id, + job.expression, + job.next_run.to_rfc3339(), + last_run, + last_status, + flags, + job.command + )); + } + + Ok(ToolResult { + success: true, + output: format!("Scheduled jobs ({}):\n{}", lines.len(), lines.join("\n")), + error: None, + }) + } + + fn handle_get(&self, id: &str) -> Result { + match cron::get_job(&self.config, id)? { + Some(job) => { + let detail = json!({ + "id": job.id, + "expression": job.expression, + "command": job.command, + "next_run": job.next_run.to_rfc3339(), + "last_run": job.last_run.map(|value| value.to_rfc3339()), + "last_status": job.last_status, + "paused": job.paused, + "one_shot": job.one_shot, + }); + Ok(ToolResult { + success: true, + output: serde_json::to_string_pretty(&detail)?, + error: None, + }) + } + None => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Job '{id}' not found")), + }), + } + } + + fn handle_create_like(&self, action: &str, args: &serde_json::Value) -> Result { + let command = args + .get("command") + .and_then(|value| value.as_str()) + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing or empty 'command' parameter"))?; + + let expression = args.get("expression").and_then(|value| value.as_str()); + let delay = args.get("delay").and_then(|value| value.as_str()); + let run_at = args.get("run_at").and_then(|value| value.as_str()); + + match action { + "add" => { + if expression.is_none() || delay.is_some() || run_at.is_some() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'add' requires 'expression' and forbids delay/run_at".into()), + }); + } + } + "once" => { + if expression.is_some() || (delay.is_none() && run_at.is_none()) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'once' requires exactly one of 'delay' or 'run_at'".into()), + }); + } + if delay.is_some() && run_at.is_some() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("'once' supports either delay or run_at, not both".into()), + }); + } + } + _ => { + let count = [expression.is_some(), delay.is_some(), run_at.is_some()] + .into_iter() + .filter(|value| *value) + .count(); + if count != 1 { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "Exactly one of 'expression', 'delay', or 'run_at' must be provided" + .into(), + ), + }); + } + } + } + + if let Some(value) = expression { + let job = cron::add_job(&self.config, value, command)?; + return Ok(ToolResult { + success: true, + output: format!( + "Created recurring job {} (expr: {}, next: {}, cmd: {})", + job.id, + job.expression, + job.next_run.to_rfc3339(), + job.command + ), + error: None, + }); + } + + if let Some(value) = delay { + let job = cron::add_once(&self.config, value, command)?; + return Ok(ToolResult { + success: true, + output: format!( + "Created one-shot job {} (runs at: {}, cmd: {})", + job.id, + job.next_run.to_rfc3339(), + job.command + ), + error: None, + }); + } + + let run_at_raw = run_at.ok_or_else(|| anyhow::anyhow!("Missing scheduling parameters"))?; + let run_at_parsed: DateTime = DateTime::parse_from_rfc3339(run_at_raw) + .map_err(|error| anyhow::anyhow!("Invalid run_at timestamp: {error}"))? + .with_timezone(&Utc); + + let job = cron::add_once_at(&self.config, run_at_parsed, command)?; + Ok(ToolResult { + success: true, + output: format!( + "Created one-shot job {} (runs at: {}, cmd: {})", + job.id, + job.next_run.to_rfc3339(), + job.command + ), + error: None, + }) + } + + fn handle_cancel(&self, id: &str) -> ToolResult { + match cron::remove_job(&self.config, id) { + Ok(()) => ToolResult { + success: true, + output: format!("Cancelled job {id}"), + error: None, + }, + Err(error) => ToolResult { + success: false, + output: String::new(), + error: Some(error.to_string()), + }, + } + } + + fn handle_pause_resume(&self, id: &str, pause: bool) -> ToolResult { + let operation = if pause { + cron::pause_job(&self.config, id) + } else { + cron::resume_job(&self.config, id) + }; + + match operation { + Ok(()) => ToolResult { + success: true, + output: if pause { + format!("Paused job {id}") + } else { + format!("Resumed job {id}") + }, + error: None, + }, + Err(error) => ToolResult { + success: false, + output: String::new(), + error: Some(error.to_string()), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::AutonomyLevel; + use tempfile::TempDir; + + fn test_setup() -> (TempDir, Config, Arc) { + let tmp = TempDir::new().unwrap(); + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + (tmp, config, security) + } + + #[test] + fn tool_name_and_schema() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + assert_eq!(tool.name(), "schedule"); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["action"].is_object()); + } + + #[tokio::test] + async fn list_empty() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + + let result = tool.execute(json!({"action": "list"})).await.unwrap(); + assert!(result.success); + assert!(result.output.contains("No scheduled jobs")); + } + + #[tokio::test] + async fn create_get_and_cancel_roundtrip() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + + let create = tool + .execute(json!({ + "action": "create", + "expression": "*/5 * * * *", + "command": "echo hello" + })) + .await + .unwrap(); + assert!(create.success); + assert!(create.output.contains("Created recurring job")); + + let list = tool.execute(json!({"action": "list"})).await.unwrap(); + assert!(list.success); + assert!(list.output.contains("echo hello")); + + let id = create.output.split_whitespace().nth(3).unwrap(); + + let get = tool + .execute(json!({"action": "get", "id": id})) + .await + .unwrap(); + assert!(get.success); + assert!(get.output.contains("echo hello")); + + let cancel = tool + .execute(json!({"action": "cancel", "id": id})) + .await + .unwrap(); + assert!(cancel.success); + } + + #[tokio::test] + async fn once_and_pause_resume_aliases_work() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + + let once = tool + .execute(json!({ + "action": "once", + "delay": "30m", + "command": "echo delayed" + })) + .await + .unwrap(); + assert!(once.success); + + let add = tool + .execute(json!({ + "action": "add", + "expression": "*/10 * * * *", + "command": "echo recurring" + })) + .await + .unwrap(); + assert!(add.success); + + let id = add.output.split_whitespace().nth(3).unwrap(); + let pause = tool + .execute(json!({"action": "pause", "id": id})) + .await + .unwrap(); + assert!(pause.success); + + let resume = tool + .execute(json!({"action": "resume", "id": id})) + .await + .unwrap(); + assert!(resume.success); + } + + #[tokio::test] + async fn readonly_blocks_mutating_actions() { + let tmp = TempDir::new().unwrap(); + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + autonomy: crate::config::AutonomyConfig { + level: AutonomyLevel::ReadOnly, + ..Default::default() + }, + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + + let tool = ScheduleTool::new(security, config); + + let blocked = tool + .execute(json!({ + "action": "create", + "expression": "* * * * *", + "command": "echo blocked" + })) + .await + .unwrap(); + assert!(!blocked.success); + assert!(blocked.error.as_deref().unwrap().contains("read-only")); + + let list = tool.execute(json!({"action": "list"})).await.unwrap(); + assert!(list.success); + } + + #[tokio::test] + async fn unknown_action_returns_failure() { + let (_tmp, config, security) = test_setup(); + let tool = ScheduleTool::new(security, config); + + let result = tool.execute(json!({"action": "explode"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.as_deref().unwrap().contains("Unknown action")); + } +} From e9fa267c8442f11ed410f347490ce0bda0057d93 Mon Sep 17 00:00:00 2001 From: Chummy Date: Mon, 16 Feb 2026 23:40:33 +0800 Subject: [PATCH 02/12] feat(onboard): add provider model refresh command with TTL cache (#323) --- src/main.rs | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/src/main.rs b/src/main.rs index 3253594..a5c17f4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -272,6 +272,20 @@ enum ModelCommands { }, } +#[derive(Subcommand, Debug)] +enum ModelCommands { + /// Refresh and cache provider models + Refresh { + /// Provider name (defaults to configured default provider) + #[arg(long)] + provider: Option, + + /// Force live refresh and ignore fresh cache + #[arg(long)] + force: bool, + }, +} + #[derive(Subcommand, Debug)] enum ChannelCommands { /// List configured channels From fe1fb042787ed5089e2b666860a2e8855c8f3373 Mon Sep 17 00:00:00 2001 From: Chummy Date: Mon, 16 Feb 2026 23:40:37 +0800 Subject: [PATCH 03/12] fix(composio): align v3 execute path and honor configured entity_id (#322) --- README.md | 2 ++ src/agent/loop_.rs | 12 +++++--- src/channels/mod.rs | 12 +++++--- src/gateway/mod.rs | 10 +++++-- src/tools/composio.rs | 69 ++++++++++++++++++++++++++++++++----------- src/tools/mod.rs | 13 ++++++-- 6 files changed, 86 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 6ff65b9..7cd5aab 100644 --- a/README.md +++ b/README.md @@ -315,6 +315,8 @@ native_webdriver_url = "http://127.0.0.1:9515" # WebDriver endpoint (chromedrive [composio] enabled = false # opt-in: 1000+ OAuth apps via composio.dev +# api_key = "cmp_..." # optional: stored encrypted when [secrets].encrypt = true +entity_id = "default" # default user_id for Composio tool calls [identity] format = "openclaw" # "openclaw" (default, markdown files) or "aieos" (JSON) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 2558bfa..932606f 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -583,16 +583,20 @@ pub async fn run( tracing::info!(backend = mem.name(), "Memory initialized"); // ── Tools (including memory tools) ──────────────────────────── - let composio_key = if config.composio.enabled { - config.composio.api_key.as_deref() + let (composio_key, composio_entity_id) = if config.composio.enabled { + ( + config.composio.api_key.as_deref(), + Some(config.composio.entity_id.as_str()), + ) } else { - None + (None, None) }; let tools_registry = tools::all_tools_with_runtime( &security, runtime, mem.clone(), composio_key, + composio_entity_id, &config.browser, &config.http_request, &config.workspace_dir, @@ -670,7 +674,7 @@ pub async fn run( if config.composio.enabled { tool_descs.push(( "composio", - "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.", + "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run (optionally with connected_account_id), 'connect' to OAuth.", )); } tool_descs.push(( diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 21f99d0..9579ff8 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -715,16 +715,20 @@ pub async fn start_channels(config: Config) -> Result<()> { config.api_key.as_deref(), )?); - let composio_key = if config.composio.enabled { - config.composio.api_key.as_deref() + let (composio_key, composio_entity_id) = if config.composio.enabled { + ( + config.composio.api_key.as_deref(), + Some(config.composio.entity_id.as_str()), + ) } else { - None + (None, None) }; let tools_registry = Arc::new(tools::all_tools_with_runtime( &security, runtime, Arc::clone(&mem), composio_key, + composio_entity_id, &config.browser, &config.http_request, &config.workspace_dir, @@ -774,7 +778,7 @@ pub async fn start_channels(config: Config) -> Result<()> { if config.composio.enabled { tool_descs.push(( "composio", - "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.", + "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run (optionally with connected_account_id), 'connect' to OAuth.", )); } tool_descs.push(( diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 104d4de..638de00 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -251,10 +251,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { &config.workspace_dir, )); - let composio_key = if config.composio.enabled { - config.composio.api_key.as_deref() + let (composio_key, composio_entity_id) = if config.composio.enabled { + ( + config.composio.api_key.as_deref(), + Some(config.composio.entity_id.as_str()), + ) } else { - None + (None, None) }; let tools_registry = Arc::new(tools::all_tools_with_runtime( @@ -262,6 +265,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { runtime, Arc::clone(&mem), composio_key, + composio_entity_id, &config.browser, &config.http_request, &config.workspace_dir, diff --git a/src/tools/composio.rs b/src/tools/composio.rs index 2850d33..b010240 100644 --- a/src/tools/composio.rs +++ b/src/tools/composio.rs @@ -19,13 +19,15 @@ const COMPOSIO_API_BASE_V3: &str = "https://backend.composio.dev/api/v3"; /// A tool that proxies actions to the Composio managed tool platform. pub struct ComposioTool { api_key: String, + default_entity_id: String, client: Client, } impl ComposioTool { - pub fn new(api_key: &str) -> Self { + pub fn new(api_key: &str, default_entity_id: Option<&str>) -> Self { Self { api_key: api_key.to_string(), + default_entity_id: normalize_entity_id(default_entity_id.unwrap_or("default")), client: Client::builder() .timeout(std::time::Duration::from_secs(60)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -59,9 +61,9 @@ impl ComposioTool { let url = format!("{COMPOSIO_API_BASE_V3}/tools"); let mut req = self.client.get(&url).header("x-api-key", &self.api_key); - req = req.query(&[("limit", 200_u16)]); - if let Some(app) = app_name { - req = req.query(&[("toolkit_slug", app)]); + req = req.query(&[("limit", "200")]); + if let Some(app) = app_name.map(str::trim).filter(|app| !app.is_empty()) { + req = req.query(&[("toolkits", app), ("toolkit_slug", app)]); } let resp = req.send().await?; @@ -110,11 +112,12 @@ impl ComposioTool { action_name: &str, params: serde_json::Value, entity_id: Option<&str>, + connected_account_id: Option<&str>, ) -> anyhow::Result { let tool_slug = normalize_tool_slug(action_name); match self - .execute_action_v3(&tool_slug, params.clone(), entity_id) + .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id) .await { Ok(result) => Ok(result), @@ -132,8 +135,16 @@ impl ComposioTool { tool_slug: &str, params: serde_json::Value, entity_id: Option<&str>, + connected_account_id: Option<&str>, ) -> anyhow::Result { - let url = format!("{COMPOSIO_API_BASE_V3}/tools/execute/{tool_slug}"); + let url = if let Some(connected_account_id) = connected_account_id + .map(str::trim) + .filter(|id| !id.is_empty()) + { + format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}") + } else { + format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute") + }; let mut body = json!({ "arguments": params, @@ -355,7 +366,7 @@ impl Tool for ComposioTool { fn description(&self) -> &str { "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). \ - Use action='list' to see available actions, action='execute' with action_name/tool_slug and params, \ + Use action='list' to see available actions, action='execute' with action_name/tool_slug, params, and optional connected_account_id, \ or action='connect' with app/auth_config_id to get OAuth URL." } @@ -386,11 +397,15 @@ impl Tool for ComposioTool { }, "entity_id": { "type": "string", - "description": "Entity/user ID for multi-user setups (defaults to 'default')" + "description": "Entity/user ID for multi-user setups (defaults to composio.entity_id from config)" }, "auth_config_id": { "type": "string", "description": "Optional Composio v3 auth config id for connect flow" + }, + "connected_account_id": { + "type": "string", + "description": "Optional connected account ID for execute flow when a specific account is required" } }, "required": ["action"] @@ -406,7 +421,7 @@ impl Tool for ComposioTool { let entity_id = args .get("entity_id") .and_then(|v| v.as_str()) - .unwrap_or("default"); + .unwrap_or(self.default_entity_id.as_str()); match action { "list" => { @@ -459,9 +474,11 @@ impl Tool for ComposioTool { })?; let params = args.get("params").cloned().unwrap_or(json!({})); + let connected_account_id = + args.get("connected_account_id").and_then(|v| v.as_str()); match self - .execute_action(action_name, params, Some(entity_id)) + .execute_action(action_name, params, Some(entity_id), connected_account_id) .await { Ok(result) => { @@ -521,6 +538,15 @@ impl Tool for ComposioTool { } } +fn normalize_entity_id(entity_id: &str) -> String { + let trimmed = entity_id.trim(); + if trimmed.is_empty() { + "default".to_string() + } else { + trimmed.to_string() + } +} + fn normalize_tool_slug(action_name: &str) -> String { action_name.trim().replace('_', "-").to_ascii_lowercase() } @@ -668,20 +694,20 @@ mod tests { #[test] fn composio_tool_has_correct_name() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); assert_eq!(tool.name(), "composio"); } #[test] fn composio_tool_has_description() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); assert!(!tool.description().is_empty()); assert!(tool.description().contains("1000+")); } #[test] fn composio_tool_schema_has_required_fields() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let schema = tool.parameters_schema(); assert!(schema["properties"]["action"].is_object()); assert!(schema["properties"]["action_name"].is_object()); @@ -689,13 +715,14 @@ mod tests { assert!(schema["properties"]["params"].is_object()); assert!(schema["properties"]["app"].is_object()); assert!(schema["properties"]["auth_config_id"].is_object()); + assert!(schema["properties"]["connected_account_id"].is_object()); let required = schema["required"].as_array().unwrap(); assert!(required.contains(&json!("action"))); } #[test] fn composio_tool_spec_roundtrip() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let spec = tool.spec(); assert_eq!(spec.name, "composio"); assert!(spec.parameters.is_object()); @@ -705,14 +732,14 @@ mod tests { #[tokio::test] async fn execute_missing_action_returns_error() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let result = tool.execute(json!({})).await; assert!(result.is_err()); } #[tokio::test] async fn execute_unknown_action_returns_error() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let result = tool.execute(json!({"action": "unknown"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("Unknown action")); @@ -720,14 +747,14 @@ mod tests { #[tokio::test] async fn execute_without_action_name_returns_error() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let result = tool.execute(json!({"action": "execute"})).await; assert!(result.is_err()); } #[tokio::test] async fn connect_without_target_returns_error() { - let tool = ComposioTool::new("test-key"); + let tool = ComposioTool::new("test-key", None); let result = tool.execute(json!({"action": "connect"})).await; assert!(result.is_err()); } @@ -788,6 +815,12 @@ mod tests { ); } + #[test] + fn normalize_entity_id_falls_back_to_default_when_blank() { + assert_eq!(normalize_entity_id(" "), "default"); + assert_eq!(normalize_entity_id("workspace-user"), "workspace-user"); + } + #[test] fn normalize_tool_slug_supports_legacy_action_name() { assert_eq!( diff --git a/src/tools/mod.rs b/src/tools/mod.rs index b5cd67a..964ba5b 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -59,11 +59,12 @@ pub fn default_tools_with_runtime( } /// Create full tool registry including memory tools and optional Composio -#[allow(clippy::implicit_hasher)] +#[allow(clippy::implicit_hasher, clippy::too_many_arguments)] pub fn all_tools( security: &Arc, memory: Arc, composio_key: Option<&str>, + composio_entity_id: Option<&str>, browser_config: &crate::config::BrowserConfig, http_config: &crate::config::HttpRequestConfig, workspace_dir: &std::path::Path, @@ -76,6 +77,7 @@ pub fn all_tools( Arc::new(NativeRuntime::new()), memory, composio_key, + composio_entity_id, browser_config, http_config, workspace_dir, @@ -86,12 +88,13 @@ pub fn all_tools( } /// Create full tool registry including memory tools and optional Composio. -#[allow(clippy::implicit_hasher)] +#[allow(clippy::implicit_hasher, clippy::too_many_arguments)] pub fn all_tools_with_runtime( security: &Arc, runtime: Arc, memory: Arc, composio_key: Option<&str>, + composio_entity_id: Option<&str>, browser_config: &crate::config::BrowserConfig, http_config: &crate::config::HttpRequestConfig, workspace_dir: &std::path::Path, @@ -146,7 +149,7 @@ pub fn all_tools_with_runtime( if let Some(key) = composio_key { if !key.is_empty() { - tools.push(Box::new(ComposioTool::new(key))); + tools.push(Box::new(ComposioTool::new(key, composio_entity_id))); } } @@ -206,6 +209,7 @@ mod tests { &security, mem, None, + None, &browser, &http, tmp.path(), @@ -242,6 +246,7 @@ mod tests { &security, mem, None, + None, &browser, &http, tmp.path(), @@ -379,6 +384,7 @@ mod tests { &security, mem, None, + None, &browser, &http, tmp.path(), @@ -409,6 +415,7 @@ mod tests { &security, mem, None, + None, &browser, &http, tmp.path(), From a85fcf43c37222457a4ef29a969c357a68211668 Mon Sep 17 00:00:00 2001 From: Chummy Date: Mon, 16 Feb 2026 23:40:40 +0800 Subject: [PATCH 04/12] fix(build): reduce release-build memory pressure on low-RAM devices (#303) --- Cargo.toml | 8 ++++---- README.md | 2 ++ 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 61b5d6a..6a6bc78 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -114,10 +114,10 @@ path = "src/main.rs" [profile.release] opt-level = "z" # Optimize for size -lto = true # Link-time optimization -codegen-units = 1 # Better optimization -strip = true # Remove debug symbols -panic = "abort" # Reduce binary size +lto = "thin" # Lower memory use during release builds +codegen-units = 8 # Faster, lower-RAM codegen for small devices +strip = true # Remove debug symbols +panic = "abort" # Reduce binary size [profile.dist] inherits = "release" diff --git a/README.md b/README.md index 7cd5aab..ac9a8b2 100644 --- a/README.md +++ b/README.md @@ -114,6 +114,7 @@ zeroclaw migrate openclaw ``` > **Dev fallback (no global install):** prefix commands with `cargo run --release --` (example: `cargo run --release -- status`). +> **Low-memory boards (e.g., Raspberry Pi 3, 1GB RAM):** run `CARGO_BUILD_JOBS=1 cargo build --release` if the kernel kills rustc during compilation. ## Architecture @@ -425,6 +426,7 @@ See [aieos.org](https://aieos.org) for the full schema and live examples. ```bash cargo build # Dev build cargo build --release # Release build (~3.4MB) +CARGO_BUILD_JOBS=1 cargo build --release # Low-memory fallback (Raspberry Pi 3, 1GB RAM) cargo test # 1,017 tests cargo clippy # Lint (0 warnings) cargo fmt # Format From fac1b780cda8a2e4279a4bc3eb4e6f096cb0f531 Mon Sep 17 00:00:00 2001 From: Chummy Date: Mon, 16 Feb 2026 23:40:44 +0800 Subject: [PATCH 05/12] fix(onboard): refresh MiniMax defaults and endpoint (#299) --- src/channels/mod.rs | 2 +- src/channels/telegram.rs | 3 +- src/onboard/wizard.rs | 151 +++++++++++++++++++++++++++++++++++- src/providers/compatible.rs | 12 ++- src/providers/mod.rs | 5 +- src/tools/git_operations.rs | 2 +- 6 files changed, 168 insertions(+), 7 deletions(-) diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 9579ff8..1981472 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -186,7 +186,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C &mut history, ctx.tools_registry.as_ref(), ctx.observer.as_ref(), - ctx.provider_name.as_str(), + "channels", ctx.model.as_str(), ctx.temperature, ), diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index ea90e79..94ff767 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -919,8 +919,7 @@ mod tests { #[test] fn telegram_split_at_newline() { - let line = "Line of text\n"; - let text_block = line.repeat(TELEGRAM_MAX_MESSAGE_LENGTH / line.len() + 1); + let text_block = "Line of text\n".repeat(TELEGRAM_MAX_MESSAGE_LENGTH / 13 + 1); let chunks = split_message_for_telegram(&text_block); assert!(chunks.len() >= 2); for chunk in chunks { diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 7fbcc44..5fee2b6 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -428,11 +428,20 @@ fn canonical_provider_name(provider_name: &str) -> &str { } /// Pick a sensible default model for the given provider. +const MINIMAX_ONBOARD_MODELS: [(&str, &str); 5] = [ + ("MiniMax-M2.5", "MiniMax M2.5 (latest, recommended)"), + ("MiniMax-M2.5-highspeed", "MiniMax M2.5 High-Speed (faster)"), + ("MiniMax-M2.1", "MiniMax M2.1 (stable)"), + ("MiniMax-M2.1-highspeed", "MiniMax M2.1 High-Speed (faster)"), + ("MiniMax-M2", "MiniMax M2 (legacy)"), +]; + fn default_model_for_provider(provider: &str) -> String { match canonical_provider_name(provider) { "anthropic" => "claude-sonnet-4-20250514".into(), "openai" => "gpt-5.2".into(), "glm" | "zhipu" | "zai" | "z.ai" => "glm-5".into(), + "minimax" => "MiniMax-M2.5".into(), "ollama" => "llama3.2".into(), "groq" => "llama-3.3-70b-versatile".into(), "deepseek" => "deepseek-chat".into(), @@ -1454,7 +1463,131 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String)> { }; // ── Model selection ── - let mut model_options = curated_models_for_provider(provider_name); + let models: Vec<(&str, &str)> = match provider_name { + "openrouter" => vec![ + ( + "anthropic/claude-sonnet-4", + "Claude Sonnet 4 (balanced, recommended)", + ), + ( + "anthropic/claude-3.5-sonnet", + "Claude 3.5 Sonnet (fast, affordable)", + ), + ("openai/gpt-4o", "GPT-4o (OpenAI flagship)"), + ("openai/gpt-4o-mini", "GPT-4o Mini (fast, cheap)"), + ( + "google/gemini-2.0-flash-001", + "Gemini 2.0 Flash (Google, fast)", + ), + ( + "meta-llama/llama-3.3-70b-instruct", + "Llama 3.3 70B (open source)", + ), + ("deepseek/deepseek-chat", "DeepSeek Chat (affordable)"), + ], + "anthropic" => vec![ + ( + "claude-sonnet-4-20250514", + "Claude Sonnet 4 (balanced, recommended)", + ), + ("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet (fast)"), + ( + "claude-3-5-haiku-20241022", + "Claude 3.5 Haiku (fastest, cheapest)", + ), + ], + "openai" => vec![ + ("gpt-4o", "GPT-4o (flagship)"), + ("gpt-4o-mini", "GPT-4o Mini (fast, cheap)"), + ("o1-mini", "o1-mini (reasoning)"), + ], + "venice" => vec![ + ("llama-3.3-70b", "Llama 3.3 70B (default, fast)"), + ("claude-opus-45", "Claude Opus 4.5 via Venice (strongest)"), + ("llama-3.1-405b", "Llama 3.1 405B (largest open source)"), + ], + "groq" => vec![ + ( + "llama-3.3-70b-versatile", + "Llama 3.3 70B (fast, recommended)", + ), + ("llama-3.1-8b-instant", "Llama 3.1 8B (instant)"), + ("mixtral-8x7b-32768", "Mixtral 8x7B (32K context)"), + ], + "mistral" => vec![ + ("mistral-large-latest", "Mistral Large (flagship)"), + ("codestral-latest", "Codestral (code-focused)"), + ("mistral-small-latest", "Mistral Small (fast, cheap)"), + ], + "deepseek" => vec![ + ("deepseek-chat", "DeepSeek Chat (V3, recommended)"), + ("deepseek-reasoner", "DeepSeek Reasoner (R1)"), + ], + "xai" => vec![ + ("grok-3", "Grok 3 (flagship)"), + ("grok-3-mini", "Grok 3 Mini (fast)"), + ], + "perplexity" => vec![ + ("sonar-pro", "Sonar Pro (search + reasoning)"), + ("sonar", "Sonar (search, fast)"), + ], + "fireworks" => vec![ + ( + "accounts/fireworks/models/llama-v3p3-70b-instruct", + "Llama 3.3 70B", + ), + ( + "accounts/fireworks/models/mixtral-8x22b-instruct", + "Mixtral 8x22B", + ), + ], + "together" => vec![ + ( + "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", + "Llama 3.1 70B Turbo", + ), + ( + "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", + "Llama 3.1 8B Turbo", + ), + ("mistralai/Mixtral-8x22B-Instruct-v0.1", "Mixtral 8x22B"), + ], + "cohere" => vec![ + ("command-r-plus", "Command R+ (flagship)"), + ("command-r", "Command R (fast)"), + ], + "moonshot" => vec![ + ("moonshot-v1-128k", "Moonshot V1 128K"), + ("moonshot-v1-32k", "Moonshot V1 32K"), + ], + "glm" | "zhipu" | "zai" | "z.ai" => vec![ + ("glm-5", "GLM-5 (latest)"), + ("glm-4-plus", "GLM-4 Plus (flagship)"), + ("glm-4-flash", "GLM-4 Flash (fast)"), + ], + "minimax" => MINIMAX_ONBOARD_MODELS.to_vec(), + "ollama" => vec![ + ("llama3.2", "Llama 3.2 (recommended local)"), + ("mistral", "Mistral 7B"), + ("codellama", "Code Llama"), + ("phi3", "Phi-3 (small, fast)"), + ], + "gemini" | "google" | "google-gemini" => vec![ + ("gemini-2.0-flash", "Gemini 2.0 Flash (fast, recommended)"), + ( + "gemini-2.0-flash-lite", + "Gemini 2.0 Flash Lite (fastest, cheapest)", + ), + ("gemini-1.5-pro", "Gemini 1.5 Pro (best quality)"), + ("gemini-1.5-flash", "Gemini 1.5 Flash (balanced)"), + ], + _ => vec![("default", "Default model")], + }; + + let mut model_options: Vec<(String, String)> = models + .into_iter() + .map(|(model_id, label)| (model_id.to_string(), label.to_string())) + .collect(); let mut live_options: Option> = None; if supports_live_model_fetch(provider_name) { @@ -4206,4 +4339,20 @@ mod tests { fn provider_env_var_unknown_falls_back() { assert_eq!(provider_env_var("some-new-provider"), "API_KEY"); } + + #[test] + fn default_model_for_minimax_is_m2_5() { + assert_eq!(default_model_for_provider("minimax"), "MiniMax-M2.5"); + } + + #[test] + fn minimax_onboard_models_include_m2_variants() { + let model_names: Vec<&str> = MINIMAX_ONBOARD_MODELS + .iter() + .map(|(name, _)| *name) + .collect(); + assert_eq!(model_names.first().copied(), Some("MiniMax-M2.5")); + assert!(model_names.contains(&"MiniMax-M2.1")); + assert!(model_names.contains(&"MiniMax-M2.1-highspeed")); + } } diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index de7bff0..4c59992 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -584,7 +584,7 @@ mod tests { make_provider("Venice", "https://api.venice.ai", None), make_provider("Moonshot", "https://api.moonshot.cn", None), make_provider("GLM", "https://open.bigmodel.cn", None), - make_provider("MiniMax", "https://api.minimax.chat", None), + make_provider("MiniMax", "https://api.minimaxi.com/v1", None), make_provider("Groq", "https://api.groq.com/openai", None), make_provider("Mistral", "https://api.mistral.ai", None), make_provider("xAI", "https://api.x.ai", None), @@ -793,6 +793,16 @@ mod tests { ); } + #[test] + fn chat_completions_url_minimax() { + // MiniMax OpenAI-compatible endpoint requires /v1 base path. + let p = make_provider("minimax", "https://api.minimaxi.com/v1", None); + assert_eq!( + p.chat_completions_url(), + "https://api.minimaxi.com/v1/chat/completions" + ); + } + #[test] fn chat_completions_url_glm() { // GLM (BigModel) uses /api/paas/v4 base path diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 5dd1212..1ba11b7 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -221,7 +221,10 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(OpenAiCompatibleProvider::new( - "MiniMax", "https://api.minimax.chat", key, AuthStyle::Bearer, + "MiniMax", + "https://api.minimaxi.com/v1", + key, + AuthStyle::Bearer, ))), "bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new( "Amazon Bedrock", diff --git a/src/tools/git_operations.rs b/src/tools/git_operations.rs index c197eff..fc4b4d2 100644 --- a/src/tools/git_operations.rs +++ b/src/tools/git_operations.rs @@ -558,7 +558,7 @@ mod tests { use std::path::Path; use tempfile::TempDir; - fn test_tool(dir: &Path) -> GitOperationsTool { + fn test_tool(dir: &std::path::Path) -> GitOperationsTool { let security = Arc::new(SecurityPolicy { autonomy: AutonomyLevel::Supervised, ..SecurityPolicy::default() From 22714271fde7fa14806c9c1eee5d602dc67c4d4d Mon Sep 17 00:00:00 2001 From: Chummy Date: Mon, 16 Feb 2026 23:40:47 +0800 Subject: [PATCH 06/12] feat(cost): add budget tracking core and harden storage reliability (#292) --- src/channels/mod.rs | 3 +- src/config/mod.rs | 2 +- src/config/schema.rs | 147 ++++++++++++ src/cost/mod.rs | 5 + src/cost/tracker.rs | 539 ++++++++++++++++++++++++++++++++++++++++++ src/cost/types.rs | 193 +++++++++++++++ src/lib.rs | 1 + src/onboard/wizard.rs | 2 + 8 files changed, 890 insertions(+), 2 deletions(-) create mode 100644 src/cost/mod.rs create mode 100644 src/cost/tracker.rs create mode 100644 src/cost/types.rs diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 1981472..0589e2e 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -682,7 +682,8 @@ pub async fn start_channels(config: Config) -> Result<()> { let provider_name = config .default_provider .clone() - .unwrap_or_else(|| "openrouter".to_string()); + .unwrap_or_else(|| "openrouter".into()); + let provider: Arc = Arc::from(providers::create_resilient_provider( provider_name.as_str(), config.api_key.as_deref(), diff --git a/src/config/mod.rs b/src/config/mod.rs index a61c29c..e53b597 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,7 +2,7 @@ pub mod schema; #[allow(unused_imports)] pub use schema::{ - AuditConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, + AuditConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, CostConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, diff --git a/src/config/schema.rs b/src/config/schema.rs index 8d2ec55..8a66124 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -71,6 +71,9 @@ pub struct Config { #[serde(default)] pub identity: IdentityConfig, + #[serde(default)] + pub cost: CostConfig, + /// Hardware Abstraction Layer (HAL) configuration. /// Controls how ZeroClaw interfaces with physical hardware /// (GPIO, serial, debug probes). @@ -127,6 +130,147 @@ impl Default for IdentityConfig { } } +// ── Cost tracking and budget enforcement ─────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostConfig { + /// Enable cost tracking (default: false) + #[serde(default)] + pub enabled: bool, + + /// Daily spending limit in USD (default: 10.00) + #[serde(default = "default_daily_limit")] + pub daily_limit_usd: f64, + + /// Monthly spending limit in USD (default: 100.00) + #[serde(default = "default_monthly_limit")] + pub monthly_limit_usd: f64, + + /// Warn when spending reaches this percentage of limit (default: 80) + #[serde(default = "default_warn_percent")] + pub warn_at_percent: u8, + + /// Allow requests to exceed budget with --override flag (default: false) + #[serde(default)] + pub allow_override: bool, + + /// Per-model pricing (USD per 1M tokens) + #[serde(default)] + pub prices: std::collections::HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelPricing { + /// Input price per 1M tokens + #[serde(default)] + pub input: f64, + + /// Output price per 1M tokens + #[serde(default)] + pub output: f64, +} + +fn default_daily_limit() -> f64 { + 10.0 +} + +fn default_monthly_limit() -> f64 { + 100.0 +} + +fn default_warn_percent() -> u8 { + 80 +} + +impl Default for CostConfig { + fn default() -> Self { + Self { + enabled: false, + daily_limit_usd: default_daily_limit(), + monthly_limit_usd: default_monthly_limit(), + warn_at_percent: default_warn_percent(), + allow_override: false, + prices: get_default_pricing(), + } + } +} + +/// Default pricing for popular models (USD per 1M tokens) +fn get_default_pricing() -> std::collections::HashMap { + let mut prices = std::collections::HashMap::new(); + + // Anthropic models + prices.insert( + "anthropic/claude-sonnet-4-20250514".into(), + ModelPricing { + input: 3.0, + output: 15.0, + }, + ); + prices.insert( + "anthropic/claude-opus-4-20250514".into(), + ModelPricing { + input: 15.0, + output: 75.0, + }, + ); + prices.insert( + "anthropic/claude-3.5-sonnet".into(), + ModelPricing { + input: 3.0, + output: 15.0, + }, + ); + prices.insert( + "anthropic/claude-3-haiku".into(), + ModelPricing { + input: 0.25, + output: 1.25, + }, + ); + + // OpenAI models + prices.insert( + "openai/gpt-4o".into(), + ModelPricing { + input: 5.0, + output: 15.0, + }, + ); + prices.insert( + "openai/gpt-4o-mini".into(), + ModelPricing { + input: 0.15, + output: 0.60, + }, + ); + prices.insert( + "openai/o1-preview".into(), + ModelPricing { + input: 15.0, + output: 60.0, + }, + ); + + // Google models + prices.insert( + "google/gemini-2.0-flash".into(), + ModelPricing { + input: 0.10, + output: 0.40, + }, + ); + prices.insert( + "google/gemini-1.5-pro".into(), + ModelPricing { + input: 1.25, + output: 5.0, + }, + ); + + prices +} + // ── Agent delegation ───────────────────────────────────────────── /// Configuration for a named delegate agent that can be invoked via the @@ -1200,6 +1344,7 @@ impl Default for Config { browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), hardware: crate::hardware::HardwareConfig::default(), agents: HashMap::new(), security: SecurityConfig::default(), @@ -1556,6 +1701,7 @@ mod tests { browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), hardware: crate::hardware::HardwareConfig::default(), agents: HashMap::new(), security: SecurityConfig::default(), @@ -1632,6 +1778,7 @@ default_temperature = 0.7 browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), hardware: crate::hardware::HardwareConfig::default(), agents: HashMap::new(), security: SecurityConfig::default(), diff --git a/src/cost/mod.rs b/src/cost/mod.rs new file mode 100644 index 0000000..14c634d --- /dev/null +++ b/src/cost/mod.rs @@ -0,0 +1,5 @@ +pub mod tracker; +pub mod types; + +pub use tracker::CostTracker; +pub use types::{BudgetCheck, CostRecord, CostSummary, ModelStats, TokenUsage, UsagePeriod}; diff --git a/src/cost/tracker.rs b/src/cost/tracker.rs new file mode 100644 index 0000000..16b874f --- /dev/null +++ b/src/cost/tracker.rs @@ -0,0 +1,539 @@ +use super::types::{BudgetCheck, CostRecord, CostSummary, ModelStats, TokenUsage, UsagePeriod}; +use crate::config::CostConfig; +use anyhow::{anyhow, Context, Result}; +use chrono::{Datelike, NaiveDate, Utc}; +use std::collections::HashMap; +use std::fs::{self, File, OpenOptions}; +use std::io::{BufRead, BufReader, Write}; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex, MutexGuard}; + +/// Cost tracker for API usage monitoring and budget enforcement. +pub struct CostTracker { + config: CostConfig, + storage: Arc>, + session_id: String, + session_costs: Arc>>, +} + +impl CostTracker { + /// Create a new cost tracker. + pub fn new(config: CostConfig, workspace_dir: &Path) -> Result { + let storage_path = resolve_storage_path(workspace_dir)?; + + let storage = CostStorage::new(&storage_path).with_context(|| { + format!("Failed to open cost storage at {}", storage_path.display()) + })?; + + Ok(Self { + config, + storage: Arc::new(Mutex::new(storage)), + session_id: uuid::Uuid::new_v4().to_string(), + session_costs: Arc::new(Mutex::new(Vec::new())), + }) + } + + /// Get the session ID. + pub fn session_id(&self) -> &str { + &self.session_id + } + + fn lock_storage(&self) -> Result> { + self.storage + .lock() + .map_err(|_| anyhow!("Cost storage lock poisoned")) + } + + fn lock_session_costs(&self) -> Result>> { + self.session_costs + .lock() + .map_err(|_| anyhow!("Session cost lock poisoned")) + } + + /// Check if a request is within budget. + pub fn check_budget(&self, estimated_cost_usd: f64) -> Result { + if !self.config.enabled { + return Ok(BudgetCheck::Allowed); + } + + if !estimated_cost_usd.is_finite() || estimated_cost_usd < 0.0 { + return Err(anyhow!( + "Estimated cost must be a finite, non-negative value" + )); + } + + let mut storage = self.lock_storage()?; + let (daily_cost, monthly_cost) = storage.get_aggregated_costs()?; + + // Check daily limit + let projected_daily = daily_cost + estimated_cost_usd; + if projected_daily > self.config.daily_limit_usd { + return Ok(BudgetCheck::Exceeded { + current_usd: daily_cost, + limit_usd: self.config.daily_limit_usd, + period: UsagePeriod::Day, + }); + } + + // Check monthly limit + let projected_monthly = monthly_cost + estimated_cost_usd; + if projected_monthly > self.config.monthly_limit_usd { + return Ok(BudgetCheck::Exceeded { + current_usd: monthly_cost, + limit_usd: self.config.monthly_limit_usd, + period: UsagePeriod::Month, + }); + } + + // Check warning thresholds + let warn_threshold = f64::from(self.config.warn_at_percent.min(100)) / 100.0; + let daily_warn_threshold = self.config.daily_limit_usd * warn_threshold; + let monthly_warn_threshold = self.config.monthly_limit_usd * warn_threshold; + + if projected_daily >= daily_warn_threshold { + return Ok(BudgetCheck::Warning { + current_usd: daily_cost, + limit_usd: self.config.daily_limit_usd, + period: UsagePeriod::Day, + }); + } + + if projected_monthly >= monthly_warn_threshold { + return Ok(BudgetCheck::Warning { + current_usd: monthly_cost, + limit_usd: self.config.monthly_limit_usd, + period: UsagePeriod::Month, + }); + } + + Ok(BudgetCheck::Allowed) + } + + /// Record a usage event. + pub fn record_usage(&self, usage: TokenUsage) -> Result<()> { + if !self.config.enabled { + return Ok(()); + } + + if !usage.cost_usd.is_finite() || usage.cost_usd < 0.0 { + return Err(anyhow!( + "Token usage cost must be a finite, non-negative value" + )); + } + + let record = CostRecord::new(&self.session_id, usage); + + // Persist first for durability guarantees. + { + let mut storage = self.lock_storage()?; + storage.add_record(record.clone())?; + } + + // Then update in-memory session snapshot. + let mut session_costs = self.lock_session_costs()?; + session_costs.push(record); + + Ok(()) + } + + /// Get the current cost summary. + pub fn get_summary(&self) -> Result { + let (daily_cost, monthly_cost) = { + let mut storage = self.lock_storage()?; + storage.get_aggregated_costs()? + }; + + let session_costs = self.lock_session_costs()?; + let session_cost: f64 = session_costs + .iter() + .map(|record| record.usage.cost_usd) + .sum(); + let total_tokens: u64 = session_costs + .iter() + .map(|record| record.usage.total_tokens) + .sum(); + let request_count = session_costs.len(); + let by_model = build_session_model_stats(&session_costs); + + Ok(CostSummary { + session_cost_usd: session_cost, + daily_cost_usd: daily_cost, + monthly_cost_usd: monthly_cost, + total_tokens, + request_count, + by_model, + }) + } + + /// Get the daily cost for a specific date. + pub fn get_daily_cost(&self, date: NaiveDate) -> Result { + let storage = self.lock_storage()?; + storage.get_cost_for_date(date) + } + + /// Get the monthly cost for a specific month. + pub fn get_monthly_cost(&self, year: i32, month: u32) -> Result { + let storage = self.lock_storage()?; + storage.get_cost_for_month(year, month) + } +} + +fn resolve_storage_path(workspace_dir: &Path) -> Result { + let storage_path = workspace_dir.join("state").join("costs.jsonl"); + let legacy_path = workspace_dir.join(".zeroclaw").join("costs.db"); + + if !storage_path.exists() && legacy_path.exists() { + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + if let Err(error) = fs::rename(&legacy_path, &storage_path) { + tracing::warn!( + "Failed to move legacy cost storage from {} to {}: {error}; falling back to copy", + legacy_path.display(), + storage_path.display() + ); + fs::copy(&legacy_path, &storage_path).with_context(|| { + format!( + "Failed to copy legacy cost storage from {} to {}", + legacy_path.display(), + storage_path.display() + ) + })?; + } + } + + Ok(storage_path) +} + +fn build_session_model_stats(session_costs: &[CostRecord]) -> HashMap { + let mut by_model: HashMap = HashMap::new(); + + for record in session_costs { + let entry = by_model + .entry(record.usage.model.clone()) + .or_insert_with(|| ModelStats { + model: record.usage.model.clone(), + cost_usd: 0.0, + total_tokens: 0, + request_count: 0, + }); + + entry.cost_usd += record.usage.cost_usd; + entry.total_tokens += record.usage.total_tokens; + entry.request_count += 1; + } + + by_model +} + +/// Persistent storage for cost records. +struct CostStorage { + path: PathBuf, + daily_cost_usd: f64, + monthly_cost_usd: f64, + cached_day: NaiveDate, + cached_year: i32, + cached_month: u32, +} + +impl CostStorage { + /// Create or open cost storage. + fn new(path: &Path) -> Result { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + let now = Utc::now(); + let mut storage = Self { + path: path.to_path_buf(), + daily_cost_usd: 0.0, + monthly_cost_usd: 0.0, + cached_day: now.date_naive(), + cached_year: now.year(), + cached_month: now.month(), + }; + + storage.rebuild_aggregates( + storage.cached_day, + storage.cached_year, + storage.cached_month, + )?; + + Ok(storage) + } + + fn for_each_record(&self, mut on_record: F) -> Result<()> + where + F: FnMut(CostRecord), + { + if !self.path.exists() { + return Ok(()); + } + + let file = File::open(&self.path) + .with_context(|| format!("Failed to read cost storage from {}", self.path.display()))?; + let reader = BufReader::new(file); + + for (line_number, line) in reader.lines().enumerate() { + let raw_line = line.with_context(|| { + format!( + "Failed to read line {} from cost storage {}", + line_number + 1, + self.path.display() + ) + })?; + + let trimmed = raw_line.trim(); + if trimmed.is_empty() { + continue; + } + + match serde_json::from_str::(trimmed) { + Ok(record) => on_record(record), + Err(error) => { + tracing::warn!( + "Skipping malformed cost record at {}:{}: {error}", + self.path.display(), + line_number + 1 + ); + } + } + } + + Ok(()) + } + + fn rebuild_aggregates(&mut self, day: NaiveDate, year: i32, month: u32) -> Result<()> { + let mut daily_cost = 0.0; + let mut monthly_cost = 0.0; + + self.for_each_record(|record| { + let timestamp = record.usage.timestamp.naive_utc(); + + if timestamp.date() == day { + daily_cost += record.usage.cost_usd; + } + + if timestamp.year() == year && timestamp.month() == month { + monthly_cost += record.usage.cost_usd; + } + })?; + + self.daily_cost_usd = daily_cost; + self.monthly_cost_usd = monthly_cost; + self.cached_day = day; + self.cached_year = year; + self.cached_month = month; + + Ok(()) + } + + fn ensure_period_cache_current(&mut self) -> Result<()> { + let now = Utc::now(); + let day = now.date_naive(); + let year = now.year(); + let month = now.month(); + + if day != self.cached_day || year != self.cached_year || month != self.cached_month { + self.rebuild_aggregates(day, year, month)?; + } + + Ok(()) + } + + /// Add a new record. + fn add_record(&mut self, record: CostRecord) -> Result<()> { + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.path) + .with_context(|| format!("Failed to open cost storage at {}", self.path.display()))?; + + writeln!(file, "{}", serde_json::to_string(&record)?) + .with_context(|| format!("Failed to write cost record to {}", self.path.display()))?; + file.sync_all() + .with_context(|| format!("Failed to sync cost storage at {}", self.path.display()))?; + + self.ensure_period_cache_current()?; + + let timestamp = record.usage.timestamp.naive_utc(); + if timestamp.date() == self.cached_day { + self.daily_cost_usd += record.usage.cost_usd; + } + if timestamp.year() == self.cached_year && timestamp.month() == self.cached_month { + self.monthly_cost_usd += record.usage.cost_usd; + } + + Ok(()) + } + + /// Get aggregated costs for current day and month. + fn get_aggregated_costs(&mut self) -> Result<(f64, f64)> { + self.ensure_period_cache_current()?; + Ok((self.daily_cost_usd, self.monthly_cost_usd)) + } + + /// Get cost for a specific date. + fn get_cost_for_date(&self, date: NaiveDate) -> Result { + let mut cost = 0.0; + + self.for_each_record(|record| { + if record.usage.timestamp.naive_utc().date() == date { + cost += record.usage.cost_usd; + } + })?; + + Ok(cost) + } + + /// Get cost for a specific month. + fn get_cost_for_month(&self, year: i32, month: u32) -> Result { + let mut cost = 0.0; + + self.for_each_record(|record| { + let timestamp = record.usage.timestamp.naive_utc(); + if timestamp.year() == year && timestamp.month() == month { + cost += record.usage.cost_usd; + } + })?; + + Ok(cost) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn enabled_config() -> CostConfig { + CostConfig { + enabled: true, + ..Default::default() + } + } + + #[test] + fn cost_tracker_initialization() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + assert!(!tracker.session_id().is_empty()); + } + + #[test] + fn budget_check_when_disabled() { + let tmp = TempDir::new().unwrap(); + let config = CostConfig { + enabled: false, + ..Default::default() + }; + + let tracker = CostTracker::new(config, tmp.path()).unwrap(); + let check = tracker.check_budget(1000.0).unwrap(); + assert!(matches!(check, BudgetCheck::Allowed)); + } + + #[test] + fn record_usage_and_get_summary() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + + let usage = TokenUsage::new("test/model", 1000, 500, 1.0, 2.0); + tracker.record_usage(usage).unwrap(); + + let summary = tracker.get_summary().unwrap(); + assert_eq!(summary.request_count, 1); + assert!(summary.session_cost_usd > 0.0); + assert_eq!(summary.by_model.len(), 1); + } + + #[test] + fn budget_exceeded_daily_limit() { + let tmp = TempDir::new().unwrap(); + let config = CostConfig { + enabled: true, + daily_limit_usd: 0.01, // Very low limit + ..Default::default() + }; + + let tracker = CostTracker::new(config, tmp.path()).unwrap(); + + // Record a usage that exceeds the limit + let usage = TokenUsage::new("test/model", 10000, 5000, 1.0, 2.0); // ~0.02 USD + tracker.record_usage(usage).unwrap(); + + let check = tracker.check_budget(0.01).unwrap(); + assert!(matches!(check, BudgetCheck::Exceeded { .. })); + } + + #[test] + fn summary_by_model_is_session_scoped() { + let tmp = TempDir::new().unwrap(); + let storage_path = resolve_storage_path(tmp.path()).unwrap(); + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent).unwrap(); + } + + let old_record = CostRecord::new( + "old-session", + TokenUsage::new("legacy/model", 500, 500, 1.0, 1.0), + ); + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(storage_path) + .unwrap(); + writeln!(file, "{}", serde_json::to_string(&old_record).unwrap()).unwrap(); + file.sync_all().unwrap(); + + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + tracker + .record_usage(TokenUsage::new("session/model", 1000, 1000, 1.0, 1.0)) + .unwrap(); + + let summary = tracker.get_summary().unwrap(); + assert_eq!(summary.by_model.len(), 1); + assert!(summary.by_model.contains_key("session/model")); + assert!(!summary.by_model.contains_key("legacy/model")); + } + + #[test] + fn malformed_lines_are_ignored_while_loading() { + let tmp = TempDir::new().unwrap(); + let storage_path = resolve_storage_path(tmp.path()).unwrap(); + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent).unwrap(); + } + + let valid_usage = TokenUsage::new("test/model", 1000, 0, 1.0, 1.0); + let valid_record = CostRecord::new("session-a", valid_usage.clone()); + + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(storage_path) + .unwrap(); + writeln!(file, "{}", serde_json::to_string(&valid_record).unwrap()).unwrap(); + writeln!(file, "not-a-json-line").unwrap(); + writeln!(file).unwrap(); + file.sync_all().unwrap(); + + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + let today_cost = tracker.get_daily_cost(Utc::now().date_naive()).unwrap(); + assert!((today_cost - valid_usage.cost_usd).abs() < f64::EPSILON); + } + + #[test] + fn invalid_budget_estimate_is_rejected() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + + let err = tracker.check_budget(f64::NAN).unwrap_err(); + assert!(err + .to_string() + .contains("Estimated cost must be a finite, non-negative value")); + } +} diff --git a/src/cost/types.rs b/src/cost/types.rs new file mode 100644 index 0000000..0e8d167 --- /dev/null +++ b/src/cost/types.rs @@ -0,0 +1,193 @@ +use serde::{Deserialize, Serialize}; + +/// Token usage information from a single API call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenUsage { + /// Model identifier (e.g., "anthropic/claude-sonnet-4-20250514") + pub model: String, + /// Input/prompt tokens + pub input_tokens: u64, + /// Output/completion tokens + pub output_tokens: u64, + /// Total tokens + pub total_tokens: u64, + /// Calculated cost in USD + pub cost_usd: f64, + /// Timestamp of the request + pub timestamp: chrono::DateTime, +} + +impl TokenUsage { + fn sanitize_price(value: f64) -> f64 { + if value.is_finite() && value > 0.0 { + value + } else { + 0.0 + } + } + + /// Create a new token usage record. + pub fn new( + model: impl Into, + input_tokens: u64, + output_tokens: u64, + input_price_per_million: f64, + output_price_per_million: f64, + ) -> Self { + let model = model.into(); + let input_price_per_million = Self::sanitize_price(input_price_per_million); + let output_price_per_million = Self::sanitize_price(output_price_per_million); + let total_tokens = input_tokens.saturating_add(output_tokens); + + // Calculate cost: (tokens / 1M) * price_per_million + let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price_per_million; + let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price_per_million; + let cost_usd = input_cost + output_cost; + + Self { + model, + input_tokens, + output_tokens, + total_tokens, + cost_usd, + timestamp: chrono::Utc::now(), + } + } + + /// Get the total cost. + pub fn cost(&self) -> f64 { + self.cost_usd + } +} + +/// Time period for cost aggregation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum UsagePeriod { + Session, + Day, + Month, +} + +/// A single cost record for persistent storage. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostRecord { + /// Unique identifier + pub id: String, + /// Token usage details + pub usage: TokenUsage, + /// Session identifier (for grouping) + pub session_id: String, +} + +impl CostRecord { + /// Create a new cost record. + pub fn new(session_id: impl Into, usage: TokenUsage) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + usage, + session_id: session_id.into(), + } + } +} + +/// Budget enforcement result. +#[derive(Debug, Clone)] +pub enum BudgetCheck { + /// Within budget, request can proceed + Allowed, + /// Warning threshold exceeded but request can proceed + Warning { + current_usd: f64, + limit_usd: f64, + period: UsagePeriod, + }, + /// Budget exceeded, request blocked + Exceeded { + current_usd: f64, + limit_usd: f64, + period: UsagePeriod, + }, +} + +/// Cost summary for reporting. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostSummary { + /// Total cost for the session + pub session_cost_usd: f64, + /// Total cost for the day + pub daily_cost_usd: f64, + /// Total cost for the month + pub monthly_cost_usd: f64, + /// Total tokens used + pub total_tokens: u64, + /// Number of requests + pub request_count: usize, + /// Breakdown by model + pub by_model: std::collections::HashMap, +} + +/// Statistics for a specific model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelStats { + /// Model name + pub model: String, + /// Total cost for this model + pub cost_usd: f64, + /// Total tokens for this model + pub total_tokens: u64, + /// Number of requests for this model + pub request_count: usize, +} + +impl Default for CostSummary { + fn default() -> Self { + Self { + session_cost_usd: 0.0, + daily_cost_usd: 0.0, + monthly_cost_usd: 0.0, + total_tokens: 0, + request_count: 0, + by_model: std::collections::HashMap::new(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_usage_calculation() { + let usage = TokenUsage::new("test/model", 1000, 500, 3.0, 15.0); + + // Expected: (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105 + assert!((usage.cost_usd - 0.0105).abs() < 0.0001); + assert_eq!(usage.input_tokens, 1000); + assert_eq!(usage.output_tokens, 500); + assert_eq!(usage.total_tokens, 1500); + } + + #[test] + fn token_usage_zero_tokens() { + let usage = TokenUsage::new("test/model", 0, 0, 3.0, 15.0); + assert!(usage.cost_usd.abs() < f64::EPSILON); + assert_eq!(usage.total_tokens, 0); + } + + #[test] + fn token_usage_negative_or_non_finite_prices_are_clamped() { + let usage = TokenUsage::new("test/model", 1000, 1000, -3.0, f64::NAN); + assert!(usage.cost_usd.abs() < f64::EPSILON); + assert_eq!(usage.total_tokens, 2000); + } + + #[test] + fn cost_record_creation() { + let usage = TokenUsage::new("test/model", 100, 50, 1.0, 2.0); + let record = CostRecord::new("session-123", usage); + + assert_eq!(record.session_id, "session-123"); + assert!(!record.id.is_empty()); + assert_eq!(record.usage.model, "test/model"); + } +} diff --git a/src/lib.rs b/src/lib.rs index 61a2bc6..588ada3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ use serde::{Deserialize, Serialize}; pub mod agent; pub mod channels; pub mod config; +pub mod cost; pub mod cron; pub mod daemon; pub mod doctor; diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 5fee2b6..ddac80e 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -122,6 +122,7 @@ pub fn run_wizard() -> Result { browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), identity: crate::config::IdentityConfig::default(), + cost: crate::config::CostConfig::default(), hardware: hardware_config, agents: std::collections::HashMap::new(), security: crate::config::SecurityConfig::default(), @@ -318,6 +319,7 @@ pub fn run_quick_setup( browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), identity: crate::config::IdentityConfig::default(), + cost: crate::config::CostConfig::default(), hardware: HardwareConfig::default(), agents: std::collections::HashMap::new(), security: crate::config::SecurityConfig::default(), From e349067f708fa451148b40b39656d350e9f58c04 Mon Sep 17 00:00:00 2001 From: cd slash <29688941+cd-slash@users.noreply.github.com> Date: Mon, 16 Feb 2026 15:53:34 +0000 Subject: [PATCH 07/12] fix(providers): correct Fireworks AI base URL to include /v1 path (#346) The Fireworks API endpoint requires /v1/chat/completions, but the base URL was missing the /v1 path segment, causing 404 errors and triggering a broken responses fallback. Fix: Add /v1 to base URL so correct endpoint is built: https://api.fireworks.ai/inference/v1/chat/completions --- src/providers/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 1ba11b7..b342675 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -253,7 +253,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(OpenAiCompatibleProvider::new( - "Fireworks AI", "https://api.fireworks.ai/inference", key, AuthStyle::Bearer, + "Fireworks AI", "https://api.fireworks.ai/inference/v1", key, AuthStyle::Bearer, ))), "perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new( "Perplexity", "https://api.perplexity.ai", key, AuthStyle::Bearer, From 8e23cbc59622c4342b4f659dec773a694ca8724c Mon Sep 17 00:00:00 2001 From: Will Sarg <12886992+willsarg@users.noreply.github.com> Date: Mon, 16 Feb 2026 10:56:53 -0500 Subject: [PATCH 08/12] ci: route trusted pushes to self-hosted runner (#369) --- .github/workflows/ci.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 68cb185..e7b54ad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -118,7 +118,7 @@ jobs: name: Format & Lint needs: [changes] if: needs.changes.outputs.rust_changed == 'true' - runs-on: ubuntu-latest + runs-on: ${{ github.event_name == 'push' && fromJSON('["self-hosted","Linux","X64","lxc-ci"]') || 'ubuntu-latest' }} timeout-minutes: 20 steps: - uses: actions/checkout@v4 @@ -138,7 +138,7 @@ jobs: name: Test needs: [changes] if: needs.changes.outputs.rust_changed == 'true' - runs-on: ubuntu-latest + runs-on: ${{ github.event_name == 'push' && fromJSON('["self-hosted","Linux","X64","lxc-ci"]') || 'ubuntu-latest' }} timeout-minutes: 30 steps: - uses: actions/checkout@v4 @@ -153,7 +153,7 @@ jobs: name: Build (Smoke) needs: [changes] if: needs.changes.outputs.rust_changed == 'true' - runs-on: ubuntu-latest + runs-on: ${{ github.event_name == 'push' && fromJSON('["self-hosted","Linux","X64","lxc-ci"]') || 'ubuntu-latest' }} timeout-minutes: 20 steps: @@ -187,7 +187,7 @@ jobs: name: Docs Quality needs: [changes] if: needs.changes.outputs.docs_changed == 'true' - runs-on: ubuntu-latest + runs-on: ${{ github.event_name == 'push' && fromJSON('["self-hosted","Linux","X64","lxc-ci"]') || 'ubuntu-latest' }} timeout-minutes: 15 steps: - uses: actions/checkout@v4 From 444d80e1785e8421506260e5a4c552ee5ad37a13 Mon Sep 17 00:00:00 2001 From: fettpl <38704082+fettpl@users.noreply.github.com> Date: Mon, 16 Feb 2026 16:51:38 +0100 Subject: [PATCH 09/12] fix(tools): use original headers for HTTP requests, redact only in display sanitize_headers was replacing sensitive header values with ***REDACTED*** before passing them to the actual HTTP request, breaking any authenticated API call. Split into parse_headers (preserves original values for the request) and redact_headers_for_display (returns redacted copy for output/logging). Closes #348 Co-Authored-By: Claude Opus 4.6 --- src/tools/http_request.rs | 84 +++++++++++++++++++++++++++++---------- 1 file changed, 62 insertions(+), 22 deletions(-) diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs index 36ebbd6..43b05ac 100644 --- a/src/tools/http_request.rs +++ b/src/tools/http_request.rs @@ -76,28 +76,37 @@ impl HttpRequestTool { } } - fn sanitize_headers(&self, headers: &serde_json::Value) -> Vec<(String, String)> { + fn parse_headers(&self, headers: &serde_json::Value) -> Vec<(String, String)> { let mut result = Vec::new(); if let Some(obj) = headers.as_object() { for (key, value) in obj { if let Some(str_val) = value.as_str() { - // Redact sensitive headers from logs (we don't log headers, but this is defense-in-depth) - let is_sensitive = key.to_lowercase().contains("authorization") - || key.to_lowercase().contains("api-key") - || key.to_lowercase().contains("apikey") - || key.to_lowercase().contains("token") - || key.to_lowercase().contains("secret"); - if is_sensitive { - result.push((key.clone(), "***REDACTED***".into())); - } else { - result.push((key.clone(), str_val.to_string())); - } + result.push((key.clone(), str_val.to_string())); } } } result } + fn redact_headers_for_display(headers: &[(String, String)]) -> Vec<(String, String)> { + headers + .iter() + .map(|(key, value)| { + let lower = key.to_lowercase(); + let is_sensitive = lower.contains("authorization") + || lower.contains("api-key") + || lower.contains("apikey") + || lower.contains("token") + || lower.contains("secret"); + if is_sensitive { + (key.clone(), "***REDACTED***".into()) + } else { + (key.clone(), value.clone()) + } + }) + .collect() + } + async fn execute_request( &self, url: &str, @@ -222,10 +231,10 @@ impl Tool for HttpRequestTool { } }; - let sanitized_headers = self.sanitize_headers(&headers_val); + let request_headers = self.parse_headers(&headers_val); match self - .execute_request(&url, method, sanitized_headers, body) + .execute_request(&url, method, request_headers, body) .await { Ok(response) => { @@ -600,23 +609,54 @@ mod tests { } #[test] - fn sanitize_headers_redacts_sensitive() { + fn parse_headers_preserves_original_values() { let tool = test_tool(vec!["example.com"]); let headers = json!({ "Authorization": "Bearer secret", "Content-Type": "application/json", "X-API-Key": "my-key" }); - let sanitized = tool.sanitize_headers(&headers); - assert_eq!(sanitized.len(), 3); - assert!(sanitized + let parsed = tool.parse_headers(&headers); + assert_eq!(parsed.len(), 3); + assert!(parsed .iter() - .any(|(k, v)| k == "Authorization" && v == "***REDACTED***")); - assert!(sanitized + .any(|(k, v)| k == "Authorization" && v == "Bearer secret")); + assert!(parsed .iter() - .any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***")); - assert!(sanitized + .any(|(k, v)| k == "X-API-Key" && v == "my-key")); + assert!(parsed .iter() .any(|(k, v)| k == "Content-Type" && v == "application/json")); } + + #[test] + fn redact_headers_for_display_redacts_sensitive() { + let headers = vec![ + ("Authorization".into(), "Bearer secret".into()), + ("Content-Type".into(), "application/json".into()), + ("X-API-Key".into(), "my-key".into()), + ("X-Secret-Token".into(), "tok-123".into()), + ]; + let redacted = HttpRequestTool::redact_headers_for_display(&headers); + assert_eq!(redacted.len(), 4); + assert!(redacted + .iter() + .any(|(k, v)| k == "Authorization" && v == "***REDACTED***")); + assert!(redacted + .iter() + .any(|(k, v)| k == "X-API-Key" && v == "***REDACTED***")); + assert!(redacted + .iter() + .any(|(k, v)| k == "X-Secret-Token" && v == "***REDACTED***")); + assert!(redacted + .iter() + .any(|(k, v)| k == "Content-Type" && v == "application/json")); + } + + #[test] + fn redact_headers_does_not_alter_original() { + let headers = vec![("Authorization".into(), "Bearer real-token".into())]; + let _ = HttpRequestTool::redact_headers_for_display(&headers); + assert_eq!(headers[0].1, "Bearer real-token"); + } } From a7d19b332e6547b7d03083b4f32c482d95118fad Mon Sep 17 00:00:00 2001 From: Will Sarg <12886992+willsarg@users.noreply.github.com> Date: Mon, 16 Feb 2026 10:58:45 -0500 Subject: [PATCH 10/12] ci: route trusted security and workflow checks to self-hosted (#370) --- .github/workflows/security.yml | 4 ++-- .github/workflows/workflow-sanity.yml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/security.yml b/.github/workflows/security.yml index 60febb7..bff64dc 100644 --- a/.github/workflows/security.yml +++ b/.github/workflows/security.yml @@ -21,7 +21,7 @@ env: jobs: audit: name: Security Audit - runs-on: ubuntu-latest + runs-on: ${{ github.event_name != 'pull_request' && fromJSON('["self-hosted","Linux","X64","lxc-ci"]') || 'ubuntu-latest' }} timeout-minutes: 20 steps: - uses: actions/checkout@v4 @@ -37,7 +37,7 @@ jobs: deny: name: License & Supply Chain - runs-on: ubuntu-latest + runs-on: ${{ github.event_name != 'pull_request' && fromJSON('["self-hosted","Linux","X64","lxc-ci"]') || 'ubuntu-latest' }} timeout-minutes: 20 steps: - uses: actions/checkout@v4 diff --git a/.github/workflows/workflow-sanity.yml b/.github/workflows/workflow-sanity.yml index 47d692d..c37c1f9 100644 --- a/.github/workflows/workflow-sanity.yml +++ b/.github/workflows/workflow-sanity.yml @@ -22,7 +22,7 @@ permissions: jobs: no-tabs: - runs-on: ubuntu-latest + runs-on: ${{ github.event_name == 'push' && fromJSON('["self-hosted","Linux","X64","lxc-ci"]') || 'ubuntu-latest' }} timeout-minutes: 10 steps: - name: Checkout @@ -55,7 +55,7 @@ jobs: PY actionlint: - runs-on: ubuntu-latest + runs-on: ${{ github.event_name == 'push' && fromJSON('["self-hosted","Linux","X64","lxc-ci"]') || 'ubuntu-latest' }} timeout-minutes: 10 steps: - name: Checkout From d5ca9a4a5c13c76c3676f2d5c148cf768f7fa7d0 Mon Sep 17 00:00:00 2001 From: fettpl <38704082+fettpl@users.noreply.github.com> Date: Mon, 16 Feb 2026 16:57:00 +0100 Subject: [PATCH 11/12] fix(main): remove duplicate ModelCommands enum definition A duplicate ModelCommands enum was introduced in a recent merge, causing E0119/E0428 compile errors on CI (Rust 1.92). Co-Authored-By: Claude Opus 4.6 --- src/main.rs | 14 -------------- src/tools/git_operations.rs | 1 - 2 files changed, 15 deletions(-) diff --git a/src/main.rs b/src/main.rs index a5c17f4..3253594 100644 --- a/src/main.rs +++ b/src/main.rs @@ -272,20 +272,6 @@ enum ModelCommands { }, } -#[derive(Subcommand, Debug)] -enum ModelCommands { - /// Refresh and cache provider models - Refresh { - /// Provider name (defaults to configured default provider) - #[arg(long)] - provider: Option, - - /// Force live refresh and ignore fresh cache - #[arg(long)] - force: bool, - }, -} - #[derive(Subcommand, Debug)] enum ChannelCommands { /// List configured channels diff --git a/src/tools/git_operations.rs b/src/tools/git_operations.rs index fc4b4d2..e20113a 100644 --- a/src/tools/git_operations.rs +++ b/src/tools/git_operations.rs @@ -555,7 +555,6 @@ impl Tool for GitOperationsTool { mod tests { use super::*; use crate::security::SecurityPolicy; - use std::path::Path; use tempfile::TempDir; fn test_tool(dir: &std::path::Path) -> GitOperationsTool { From 6fd8b523b92cc58533ec7fb712496fe69075b057 Mon Sep 17 00:00:00 2001 From: Will Sarg <12886992+willsarg@users.noreply.github.com> Date: Mon, 16 Feb 2026 11:00:25 -0500 Subject: [PATCH 12/12] ci: route trusted docker and release publish jobs to self-hosted (#371) --- .github/workflows/docker.yml | 2 +- .github/workflows/release.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index fd52635..ec37a37 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -62,7 +62,7 @@ jobs: publish: name: Build and Push Docker Image if: github.event_name == 'push' - runs-on: ubuntu-latest + runs-on: [self-hosted, Linux, X64, lxc-ci] timeout-minutes: 25 permissions: contents: read diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 922cff9..aa1a475 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -74,7 +74,7 @@ jobs: publish: name: Publish Release needs: build-release - runs-on: ubuntu-latest + runs-on: [self-hosted, Linux, X64, lxc-ci] timeout-minutes: 15 steps: - uses: actions/checkout@v4