From e4944a5fc2f2e3ccd0caf433cfdf5ab62e849721 Mon Sep 17 00:00:00 2001 From: Chummy Date: Mon, 16 Feb 2026 23:40:47 +0800 Subject: [PATCH] feat(cost): add budget tracking core and harden storage reliability (#292) --- src/channels/mod.rs | 3 +- src/config/mod.rs | 2 +- src/config/schema.rs | 147 ++++++++++++ src/cost/mod.rs | 5 + src/cost/tracker.rs | 539 ++++++++++++++++++++++++++++++++++++++++++ src/cost/types.rs | 193 +++++++++++++++ src/lib.rs | 1 + src/onboard/wizard.rs | 2 + 8 files changed, 890 insertions(+), 2 deletions(-) create mode 100644 src/cost/mod.rs create mode 100644 src/cost/tracker.rs create mode 100644 src/cost/types.rs diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 1981472..0589e2e 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -682,7 +682,8 @@ pub async fn start_channels(config: Config) -> Result<()> { let provider_name = config .default_provider .clone() - .unwrap_or_else(|| "openrouter".to_string()); + .unwrap_or_else(|| "openrouter".into()); + let provider: Arc = Arc::from(providers::create_resilient_provider( provider_name.as_str(), config.api_key.as_deref(), diff --git a/src/config/mod.rs b/src/config/mod.rs index a61c29c..e53b597 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,7 +2,7 @@ pub mod schema; #[allow(unused_imports)] pub use schema::{ - AuditConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, + AuditConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, CostConfig, DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, diff --git a/src/config/schema.rs b/src/config/schema.rs index 8d2ec55..8a66124 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -71,6 +71,9 @@ pub struct Config { #[serde(default)] pub identity: IdentityConfig, + #[serde(default)] + pub cost: CostConfig, + /// Hardware Abstraction Layer (HAL) configuration. /// Controls how ZeroClaw interfaces with physical hardware /// (GPIO, serial, debug probes). @@ -127,6 +130,147 @@ impl Default for IdentityConfig { } } +// ── Cost tracking and budget enforcement ─────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostConfig { + /// Enable cost tracking (default: false) + #[serde(default)] + pub enabled: bool, + + /// Daily spending limit in USD (default: 10.00) + #[serde(default = "default_daily_limit")] + pub daily_limit_usd: f64, + + /// Monthly spending limit in USD (default: 100.00) + #[serde(default = "default_monthly_limit")] + pub monthly_limit_usd: f64, + + /// Warn when spending reaches this percentage of limit (default: 80) + #[serde(default = "default_warn_percent")] + pub warn_at_percent: u8, + + /// Allow requests to exceed budget with --override flag (default: false) + #[serde(default)] + pub allow_override: bool, + + /// Per-model pricing (USD per 1M tokens) + #[serde(default)] + pub prices: std::collections::HashMap, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelPricing { + /// Input price per 1M tokens + #[serde(default)] + pub input: f64, + + /// Output price per 1M tokens + #[serde(default)] + pub output: f64, +} + +fn default_daily_limit() -> f64 { + 10.0 +} + +fn default_monthly_limit() -> f64 { + 100.0 +} + +fn default_warn_percent() -> u8 { + 80 +} + +impl Default for CostConfig { + fn default() -> Self { + Self { + enabled: false, + daily_limit_usd: default_daily_limit(), + monthly_limit_usd: default_monthly_limit(), + warn_at_percent: default_warn_percent(), + allow_override: false, + prices: get_default_pricing(), + } + } +} + +/// Default pricing for popular models (USD per 1M tokens) +fn get_default_pricing() -> std::collections::HashMap { + let mut prices = std::collections::HashMap::new(); + + // Anthropic models + prices.insert( + "anthropic/claude-sonnet-4-20250514".into(), + ModelPricing { + input: 3.0, + output: 15.0, + }, + ); + prices.insert( + "anthropic/claude-opus-4-20250514".into(), + ModelPricing { + input: 15.0, + output: 75.0, + }, + ); + prices.insert( + "anthropic/claude-3.5-sonnet".into(), + ModelPricing { + input: 3.0, + output: 15.0, + }, + ); + prices.insert( + "anthropic/claude-3-haiku".into(), + ModelPricing { + input: 0.25, + output: 1.25, + }, + ); + + // OpenAI models + prices.insert( + "openai/gpt-4o".into(), + ModelPricing { + input: 5.0, + output: 15.0, + }, + ); + prices.insert( + "openai/gpt-4o-mini".into(), + ModelPricing { + input: 0.15, + output: 0.60, + }, + ); + prices.insert( + "openai/o1-preview".into(), + ModelPricing { + input: 15.0, + output: 60.0, + }, + ); + + // Google models + prices.insert( + "google/gemini-2.0-flash".into(), + ModelPricing { + input: 0.10, + output: 0.40, + }, + ); + prices.insert( + "google/gemini-1.5-pro".into(), + ModelPricing { + input: 1.25, + output: 5.0, + }, + ); + + prices +} + // ── Agent delegation ───────────────────────────────────────────── /// Configuration for a named delegate agent that can be invoked via the @@ -1200,6 +1344,7 @@ impl Default for Config { browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), hardware: crate::hardware::HardwareConfig::default(), agents: HashMap::new(), security: SecurityConfig::default(), @@ -1556,6 +1701,7 @@ mod tests { browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), hardware: crate::hardware::HardwareConfig::default(), agents: HashMap::new(), security: SecurityConfig::default(), @@ -1632,6 +1778,7 @@ default_temperature = 0.7 browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), identity: IdentityConfig::default(), + cost: CostConfig::default(), hardware: crate::hardware::HardwareConfig::default(), agents: HashMap::new(), security: SecurityConfig::default(), diff --git a/src/cost/mod.rs b/src/cost/mod.rs new file mode 100644 index 0000000..14c634d --- /dev/null +++ b/src/cost/mod.rs @@ -0,0 +1,5 @@ +pub mod tracker; +pub mod types; + +pub use tracker::CostTracker; +pub use types::{BudgetCheck, CostRecord, CostSummary, ModelStats, TokenUsage, UsagePeriod}; diff --git a/src/cost/tracker.rs b/src/cost/tracker.rs new file mode 100644 index 0000000..16b874f --- /dev/null +++ b/src/cost/tracker.rs @@ -0,0 +1,539 @@ +use super::types::{BudgetCheck, CostRecord, CostSummary, ModelStats, TokenUsage, UsagePeriod}; +use crate::config::CostConfig; +use anyhow::{anyhow, Context, Result}; +use chrono::{Datelike, NaiveDate, Utc}; +use std::collections::HashMap; +use std::fs::{self, File, OpenOptions}; +use std::io::{BufRead, BufReader, Write}; +use std::path::{Path, PathBuf}; +use std::sync::{Arc, Mutex, MutexGuard}; + +/// Cost tracker for API usage monitoring and budget enforcement. +pub struct CostTracker { + config: CostConfig, + storage: Arc>, + session_id: String, + session_costs: Arc>>, +} + +impl CostTracker { + /// Create a new cost tracker. + pub fn new(config: CostConfig, workspace_dir: &Path) -> Result { + let storage_path = resolve_storage_path(workspace_dir)?; + + let storage = CostStorage::new(&storage_path).with_context(|| { + format!("Failed to open cost storage at {}", storage_path.display()) + })?; + + Ok(Self { + config, + storage: Arc::new(Mutex::new(storage)), + session_id: uuid::Uuid::new_v4().to_string(), + session_costs: Arc::new(Mutex::new(Vec::new())), + }) + } + + /// Get the session ID. + pub fn session_id(&self) -> &str { + &self.session_id + } + + fn lock_storage(&self) -> Result> { + self.storage + .lock() + .map_err(|_| anyhow!("Cost storage lock poisoned")) + } + + fn lock_session_costs(&self) -> Result>> { + self.session_costs + .lock() + .map_err(|_| anyhow!("Session cost lock poisoned")) + } + + /// Check if a request is within budget. + pub fn check_budget(&self, estimated_cost_usd: f64) -> Result { + if !self.config.enabled { + return Ok(BudgetCheck::Allowed); + } + + if !estimated_cost_usd.is_finite() || estimated_cost_usd < 0.0 { + return Err(anyhow!( + "Estimated cost must be a finite, non-negative value" + )); + } + + let mut storage = self.lock_storage()?; + let (daily_cost, monthly_cost) = storage.get_aggregated_costs()?; + + // Check daily limit + let projected_daily = daily_cost + estimated_cost_usd; + if projected_daily > self.config.daily_limit_usd { + return Ok(BudgetCheck::Exceeded { + current_usd: daily_cost, + limit_usd: self.config.daily_limit_usd, + period: UsagePeriod::Day, + }); + } + + // Check monthly limit + let projected_monthly = monthly_cost + estimated_cost_usd; + if projected_monthly > self.config.monthly_limit_usd { + return Ok(BudgetCheck::Exceeded { + current_usd: monthly_cost, + limit_usd: self.config.monthly_limit_usd, + period: UsagePeriod::Month, + }); + } + + // Check warning thresholds + let warn_threshold = f64::from(self.config.warn_at_percent.min(100)) / 100.0; + let daily_warn_threshold = self.config.daily_limit_usd * warn_threshold; + let monthly_warn_threshold = self.config.monthly_limit_usd * warn_threshold; + + if projected_daily >= daily_warn_threshold { + return Ok(BudgetCheck::Warning { + current_usd: daily_cost, + limit_usd: self.config.daily_limit_usd, + period: UsagePeriod::Day, + }); + } + + if projected_monthly >= monthly_warn_threshold { + return Ok(BudgetCheck::Warning { + current_usd: monthly_cost, + limit_usd: self.config.monthly_limit_usd, + period: UsagePeriod::Month, + }); + } + + Ok(BudgetCheck::Allowed) + } + + /// Record a usage event. + pub fn record_usage(&self, usage: TokenUsage) -> Result<()> { + if !self.config.enabled { + return Ok(()); + } + + if !usage.cost_usd.is_finite() || usage.cost_usd < 0.0 { + return Err(anyhow!( + "Token usage cost must be a finite, non-negative value" + )); + } + + let record = CostRecord::new(&self.session_id, usage); + + // Persist first for durability guarantees. + { + let mut storage = self.lock_storage()?; + storage.add_record(record.clone())?; + } + + // Then update in-memory session snapshot. + let mut session_costs = self.lock_session_costs()?; + session_costs.push(record); + + Ok(()) + } + + /// Get the current cost summary. + pub fn get_summary(&self) -> Result { + let (daily_cost, monthly_cost) = { + let mut storage = self.lock_storage()?; + storage.get_aggregated_costs()? + }; + + let session_costs = self.lock_session_costs()?; + let session_cost: f64 = session_costs + .iter() + .map(|record| record.usage.cost_usd) + .sum(); + let total_tokens: u64 = session_costs + .iter() + .map(|record| record.usage.total_tokens) + .sum(); + let request_count = session_costs.len(); + let by_model = build_session_model_stats(&session_costs); + + Ok(CostSummary { + session_cost_usd: session_cost, + daily_cost_usd: daily_cost, + monthly_cost_usd: monthly_cost, + total_tokens, + request_count, + by_model, + }) + } + + /// Get the daily cost for a specific date. + pub fn get_daily_cost(&self, date: NaiveDate) -> Result { + let storage = self.lock_storage()?; + storage.get_cost_for_date(date) + } + + /// Get the monthly cost for a specific month. + pub fn get_monthly_cost(&self, year: i32, month: u32) -> Result { + let storage = self.lock_storage()?; + storage.get_cost_for_month(year, month) + } +} + +fn resolve_storage_path(workspace_dir: &Path) -> Result { + let storage_path = workspace_dir.join("state").join("costs.jsonl"); + let legacy_path = workspace_dir.join(".zeroclaw").join("costs.db"); + + if !storage_path.exists() && legacy_path.exists() { + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + if let Err(error) = fs::rename(&legacy_path, &storage_path) { + tracing::warn!( + "Failed to move legacy cost storage from {} to {}: {error}; falling back to copy", + legacy_path.display(), + storage_path.display() + ); + fs::copy(&legacy_path, &storage_path).with_context(|| { + format!( + "Failed to copy legacy cost storage from {} to {}", + legacy_path.display(), + storage_path.display() + ) + })?; + } + } + + Ok(storage_path) +} + +fn build_session_model_stats(session_costs: &[CostRecord]) -> HashMap { + let mut by_model: HashMap = HashMap::new(); + + for record in session_costs { + let entry = by_model + .entry(record.usage.model.clone()) + .or_insert_with(|| ModelStats { + model: record.usage.model.clone(), + cost_usd: 0.0, + total_tokens: 0, + request_count: 0, + }); + + entry.cost_usd += record.usage.cost_usd; + entry.total_tokens += record.usage.total_tokens; + entry.request_count += 1; + } + + by_model +} + +/// Persistent storage for cost records. +struct CostStorage { + path: PathBuf, + daily_cost_usd: f64, + monthly_cost_usd: f64, + cached_day: NaiveDate, + cached_year: i32, + cached_month: u32, +} + +impl CostStorage { + /// Create or open cost storage. + fn new(path: &Path) -> Result { + if let Some(parent) = path.parent() { + fs::create_dir_all(parent) + .with_context(|| format!("Failed to create directory {}", parent.display()))?; + } + + let now = Utc::now(); + let mut storage = Self { + path: path.to_path_buf(), + daily_cost_usd: 0.0, + monthly_cost_usd: 0.0, + cached_day: now.date_naive(), + cached_year: now.year(), + cached_month: now.month(), + }; + + storage.rebuild_aggregates( + storage.cached_day, + storage.cached_year, + storage.cached_month, + )?; + + Ok(storage) + } + + fn for_each_record(&self, mut on_record: F) -> Result<()> + where + F: FnMut(CostRecord), + { + if !self.path.exists() { + return Ok(()); + } + + let file = File::open(&self.path) + .with_context(|| format!("Failed to read cost storage from {}", self.path.display()))?; + let reader = BufReader::new(file); + + for (line_number, line) in reader.lines().enumerate() { + let raw_line = line.with_context(|| { + format!( + "Failed to read line {} from cost storage {}", + line_number + 1, + self.path.display() + ) + })?; + + let trimmed = raw_line.trim(); + if trimmed.is_empty() { + continue; + } + + match serde_json::from_str::(trimmed) { + Ok(record) => on_record(record), + Err(error) => { + tracing::warn!( + "Skipping malformed cost record at {}:{}: {error}", + self.path.display(), + line_number + 1 + ); + } + } + } + + Ok(()) + } + + fn rebuild_aggregates(&mut self, day: NaiveDate, year: i32, month: u32) -> Result<()> { + let mut daily_cost = 0.0; + let mut monthly_cost = 0.0; + + self.for_each_record(|record| { + let timestamp = record.usage.timestamp.naive_utc(); + + if timestamp.date() == day { + daily_cost += record.usage.cost_usd; + } + + if timestamp.year() == year && timestamp.month() == month { + monthly_cost += record.usage.cost_usd; + } + })?; + + self.daily_cost_usd = daily_cost; + self.monthly_cost_usd = monthly_cost; + self.cached_day = day; + self.cached_year = year; + self.cached_month = month; + + Ok(()) + } + + fn ensure_period_cache_current(&mut self) -> Result<()> { + let now = Utc::now(); + let day = now.date_naive(); + let year = now.year(); + let month = now.month(); + + if day != self.cached_day || year != self.cached_year || month != self.cached_month { + self.rebuild_aggregates(day, year, month)?; + } + + Ok(()) + } + + /// Add a new record. + fn add_record(&mut self, record: CostRecord) -> Result<()> { + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(&self.path) + .with_context(|| format!("Failed to open cost storage at {}", self.path.display()))?; + + writeln!(file, "{}", serde_json::to_string(&record)?) + .with_context(|| format!("Failed to write cost record to {}", self.path.display()))?; + file.sync_all() + .with_context(|| format!("Failed to sync cost storage at {}", self.path.display()))?; + + self.ensure_period_cache_current()?; + + let timestamp = record.usage.timestamp.naive_utc(); + if timestamp.date() == self.cached_day { + self.daily_cost_usd += record.usage.cost_usd; + } + if timestamp.year() == self.cached_year && timestamp.month() == self.cached_month { + self.monthly_cost_usd += record.usage.cost_usd; + } + + Ok(()) + } + + /// Get aggregated costs for current day and month. + fn get_aggregated_costs(&mut self) -> Result<(f64, f64)> { + self.ensure_period_cache_current()?; + Ok((self.daily_cost_usd, self.monthly_cost_usd)) + } + + /// Get cost for a specific date. + fn get_cost_for_date(&self, date: NaiveDate) -> Result { + let mut cost = 0.0; + + self.for_each_record(|record| { + if record.usage.timestamp.naive_utc().date() == date { + cost += record.usage.cost_usd; + } + })?; + + Ok(cost) + } + + /// Get cost for a specific month. + fn get_cost_for_month(&self, year: i32, month: u32) -> Result { + let mut cost = 0.0; + + self.for_each_record(|record| { + let timestamp = record.usage.timestamp.naive_utc(); + if timestamp.year() == year && timestamp.month() == month { + cost += record.usage.cost_usd; + } + })?; + + Ok(cost) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn enabled_config() -> CostConfig { + CostConfig { + enabled: true, + ..Default::default() + } + } + + #[test] + fn cost_tracker_initialization() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + assert!(!tracker.session_id().is_empty()); + } + + #[test] + fn budget_check_when_disabled() { + let tmp = TempDir::new().unwrap(); + let config = CostConfig { + enabled: false, + ..Default::default() + }; + + let tracker = CostTracker::new(config, tmp.path()).unwrap(); + let check = tracker.check_budget(1000.0).unwrap(); + assert!(matches!(check, BudgetCheck::Allowed)); + } + + #[test] + fn record_usage_and_get_summary() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + + let usage = TokenUsage::new("test/model", 1000, 500, 1.0, 2.0); + tracker.record_usage(usage).unwrap(); + + let summary = tracker.get_summary().unwrap(); + assert_eq!(summary.request_count, 1); + assert!(summary.session_cost_usd > 0.0); + assert_eq!(summary.by_model.len(), 1); + } + + #[test] + fn budget_exceeded_daily_limit() { + let tmp = TempDir::new().unwrap(); + let config = CostConfig { + enabled: true, + daily_limit_usd: 0.01, // Very low limit + ..Default::default() + }; + + let tracker = CostTracker::new(config, tmp.path()).unwrap(); + + // Record a usage that exceeds the limit + let usage = TokenUsage::new("test/model", 10000, 5000, 1.0, 2.0); // ~0.02 USD + tracker.record_usage(usage).unwrap(); + + let check = tracker.check_budget(0.01).unwrap(); + assert!(matches!(check, BudgetCheck::Exceeded { .. })); + } + + #[test] + fn summary_by_model_is_session_scoped() { + let tmp = TempDir::new().unwrap(); + let storage_path = resolve_storage_path(tmp.path()).unwrap(); + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent).unwrap(); + } + + let old_record = CostRecord::new( + "old-session", + TokenUsage::new("legacy/model", 500, 500, 1.0, 1.0), + ); + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(storage_path) + .unwrap(); + writeln!(file, "{}", serde_json::to_string(&old_record).unwrap()).unwrap(); + file.sync_all().unwrap(); + + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + tracker + .record_usage(TokenUsage::new("session/model", 1000, 1000, 1.0, 1.0)) + .unwrap(); + + let summary = tracker.get_summary().unwrap(); + assert_eq!(summary.by_model.len(), 1); + assert!(summary.by_model.contains_key("session/model")); + assert!(!summary.by_model.contains_key("legacy/model")); + } + + #[test] + fn malformed_lines_are_ignored_while_loading() { + let tmp = TempDir::new().unwrap(); + let storage_path = resolve_storage_path(tmp.path()).unwrap(); + if let Some(parent) = storage_path.parent() { + fs::create_dir_all(parent).unwrap(); + } + + let valid_usage = TokenUsage::new("test/model", 1000, 0, 1.0, 1.0); + let valid_record = CostRecord::new("session-a", valid_usage.clone()); + + let mut file = OpenOptions::new() + .create(true) + .append(true) + .open(storage_path) + .unwrap(); + writeln!(file, "{}", serde_json::to_string(&valid_record).unwrap()).unwrap(); + writeln!(file, "not-a-json-line").unwrap(); + writeln!(file).unwrap(); + file.sync_all().unwrap(); + + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + let today_cost = tracker.get_daily_cost(Utc::now().date_naive()).unwrap(); + assert!((today_cost - valid_usage.cost_usd).abs() < f64::EPSILON); + } + + #[test] + fn invalid_budget_estimate_is_rejected() { + let tmp = TempDir::new().unwrap(); + let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); + + let err = tracker.check_budget(f64::NAN).unwrap_err(); + assert!(err + .to_string() + .contains("Estimated cost must be a finite, non-negative value")); + } +} diff --git a/src/cost/types.rs b/src/cost/types.rs new file mode 100644 index 0000000..0e8d167 --- /dev/null +++ b/src/cost/types.rs @@ -0,0 +1,193 @@ +use serde::{Deserialize, Serialize}; + +/// Token usage information from a single API call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct TokenUsage { + /// Model identifier (e.g., "anthropic/claude-sonnet-4-20250514") + pub model: String, + /// Input/prompt tokens + pub input_tokens: u64, + /// Output/completion tokens + pub output_tokens: u64, + /// Total tokens + pub total_tokens: u64, + /// Calculated cost in USD + pub cost_usd: f64, + /// Timestamp of the request + pub timestamp: chrono::DateTime, +} + +impl TokenUsage { + fn sanitize_price(value: f64) -> f64 { + if value.is_finite() && value > 0.0 { + value + } else { + 0.0 + } + } + + /// Create a new token usage record. + pub fn new( + model: impl Into, + input_tokens: u64, + output_tokens: u64, + input_price_per_million: f64, + output_price_per_million: f64, + ) -> Self { + let model = model.into(); + let input_price_per_million = Self::sanitize_price(input_price_per_million); + let output_price_per_million = Self::sanitize_price(output_price_per_million); + let total_tokens = input_tokens.saturating_add(output_tokens); + + // Calculate cost: (tokens / 1M) * price_per_million + let input_cost = (input_tokens as f64 / 1_000_000.0) * input_price_per_million; + let output_cost = (output_tokens as f64 / 1_000_000.0) * output_price_per_million; + let cost_usd = input_cost + output_cost; + + Self { + model, + input_tokens, + output_tokens, + total_tokens, + cost_usd, + timestamp: chrono::Utc::now(), + } + } + + /// Get the total cost. + pub fn cost(&self) -> f64 { + self.cost_usd + } +} + +/// Time period for cost aggregation. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +pub enum UsagePeriod { + Session, + Day, + Month, +} + +/// A single cost record for persistent storage. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostRecord { + /// Unique identifier + pub id: String, + /// Token usage details + pub usage: TokenUsage, + /// Session identifier (for grouping) + pub session_id: String, +} + +impl CostRecord { + /// Create a new cost record. + pub fn new(session_id: impl Into, usage: TokenUsage) -> Self { + Self { + id: uuid::Uuid::new_v4().to_string(), + usage, + session_id: session_id.into(), + } + } +} + +/// Budget enforcement result. +#[derive(Debug, Clone)] +pub enum BudgetCheck { + /// Within budget, request can proceed + Allowed, + /// Warning threshold exceeded but request can proceed + Warning { + current_usd: f64, + limit_usd: f64, + period: UsagePeriod, + }, + /// Budget exceeded, request blocked + Exceeded { + current_usd: f64, + limit_usd: f64, + period: UsagePeriod, + }, +} + +/// Cost summary for reporting. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct CostSummary { + /// Total cost for the session + pub session_cost_usd: f64, + /// Total cost for the day + pub daily_cost_usd: f64, + /// Total cost for the month + pub monthly_cost_usd: f64, + /// Total tokens used + pub total_tokens: u64, + /// Number of requests + pub request_count: usize, + /// Breakdown by model + pub by_model: std::collections::HashMap, +} + +/// Statistics for a specific model. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelStats { + /// Model name + pub model: String, + /// Total cost for this model + pub cost_usd: f64, + /// Total tokens for this model + pub total_tokens: u64, + /// Number of requests for this model + pub request_count: usize, +} + +impl Default for CostSummary { + fn default() -> Self { + Self { + session_cost_usd: 0.0, + daily_cost_usd: 0.0, + monthly_cost_usd: 0.0, + total_tokens: 0, + request_count: 0, + by_model: std::collections::HashMap::new(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn token_usage_calculation() { + let usage = TokenUsage::new("test/model", 1000, 500, 3.0, 15.0); + + // Expected: (1000/1M)*3 + (500/1M)*15 = 0.003 + 0.0075 = 0.0105 + assert!((usage.cost_usd - 0.0105).abs() < 0.0001); + assert_eq!(usage.input_tokens, 1000); + assert_eq!(usage.output_tokens, 500); + assert_eq!(usage.total_tokens, 1500); + } + + #[test] + fn token_usage_zero_tokens() { + let usage = TokenUsage::new("test/model", 0, 0, 3.0, 15.0); + assert!(usage.cost_usd.abs() < f64::EPSILON); + assert_eq!(usage.total_tokens, 0); + } + + #[test] + fn token_usage_negative_or_non_finite_prices_are_clamped() { + let usage = TokenUsage::new("test/model", 1000, 1000, -3.0, f64::NAN); + assert!(usage.cost_usd.abs() < f64::EPSILON); + assert_eq!(usage.total_tokens, 2000); + } + + #[test] + fn cost_record_creation() { + let usage = TokenUsage::new("test/model", 100, 50, 1.0, 2.0); + let record = CostRecord::new("session-123", usage); + + assert_eq!(record.session_id, "session-123"); + assert!(!record.id.is_empty()); + assert_eq!(record.usage.model, "test/model"); + } +} diff --git a/src/lib.rs b/src/lib.rs index 61a2bc6..588ada3 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -41,6 +41,7 @@ use serde::{Deserialize, Serialize}; pub mod agent; pub mod channels; pub mod config; +pub mod cost; pub mod cron; pub mod daemon; pub mod doctor; diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 5fee2b6..ddac80e 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -122,6 +122,7 @@ pub fn run_wizard() -> Result { browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), identity: crate::config::IdentityConfig::default(), + cost: crate::config::CostConfig::default(), hardware: hardware_config, agents: std::collections::HashMap::new(), security: crate::config::SecurityConfig::default(), @@ -318,6 +319,7 @@ pub fn run_quick_setup( browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), identity: crate::config::IdentityConfig::default(), + cost: crate::config::CostConfig::default(), hardware: HardwareConfig::default(), agents: std::collections::HashMap::new(), security: crate::config::SecurityConfig::default(),