diff --git a/src/channels/mod.rs b/src/channels/mod.rs
index 55bf8e0..1acc502 100644
--- a/src/channels/mod.rs
+++ b/src/channels/mod.rs
@@ -20,7 +20,7 @@ pub use telegram::TelegramChannel;
 pub use traits::Channel;
 pub use whatsapp::WhatsAppChannel;
 
-use crate::agent::loop_::{agent_turn, build_tool_instructions};
+use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop};
 use crate::config::Config;
 use crate::identity;
 use crate::memory::{self, Memory};
@@ -181,7 +181,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C
 
     let llm_result = tokio::time::timeout(
         Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
-        agent_turn(
+        run_tool_call_loop(
             ctx.provider.as_ref(),
             &mut history,
             ctx.tools_registry.as_ref(),
diff --git a/src/main.rs b/src/main.rs
index 6c59090..426fdfd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -178,6 +178,12 @@ enum Commands {
         cron_command: CronCommands,
     },
 
+    /// Manage provider model catalogs
+    Models {
+        #[command(subcommand)]
+        model_command: ModelCommands,
+    },
+
     /// Manage channels (telegram, discord, slack)
     Channel {
         #[command(subcommand)]
@@ -235,6 +241,20 @@ enum CronCommands {
     },
 }
 
+#[derive(Subcommand, Debug)]
+enum ModelCommands {
+    /// Refresh and cache provider models
+    Refresh {
+        /// Provider name (defaults to configured default provider)
+        #[arg(long)]
+        provider: Option<String>,
+
+        /// Force live refresh and ignore fresh cache
+        #[arg(long)]
+        force: bool,
+    },
+}
+
 #[derive(Subcommand, Debug)]
 enum ChannelCommands {
     /// List configured channels
@@ -435,6 +455,12 @@ async fn main() -> Result<()> {
 
         Commands::Cron { cron_command } => cron::handle_command(cron_command, &config),
 
+        Commands::Models { model_command } => match model_command {
+            ModelCommands::Refresh { provider, force } => {
+                onboard::run_models_refresh(&config, provider.as_deref(), force)
+            }
+        },
+
         Commands::Service { service_command } => service::handle_command(&service_command, &config),
 
         Commands::Doctor => doctor::run(&config),
diff --git a/src/onboard/mod.rs b/src/onboard/mod.rs
index c3658bd..5117897 100644
--- a/src/onboard/mod.rs
+++ b/src/onboard/mod.rs
@@ -1,6 +1,6 @@
 pub mod wizard;
 
-pub use wizard::{run_channels_repair_wizard, run_quick_setup, run_wizard};
+pub use wizard::{run_channels_repair_wizard, run_models_refresh, run_quick_setup, run_wizard};
 
 #[cfg(test)]
 mod tests {
@@ -13,5 +13,6 @@ mod tests {
         assert_reexport_exists(run_wizard);
         assert_reexport_exists(run_channels_repair_wizard);
         assert_reexport_exists(run_quick_setup);
+        assert_reexport_exists(run_models_refresh);
     }
 }
diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs
index c749d07..0447d23 100644
--- a/src/onboard/wizard.rs
+++ b/src/onboard/wizard.rs
@@ -8,8 +8,12 @@ use crate::hardware::{self, HardwareConfig};
 use anyhow::{Context, Result};
 use console::style;
 use dialoguer::{Confirm, Input, Select};
+use serde::{Deserialize, Serialize};
+use serde_json::Value;
+use std::collections::BTreeSet;
 use std::fs;
 use std::path::{Path, PathBuf};
+use std::time::Duration;
 
 // ── Project context collected during wizard ──────────────────────
 
@@ -39,6 +43,12 @@ const BANNER: &str = r"
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
 ";
 
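+// Tuning knobs for the live model-discovery flow: cap the interactive list,
+// limit preview output, name the on-disk cache file, expire entries after
+// 12 hours, and reserve a sentinel for the manual "type a model ID" entry.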
+const LIVE_MODEL_MAX_OPTIONS: usize = 120;
+const MODEL_PREVIEW_LIMIT: usize = 20;
+const MODEL_CACHE_FILE: &str = "models_cache.json";
+const MODEL_CACHE_TTL_SECS: u64 = 12 * 60 * 60;
+const CUSTOM_MODEL_SENTINEL: &str = "__custom_model__";
+
 // ── Main wizard entry point ──────────────────────────────────────
 
 pub fn run_wizard() -> Result<Config> {
@@ -60,7 +70,7 @@ pub fn run_wizard() -> Result<Config> {
     let (workspace_dir, config_path) = setup_workspace()?;
 
     print_step(2, 9, "AI Provider & API Key");
-    let (provider, api_key, model) = setup_provider()?;
+    let (provider, api_key, model) = setup_provider(&workspace_dir)?;
 
     print_step(3, 9, "Channels (How You Talk to ZeroClaw)");
     let channels_config = setup_channels()?;
@@ -406,17 +416,766 @@ pub fn run_quick_setup(
 
     Ok(config)
 }
 
+fn canonical_provider_name(provider_name: &str) -> &str {
+    match provider_name {
+        "grok" => "xai",
+        "together" => "together-ai",
+        "google" | "google-gemini" => "gemini",
+        _ => provider_name,
+    }
+}
+
 /// Pick a sensible default model for the given provider.
 fn default_model_for_provider(provider: &str) -> String {
-    match provider {
+    match canonical_provider_name(provider) {
         "anthropic" => "claude-sonnet-4-20250514".into(),
-        "openai" => "gpt-4o".into(),
+        "openai" => "gpt-5.2".into(),
         "glm" | "zhipu" | "zai" | "z.ai" => "glm-5".into(),
         "ollama" => "llama3.2".into(),
         "groq" => "llama-3.3-70b-versatile".into(),
         "deepseek" => "deepseek-chat".into(),
-        "gemini" | "google" | "google-gemini" => "gemini-2.0-flash".into(),
-        _ => "anthropic/claude-sonnet-4".into(),
+        "gemini" => "gemini-2.5-pro".into(),
+        _ => "anthropic/claude-sonnet-4.5".into(),
     }
 }
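+
+/// Curated fallback catalog per provider, shown when live discovery is
+/// unavailable (no API key, offline, or an unsupported provider).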
flagship)".to_string(), + ), + ( + "mistral-medium-latest".to_string(), + "Mistral Medium (balanced)".to_string(), + ), + ( + "codestral-latest".to_string(), + "Codestral (code-focused)".to_string(), + ), + ( + "devstral-latest".to_string(), + "Devstral (software engineering specialist)".to_string(), + ), + ], + "deepseek" => vec![ + ( + "deepseek-chat".to_string(), + "DeepSeek Chat (mapped to V3.2 non-thinking)".to_string(), + ), + ( + "deepseek-reasoner".to_string(), + "DeepSeek Reasoner (mapped to V3.2 thinking)".to_string(), + ), + ], + "xai" => vec![ + ( + "grok-4-1-fast-reasoning".to_string(), + "Grok 4.1 Fast Reasoning (recommended)".to_string(), + ), + ( + "grok-4-1-fast-non-reasoning".to_string(), + "Grok 4.1 Fast Non-Reasoning (low latency)".to_string(), + ), + ( + "grok-code-fast-1".to_string(), + "Grok Code Fast 1 (coding specialist)".to_string(), + ), + ("grok-4".to_string(), "Grok 4 (max quality)".to_string()), + ], + "perplexity" => vec![ + ( + "sonar-pro".to_string(), + "Sonar Pro (flagship web-grounded model)".to_string(), + ), + ( + "sonar-reasoning-pro".to_string(), + "Sonar Reasoning Pro (complex multi-step reasoning)".to_string(), + ), + ( + "sonar-deep-research".to_string(), + "Sonar Deep Research (long-form research)".to_string(), + ), + ("sonar".to_string(), "Sonar (search, fast)".to_string()), + ], + "fireworks" => vec![ + ( + "accounts/fireworks/models/llama-v3p3-70b-instruct".to_string(), + "Llama 3.3 70B".to_string(), + ), + ( + "accounts/fireworks/models/mixtral-8x22b-instruct".to_string(), + "Mixtral 8x22B".to_string(), + ), + ], + "together-ai" => vec![ + ( + "meta-llama/Llama-3.3-70B-Instruct-Turbo".to_string(), + "Llama 3.3 70B Instruct Turbo (recommended)".to_string(), + ), + ( + "moonshotai/Kimi-K2.5".to_string(), + "Kimi K2.5 (reasoning + coding)".to_string(), + ), + ( + "deepseek-ai/DeepSeek-V3.1".to_string(), + "DeepSeek V3.1 (strong value)".to_string(), + ), + ], + "cohere" => vec![ + ( + "command-a-03-2025".to_string(), + "Command A (flagship enterprise model)".to_string(), + ), + ( + "command-a-reasoning-08-2025".to_string(), + "Command A Reasoning (agentic reasoning)".to_string(), + ), + ( + "command-r-08-2024".to_string(), + "Command R (stable fast baseline)".to_string(), + ), + ], + "moonshot" => vec![ + ( + "kimi-latest".to_string(), + "Kimi Latest (rolling latest assistant model)".to_string(), + ), + ( + "kimi-k2-0905-preview".to_string(), + "Kimi K2 0905 Preview (strong coding)".to_string(), + ), + ( + "kimi-thinking-preview".to_string(), + "Kimi Thinking Preview (deep reasoning)".to_string(), + ), + ], + "glm" | "zhipu" | "zai" | "z.ai" => vec![ + ( + "glm-4.7".to_string(), + "GLM-4.7 (latest flagship)".to_string(), + ), + ("glm-5".to_string(), "GLM-5 (high reasoning)".to_string()), + ( + "glm-4-plus".to_string(), + "GLM-4 Plus (stable baseline)".to_string(), + ), + ], + "minimax" => vec![ + ( + "MiniMax-M2.5".to_string(), + "MiniMax M2.5 (latest flagship)".to_string(), + ), + ( + "MiniMax-M2.1".to_string(), + "MiniMax M2.1 (strong coding/reasoning)".to_string(), + ), + ( + "MiniMax-M2.1-lightning".to_string(), + "MiniMax M2.1 Lightning (fast)".to_string(), + ), + ], + "ollama" => vec![ + ( + "llama3.2".to_string(), + "Llama 3.2 (recommended local)".to_string(), + ), + ("mistral".to_string(), "Mistral 7B".to_string()), + ("codellama".to_string(), "Code Llama".to_string()), + ("phi3".to_string(), "Phi-3 (small, fast)".to_string()), + ], + "gemini" => vec![ + ( + "gemini-3-pro-preview".to_string(), + "Gemini 3 Pro Preview (latest frontier 
reasoning)".to_string(), + ), + ( + "gemini-2.5-pro".to_string(), + "Gemini 2.5 Pro (stable reasoning)".to_string(), + ), + ( + "gemini-2.5-flash".to_string(), + "Gemini 2.5 Flash (best price/performance)".to_string(), + ), + ( + "gemini-2.5-flash-lite".to_string(), + "Gemini 2.5 Flash-Lite (lowest cost)".to_string(), + ), + ], + _ => vec![("default".to_string(), "Default model".to_string())], + } +} + +fn supports_live_model_fetch(provider_name: &str) -> bool { + matches!( + canonical_provider_name(provider_name), + "openrouter" + | "openai" + | "anthropic" + | "groq" + | "mistral" + | "deepseek" + | "xai" + | "together-ai" + | "gemini" + | "ollama" + ) +} + +fn build_model_fetch_client() -> Result { + reqwest::blocking::Client::builder() + .timeout(Duration::from_secs(8)) + .connect_timeout(Duration::from_secs(4)) + .build() + .context("failed to build model-fetch HTTP client") +} + +fn normalize_model_ids(ids: Vec) -> Vec { + let mut unique = BTreeSet::new(); + for id in ids { + let trimmed = id.trim(); + if !trimmed.is_empty() { + unique.insert(trimmed.to_string()); + } + } + unique.into_iter().collect() +} + +fn parse_openai_compatible_model_ids(payload: &Value) -> Vec { + let mut models = Vec::new(); + + if let Some(data) = payload.get("data").and_then(Value::as_array) { + for model in data { + if let Some(id) = model.get("id").and_then(Value::as_str) { + models.push(id.to_string()); + } + } + } else if let Some(data) = payload.as_array() { + for model in data { + if let Some(id) = model.get("id").and_then(Value::as_str) { + models.push(id.to_string()); + } + } + } + + normalize_model_ids(models) +} + +fn parse_gemini_model_ids(payload: &Value) -> Vec { + let Some(models) = payload.get("models").and_then(Value::as_array) else { + return Vec::new(); + }; + + let mut ids = Vec::new(); + for model in models { + let supports_generate_content = model + .get("supportedGenerationMethods") + .and_then(Value::as_array) + .is_none_or(|methods| { + methods + .iter() + .any(|method| method.as_str() == Some("generateContent")) + }); + + if !supports_generate_content { + continue; + } + + if let Some(name) = model.get("name").and_then(Value::as_str) { + ids.push(name.trim_start_matches("models/").to_string()); + } + } + + normalize_model_ids(ids) +} + +fn parse_ollama_model_ids(payload: &Value) -> Vec { + let Some(models) = payload.get("models").and_then(Value::as_array) else { + return Vec::new(); + }; + + let mut ids = Vec::new(); + for model in models { + if let Some(name) = model.get("name").and_then(Value::as_str) { + ids.push(name.to_string()); + } + } + + normalize_model_ids(ids) +} + +fn fetch_openai_compatible_models(endpoint: &str, api_key: Option<&str>) -> Result> { + let Some(api_key) = api_key else { + return Ok(Vec::new()); + }; + + let client = build_model_fetch_client()?; + let payload: Value = client + .get(endpoint) + .bearer_auth(api_key) + .send() + .and_then(reqwest::blocking::Response::error_for_status) + .with_context(|| format!("model fetch failed: GET {endpoint}"))? 
+fn fetch_openai_compatible_models(endpoint: &str, api_key: Option<&str>) -> Result<Vec<String>> {
+    let Some(api_key) = api_key else {
+        return Ok(Vec::new());
+    };
+
+    let client = build_model_fetch_client()?;
+    let payload: Value = client
+        .get(endpoint)
+        .bearer_auth(api_key)
+        .send()
+        .and_then(reqwest::blocking::Response::error_for_status)
+        .with_context(|| format!("model fetch failed: GET {endpoint}"))?
+        .json()
+        .context("failed to parse model list response")?;
+
+    Ok(parse_openai_compatible_model_ids(&payload))
+}
+
+fn fetch_openrouter_models(api_key: Option<&str>) -> Result<Vec<String>> {
+    let client = build_model_fetch_client()?;
+    let mut request = client.get("https://openrouter.ai/api/v1/models");
+    if let Some(api_key) = api_key {
+        request = request.bearer_auth(api_key);
+    }
+
+    let payload: Value = request
+        .send()
+        .and_then(reqwest::blocking::Response::error_for_status)
+        .context("model fetch failed: GET https://openrouter.ai/api/v1/models")?
+        .json()
+        .context("failed to parse OpenRouter model list response")?;
+
+    Ok(parse_openai_compatible_model_ids(&payload))
+}
+
+fn fetch_anthropic_models(api_key: Option<&str>) -> Result<Vec<String>> {
+    let Some(api_key) = api_key else {
+        return Ok(Vec::new());
+    };
+
+    let client = build_model_fetch_client()?;
+    let payload: Value = client
+        .get("https://api.anthropic.com/v1/models")
+        .header("x-api-key", api_key)
+        .header("anthropic-version", "2023-06-01")
+        .send()
+        .and_then(reqwest::blocking::Response::error_for_status)
+        .context("model fetch failed: GET https://api.anthropic.com/v1/models")?
+        .json()
+        .context("failed to parse Anthropic model list response")?;
+
+    Ok(parse_openai_compatible_model_ids(&payload))
+}
+
+fn fetch_gemini_models(api_key: Option<&str>) -> Result<Vec<String>> {
+    let Some(api_key) = api_key else {
+        return Ok(Vec::new());
+    };
+
+    let client = build_model_fetch_client()?;
+    let payload: Value = client
+        .get("https://generativelanguage.googleapis.com/v1beta/models")
+        .query(&[("key", api_key), ("pageSize", "200")])
+        .send()
+        .and_then(reqwest::blocking::Response::error_for_status)
+        .context("model fetch failed: GET Gemini models")?
+        .json()
+        .context("failed to parse Gemini model list response")?;
+
+    Ok(parse_gemini_model_ids(&payload))
+}
+
+fn fetch_ollama_models() -> Result<Vec<String>> {
+    let client = build_model_fetch_client()?;
+    let payload: Value = client
+        .get("http://localhost:11434/api/tags")
+        .send()
+        .and_then(reqwest::blocking::Response::error_for_status)
+        .context("model fetch failed: GET http://localhost:11434/api/tags")?
+        .json()
+        .context("failed to parse Ollama model list response")?;
+
+    Ok(parse_ollama_model_ids(&payload))
+}
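+
+/// Dispatch live discovery to the right per-provider fetcher, falling back to
+/// the provider's conventional env var when the wizard holds no API key.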
+ } + "deepseek" => fetch_openai_compatible_models( + "https://api.deepseek.com/v1/models", + api_key.as_deref(), + )?, + "xai" => fetch_openai_compatible_models("https://api.x.ai/v1/models", api_key.as_deref())?, + "together-ai" => fetch_openai_compatible_models( + "https://api.together.xyz/v1/models", + api_key.as_deref(), + )?, + "anthropic" => fetch_anthropic_models(api_key.as_deref())?, + "gemini" => fetch_gemini_models(api_key.as_deref())?, + "ollama" => fetch_ollama_models()?, + _ => Vec::new(), + }; + + Ok(models) +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct ModelCacheEntry { + provider: String, + fetched_at_unix: u64, + models: Vec, +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct ModelCacheState { + entries: Vec, +} + +#[derive(Debug, Clone)] +struct CachedModels { + models: Vec, + age_secs: u64, +} + +fn model_cache_path(workspace_dir: &Path) -> PathBuf { + workspace_dir.join("state").join(MODEL_CACHE_FILE) +} + +fn now_unix_secs() -> u64 { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs()) +} + +fn load_model_cache_state(workspace_dir: &Path) -> Result { + let path = model_cache_path(workspace_dir); + if !path.exists() { + return Ok(ModelCacheState::default()); + } + + let raw = fs::read_to_string(&path) + .with_context(|| format!("failed to read model cache at {}", path.display()))?; + + match serde_json::from_str::(&raw) { + Ok(state) => Ok(state), + Err(_) => Ok(ModelCacheState::default()), + } +} + +fn save_model_cache_state(workspace_dir: &Path, state: &ModelCacheState) -> Result<()> { + let path = model_cache_path(workspace_dir); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent).with_context(|| { + format!( + "failed to create model cache directory {}", + parent.display() + ) + })?; + } + + let json = serde_json::to_vec_pretty(state).context("failed to serialize model cache")?; + fs::write(&path, json) + .with_context(|| format!("failed to write model cache at {}", path.display()))?; + + Ok(()) +} + +fn cache_live_models_for_provider( + workspace_dir: &Path, + provider_name: &str, + models: &[String], +) -> Result<()> { + let normalized_models = normalize_model_ids(models.to_vec()); + if normalized_models.is_empty() { + return Ok(()); + } + + let mut state = load_model_cache_state(workspace_dir)?; + let now = now_unix_secs(); + + if let Some(entry) = state + .entries + .iter_mut() + .find(|entry| entry.provider == provider_name) + { + entry.fetched_at_unix = now; + entry.models = normalized_models; + } else { + state.entries.push(ModelCacheEntry { + provider: provider_name.to_string(), + fetched_at_unix: now, + models: normalized_models, + }); + } + + save_model_cache_state(workspace_dir, &state) +} + +fn load_cached_models_for_provider_internal( + workspace_dir: &Path, + provider_name: &str, + ttl_secs: Option, +) -> Result> { + let state = load_model_cache_state(workspace_dir)?; + let now = now_unix_secs(); + + let Some(entry) = state + .entries + .into_iter() + .find(|entry| entry.provider == provider_name) + else { + return Ok(None); + }; + + if entry.models.is_empty() { + return Ok(None); + } + + let age_secs = now.saturating_sub(entry.fetched_at_unix); + if ttl_secs.is_some_and(|ttl| age_secs > ttl) { + return Ok(None); + } + + Ok(Some(CachedModels { + models: entry.models, + age_secs, + })) +} + +fn load_cached_models_for_provider( + workspace_dir: &Path, + provider_name: &str, + ttl_secs: u64, +) -> Result> { + 
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct ModelCacheEntry {
+    provider: String,
+    fetched_at_unix: u64,
+    models: Vec<String>,
+}
+
+#[derive(Debug, Clone, Default, Serialize, Deserialize)]
+struct ModelCacheState {
+    entries: Vec<ModelCacheEntry>,
+}
+
+#[derive(Debug, Clone)]
+struct CachedModels {
+    models: Vec<String>,
+    age_secs: u64,
+}
+
+fn model_cache_path(workspace_dir: &Path) -> PathBuf {
+    workspace_dir.join("state").join(MODEL_CACHE_FILE)
+}
+
+fn now_unix_secs() -> u64 {
+    std::time::SystemTime::now()
+        .duration_since(std::time::UNIX_EPOCH)
+        .map_or(0, |duration| duration.as_secs())
+}
+
+fn load_model_cache_state(workspace_dir: &Path) -> Result<ModelCacheState> {
+    let path = model_cache_path(workspace_dir);
+    if !path.exists() {
+        return Ok(ModelCacheState::default());
+    }
+
+    let raw = fs::read_to_string(&path)
+        .with_context(|| format!("failed to read model cache at {}", path.display()))?;
+
+    match serde_json::from_str::<ModelCacheState>(&raw) {
+        Ok(state) => Ok(state),
+        Err(_) => Ok(ModelCacheState::default()),
+    }
+}
+
+fn save_model_cache_state(workspace_dir: &Path, state: &ModelCacheState) -> Result<()> {
+    let path = model_cache_path(workspace_dir);
+    if let Some(parent) = path.parent() {
+        fs::create_dir_all(parent).with_context(|| {
+            format!(
+                "failed to create model cache directory {}",
+                parent.display()
+            )
+        })?;
+    }
+
+    let json = serde_json::to_vec_pretty(state).context("failed to serialize model cache")?;
+    fs::write(&path, json)
+        .with_context(|| format!("failed to write model cache at {}", path.display()))?;
+
+    Ok(())
+}
+
+fn cache_live_models_for_provider(
+    workspace_dir: &Path,
+    provider_name: &str,
+    models: &[String],
+) -> Result<()> {
+    let normalized_models = normalize_model_ids(models.to_vec());
+    if normalized_models.is_empty() {
+        return Ok(());
+    }
+
+    let mut state = load_model_cache_state(workspace_dir)?;
+    let now = now_unix_secs();
+
+    if let Some(entry) = state
+        .entries
+        .iter_mut()
+        .find(|entry| entry.provider == provider_name)
+    {
+        entry.fetched_at_unix = now;
+        entry.models = normalized_models;
+    } else {
+        state.entries.push(ModelCacheEntry {
+            provider: provider_name.to_string(),
+            fetched_at_unix: now,
+            models: normalized_models,
+        });
+    }
+
+    save_model_cache_state(workspace_dir, &state)
+}
+
+fn load_cached_models_for_provider_internal(
+    workspace_dir: &Path,
+    provider_name: &str,
+    ttl_secs: Option<u64>,
+) -> Result<Option<CachedModels>> {
+    let state = load_model_cache_state(workspace_dir)?;
+    let now = now_unix_secs();
+
+    let Some(entry) = state
+        .entries
+        .into_iter()
+        .find(|entry| entry.provider == provider_name)
+    else {
+        return Ok(None);
+    };
+
+    if entry.models.is_empty() {
+        return Ok(None);
+    }
+
+    let age_secs = now.saturating_sub(entry.fetched_at_unix);
+    if ttl_secs.is_some_and(|ttl| age_secs > ttl) {
+        return Ok(None);
+    }
+
+    Ok(Some(CachedModels {
+        models: entry.models,
+        age_secs,
+    }))
+}
+
+fn load_cached_models_for_provider(
+    workspace_dir: &Path,
+    provider_name: &str,
+    ttl_secs: u64,
+) -> Result<Option<CachedModels>> {
+    load_cached_models_for_provider_internal(workspace_dir, provider_name, Some(ttl_secs))
+}
+
+fn load_any_cached_models_for_provider(
+    workspace_dir: &Path,
+    provider_name: &str,
+) -> Result<Option<CachedModels>> {
+    load_cached_models_for_provider_internal(workspace_dir, provider_name, None)
+}
+
+fn humanize_age(age_secs: u64) -> String {
+    if age_secs < 60 {
+        format!("{age_secs}s")
+    } else if age_secs < 60 * 60 {
+        format!("{}m", age_secs / 60)
+    } else {
+        format!("{}h", age_secs / (60 * 60))
+    }
+}
+
+fn build_model_options(model_ids: Vec<String>, source: &str) -> Vec<(String, String)> {
+    model_ids
+        .into_iter()
+        .map(|model_id| {
+            let label = format!("{model_id} ({source})");
+            (model_id, label)
+        })
+        .collect()
+}
+
+fn print_model_preview(models: &[String]) {
+    for model in models.iter().take(MODEL_PREVIEW_LIMIT) {
+        println!(" {} {model}", style("-"));
+    }
+
+    if models.len() > MODEL_PREVIEW_LIMIT {
+        println!(
+            " {} ... and {} more",
+            style("-"),
+            models.len() - MODEL_PREVIEW_LIMIT
+        );
+    }
+}
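+
+/// Entry point for `zeroclaw models refresh`: serve the fresh cache unless
+/// `--force` is given, otherwise fetch live and fall back to any stale cache.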
+pub fn run_models_refresh(
+    config: &Config,
+    provider_override: Option<&str>,
+    force: bool,
+) -> Result<()> {
+    let provider_name = provider_override
+        .or(config.default_provider.as_deref())
+        .unwrap_or("openrouter")
+        .trim()
+        .to_string();
+
+    if provider_name.is_empty() {
+        anyhow::bail!("Provider name cannot be empty");
+    }
+
+    if !supports_live_model_fetch(&provider_name) {
+        anyhow::bail!("Provider '{provider_name}' does not support live model discovery yet");
+    }
+
+    if !force {
+        if let Some(cached) = load_cached_models_for_provider(
+            &config.workspace_dir,
+            &provider_name,
+            MODEL_CACHE_TTL_SECS,
+        )? {
+            println!(
+                "Using cached model list for '{}' (updated {} ago):",
+                provider_name,
+                humanize_age(cached.age_secs)
+            );
+            print_model_preview(&cached.models);
+            println!();
+            println!(
+                "Tip: run `zeroclaw models refresh --force --provider {}` to fetch latest now.",
+                provider_name
+            );
+            return Ok(());
+        }
+    }
+
+    let api_key = config.api_key.clone().unwrap_or_default();
+
+    match fetch_live_models_for_provider(&provider_name, &api_key) {
+        Ok(models) if !models.is_empty() => {
+            cache_live_models_for_provider(&config.workspace_dir, &provider_name, &models)?;
+            println!(
+                "Refreshed '{}' model cache with {} models.",
+                provider_name,
+                models.len()
+            );
+            print_model_preview(&models);
+            Ok(())
+        }
+        Ok(_) => {
+            if let Some(stale_cache) =
+                load_any_cached_models_for_provider(&config.workspace_dir, &provider_name)?
+            {
+                println!(
+                    "Provider returned no models; using stale cache (updated {} ago):",
+                    humanize_age(stale_cache.age_secs)
+                );
+                print_model_preview(&stale_cache.models);
+                return Ok(());
+            }
+
+            anyhow::bail!("Provider '{}' returned an empty model list", provider_name)
+        }
+        Err(error) => {
+            if let Some(stale_cache) =
+                load_any_cached_models_for_provider(&config.workspace_dir, &provider_name)?
+            {
+                println!(
+                    "Live refresh failed ({}). Falling back to stale cache (updated {} ago):",
+                    error,
+                    humanize_age(stale_cache.age_secs)
+                );
+                print_model_preview(&stale_cache.models);
+                return Ok(());
+            }
+
+            Err(error)
+                .with_context(|| format!("failed to refresh models for provider '{provider_name}'"))
+        }
+    }
+}
@@ -481,7 +1240,7 @@ fn setup_workspace() -> Result<(PathBuf, PathBuf)> {
 
 // ── Step 2: Provider & API Key ───────────────────────────────────
 
 #[allow(clippy::too_many_lines)]
-fn setup_provider() -> Result<(String, String, String)> {
+fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String)> {
     // ── Tier selection ──
     let tiers = vec![
         "⭐ Recommended (OpenRouter, Venice, Anthropic, OpenAI, Gemini)",
@@ -519,7 +1278,7 @@ fn setup_provider() -> Result<(String, String, String)> {
         1 => vec![
             ("groq", "Groq — ultra-fast LPU inference"),
             ("fireworks", "Fireworks AI — fast open-source inference"),
-            ("together", "Together AI — open-source model hosting"),
+            ("together-ai", "Together AI — open-source model hosting"),
         ],
         2 => vec![
             ("vercel", "Vercel AI Gateway"),
@@ -597,10 +1356,7 @@ fn setup_provider() -> Result<(String, String, String)> {
     let api_key = if provider_name == "ollama" {
         print_bullet("Ollama runs locally — no API key needed!");
         String::new()
-    } else if provider_name == "gemini"
-        || provider_name == "google"
-        || provider_name == "google-gemini"
-    {
+    } else if canonical_provider_name(provider_name) == "gemini" {
        // Special handling for Gemini: check for CLI auth first
        if crate::providers::gemini::GeminiProvider::has_cli_credentials() {
            print_bullet(&format!(
@@ -653,7 +1409,7 @@ fn setup_provider() -> Result<(String, String, String)> {
         "groq" => "https://console.groq.com/keys",
         "mistral" => "https://console.mistral.ai/api-keys",
         "deepseek" => "https://platform.deepseek.com/api_keys",
-        "together" => "https://api.together.xyz/settings/api-keys",
+        "together-ai" => "https://api.together.xyz/settings/api-keys",
         "fireworks" => "https://fireworks.ai/account/api-keys",
         "perplexity" => "https://www.perplexity.ai/settings/api",
         "xai" => "https://console.x.ai",
@@ -665,7 +1421,7 @@
         "vercel" => "https://vercel.com/account/tokens",
         "cloudflare" => "https://dash.cloudflare.com/profile/api-tokens",
         "bedrock" => "https://console.aws.amazon.com/iam",
-        "gemini" | "google" | "google-gemini" => "https://aistudio.google.com/app/apikey",
+        "gemini" => "https://aistudio.google.com/app/apikey",
         _ => "",
     };
@@ -696,132 +1452,141 @@
     };
 
     // ── Model selection ──
-    let models: Vec<(&str, &str)> = match provider_name {
-        "openrouter" => vec![
-            (
-                "anthropic/claude-sonnet-4",
-                "Claude Sonnet 4 (balanced, recommended)",
-            ),
-            (
-                "anthropic/claude-3.5-sonnet",
-                "Claude 3.5 Sonnet (fast, affordable)",
-            ),
-            ("openai/gpt-4o", "GPT-4o (OpenAI flagship)"),
-            ("openai/gpt-4o-mini", "GPT-4o Mini (fast, cheap)"),
-            (
-                "google/gemini-2.0-flash-001",
-                "Gemini 2.0 Flash (Google, fast)",
-            ),
-            (
-                "meta-llama/llama-3.3-70b-instruct",
-                "Llama 3.3 70B (open source)",
-            ),
-            ("deepseek/deepseek-chat", "DeepSeek Chat (affordable)"),
-        ],
-        "anthropic" => vec![
-            (
-                "claude-sonnet-4-20250514",
-                "Claude Sonnet 4 (balanced, recommended)",
-            ),
-            ("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet (fast)"),
-            (
-                "claude-3-5-haiku-20241022",
-                "Claude 3.5 Haiku (fastest, cheapest)",
-            ),
-        ],
-        "openai" => vec![
-            ("gpt-4o", "GPT-4o (flagship)"),
-            ("gpt-4o-mini", "GPT-4o Mini (fast, cheap)"),
-            ("o1-mini", "o1-mini (reasoning)"),
(reasoning)"), - ], - "venice" => vec![ - ("llama-3.3-70b", "Llama 3.3 70B (default, fast)"), - ("claude-opus-45", "Claude Opus 4.5 via Venice (strongest)"), - ("llama-3.1-405b", "Llama 3.1 405B (largest open source)"), - ], - "groq" => vec![ - ( - "llama-3.3-70b-versatile", - "Llama 3.3 70B (fast, recommended)", - ), - ("llama-3.1-8b-instant", "Llama 3.1 8B (instant)"), - ("mixtral-8x7b-32768", "Mixtral 8x7B (32K context)"), - ], - "mistral" => vec![ - ("mistral-large-latest", "Mistral Large (flagship)"), - ("codestral-latest", "Codestral (code-focused)"), - ("mistral-small-latest", "Mistral Small (fast, cheap)"), - ], - "deepseek" => vec![ - ("deepseek-chat", "DeepSeek Chat (V3, recommended)"), - ("deepseek-reasoner", "DeepSeek Reasoner (R1)"), - ], - "xai" => vec![ - ("grok-3", "Grok 3 (flagship)"), - ("grok-3-mini", "Grok 3 Mini (fast)"), - ], - "perplexity" => vec![ - ("sonar-pro", "Sonar Pro (search + reasoning)"), - ("sonar", "Sonar (search, fast)"), - ], - "fireworks" => vec![ - ( - "accounts/fireworks/models/llama-v3p3-70b-instruct", - "Llama 3.3 70B", - ), - ( - "accounts/fireworks/models/mixtral-8x22b-instruct", - "Mixtral 8x22B", - ), - ], - "together" => vec![ - ( - "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo", - "Llama 3.1 70B Turbo", - ), - ( - "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", - "Llama 3.1 8B Turbo", - ), - ("mistralai/Mixtral-8x22B-Instruct-v0.1", "Mixtral 8x22B"), - ], - "cohere" => vec![ - ("command-r-plus", "Command R+ (flagship)"), - ("command-r", "Command R (fast)"), - ], - "moonshot" => vec![ - ("moonshot-v1-128k", "Moonshot V1 128K"), - ("moonshot-v1-32k", "Moonshot V1 32K"), - ], - "glm" | "zhipu" | "zai" | "z.ai" => vec![ - ("glm-5", "GLM-5 (latest)"), - ("glm-4-plus", "GLM-4 Plus (flagship)"), - ("glm-4-flash", "GLM-4 Flash (fast)"), - ], - "minimax" => vec![ - ("MiniMax-M2.5", "MiniMax M2.5 (latest flagship)"), - ("MiniMax-M2.5-highspeed", "MiniMax M2.5 Highspeed (faster)"), - ("MiniMax-M2.1", "MiniMax M2.1 (previous gen)"), - ], - "ollama" => vec![ - ("llama3.2", "Llama 3.2 (recommended local)"), - ("mistral", "Mistral 7B"), - ("codellama", "Code Llama"), - ("phi3", "Phi-3 (small, fast)"), - ], - "gemini" | "google" | "google-gemini" => vec![ - ("gemini-2.0-flash", "Gemini 2.0 Flash (fast, recommended)"), - ( - "gemini-2.0-flash-lite", - "Gemini 2.0 Flash Lite (fastest, cheapest)", - ), - ("gemini-1.5-pro", "Gemini 1.5 Pro (best quality)"), - ("gemini-1.5-flash", "Gemini 1.5 Flash (balanced)"), - ], - _ => vec![("default", "Default model")], - }; + let mut model_options = curated_models_for_provider(provider_name); + let mut live_options: Option> = None; - let model_labels: Vec<&str> = models.iter().map(|(_, label)| *label).collect(); + if supports_live_model_fetch(provider_name) { + let can_fetch_without_key = matches!(provider_name, "openrouter" | "ollama"); + let has_api_key = !api_key.trim().is_empty() + || std::env::var(provider_env_var(provider_name)) + .ok() + .is_some_and(|value| !value.trim().is_empty()); + + if can_fetch_without_key || has_api_key { + if let Some(cached) = + load_cached_models_for_provider(workspace_dir, provider_name, MODEL_CACHE_TTL_SECS)? 
+    if supports_live_model_fetch(provider_name) {
+        let can_fetch_without_key = matches!(provider_name, "openrouter" | "ollama");
+        let has_api_key = !api_key.trim().is_empty()
+            || std::env::var(provider_env_var(provider_name))
+                .ok()
+                .is_some_and(|value| !value.trim().is_empty());
+
+        if can_fetch_without_key || has_api_key {
+            if let Some(cached) =
+                load_cached_models_for_provider(workspace_dir, provider_name, MODEL_CACHE_TTL_SECS)?
+            {
+                let shown_count = cached.models.len().min(LIVE_MODEL_MAX_OPTIONS);
+                print_bullet(&format!(
+                    "Found cached models ({shown_count}) updated {} ago.",
+                    humanize_age(cached.age_secs)
+                ));
+
+                live_options = Some(build_model_options(
+                    cached
+                        .models
+                        .into_iter()
+                        .take(LIVE_MODEL_MAX_OPTIONS)
+                        .collect(),
+                    "cached",
+                ));
+            }
+
+            let should_fetch_now = Confirm::new()
+                .with_prompt(if live_options.is_some() {
+                    " Refresh models from provider now?"
+                } else {
+                    " Fetch latest models from provider now?"
+                })
+                .default(live_options.is_none())
+                .interact()?;
+
+            if should_fetch_now {
+                match fetch_live_models_for_provider(provider_name, &api_key) {
+                    Ok(live_model_ids) if !live_model_ids.is_empty() => {
+                        cache_live_models_for_provider(
+                            workspace_dir,
+                            provider_name,
+                            &live_model_ids,
+                        )?;
+
+                        let fetched_count = live_model_ids.len();
+                        let shown_count = fetched_count.min(LIVE_MODEL_MAX_OPTIONS);
+                        let shown_models: Vec<String> = live_model_ids
+                            .into_iter()
+                            .take(LIVE_MODEL_MAX_OPTIONS)
+                            .collect();
+
+                        if shown_count < fetched_count {
+                            print_bullet(&format!(
+                                "Fetched {fetched_count} models. Showing first {shown_count}."
+                            ));
+                        } else {
+                            print_bullet(&format!("Fetched {shown_count} live models."));
+                        }
+
+                        live_options = Some(build_model_options(shown_models, "live"));
+                    }
+                    Ok(_) => {
+                        print_bullet("Provider returned no models; using curated list.");
+                    }
+                    Err(error) => {
+                        print_bullet(&format!(
+                            "Live fetch failed ({}); using cached/curated list.",
+                            style(error.to_string()).yellow()
+                        ));
+
+                        if live_options.is_none() {
+                            if let Some(stale) =
+                                load_any_cached_models_for_provider(workspace_dir, provider_name)?
+                            {
+                                print_bullet(&format!(
+                                    "Loaded stale cache from {} ago.",
+                                    humanize_age(stale.age_secs)
+                                ));
+
+                                live_options = Some(build_model_options(
+                                    stale
+                                        .models
+                                        .into_iter()
+                                        .take(LIVE_MODEL_MAX_OPTIONS)
+                                        .collect(),
+                                    "stale-cache",
+                                ));
+                            }
+                        }
+                    }
+                }
+            }
+        } else {
+            print_bullet("No API key detected, so using curated model list.");
+            print_bullet("Tip: add an API key and rerun onboarding to fetch live models.");
+        }
+    }
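+
+    // Let the user pick between the fetched catalog and the curated list,
+    // then always append a manual-entry escape hatch.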
+    if let Some(live_model_options) = live_options {
+        let source_options = vec![
+            format!("Provider model list ({})", live_model_options.len()),
+            format!("Curated starter list ({})", model_options.len()),
+        ];
+
+        let source_idx = Select::new()
+            .with_prompt(" Model source")
+            .items(&source_options)
+            .default(0)
+            .interact()?;
+
+        if source_idx == 0 {
+            model_options = live_model_options;
+        }
+    }
+
+    if model_options.is_empty() {
+        model_options.push((
+            default_model_for_provider(provider_name),
+            "Provider default model".to_string(),
+        ));
+    }
+
+    model_options.push((
+        CUSTOM_MODEL_SENTINEL.to_string(),
+        "Custom model ID (type manually)".to_string(),
+    ));
+
+    let model_labels: Vec<String> = model_options
+        .iter()
+        .map(|(model_id, label)| format!("{label} — {}", style(model_id).dim()))
+        .collect();
 
     let model_idx = Select::new()
         .with_prompt(" Select your default model")
         .items(&model_labels)
         .default(0)
         .interact()?;
 
-    let model = models[model_idx].0.to_string();
+    let selected_model = model_options[model_idx].0.clone();
+    let model = if selected_model == CUSTOM_MODEL_SENTINEL {
+        Input::new()
+            .with_prompt(" Enter custom model ID")
+            .default(default_model_for_provider(provider_name))
+            .interact_text()?
+    } else {
+        selected_model
+    };
 
     println!(
         " {} Provider: {} | Model: {}",
@@ -843,7 +1616,7 @@
 /// Map provider name to its conventional env var
 fn provider_env_var(name: &str) -> &'static str {
-    match name {
+    match canonical_provider_name(name) {
         "openrouter" => "OPENROUTER_API_KEY",
         "anthropic" => "ANTHROPIC_API_KEY",
         "openai" => "OPENAI_API_KEY",
@@ -851,8 +1624,8 @@ fn provider_env_var(name: &str) -> &'static str {
         "groq" => "GROQ_API_KEY",
         "mistral" => "MISTRAL_API_KEY",
         "deepseek" => "DEEPSEEK_API_KEY",
-        "xai" | "grok" => "XAI_API_KEY",
-        "together" | "together-ai" => "TOGETHER_API_KEY",
+        "xai" => "XAI_API_KEY",
+        "together-ai" => "TOGETHER_API_KEY",
         "fireworks" | "fireworks-ai" => "FIREWORKS_API_KEY",
         "perplexity" => "PERPLEXITY_API_KEY",
         "cohere" => "COHERE_API_KEY",
@@ -866,7 +1639,7 @@ fn provider_env_var(name: &str) -> &'static str {
         "vercel" | "vercel-ai" => "VERCEL_API_KEY",
         "cloudflare" | "cloudflare-ai" => "CLOUDFLARE_API_KEY",
         "bedrock" | "aws-bedrock" => "AWS_ACCESS_KEY_ID",
-        "gemini" | "google" | "google-gemini" => "GEMINI_API_KEY",
+        "gemini" => "GEMINI_API_KEY",
         _ => "API_KEY",
     }
 }
@@ -2796,6 +3569,7 @@ fn print_summary(config: &Config) {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use serde_json::json;
     use tempfile::TempDir;
 
     // ── ProjectContext defaults ──────────────────────────────────
@@ -3211,6 +3985,204 @@
         assert!(heartbeat.contains("Claw"));
     }
 
+    // ── model helper coverage ───────────────────────────────────
+
+    #[test]
+    fn default_model_for_provider_uses_latest_defaults() {
+        assert_eq!(default_model_for_provider("openai"), "gpt-5.2");
+        assert_eq!(
+            default_model_for_provider("anthropic"),
+            "claude-sonnet-4-20250514"
+        );
+        assert_eq!(default_model_for_provider("gemini"), "gemini-2.5-pro");
+        assert_eq!(default_model_for_provider("google"), "gemini-2.5-pro");
+        assert_eq!(
+            default_model_for_provider("google-gemini"),
+            "gemini-2.5-pro"
+        );
+    }
+
+    #[test]
+    fn curated_models_for_openai_include_latest_choices() {
+        let ids: Vec<String> = curated_models_for_provider("openai")
+            .into_iter()
+            .map(|(id, _)| id)
+            .collect();
+
+        assert!(ids.contains(&"gpt-5.2".to_string()));
+        assert!(ids.contains(&"gpt-5-mini".to_string()));
+    }
+
+    #[test]
+    fn supports_live_model_fetch_for_supported_and_unsupported_providers() {
+        assert!(supports_live_model_fetch("openai"));
+        assert!(supports_live_model_fetch("anthropic"));
+        assert!(supports_live_model_fetch("gemini"));
+        assert!(supports_live_model_fetch("google"));
+        assert!(supports_live_model_fetch("grok"));
+        assert!(supports_live_model_fetch("together"));
+        assert!(supports_live_model_fetch("ollama"));
+        assert!(!supports_live_model_fetch("venice"));
+    }
+
+    #[test]
+    fn curated_models_provider_aliases_share_same_catalog() {
+        assert_eq!(
+            curated_models_for_provider("xai"),
+            curated_models_for_provider("grok")
+        );
+        assert_eq!(
+            curated_models_for_provider("together-ai"),
+            curated_models_for_provider("together")
+        );
+        assert_eq!(
+            curated_models_for_provider("gemini"),
+            curated_models_for_provider("google")
+        );
+        assert_eq!(
+            curated_models_for_provider("gemini"),
+            curated_models_for_provider("google-gemini")
+        );
+    }
+
+    #[test]
+    fn parse_openai_model_ids_supports_data_array_payload() {
+        let payload = json!({
+            "data": [
+                {"id": " gpt-5.1 "},
+                {"id": "gpt-5-mini"},
+                {"id": "gpt-5.1"},
+                {"id": ""}
+            ]
+        });
+
+        let ids = parse_openai_compatible_model_ids(&payload);
+        assert_eq!(ids, vec!["gpt-5-mini".to_string(), "gpt-5.1".to_string()]);
+    }
+
+    #[test]
+    fn parse_openai_model_ids_supports_root_array_payload() {
+        let payload = json!([
+            {"id": "alpha"},
+            {"id": "beta"},
+            {"id": "alpha"}
+        ]);
+
+        let ids = parse_openai_compatible_model_ids(&payload);
+        assert_eq!(ids, vec!["alpha".to_string(), "beta".to_string()]);
+    }
+
+    #[test]
+    fn parse_gemini_model_ids_filters_for_generate_content() {
+        let payload = json!({
+            "models": [
+                {
+                    "name": "models/gemini-2.5-pro",
+                    "supportedGenerationMethods": ["generateContent", "countTokens"]
+                },
+                {
+                    "name": "models/text-embedding-004",
+                    "supportedGenerationMethods": ["embedContent"]
+                },
+                {
+                    "name": "models/gemini-2.5-flash",
+                    "supportedGenerationMethods": ["generateContent"]
+                }
+            ]
+        });
+
+        let ids = parse_gemini_model_ids(&payload);
+        assert_eq!(
+            ids,
+            vec!["gemini-2.5-flash".to_string(), "gemini-2.5-pro".to_string()]
+        );
+    }
+
+    #[test]
+    fn parse_ollama_model_ids_extracts_and_deduplicates_names() {
+        let payload = json!({
+            "models": [
+                {"name": "llama3.2:latest"},
+                {"name": "mistral:latest"},
+                {"name": "llama3.2:latest"}
+            ]
+        });
+
+        let ids = parse_ollama_model_ids(&payload);
+        assert_eq!(
+            ids,
+            vec!["llama3.2:latest".to_string(), "mistral:latest".to_string()]
+        );
+    }
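+
+    // ── cache & refresh coverage (TempDir-backed, no network) ───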
"gpt-5.1".to_string()]); + } + + #[test] + fn parse_openai_model_ids_supports_root_array_payload() { + let payload = json!([ + {"id": "alpha"}, + {"id": "beta"}, + {"id": "alpha"} + ]); + + let ids = parse_openai_compatible_model_ids(&payload); + assert_eq!(ids, vec!["alpha".to_string(), "beta".to_string()]); + } + + #[test] + fn parse_gemini_model_ids_filters_for_generate_content() { + let payload = json!({ + "models": [ + { + "name": "models/gemini-2.5-pro", + "supportedGenerationMethods": ["generateContent", "countTokens"] + }, + { + "name": "models/text-embedding-004", + "supportedGenerationMethods": ["embedContent"] + }, + { + "name": "models/gemini-2.5-flash", + "supportedGenerationMethods": ["generateContent"] + } + ] + }); + + let ids = parse_gemini_model_ids(&payload); + assert_eq!( + ids, + vec!["gemini-2.5-flash".to_string(), "gemini-2.5-pro".to_string()] + ); + } + + #[test] + fn parse_ollama_model_ids_extracts_and_deduplicates_names() { + let payload = json!({ + "models": [ + {"name": "llama3.2:latest"}, + {"name": "mistral:latest"}, + {"name": "llama3.2:latest"} + ] + }); + + let ids = parse_ollama_model_ids(&payload); + assert_eq!( + ids, + vec!["llama3.2:latest".to_string(), "mistral:latest".to_string()] + ); + } + + #[test] + fn model_cache_round_trip_returns_fresh_entry() { + let tmp = TempDir::new().unwrap(); + let models = vec!["gpt-5.1".to_string(), "gpt-5-mini".to_string()]; + + cache_live_models_for_provider(tmp.path(), "openai", &models).unwrap(); + + let cached = + load_cached_models_for_provider(tmp.path(), "openai", MODEL_CACHE_TTL_SECS).unwrap(); + let cached = cached.expect("expected fresh cached models"); + + assert_eq!(cached.models.len(), 2); + assert!(cached.models.contains(&"gpt-5.1".to_string())); + assert!(cached.models.contains(&"gpt-5-mini".to_string())); + } + + #[test] + fn model_cache_ttl_filters_stale_entries() { + let tmp = TempDir::new().unwrap(); + let stale = ModelCacheState { + entries: vec![ModelCacheEntry { + provider: "openai".to_string(), + fetched_at_unix: now_unix_secs().saturating_sub(MODEL_CACHE_TTL_SECS + 120), + models: vec!["gpt-5.1".to_string()], + }], + }; + + save_model_cache_state(tmp.path(), &stale).unwrap(); + + let fresh = + load_cached_models_for_provider(tmp.path(), "openai", MODEL_CACHE_TTL_SECS).unwrap(); + assert!(fresh.is_none()); + + let stale_any = load_any_cached_models_for_provider(tmp.path(), "openai").unwrap(); + assert!(stale_any.is_some()); + } + + #[test] + fn run_models_refresh_uses_fresh_cache_without_network() { + let tmp = TempDir::new().unwrap(); + + cache_live_models_for_provider(tmp.path(), "openai", &["gpt-5.1".to_string()]).unwrap(); + + let config = Config { + workspace_dir: tmp.path().to_path_buf(), + default_provider: Some("openai".to_string()), + ..Config::default() + }; + + run_models_refresh(&config, None, false).unwrap(); + } + + #[test] + fn run_models_refresh_rejects_unsupported_provider() { + let tmp = TempDir::new().unwrap(); + + let config = Config { + workspace_dir: tmp.path().to_path_buf(), + default_provider: Some("venice".to_string()), + ..Config::default() + }; + + let err = run_models_refresh(&config, None, true).unwrap_err(); + assert!(err + .to_string() + .contains("does not support live model discovery")); + } + // ── provider_env_var ──────────────────────────────────────── #[test] @@ -3221,8 +4193,11 @@ mod tests { assert_eq!(provider_env_var("ollama"), "API_KEY"); // fallback assert_eq!(provider_env_var("xai"), "XAI_API_KEY"); assert_eq!(provider_env_var("grok"), "XAI_API_KEY"); 
-        assert_eq!(provider_env_var("together"), "TOGETHER_API_KEY");
-        assert_eq!(provider_env_var("together-ai"), "TOGETHER_API_KEY"); // alias
+        assert_eq!(provider_env_var("together"), "TOGETHER_API_KEY"); // alias
+        assert_eq!(provider_env_var("together-ai"), "TOGETHER_API_KEY");
+        assert_eq!(provider_env_var("google"), "GEMINI_API_KEY"); // alias
+        assert_eq!(provider_env_var("google-gemini"), "GEMINI_API_KEY"); // alias
+        assert_eq!(provider_env_var("gemini"), "GEMINI_API_KEY");
     }
 
     #[test]