use super::types::{BudgetCheck, CostRecord, CostSummary, ModelStats, TokenUsage, UsagePeriod}; use crate::config::schema::CostConfig; use anyhow::{anyhow, Context, Result}; use chrono::{Datelike, NaiveDate, Utc}; use parking_lot::{Mutex, MutexGuard}; use std::collections::HashMap; use std::fs::{self, File, OpenOptions}; use std::io::{BufRead, BufReader, Write}; use std::path::{Path, PathBuf}; use std::sync::Arc; /// Cost tracker for API usage monitoring and budget enforcement. pub struct CostTracker { config: CostConfig, storage: Arc>, session_id: String, session_costs: Arc>>, } impl CostTracker { /// Create a new cost tracker. pub fn new(config: CostConfig, workspace_dir: &Path) -> Result { let storage_path = resolve_storage_path(workspace_dir)?; let storage = CostStorage::new(&storage_path).with_context(|| { format!("Failed to open cost storage at {}", storage_path.display()) })?; Ok(Self { config, storage: Arc::new(Mutex::new(storage)), session_id: uuid::Uuid::new_v4().to_string(), session_costs: Arc::new(Mutex::new(Vec::new())), }) } /// Get the session ID. pub fn session_id(&self) -> &str { &self.session_id } fn lock_storage(&self) -> MutexGuard<'_, CostStorage> { self.storage.lock() } fn lock_session_costs(&self) -> MutexGuard<'_, Vec> { self.session_costs.lock() } /// Check if a request is within budget. pub fn check_budget(&self, estimated_cost_usd: f64) -> Result { if !self.config.enabled { return Ok(BudgetCheck::Allowed); } if !estimated_cost_usd.is_finite() || estimated_cost_usd < 0.0 { return Err(anyhow!( "Estimated cost must be a finite, non-negative value" )); } let mut storage = self.lock_storage(); let (daily_cost, monthly_cost) = storage.get_aggregated_costs()?; // Check daily limit let projected_daily = daily_cost + estimated_cost_usd; if projected_daily > self.config.daily_limit_usd { return Ok(BudgetCheck::Exceeded { current_usd: daily_cost, limit_usd: self.config.daily_limit_usd, period: UsagePeriod::Day, }); } // Check monthly limit let projected_monthly = monthly_cost + estimated_cost_usd; if projected_monthly > self.config.monthly_limit_usd { return Ok(BudgetCheck::Exceeded { current_usd: monthly_cost, limit_usd: self.config.monthly_limit_usd, period: UsagePeriod::Month, }); } // Check warning thresholds let warn_threshold = f64::from(self.config.warn_at_percent.min(100)) / 100.0; let daily_warn_threshold = self.config.daily_limit_usd * warn_threshold; let monthly_warn_threshold = self.config.monthly_limit_usd * warn_threshold; if projected_daily >= daily_warn_threshold { return Ok(BudgetCheck::Warning { current_usd: daily_cost, limit_usd: self.config.daily_limit_usd, period: UsagePeriod::Day, }); } if projected_monthly >= monthly_warn_threshold { return Ok(BudgetCheck::Warning { current_usd: monthly_cost, limit_usd: self.config.monthly_limit_usd, period: UsagePeriod::Month, }); } Ok(BudgetCheck::Allowed) } /// Record a usage event. pub fn record_usage(&self, usage: TokenUsage) -> Result<()> { if !self.config.enabled { return Ok(()); } if !usage.cost_usd.is_finite() || usage.cost_usd < 0.0 { return Err(anyhow!( "Token usage cost must be a finite, non-negative value" )); } let record = CostRecord::new(&self.session_id, usage); // Persist first for durability guarantees. { let mut storage = self.lock_storage(); storage.add_record(record.clone())?; } // Then update in-memory session snapshot. let mut session_costs = self.lock_session_costs(); session_costs.push(record); Ok(()) } /// Get the current cost summary. pub fn get_summary(&self) -> Result { let (daily_cost, monthly_cost) = { let mut storage = self.lock_storage(); storage.get_aggregated_costs()? }; let session_costs = self.lock_session_costs(); let session_cost: f64 = session_costs .iter() .map(|record| record.usage.cost_usd) .sum(); let total_tokens: u64 = session_costs .iter() .map(|record| record.usage.total_tokens) .sum(); let request_count = session_costs.len(); let by_model = build_session_model_stats(&session_costs); Ok(CostSummary { session_cost_usd: session_cost, daily_cost_usd: daily_cost, monthly_cost_usd: monthly_cost, total_tokens, request_count, by_model, }) } /// Get the daily cost for a specific date. pub fn get_daily_cost(&self, date: NaiveDate) -> Result { let storage = self.lock_storage(); storage.get_cost_for_date(date) } /// Get the monthly cost for a specific month. pub fn get_monthly_cost(&self, year: i32, month: u32) -> Result { let storage = self.lock_storage(); storage.get_cost_for_month(year, month) } } fn resolve_storage_path(workspace_dir: &Path) -> Result { let storage_path = workspace_dir.join("state").join("costs.jsonl"); let legacy_path = workspace_dir.join(".zeroclaw").join("costs.db"); if !storage_path.exists() && legacy_path.exists() { if let Some(parent) = storage_path.parent() { fs::create_dir_all(parent) .with_context(|| format!("Failed to create directory {}", parent.display()))?; } if let Err(error) = fs::rename(&legacy_path, &storage_path) { tracing::warn!( "Failed to move legacy cost storage from {} to {}: {error}; falling back to copy", legacy_path.display(), storage_path.display() ); fs::copy(&legacy_path, &storage_path).with_context(|| { format!( "Failed to copy legacy cost storage from {} to {}", legacy_path.display(), storage_path.display() ) })?; } } Ok(storage_path) } fn build_session_model_stats(session_costs: &[CostRecord]) -> HashMap { let mut by_model: HashMap = HashMap::new(); for record in session_costs { let entry = by_model .entry(record.usage.model.clone()) .or_insert_with(|| ModelStats { model: record.usage.model.clone(), cost_usd: 0.0, total_tokens: 0, request_count: 0, }); entry.cost_usd += record.usage.cost_usd; entry.total_tokens += record.usage.total_tokens; entry.request_count += 1; } by_model } /// Persistent storage for cost records. struct CostStorage { path: PathBuf, daily_cost_usd: f64, monthly_cost_usd: f64, cached_day: NaiveDate, cached_year: i32, cached_month: u32, } impl CostStorage { /// Create or open cost storage. fn new(path: &Path) -> Result { if let Some(parent) = path.parent() { fs::create_dir_all(parent) .with_context(|| format!("Failed to create directory {}", parent.display()))?; } let now = Utc::now(); let mut storage = Self { path: path.to_path_buf(), daily_cost_usd: 0.0, monthly_cost_usd: 0.0, cached_day: now.date_naive(), cached_year: now.year(), cached_month: now.month(), }; storage.rebuild_aggregates( storage.cached_day, storage.cached_year, storage.cached_month, )?; Ok(storage) } fn for_each_record(&self, mut on_record: F) -> Result<()> where F: FnMut(CostRecord), { if !self.path.exists() { return Ok(()); } let file = File::open(&self.path) .with_context(|| format!("Failed to read cost storage from {}", self.path.display()))?; let reader = BufReader::new(file); for (line_number, line) in reader.lines().enumerate() { let raw_line = line.with_context(|| { format!( "Failed to read line {} from cost storage {}", line_number + 1, self.path.display() ) })?; let trimmed = raw_line.trim(); if trimmed.is_empty() { continue; } match serde_json::from_str::(trimmed) { Ok(record) => on_record(record), Err(error) => { tracing::warn!( "Skipping malformed cost record at {}:{}: {error}", self.path.display(), line_number + 1 ); } } } Ok(()) } fn rebuild_aggregates(&mut self, day: NaiveDate, year: i32, month: u32) -> Result<()> { let mut daily_cost = 0.0; let mut monthly_cost = 0.0; self.for_each_record(|record| { let timestamp = record.usage.timestamp.naive_utc(); if timestamp.date() == day { daily_cost += record.usage.cost_usd; } if timestamp.year() == year && timestamp.month() == month { monthly_cost += record.usage.cost_usd; } })?; self.daily_cost_usd = daily_cost; self.monthly_cost_usd = monthly_cost; self.cached_day = day; self.cached_year = year; self.cached_month = month; Ok(()) } fn ensure_period_cache_current(&mut self) -> Result<()> { let now = Utc::now(); let day = now.date_naive(); let year = now.year(); let month = now.month(); if day != self.cached_day || year != self.cached_year || month != self.cached_month { self.rebuild_aggregates(day, year, month)?; } Ok(()) } /// Add a new record. fn add_record(&mut self, record: CostRecord) -> Result<()> { let mut file = OpenOptions::new() .create(true) .append(true) .open(&self.path) .with_context(|| format!("Failed to open cost storage at {}", self.path.display()))?; writeln!(file, "{}", serde_json::to_string(&record)?) .with_context(|| format!("Failed to write cost record to {}", self.path.display()))?; file.sync_all() .with_context(|| format!("Failed to sync cost storage at {}", self.path.display()))?; self.ensure_period_cache_current()?; let timestamp = record.usage.timestamp.naive_utc(); if timestamp.date() == self.cached_day { self.daily_cost_usd += record.usage.cost_usd; } if timestamp.year() == self.cached_year && timestamp.month() == self.cached_month { self.monthly_cost_usd += record.usage.cost_usd; } Ok(()) } /// Get aggregated costs for current day and month. fn get_aggregated_costs(&mut self) -> Result<(f64, f64)> { self.ensure_period_cache_current()?; Ok((self.daily_cost_usd, self.monthly_cost_usd)) } /// Get cost for a specific date. fn get_cost_for_date(&self, date: NaiveDate) -> Result { let mut cost = 0.0; self.for_each_record(|record| { if record.usage.timestamp.naive_utc().date() == date { cost += record.usage.cost_usd; } })?; Ok(cost) } /// Get cost for a specific month. fn get_cost_for_month(&self, year: i32, month: u32) -> Result { let mut cost = 0.0; self.for_each_record(|record| { let timestamp = record.usage.timestamp.naive_utc(); if timestamp.year() == year && timestamp.month() == month { cost += record.usage.cost_usd; } })?; Ok(cost) } } #[cfg(test)] mod tests { use super::*; use tempfile::TempDir; fn enabled_config() -> CostConfig { CostConfig { enabled: true, ..Default::default() } } #[test] fn cost_tracker_initialization() { let tmp = TempDir::new().unwrap(); let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); assert!(!tracker.session_id().is_empty()); } #[test] fn budget_check_when_disabled() { let tmp = TempDir::new().unwrap(); let config = CostConfig { enabled: false, ..Default::default() }; let tracker = CostTracker::new(config, tmp.path()).unwrap(); let check = tracker.check_budget(1000.0).unwrap(); assert!(matches!(check, BudgetCheck::Allowed)); } #[test] fn record_usage_and_get_summary() { let tmp = TempDir::new().unwrap(); let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); let usage = TokenUsage::new("test/model", 1000, 500, 1.0, 2.0); tracker.record_usage(usage).unwrap(); let summary = tracker.get_summary().unwrap(); assert_eq!(summary.request_count, 1); assert!(summary.session_cost_usd > 0.0); assert_eq!(summary.by_model.len(), 1); } #[test] fn budget_exceeded_daily_limit() { let tmp = TempDir::new().unwrap(); let config = CostConfig { enabled: true, daily_limit_usd: 0.01, // Very low limit ..Default::default() }; let tracker = CostTracker::new(config, tmp.path()).unwrap(); // Record a usage that exceeds the limit let usage = TokenUsage::new("test/model", 10000, 5000, 1.0, 2.0); // ~0.02 USD tracker.record_usage(usage).unwrap(); let check = tracker.check_budget(0.01).unwrap(); assert!(matches!(check, BudgetCheck::Exceeded { .. })); } #[test] fn summary_by_model_is_session_scoped() { let tmp = TempDir::new().unwrap(); let storage_path = resolve_storage_path(tmp.path()).unwrap(); if let Some(parent) = storage_path.parent() { fs::create_dir_all(parent).unwrap(); } let old_record = CostRecord::new( "old-session", TokenUsage::new("legacy/model", 500, 500, 1.0, 1.0), ); let mut file = OpenOptions::new() .create(true) .append(true) .open(storage_path) .unwrap(); writeln!(file, "{}", serde_json::to_string(&old_record).unwrap()).unwrap(); file.sync_all().unwrap(); let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); tracker .record_usage(TokenUsage::new("session/model", 1000, 1000, 1.0, 1.0)) .unwrap(); let summary = tracker.get_summary().unwrap(); assert_eq!(summary.by_model.len(), 1); assert!(summary.by_model.contains_key("session/model")); assert!(!summary.by_model.contains_key("legacy/model")); } #[test] fn malformed_lines_are_ignored_while_loading() { let tmp = TempDir::new().unwrap(); let storage_path = resolve_storage_path(tmp.path()).unwrap(); if let Some(parent) = storage_path.parent() { fs::create_dir_all(parent).unwrap(); } let valid_usage = TokenUsage::new("test/model", 1000, 0, 1.0, 1.0); let valid_record = CostRecord::new("session-a", valid_usage.clone()); let mut file = OpenOptions::new() .create(true) .append(true) .open(storage_path) .unwrap(); writeln!(file, "{}", serde_json::to_string(&valid_record).unwrap()).unwrap(); writeln!(file, "not-a-json-line").unwrap(); writeln!(file).unwrap(); file.sync_all().unwrap(); let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); let today_cost = tracker.get_daily_cost(Utc::now().date_naive()).unwrap(); assert!((today_cost - valid_usage.cost_usd).abs() < f64::EPSILON); } #[test] fn invalid_budget_estimate_is_rejected() { let tmp = TempDir::new().unwrap(); let tracker = CostTracker::new(enabled_config(), tmp.path()).unwrap(); let err = tracker.check_budget(f64::NAN).unwrap_err(); assert!(err .to_string() .contains("Estimated cost must be a finite, non-negative value")); } }