fix(providers): harden tool fallback and refresh model catalogs

This commit is contained in:
Chummy 2026-02-18 22:36:39 +08:00
parent 43494f8331
commit b4b379e3e7
9 changed files with 1111 additions and 367 deletions

View file

@ -2,7 +2,7 @@ use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::observability::{self, Observer, ObserverEvent};
use crate::providers::{self, ChatMessage, Provider, ToolCall};
use crate::providers::{self, ChatMessage, ChatRequest, Provider, ToolCall};
use crate::runtime;
use crate::security::SecurityPolicy;
use crate::tools::{self, Tool};
@ -868,13 +868,9 @@ pub(crate) async fn run_tool_call_loop(
max_tool_iterations
};
// Build native tool definitions once if the provider supports them.
let use_native_tools = provider.supports_native_tools() && !tools_registry.is_empty();
let tool_definitions = if use_native_tools {
tools_to_openai_format(tools_registry)
} else {
Vec::new()
};
let tool_specs: Vec<crate::tools::ToolSpec> =
tools_registry.iter().map(|tool| tool.spec()).collect();
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
for _iteration in 0..max_iterations {
observer.record_event(&ObserverEvent::LlmRequest {
@ -885,101 +881,73 @@ pub(crate) async fn run_tool_call_loop(
let llm_started_at = Instant::now();
// Choose between native tool-call API and prompt-based tool use.
// `native_tool_calls` preserves the structured ToolCall vec (with IDs) so
// that tool results can later be sent back as proper `role: tool` messages.
// Unified path via Provider::chat so provider-specific native tool logic
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
let request_tools = if use_native_tools {
Some(tool_specs.as_slice())
} else {
None
};
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
if use_native_tools {
match provider
.chat_with_tools(history, &tool_definitions, model, temperature)
.await
{
Ok(resp) => {
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
duration: llm_started_at.elapsed(),
success: true,
error_message: None,
});
let response_text = resp.text_or_empty().to_string();
let mut calls = parse_structured_tool_calls(&resp.tool_calls);
let mut parsed_text = String::new();
match provider
.chat(
ChatRequest {
messages: history,
tools: request_tools,
},
model,
temperature,
)
.await
{
Ok(resp) => {
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
duration: llm_started_at.elapsed(),
success: true,
error_message: None,
});
if calls.is_empty() {
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
if !fallback_text.is_empty() {
parsed_text = fallback_text;
}
calls = fallback_calls;
let response_text = resp.text_or_empty().to_string();
let mut calls = parse_structured_tool_calls(&resp.tool_calls);
let mut parsed_text = String::new();
if calls.is_empty() {
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
if !fallback_text.is_empty() {
parsed_text = fallback_text;
}
// Use JSON format for native tools so convert_messages()
// can reconstruct proper NativeMessage with tool_calls.
let assistant_history_content = if resp.tool_calls.is_empty() {
response_text.clone()
} else {
build_native_assistant_history(&response_text, &resp.tool_calls)
};
let native_calls = resp.tool_calls;
(
response_text,
parsed_text,
calls,
assistant_history_content,
native_calls,
)
}
Err(e) => {
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
duration: llm_started_at.elapsed(),
success: false,
error_message: Some(crate::providers::sanitize_api_error(
&e.to_string(),
)),
});
return Err(e);
calls = fallback_calls;
}
// Preserve native tool call IDs in assistant history so role=tool
// follow-up messages can reference the exact call id.
let assistant_history_content = if resp.tool_calls.is_empty() {
response_text.clone()
} else {
build_native_assistant_history(&response_text, &resp.tool_calls)
};
let native_calls = resp.tool_calls;
(
response_text,
parsed_text,
calls,
assistant_history_content,
native_calls,
)
}
} else {
match provider
.chat_with_history(history, model, temperature)
.await
{
Ok(resp) => {
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
duration: llm_started_at.elapsed(),
success: true,
error_message: None,
});
let response_text = resp;
let assistant_history_content = response_text.clone();
let (parsed_text, calls) = parse_tool_calls(&response_text);
(
response_text,
parsed_text,
calls,
assistant_history_content,
Vec::new(),
)
}
Err(e) => {
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
duration: llm_started_at.elapsed(),
success: false,
error_message: Some(crate::providers::sanitize_api_error(
&e.to_string(),
)),
});
return Err(e);
}
Err(e) => {
observer.record_event(&ObserverEvent::LlmResponse {
provider: provider_name.to_string(),
model: model.to_string(),
duration: llm_started_at.elapsed(),
success: false,
error_message: Some(crate::providers::sanitize_api_error(&e.to_string())),
});
return Err(e);
}
};

View file

@ -2402,7 +2402,7 @@ impl Default for Config {
api_key: None,
api_url: None,
default_provider: Some("openrouter".to_string()),
default_model: Some("anthropic/claude-sonnet-4".to_string()),
default_model: Some("anthropic/claude-sonnet-4.6".to_string()),
default_temperature: 0.7,
observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(),

View file

@ -99,6 +99,132 @@ pub fn run(config: &Config) -> Result<()> {
Ok(())
}
/// Outcome buckets for a single provider model-catalog probe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ModelProbeOutcome {
    /// Catalog check passed.
    Ok,
    /// Provider intentionally skipped (no live model discovery support).
    Skipped,
    /// Credentials, quota, plan, or rate-limit problem — actionable by the user.
    AuthOrAccess,
    /// Any other failure.
    Error,
}

/// Bucket a probe error message via case-insensitive substring heuristics.
///
/// "does not support live model discovery" maps to `Skipped`; common HTTP
/// auth/quota markers map to `AuthOrAccess`; everything else is `Error`.
fn classify_model_probe_error(err_message: &str) -> ModelProbeOutcome {
    const AUTH_HINTS: [&str; 11] = [
        "401",
        "403",
        "429",
        "unauthorized",
        "forbidden",
        "api key",
        "token",
        "insufficient balance",
        "insufficient quota",
        "plan does not include",
        "rate limit",
    ];
    let lower = err_message.to_lowercase();
    if lower.contains("does not support live model discovery") {
        ModelProbeOutcome::Skipped
    } else if AUTH_HINTS.iter().any(|hint| lower.contains(hint)) {
        ModelProbeOutcome::AuthOrAccess
    } else {
        ModelProbeOutcome::Error
    }
}
/// Resolve which providers to probe: a non-empty override selects exactly one
/// provider; otherwise every registered provider is probed.
fn doctor_model_targets(provider_override: Option<&str>) -> Vec<String> {
    match provider_override.map(str::trim).filter(|p| !p.is_empty()) {
        Some(provider) => vec![provider.to_string()],
        None => crate::providers::list_providers()
            .into_iter()
            .map(|provider| provider.name.to_string())
            .collect(),
    }
}
/// Probe model catalogs for one provider (or all known providers) and print a
/// per-provider report plus an aggregate summary.
///
/// `use_cache = true` prefers cached catalogs; `false` forces a live refresh.
/// Returns an error only when no providers are available, or when a single
/// explicitly targeted provider fails its probe.
pub fn run_models(config: &Config, provider_override: Option<&str>, use_cache: bool) -> Result<()> {
    let targets = doctor_model_targets(provider_override);
    if targets.is_empty() {
        anyhow::bail!("No providers available for model probing");
    }
    println!("🩺 ZeroClaw Doctor — Model Catalog Probe");
    println!(" Providers to probe: {}", targets.len());
    println!(
        " Mode: {}",
        if use_cache {
            "cache-first"
        } else {
            "force live refresh"
        }
    );
    println!();
    // Aggregate counters for the summary line.
    let mut ok_count = 0usize;
    let mut skipped_count = 0usize;
    let mut auth_count = 0usize;
    let mut error_count = 0usize;
    for provider_name in &targets {
        println!(" [{}]", provider_name);
        // `!use_cache` maps the cache-first flag onto the refresh `force` flag.
        match crate::onboard::run_models_refresh(config, Some(provider_name), !use_cache) {
            Ok(()) => {
                ok_count += 1;
                println!(" ✅ model catalog check passed");
            }
            Err(error) => {
                // Flatten the full cause chain so the classifier sees every hint.
                let error_text = format_error_chain(&error);
                match classify_model_probe_error(&error_text) {
                    ModelProbeOutcome::Skipped => {
                        skipped_count += 1;
                        println!(" ⚪ skipped: {}", truncate_for_display(&error_text, 160));
                    }
                    ModelProbeOutcome::AuthOrAccess => {
                        auth_count += 1;
                        println!(
                            " ⚠️ auth/access: {}",
                            truncate_for_display(&error_text, 160)
                        );
                    }
                    ModelProbeOutcome::Error => {
                        error_count += 1;
                        println!(" ❌ error: {}", truncate_for_display(&error_text, 160));
                    }
                    // NOTE(review): unreachable in practice — classify_model_probe_error
                    // never returns Ok; arm kept only for match exhaustiveness. Counting
                    // it as "ok" inside the Err branch would be misleading if it ever ran.
                    ModelProbeOutcome::Ok => {
                        ok_count += 1;
                    }
                }
            }
        }
        println!();
    }
    println!(
        " Summary: {} ok, {} skipped, {} auth/access, {} errors",
        ok_count, skipped_count, auth_count, error_count
    );
    if auth_count > 0 {
        println!(
            " 💡 Some providers need valid API keys/plan access before `/models` can be fetched."
        );
    }
    // A targeted probe with zero successes is a hard failure; a fleet-wide
    // probe reports per-provider failures but still exits successfully.
    if provider_override.is_some() && ok_count == 0 {
        anyhow::bail!("Model probe failed for target provider")
    }
    Ok(())
}
// ── Config semantic validation ───────────────────────────────────
fn check_config_semantics(config: &Config, items: &mut Vec<DiagItem>) {
@ -572,6 +698,22 @@ fn check_command_available(cmd: &str, args: &[&str], cat: &'static str, items: &
}
}
/// Render an error and its entire cause chain as a single `": "`-joined line,
/// dropping any empty messages along the way.
fn format_error_chain(error: &anyhow::Error) -> String {
    let parts: Vec<String> = error
        .chain()
        .map(|cause| cause.to_string())
        .filter(|message| !message.is_empty())
        .collect();
    // `join` on an empty Vec yields "", matching the old explicit early return.
    parts.join(": ")
}
fn truncate_for_display(input: &str, max_chars: usize) -> String {
let mut chars = input.chars();
let preview: String = chars.by_ref().take(max_chars).collect();
@ -615,6 +757,25 @@ mod tests {
assert_eq!(DiagItem::error("t", "m").icon(), "");
}
#[test]
fn classify_model_probe_error_marks_unsupported_as_skipped() {
    // The "no live discovery" phrasing must land in the Skipped bucket.
    assert_eq!(
        classify_model_probe_error("Provider 'copilot' does not support live model discovery yet"),
        ModelProbeOutcome::Skipped
    );
}
#[test]
fn classify_model_probe_error_marks_auth_and_plan_issues() {
    // HTTP 401 and plan/quota wording both classify as auth/access problems.
    assert_eq!(
        classify_model_probe_error("OpenAI API error (401): unauthorized"),
        ModelProbeOutcome::AuthOrAccess
    );
    assert_eq!(
        classify_model_probe_error("Z.AI API error (429): plan does not include requested model"),
        ModelProbeOutcome::AuthOrAccess
    );
}
#[test]
fn config_validation_catches_bad_temperature() {
let mut config = Config::default();

View file

@ -178,7 +178,10 @@ enum Commands {
},
/// Run diagnostics for daemon/scheduler/channel freshness
Doctor,
Doctor {
#[command(subcommand)]
doctor_command: Option<DoctorCommands>,
},
/// Show system status (full details)
Status,
@ -404,6 +407,20 @@ enum ModelCommands {
},
}
/// Subcommands available under `doctor` (plain `doctor` runs the full
/// diagnostic suite).
#[derive(Subcommand, Debug)]
enum DoctorCommands {
    /// Probe model catalogs across providers and report availability
    Models {
        /// Probe a specific provider only (default: all known providers)
        #[arg(long)]
        provider: Option<String>,
        /// Prefer cached catalogs when available (skip forced live refresh)
        #[arg(long)]
        use_cache: bool,
    },
}
#[derive(Subcommand, Debug)]
enum ChannelCommands {
/// List configured channels
@ -646,7 +663,12 @@ async fn main() -> Result<()> {
Commands::Models { model_command } => match model_command {
ModelCommands::Refresh { provider, force } => {
onboard::run_models_refresh(&config, provider.as_deref(), force)
let config_for_refresh = config.clone();
tokio::task::spawn_blocking(move || {
onboard::run_models_refresh(&config_for_refresh, provider.as_deref(), force)
})
.await
.map_err(|e| anyhow::anyhow!("models refresh task failed: {e}"))?
}
},
@ -685,7 +707,20 @@ async fn main() -> Result<()> {
Commands::Service { service_command } => service::handle_command(&service_command, &config),
Commands::Doctor => doctor::run(&config),
Commands::Doctor { doctor_command } => match doctor_command {
Some(DoctorCommands::Models {
provider,
use_cache,
}) => {
let config_for_models = config.clone();
tokio::task::spawn_blocking(move || {
doctor::run_models(&config_for_models, provider.as_deref(), use_cache)
})
.await
.map_err(|e| anyhow::anyhow!("doctor models task failed: {e}"))?
}
None => doctor::run(&config),
},
Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => channels::start_channels(config).await,

View file

@ -476,6 +476,19 @@ fn canonical_provider_name(provider_name: &str) -> &str {
}
}
/// Whether a provider's model catalog can be listed without credentials
/// (public catalog endpoints). Matching runs on the canonical provider name
/// so aliases resolve consistently.
fn allows_unauthenticated_model_fetch(provider_name: &str) -> bool {
    const PUBLIC_CATALOGS: [&str; 7] = [
        "openrouter",
        "ollama",
        "venice",
        "astrai",
        "nvidia",
        "nvidia-nim",
        "build.nvidia.com",
    ];
    PUBLIC_CATALOGS.contains(&canonical_provider_name(provider_name))
}
/// Pick a sensible default model for the given provider.
const MINIMAX_ONBOARD_MODELS: [(&str, &str); 5] = [
("MiniMax-M2.5", "MiniMax M2.5 (latest, recommended)"),
@ -488,16 +501,28 @@ const MINIMAX_ONBOARD_MODELS: [(&str, &str); 5] = [
/// Pick a sensible default model id for the given provider.
///
/// Matching runs on the canonical provider name, so alias spellings
/// (e.g. `nvidia-nim`, `build.nvidia.com`) share one arm. Unknown providers
/// fall back to a broadly routed OpenRouter-style id.
///
/// Fix: the original match repeated `"groq"` and `"deepseek"` arms; duplicate
/// arms are unreachable in Rust and trip `unreachable_patterns` — the dead
/// second occurrences are removed (first-arm behavior is preserved).
fn default_model_for_provider(provider: &str) -> String {
    match canonical_provider_name(provider) {
        "anthropic" => "claude-sonnet-4-5-20250929".into(),
        "openrouter" => "anthropic/claude-sonnet-4.6".into(),
        "openai" => "gpt-5.2".into(),
        "openai-codex" => "gpt-5-codex".into(),
        "venice" => "zai-org-glm-5".into(),
        "groq" => "llama-3.3-70b-versatile".into(),
        "mistral" => "mistral-large-latest".into(),
        "deepseek" => "deepseek-chat".into(),
        "xai" => "grok-4-1-fast-reasoning".into(),
        "perplexity" => "sonar-pro".into(),
        "fireworks" => "accounts/fireworks/models/llama-v3p3-70b-instruct".into(),
        "together-ai" => "meta-llama/Llama-3.3-70B-Instruct-Turbo".into(),
        "cohere" => "command-a-03-2025".into(),
        "moonshot" => "kimi-k2.5".into(),
        "glm" | "zai" => "glm-5".into(),
        "minimax" => "MiniMax-M2.5".into(),
        "qwen" => "qwen-plus".into(),
        "ollama" => "llama3.2".into(),
        "gemini" => "gemini-2.5-pro".into(),
        "kimi-code" => "kimi-for-coding".into(),
        "nvidia" | "nvidia-nim" | "build.nvidia.com" => "meta/llama-3.3-70b-instruct".into(),
        "astrai" => "anthropic/claude-sonnet-4.6".into(),
        _ => "anthropic/claude-sonnet-4.6".into(),
    }
}
@ -505,8 +530,8 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
match canonical_provider_name(provider_name) {
"openrouter" => vec![
(
"anthropic/claude-sonnet-4.5".to_string(),
"Claude Sonnet 4.5 (balanced, recommended)".to_string(),
"anthropic/claude-sonnet-4.6".to_string(),
"Claude Sonnet 4.6 (balanced, recommended)".to_string(),
),
(
"openai/gpt-5.2".to_string(),
@ -565,18 +590,33 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
"GPT-5.2 Codex (agentic coding)".to_string(),
),
],
"openai-codex" => vec![
(
"gpt-5-codex".to_string(),
"GPT-5 Codex (recommended)".to_string(),
),
(
"gpt-5.2-codex".to_string(),
"GPT-5.2 Codex (agentic coding)".to_string(),
),
("o4-mini".to_string(), "o4-mini (fallback)".to_string()),
],
"venice" => vec![
(
"llama-3.3-70b".to_string(),
"Llama 3.3 70B (default, fast)".to_string(),
"zai-org-glm-5".to_string(),
"GLM-5 via Venice (agentic flagship)".to_string(),
),
(
"claude-opus-45".to_string(),
"Claude Opus 4.5 via Venice (strongest)".to_string(),
"claude-sonnet-4-6".to_string(),
"Claude Sonnet 4.6 via Venice (best quality)".to_string(),
),
(
"llama-3.1-405b".to_string(),
"Llama 3.1 405B (largest open source)".to_string(),
"deepseek-v3.2".to_string(),
"DeepSeek V3.2 via Venice (strong value)".to_string(),
),
(
"grok-41-fast".to_string(),
"Grok 4.1 Fast via Venice (low latency)".to_string(),
),
],
"groq" => vec![
@ -701,27 +741,27 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
],
"moonshot" => vec![
(
"kimi-latest".to_string(),
"Kimi Latest (rolling latest assistant model)".to_string(),
"kimi-k2.5".to_string(),
"Kimi K2.5 (latest flagship, recommended)".to_string(),
),
(
"kimi-k2-thinking".to_string(),
"Kimi K2 Thinking (deep reasoning + tool use)".to_string(),
),
(
"kimi-k2-0905-preview".to_string(),
"Kimi K2 0905 Preview (strong coding)".to_string(),
),
(
"kimi-thinking-preview".to_string(),
"Kimi Thinking Preview (deep reasoning)".to_string(),
),
],
"glm" | "zai" => vec![
(
"glm-4.7".to_string(),
"GLM-4.7 (latest flagship)".to_string(),
),
("glm-5".to_string(), "GLM-5 (high reasoning)".to_string()),
(
"glm-4-plus".to_string(),
"GLM-4 Plus (stable baseline)".to_string(),
"glm-4.7".to_string(),
"GLM-4.7 (strong general-purpose quality)".to_string(),
),
(
"glm-4.5-air".to_string(),
"GLM-4.5 Air (lower latency)".to_string(),
),
],
"minimax" => vec![
@ -730,12 +770,12 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
"MiniMax M2.5 (latest flagship)".to_string(),
),
(
"MiniMax-M2.1".to_string(),
"MiniMax M2.1 (strong coding/reasoning)".to_string(),
"MiniMax-M2.5-highspeed".to_string(),
"MiniMax M2.5 High-Speed (fast)".to_string(),
),
(
"MiniMax-M2.1-lightning".to_string(),
"MiniMax M2.1 Lightning (fast)".to_string(),
"MiniMax-M2.1".to_string(),
"MiniMax M2.1 (strong coding/reasoning)".to_string(),
),
],
"qwen" => vec![
@ -752,6 +792,42 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
"Qwen Turbo (fast and cost-efficient)".to_string(),
),
],
"nvidia" | "nvidia-nim" | "build.nvidia.com" => vec![
(
"meta/llama-3.3-70b-instruct".to_string(),
"Llama 3.3 70B Instruct (balanced default)".to_string(),
),
(
"deepseek-ai/deepseek-v3.2".to_string(),
"DeepSeek V3.2 (reasoning + coding)".to_string(),
),
(
"google/gemma-3-27b-it".to_string(),
"Gemma 3 27B IT (cost-efficient)".to_string(),
),
(
"meta/llama-3.1-405b-instruct".to_string(),
"Llama 3.1 405B Instruct (max quality)".to_string(),
),
],
"astrai" => vec![
(
"anthropic/claude-sonnet-4.6".to_string(),
"Claude Sonnet 4.6 (balanced default)".to_string(),
),
(
"openai/gpt-5.2".to_string(),
"GPT-5.2 (latest flagship)".to_string(),
),
(
"deepseek/deepseek-v3.2".to_string(),
"DeepSeek V3.2 (agentic + affordable)".to_string(),
),
(
"z-ai/glm-5".to_string(),
"GLM-5 (high reasoning)".to_string(),
),
],
"ollama" => vec![
(
"llama3.2".to_string(),
@ -797,9 +873,49 @@ fn supports_live_model_fetch(provider_name: &str) -> bool {
| "gemini"
| "ollama"
| "astrai"
| "venice"
| "fireworks"
| "cohere"
| "moonshot"
| "glm"
| "zai"
| "qwen"
| "nvidia"
| "nvidia-nim"
| "build.nvidia.com"
)
}
/// Return the model-listing endpoint for a provider, if one is known.
///
/// Region-specific aliases (CN / international endpoints) are resolved against
/// the raw name first, since canonicalization would collapse them; everything
/// else goes through the canonical provider name.
fn models_endpoint_for_provider(provider_name: &str) -> Option<&'static str> {
    let regional = match provider_name {
        "qwen-intl" => Some("https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models"),
        "dashscope-us" => Some("https://dashscope-us.aliyuncs.com/compatible-mode/v1/models"),
        "moonshot-cn" | "kimi-cn" => Some("https://api.moonshot.cn/v1/models"),
        "glm-cn" | "bigmodel" => Some("https://open.bigmodel.cn/api/paas/v4/models"),
        "zai-cn" | "z.ai-cn" => Some("https://open.bigmodel.cn/api/coding/paas/v4/models"),
        _ => None,
    };
    if regional.is_some() {
        return regional;
    }
    match canonical_provider_name(provider_name) {
        "openai" => Some("https://api.openai.com/v1/models"),
        "venice" => Some("https://api.venice.ai/api/v1/models"),
        "groq" => Some("https://api.groq.com/openai/v1/models"),
        "mistral" => Some("https://api.mistral.ai/v1/models"),
        "deepseek" => Some("https://api.deepseek.com/v1/models"),
        "xai" => Some("https://api.x.ai/v1/models"),
        "together-ai" => Some("https://api.together.xyz/v1/models"),
        "fireworks" => Some("https://api.fireworks.ai/inference/v1/models"),
        "cohere" => Some("https://api.cohere.com/compatibility/v1/models"),
        "moonshot" => Some("https://api.moonshot.ai/v1/models"),
        "glm" => Some("https://api.z.ai/api/paas/v4/models"),
        "zai" => Some("https://api.z.ai/api/coding/paas/v4/models"),
        "qwen" => Some("https://dashscope.aliyuncs.com/compatible-mode/v1/models"),
        "nvidia" | "nvidia-nim" | "build.nvidia.com" => {
            Some("https://integrate.api.nvidia.com/v1/models")
        }
        "astrai" => Some("https://as-trai.com/v1/models"),
        _ => None,
    }
}
fn build_model_fetch_client() -> Result<reqwest::blocking::Client> {
reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(8))
@ -882,15 +998,21 @@ fn parse_ollama_model_ids(payload: &Value) -> Vec<String> {
normalize_model_ids(ids)
}
fn fetch_openai_compatible_models(endpoint: &str, api_key: Option<&str>) -> Result<Vec<String>> {
let Some(api_key) = api_key else {
return Ok(Vec::new());
};
fn fetch_openai_compatible_models(
endpoint: &str,
api_key: Option<&str>,
allow_unauthenticated: bool,
) -> Result<Vec<String>> {
let client = build_model_fetch_client()?;
let payload: Value = client
.get(endpoint)
.bearer_auth(api_key)
let mut request = client.get(endpoint);
if let Some(api_key) = api_key {
request = request.bearer_auth(api_key);
} else if !allow_unauthenticated {
bail!("model fetch requires API key for endpoint {endpoint}");
}
let payload: Value = request
.send()
.and_then(reqwest::blocking::Response::error_for_status)
.with_context(|| format!("model fetch failed: GET {endpoint}"))?
@ -919,7 +1041,7 @@ fn fetch_openrouter_models(api_key: Option<&str>) -> Result<Vec<String>> {
fn fetch_anthropic_models(api_key: Option<&str>) -> Result<Vec<String>> {
let Some(api_key) = api_key else {
return Ok(Vec::new());
bail!("Anthropic model fetch requires API key or OAuth token");
};
let client = build_model_fetch_client()?;
@ -954,7 +1076,7 @@ fn fetch_anthropic_models(api_key: Option<&str>) -> Result<Vec<String>> {
fn fetch_gemini_models(api_key: Option<&str>) -> Result<Vec<String>> {
let Some(api_key) = api_key else {
return Ok(Vec::new());
bail!("Gemini model fetch requires API key");
};
let client = build_model_fetch_client()?;
@ -984,6 +1106,7 @@ fn fetch_ollama_models() -> Result<Vec<String>> {
}
fn fetch_live_models_for_provider(provider_name: &str, api_key: &str) -> Result<Vec<String>> {
let requested_provider_name = provider_name;
let provider_name = canonical_provider_name(provider_name);
let api_key = if api_key.trim().is_empty() {
std::env::var(provider_env_var(provider_name))
@ -1006,25 +1129,6 @@ fn fetch_live_models_for_provider(provider_name: &str, api_key: &str) -> Result<
let models = match provider_name {
"openrouter" => fetch_openrouter_models(api_key.as_deref())?,
"openai" => {
fetch_openai_compatible_models("https://api.openai.com/v1/models", api_key.as_deref())?
}
"groq" => fetch_openai_compatible_models(
"https://api.groq.com/openai/v1/models",
api_key.as_deref(),
)?,
"mistral" => {
fetch_openai_compatible_models("https://api.mistral.ai/v1/models", api_key.as_deref())?
}
"deepseek" => fetch_openai_compatible_models(
"https://api.deepseek.com/v1/models",
api_key.as_deref(),
)?,
"xai" => fetch_openai_compatible_models("https://api.x.ai/v1/models", api_key.as_deref())?,
"together-ai" => fetch_openai_compatible_models(
"https://api.together.xyz/v1/models",
api_key.as_deref(),
)?,
"anthropic" => fetch_anthropic_models(api_key.as_deref())?,
"gemini" => fetch_gemini_models(api_key.as_deref())?,
"ollama" => {
@ -1046,10 +1150,15 @@ fn fetch_live_models_for_provider(provider_name: &str, api_key: &str) -> Result<
]
}
}
"astrai" => {
fetch_openai_compatible_models("https://as-trai.com/v1/models", api_key.as_deref())?
_ => {
if let Some(endpoint) = models_endpoint_for_provider(requested_provider_name) {
let allow_unauthenticated =
allows_unauthenticated_model_fetch(requested_provider_name);
fetch_openai_compatible_models(endpoint, api_key.as_deref(), allow_unauthenticated)?
} else {
Vec::new()
}
}
_ => Vec::new(),
};
Ok(models)
@ -1719,167 +1828,12 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Optio
// ── Model selection ──
let canonical_provider = canonical_provider_name(provider_name);
let models: Vec<(&str, &str)> = match canonical_provider {
"openrouter" => vec![
(
"anthropic/claude-sonnet-4",
"Claude Sonnet 4 (balanced, recommended)",
),
(
"anthropic/claude-3.5-sonnet",
"Claude 3.5 Sonnet (fast, affordable)",
),
("openai/gpt-4o", "GPT-4o (OpenAI flagship)"),
("openai/gpt-4o-mini", "GPT-4o Mini (fast, cheap)"),
(
"google/gemini-2.0-flash-001",
"Gemini 2.0 Flash (Google, fast)",
),
(
"meta-llama/llama-3.3-70b-instruct",
"Llama 3.3 70B (open source)",
),
("deepseek/deepseek-chat", "DeepSeek Chat (affordable)"),
],
"anthropic" => vec![
(
"claude-sonnet-4-20250514",
"Claude Sonnet 4 (balanced, recommended)",
),
("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet (fast)"),
(
"claude-3-5-haiku-20241022",
"Claude 3.5 Haiku (fastest, cheapest)",
),
],
"openai" => vec![
("gpt-4o", "GPT-4o (flagship)"),
("gpt-4o-mini", "GPT-4o Mini (fast, cheap)"),
("o1-mini", "o1-mini (reasoning)"),
],
"openai-codex" => vec![
("gpt-5-codex", "GPT-5 Codex (recommended)"),
("o4-mini", "o4-mini (fallback)"),
],
"venice" => vec![
("llama-3.3-70b", "Llama 3.3 70B (default, fast)"),
("claude-opus-45", "Claude Opus 4.5 via Venice (strongest)"),
("llama-3.1-405b", "Llama 3.1 405B (largest open source)"),
],
"groq" => vec![
(
"llama-3.3-70b-versatile",
"Llama 3.3 70B (fast, recommended)",
),
("llama-3.1-8b-instant", "Llama 3.1 8B (instant)"),
("mixtral-8x7b-32768", "Mixtral 8x7B (32K context)"),
],
"mistral" => vec![
("mistral-large-latest", "Mistral Large (flagship)"),
("codestral-latest", "Codestral (code-focused)"),
("mistral-small-latest", "Mistral Small (fast, cheap)"),
],
"deepseek" => vec![
("deepseek-chat", "DeepSeek Chat (V3, recommended)"),
("deepseek-reasoner", "DeepSeek Reasoner (R1)"),
],
"xai" => vec![
("grok-3", "Grok 3 (flagship)"),
("grok-3-mini", "Grok 3 Mini (fast)"),
],
"perplexity" => vec![
("sonar-pro", "Sonar Pro (search + reasoning)"),
("sonar", "Sonar (search, fast)"),
],
"fireworks" => vec![
(
"accounts/fireworks/models/llama-v3p3-70b-instruct",
"Llama 3.3 70B",
),
(
"accounts/fireworks/models/mixtral-8x22b-instruct",
"Mixtral 8x22B",
),
],
"together-ai" => vec![
(
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
"Llama 3.1 70B Turbo",
),
(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
"Llama 3.1 8B Turbo",
),
("mistralai/Mixtral-8x22B-Instruct-v0.1", "Mixtral 8x22B"),
],
"nvidia" | "nvidia-nim" | "build.nvidia.com" => vec![
("deepseek-ai/DeepSeek-R1", "DeepSeek R1 (reasoning)"),
("meta/llama-3.1-70b-instruct", "Llama 3.1 70B Instruct"),
("mistralai/Mistral-7B-Instruct-v0.3", "Mistral 7B Instruct"),
("meta/llama-3.1-405b-instruct", "Llama 3.1 405B Instruct"),
],
"cohere" => vec![
("command-r-plus", "Command R+ (flagship)"),
("command-r", "Command R (fast)"),
],
"kimi-code" => vec![
(
"kimi-for-coding",
"Kimi for Coding (official coding-agent model)",
),
("kimi-k2.5", "Kimi K2.5 (general coding endpoint model)"),
],
"moonshot" => vec![
("moonshot-v1-128k", "Moonshot V1 128K"),
("moonshot-v1-32k", "Moonshot V1 32K"),
],
"glm" | "zai" => vec![
("glm-5", "GLM-5 (latest)"),
("glm-4-plus", "GLM-4 Plus (flagship)"),
("glm-4-flash", "GLM-4 Flash (fast)"),
],
"minimax" => MINIMAX_ONBOARD_MODELS.to_vec(),
"qwen" => vec![
("qwen-plus", "Qwen Plus (balanced default)"),
("qwen-max", "Qwen Max (highest quality)"),
("qwen-turbo", "Qwen Turbo (fast and cost-efficient)"),
],
"ollama" => vec![
("llama3.2", "Llama 3.2 (recommended local)"),
("mistral", "Mistral 7B"),
("codellama", "Code Llama"),
("phi3", "Phi-3 (small, fast)"),
],
"gemini" | "google" | "google-gemini" => vec![
("gemini-2.0-flash", "Gemini 2.0 Flash (fast, recommended)"),
(
"gemini-2.0-flash-lite",
"Gemini 2.0 Flash Lite (fastest, cheapest)",
),
("gemini-1.5-pro", "Gemini 1.5 Pro (best quality)"),
("gemini-1.5-flash", "Gemini 1.5 Flash (balanced)"),
],
"astrai" => vec![
("auto", "Auto — Astrai best execution routing (recommended)"),
("gpt-4o", "GPT-4o (OpenAI via Astrai)"),
(
"claude-sonnet-4.5",
"Claude Sonnet 4.5 (Anthropic via Astrai)",
),
("deepseek-v3", "DeepSeek V3 (best value via Astrai)"),
("llama-3.3-70b", "Llama 3.3 70B (open source via Astrai)"),
],
_ => vec![("default", "Default model")],
};
let mut model_options: Vec<(String, String)> = curated_models_for_provider(canonical_provider);
let mut model_options: Vec<(String, String)> = models
.into_iter()
.map(|(model_id, label)| (model_id.to_string(), label.to_string()))
.collect();
let mut live_options: Option<Vec<(String, String)>> = None;
if supports_live_model_fetch(provider_name) {
let can_fetch_without_key = matches!(provider_name, "openrouter" | "ollama");
let can_fetch_without_key = allows_unauthenticated_model_fetch(provider_name);
let has_api_key = !api_key.trim().is_empty()
|| std::env::var(provider_env_var(provider_name))
.ok()
@ -4663,7 +4617,12 @@ mod tests {
#[test]
fn default_model_for_provider_uses_latest_defaults() {
assert_eq!(
default_model_for_provider("openrouter"),
"anthropic/claude-sonnet-4.6"
);
assert_eq!(default_model_for_provider("openai"), "gpt-5.2");
assert_eq!(default_model_for_provider("openai-codex"), "gpt-5-codex");
assert_eq!(
default_model_for_provider("anthropic"),
"claude-sonnet-4-5-20250929"
@ -4680,6 +4639,16 @@ mod tests {
default_model_for_provider("google-gemini"),
"gemini-2.5-pro"
);
assert_eq!(default_model_for_provider("venice"), "zai-org-glm-5");
assert_eq!(default_model_for_provider("moonshot"), "kimi-k2.5");
assert_eq!(
default_model_for_provider("nvidia-nim"),
"meta/llama-3.3-70b-instruct"
);
assert_eq!(
default_model_for_provider("astrai"),
"anthropic/claude-sonnet-4.6"
);
}
#[test]
@ -4708,6 +4677,31 @@ mod tests {
assert!(ids.contains(&"gpt-5-mini".to_string()));
}
#[test]
fn curated_models_for_glm_removes_deprecated_flash_plus_aliases() {
    let model_ids: Vec<String> = curated_models_for_provider("glm")
        .into_iter()
        .map(|(model_id, _label)| model_id)
        .collect();
    // Current lineup is present…
    for expected in ["glm-5", "glm-4.7", "glm-4.5-air"] {
        assert!(model_ids.iter().any(|id| id == expected));
    }
    // …and the deprecated aliases are gone.
    for removed in ["glm-4-plus", "glm-4-flash"] {
        assert!(model_ids.iter().all(|id| id != removed));
    }
}
#[test]
fn curated_models_for_openai_codex_include_codex_family() {
    let model_ids: Vec<String> = curated_models_for_provider("openai-codex")
        .into_iter()
        .map(|(model_id, _label)| model_id)
        .collect();
    for expected in ["gpt-5-codex", "gpt-5.2-codex"] {
        assert!(model_ids.iter().any(|id| id == expected));
    }
}
#[test]
fn curated_models_for_openrouter_use_valid_anthropic_id() {
let ids: Vec<String> = curated_models_for_provider("openrouter")
@ -4715,7 +4709,33 @@ mod tests {
.map(|(id, _)| id)
.collect();
assert!(ids.contains(&"anthropic/claude-sonnet-4.5".to_string()));
assert!(ids.contains(&"anthropic/claude-sonnet-4.6".to_string()));
}
#[test]
fn curated_models_for_moonshot_drop_deprecated_aliases() {
    let model_ids: Vec<String> = curated_models_for_provider("moonshot")
        .into_iter()
        .map(|(model_id, _label)| model_id)
        .collect();
    assert!(model_ids.iter().any(|id| id == "kimi-k2.5"));
    assert!(model_ids.iter().any(|id| id == "kimi-k2-thinking"));
    // Deprecated rolling/preview aliases must no longer be offered.
    assert!(model_ids.iter().all(|id| id != "kimi-latest"));
    assert!(model_ids.iter().all(|id| id != "kimi-thinking-preview"));
}
#[test]
fn allows_unauthenticated_model_fetch_for_public_catalogs() {
    // Public catalogs require no credentials.
    for public in [
        "openrouter",
        "venice",
        "nvidia",
        "nvidia-nim",
        "build.nvidia.com",
        "astrai",
        "ollama",
    ] {
        assert!(allows_unauthenticated_model_fetch(public));
    }
    // Key-gated catalogs do.
    for gated in ["openai", "deepseek"] {
        assert!(!allows_unauthenticated_model_fetch(gated));
    }
}
#[test]
@ -4739,7 +4759,11 @@ mod tests {
assert!(supports_live_model_fetch("together"));
assert!(supports_live_model_fetch("ollama"));
assert!(supports_live_model_fetch("astrai"));
assert!(!supports_live_model_fetch("venice"));
assert!(supports_live_model_fetch("venice"));
assert!(supports_live_model_fetch("glm-cn"));
assert!(supports_live_model_fetch("qwen-intl"));
assert!(!supports_live_model_fetch("minimax-cn"));
assert!(!supports_live_model_fetch("unknown-provider"));
}
#[test]
@ -4778,6 +4802,40 @@ mod tests {
);
}
#[test]
fn models_endpoint_for_provider_handles_region_aliases() {
    // Region aliases must resolve before canonicalization collapses them.
    let cases = [
        ("glm-cn", "https://open.bigmodel.cn/api/paas/v4/models"),
        ("zai-cn", "https://open.bigmodel.cn/api/coding/paas/v4/models"),
        (
            "qwen-intl",
            "https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models",
        ),
    ];
    for (alias, endpoint) in cases {
        assert_eq!(models_endpoint_for_provider(alias), Some(endpoint));
    }
}
#[test]
fn models_endpoint_for_provider_supports_additional_openai_compatible_providers() {
    let cases = [
        ("venice", "https://api.venice.ai/api/v1/models"),
        ("cohere", "https://api.cohere.com/compatibility/v1/models"),
        ("moonshot", "https://api.moonshot.ai/v1/models"),
    ];
    for (provider, endpoint) in cases {
        assert_eq!(models_endpoint_for_provider(provider), Some(endpoint));
    }
    // Providers without a known listing endpoint yield None.
    assert_eq!(models_endpoint_for_provider("perplexity"), None);
    assert_eq!(models_endpoint_for_provider("unknown-provider"), None);
}
#[test]
fn parse_openai_model_ids_supports_data_array_payload() {
let payload = json!({

View file

@ -263,6 +263,8 @@ impl ResponseMessage {
#[derive(Debug, Deserialize, Serialize)]
struct ToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(rename = "type")]
kind: Option<String>,
function: Option<Function>,
@ -274,6 +276,30 @@ struct Function {
arguments: Option<String>,
}
/// JSON body for an OpenAI-compatible `/chat/completions` request, including
/// the optional native tool-calling fields. `None` fields are omitted from
/// the serialized payload entirely (not sent as `null`), so schema-strict
/// endpoints only ever see the fields that were actually set.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    model: String,
    messages: Vec<NativeMessage>,
    temperature: f64,
    // Streaming flag; omitted from the payload when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    // OpenAI-style tool definitions ({"type":"function", ...} objects).
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<serde_json::Value>>,
    // Tool-choice mode string (e.g. "auto"); omitted when no tools are sent.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}
/// One message in the native chat-completions wire format. Besides plain
/// role/content pairs, this can carry an assistant's requested `tool_calls`
/// and, for `role: tool` result messages, the `tool_call_id` that links the
/// result back to the originating call.
#[derive(Debug, Serialize)]
struct NativeMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Set on `role: tool` messages to reference the originating call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    // Set on assistant messages that requested tool invocations.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Debug, Serialize)]
struct ResponsesRequest {
model: String,
@ -571,6 +597,169 @@ impl OpenAiCompatibleProvider {
extract_responses_text(responses)
.ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
}
/// Wrap each provider-agnostic tool spec in the OpenAI function-calling
/// envelope (`{"type": "function", "function": {...}}`).
///
/// Returns `None` when no tool list was supplied, so callers can omit the
/// `tools` field from the request entirely.
fn convert_tool_specs(
    tools: Option<&[crate::tools::ToolSpec]>,
) -> Option<Vec<serde_json::Value>> {
    let specs = tools?;
    let converted = specs
        .iter()
        .map(|tool| {
            serde_json::json!({
                "type": "function",
                "function": {
                    "name": tool.name,
                    "description": tool.description,
                    "parameters": tool.parameters,
                }
            })
        })
        .collect();
    Some(converted)
}
/// Translate provider-agnostic `ChatMessage`s into the native wire format
/// expected by OpenAI-compatible chat-completions endpoints.
///
/// Two roles carry structured JSON in their `content` string and get special
/// handling:
/// - `assistant` messages whose content embeds a `tool_calls` array are
///   re-expanded into native `tool_calls` entries (preserving the original
///   call IDs) plus an optional plain-text `content`.
/// - `tool` messages whose content parses as JSON are mapped to `role: tool`
///   messages carrying the payload's `tool_call_id` and string `content`.
///
/// Anything else — including content that fails to parse as JSON — passes
/// through verbatim as a plain-content message of the same role.
fn convert_messages_for_native(messages: &[ChatMessage]) -> Vec<NativeMessage> {
    messages
        .iter()
        .map(|message| {
            if message.role == "assistant" {
                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
                {
                    if let Some(tool_calls_value) = value.get("tool_calls") {
                        if let Ok(parsed_calls) =
                            serde_json::from_value::<Vec<ProviderToolCall>>(
                                tool_calls_value.clone(),
                            )
                        {
                            // Rebuild native tool_call entries, keeping the
                            // original IDs so later `role: tool` results can
                            // be correlated with their calls.
                            let tool_calls = parsed_calls
                                .into_iter()
                                .map(|tc| ToolCall {
                                    id: Some(tc.id),
                                    kind: Some("function".to_string()),
                                    function: Some(Function {
                                        name: Some(tc.name),
                                        arguments: Some(tc.arguments),
                                    }),
                                })
                                .collect::<Vec<_>>();
                            let content = value
                                .get("content")
                                .and_then(serde_json::Value::as_str)
                                .map(ToString::to_string);
                            return NativeMessage {
                                role: "assistant".to_string(),
                                content,
                                tool_call_id: None,
                                tool_calls: Some(tool_calls),
                            };
                        }
                    }
                }
            }
            if message.role == "tool" {
                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                    let tool_call_id = value
                        .get("tool_call_id")
                        .and_then(serde_json::Value::as_str)
                        .map(ToString::to_string);
                    // Fall back to the raw message content when the payload
                    // has no string `content` field.
                    let content = value
                        .get("content")
                        .and_then(serde_json::Value::as_str)
                        .map(ToString::to_string)
                        .or_else(|| Some(message.content.clone()));
                    return NativeMessage {
                        role: "tool".to_string(),
                        content,
                        tool_call_id,
                        tool_calls: None,
                    };
                }
            }
            // Default: pass the message through with its content unchanged.
            NativeMessage {
                role: message.role.clone(),
                content: Some(message.content.clone()),
                tool_call_id: None,
                tool_calls: None,
            }
        })
        .collect()
}
/// Inject prompt-based tool-use instructions into the conversation, for
/// endpoints/models that reject the native tool schema.
///
/// Returns the messages unchanged when no tools were requested. Otherwise the
/// generated instruction text is appended to the existing system message
/// (separated by a blank line), or prepended as a new system message when the
/// conversation has none.
fn with_prompt_guided_tool_instructions(
    messages: &[ChatMessage],
    tools: Option<&[crate::tools::ToolSpec]>,
) -> Vec<ChatMessage> {
    let tools = match tools {
        Some(list) if !list.is_empty() => list,
        _ => return messages.to_vec(),
    };
    let instructions = crate::providers::traits::build_tool_instructions_text(tools);
    let mut result = messages.to_vec();
    match result.iter_mut().find(|m| m.role == "system") {
        Some(system) => {
            if !system.content.is_empty() {
                system.content.push_str("\n\n");
            }
            system.content.push_str(&instructions);
        }
        None => result.insert(0, ChatMessage::system(instructions)),
    }
    result
}
/// Convert a raw chat-completions message into the provider-agnostic
/// response shape.
///
/// Tool calls missing a function body or name are dropped; missing arguments
/// default to `"{}"`. A server-issued call ID is preserved when present,
/// otherwise a fresh UUID is minted so results can still be correlated.
fn parse_native_response(message: ResponseMessage) -> ProviderChatResponse {
    let mut tool_calls = Vec::new();
    for tc in message.tool_calls.unwrap_or_default() {
        let Some(function) = tc.function else { continue };
        let Some(name) = function.name else { continue };
        let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
        let id = tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
        tool_calls.push(ProviderToolCall { id, name, arguments });
    }
    ProviderChatResponse {
        text: message.content,
        tool_calls,
    }
}
/// Heuristic: does this HTTP error mean the endpoint rejected the native
/// `tools`/`tool_choice` request schema (as opposed to some other client
/// error)?
///
/// Only 400/422 responses whose body contains a known rejection phrase
/// qualify; callers use a positive result to retry with prompt-guided tool
/// instructions instead of native tool calling.
fn is_native_tool_schema_unsupported(status: reqwest::StatusCode, error: &str) -> bool {
    let schema_error_status = status == reqwest::StatusCode::BAD_REQUEST
        || status == reqwest::StatusCode::UNPROCESSABLE_ENTITY;
    if !schema_error_status {
        return false;
    }
    const HINTS: [&str; 6] = [
        "unknown parameter: tools",
        "unsupported parameter: tools",
        "unrecognized field `tools`",
        "does not support tools",
        "function calling is not supported",
        "tool_choice",
    ];
    let lower = error.to_lowercase();
    HINTS.iter().any(|hint| lower.contains(hint))
}
}
#[async_trait]
@ -846,49 +1035,83 @@ impl Provider for OpenAiCompatibleProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
// If native tools are requested, delegate to chat_with_tools.
if let Some(tools) = request.tools {
if !tools.is_empty() && self.supports_native_tools() {
let native_tools = Self::tool_specs_to_openai_format(tools);
return self
.chat_with_tools(request.messages, &native_tools, model, temperature)
.await;
}
}
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
)
})?;
let text = self
.chat_with_history(request.messages, model, temperature)
let tools = Self::convert_tool_specs(request.tools);
let native_request = NativeChatRequest {
model: model.to_string(),
messages: Self::convert_messages_for_native(request.messages),
temperature,
stream: Some(false),
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
tools,
};
let url = self.chat_completions_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&native_request), credential)
.send()
.await?;
// Backward compatible path: chat_with_history may serialize tool_calls JSON into content.
if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) {
let parsed_text = message.effective_content_optional();
let tool_calls = message
.tool_calls
.unwrap_or_default()
.into_iter()
.filter_map(|tc| {
let function = tc.function?;
let name = function.name?;
let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
Some(ProviderToolCall {
id: uuid::Uuid::new_v4().to_string(),
name,
arguments,
})
})
.collect::<Vec<_>>();
if !response.status().is_success() {
let status = response.status();
let error = response.text().await?;
let sanitized = super::sanitize_api_error(&error);
return Ok(ProviderChatResponse {
text: parsed_text,
tool_calls,
});
if Self::is_native_tool_schema_unsupported(status, &sanitized) {
let fallback_messages =
Self::with_prompt_guided_tool_instructions(request.messages, request.tools);
let text = self
.chat_with_history(&fallback_messages, model, temperature)
.await?;
return Ok(ProviderChatResponse {
text: Some(text),
tool_calls: vec![],
});
}
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
let system = request.messages.iter().find(|m| m.role == "system");
let last_user = request.messages.iter().rfind(|m| m.role == "user");
if let Some(user_msg) = last_user {
return self
.chat_via_responses(
credential,
system.map(|m| m.content.as_str()),
&user_msg.content,
model,
)
.await
.map(|text| ProviderChatResponse {
text: Some(text),
tool_calls: vec![],
})
.map_err(|responses_err| {
anyhow::anyhow!(
"{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})",
self.name
)
});
}
}
anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
}
Ok(ProviderChatResponse {
text: Some(text),
tool_calls: vec![],
})
let native_response: ApiChatResponse = response.json().await?;
let message = native_response
.choices
.into_iter()
.next()
.map(|choice| choice.message)
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;
Ok(Self::parse_native_response(message))
}
fn supports_native_tools(&self) -> bool {
@ -1400,6 +1623,76 @@ mod tests {
);
}
#[test]
fn parse_native_response_preserves_tool_call_id() {
    // A server-issued tool-call id must survive parsing unchanged — it is
    // required later to correlate the `role: tool` result message with the
    // call that produced it.
    let message = ResponseMessage {
        content: None,
        tool_calls: Some(vec![ToolCall {
            id: Some("call_123".to_string()),
            kind: Some("function".to_string()),
            function: Some(Function {
                name: Some("shell".to_string()),
                arguments: Some(r#"{"command":"pwd"}"#.to_string()),
            }),
        }]),
    };
    let parsed = OpenAiCompatibleProvider::parse_native_response(message);
    assert_eq!(parsed.tool_calls.len(), 1);
    assert_eq!(parsed.tool_calls[0].id, "call_123");
    assert_eq!(parsed.tool_calls[0].name, "shell");
}
#[test]
fn convert_messages_for_native_maps_tool_result_payload() {
    // A tool message carrying a {"tool_call_id", "content"} JSON payload must
    // be unpacked into a native `role: tool` message with the id and the
    // plain string content.
    let input = vec![ChatMessage::tool(
        r#"{"tool_call_id":"call_abc","content":"done"}"#,
    )];
    let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input);
    assert_eq!(converted.len(), 1);
    assert_eq!(converted[0].role, "tool");
    assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_abc"));
    assert_eq!(converted[0].content.as_deref(), Some("done"));
}
#[test]
fn native_tool_schema_unsupported_detection_is_precise() {
    // A matching rejection phrase only counts on schema-error statuses
    // (400/422); the same body on a 401 must not trigger the prompt-guided
    // tool fallback.
    assert!(OpenAiCompatibleProvider::is_native_tool_schema_unsupported(
        reqwest::StatusCode::BAD_REQUEST,
        "unknown parameter: tools"
    ));
    assert!(
        !OpenAiCompatibleProvider::is_native_tool_schema_unsupported(
            reqwest::StatusCode::UNAUTHORIZED,
            "unknown parameter: tools"
        )
    );
}
#[test]
fn prompt_guided_tool_fallback_injects_system_instruction() {
    // With no pre-existing system message, the generated tool instructions
    // must be prepended as a new system message that names the declared tool.
    let input = vec![ChatMessage::user("check status")];
    let tools = vec![crate::tools::ToolSpec {
        name: "shell_exec".to_string(),
        description: "Execute shell command".to_string(),
        parameters: serde_json::json!({
            "type": "object",
            "properties": {
                "command": { "type": "string" }
            },
            "required": ["command"]
        }),
    }];
    let output =
        OpenAiCompatibleProvider::with_prompt_guided_tool_instructions(&input, Some(&tools));
    assert!(!output.is_empty());
    assert_eq!(output[0].role, "system");
    assert!(output[0].content.contains("Available Tools"));
    assert!(output[0].content.contains("shell_exec"));
}
#[tokio::test]
async fn warmup_without_key_is_noop() {
let provider = make_provider("test", "https://example.com", None);

View file

@ -67,6 +67,52 @@ fn is_rate_limited(err: &anyhow::Error) -> bool {
&& (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
}
/// Check if a 429 is a business/quota-plan error that retries cannot fix.
///
/// Examples:
/// - plan does not include requested model
/// - insufficient balance / package not active
/// - known provider business codes (e.g. Z.AI: 1311, 1113)
fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool {
    // Only classify errors that already look like rate limits.
    if !is_rate_limited(err) {
        return false;
    }
    let lower = err.to_string().to_lowercase();
    const BUSINESS_HINTS: [&str; 13] = [
        "plan does not include",
        "doesn't include",
        "not include",
        "insufficient balance",
        "insufficient_balance",
        "insufficient quota",
        "insufficient_quota",
        "quota exhausted",
        "out of credits",
        "no available package",
        "package not active",
        "purchase package",
        "model not available for your plan",
    ];
    if BUSINESS_HINTS.iter().any(|hint| lower.contains(hint)) {
        return true;
    }
    // Known provider business codes observed for 429 where retry is futile.
    lower
        .split(|c: char| !c.is_ascii_digit())
        .filter_map(|token| token.parse::<u16>().ok())
        .any(|code| matches!(code, 1113 | 1311))
}
/// Try to extract a Retry-After value (in milliseconds) from an error message.
/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string.
fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
@ -101,7 +147,9 @@ fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
}
fn failure_reason(rate_limited: bool, non_retryable: bool) -> &'static str {
if rate_limited {
if rate_limited && non_retryable {
"rate_limited_non_retryable"
} else if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
@ -244,7 +292,8 @@ impl Provider for ReliableProvider {
return Ok(resp);
}
Err(e) => {
let non_retryable = is_non_retryable(&e);
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
let failure_reason = failure_reason(rate_limited, non_retryable);
let error_detail = compact_error_detail(&e);
@ -260,7 +309,7 @@ impl Provider for ReliableProvider {
);
// On rate-limit, try rotating API key
if rate_limited {
if rate_limited && !non_retryable_rate_limit {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
@ -352,7 +401,8 @@ impl Provider for ReliableProvider {
return Ok(resp);
}
Err(e) => {
let non_retryable = is_non_retryable(&e);
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
let failure_reason = failure_reason(rate_limited, non_retryable);
let error_detail = compact_error_detail(&e);
@ -367,7 +417,7 @@ impl Provider for ReliableProvider {
&error_detail,
);
if rate_limited {
if rate_limited && !non_retryable_rate_limit {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
@ -459,7 +509,8 @@ impl Provider for ReliableProvider {
return Ok(resp);
}
Err(e) => {
let non_retryable = is_non_retryable(&e);
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
let failure_reason = failure_reason(rate_limited, non_retryable);
let error_detail = compact_error_detail(&e);
@ -474,7 +525,7 @@ impl Provider for ReliableProvider {
&error_detail,
);
if rate_limited {
if rate_limited && !non_retryable_rate_limit {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
@ -1106,6 +1157,39 @@ mod tests {
)));
}
#[test]
fn non_retryable_rate_limit_detects_plan_restricted_model() {
    // Business code 1311 ("plan does not include <model>") on a 429 cannot be
    // fixed by retrying, so it must be classified as non-retryable.
    let err = anyhow::anyhow!(
        "{}",
        "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"the current account plan does not include glm-5\"}"
    );
    assert!(
        is_non_retryable_rate_limit(&err),
        "plan-restricted 429 should skip retries"
    );
}
#[test]
fn non_retryable_rate_limit_detects_insufficient_balance() {
    // Business code 1113 (insufficient balance) on a 429 is a billing issue,
    // not throughput throttling — retrying cannot succeed.
    let err = anyhow::anyhow!(
        "{}",
        "API error (429 Too Many Requests): {\"code\":1113,\"message\":\"insufficient balance\"}"
    );
    assert!(
        is_non_retryable_rate_limit(&err),
        "insufficient-balance 429 should skip retries"
    );
}
#[test]
fn non_retryable_rate_limit_does_not_flag_generic_429() {
    // A plain throughput rate limit carries no business hint or code and must
    // remain retryable.
    let err = anyhow::anyhow!("429 Too Many Requests: rate limit exceeded");
    assert!(
        !is_non_retryable_rate_limit(&err),
        "generic rate-limit 429 should remain retryable"
    );
}
#[test]
fn compute_backoff_uses_retry_after() {
let provider = ReliableProvider::new(vec![], 0, 500);
@ -1261,6 +1345,35 @@ mod tests {
);
}
#[tokio::test]
async fn non_retryable_rate_limit_skips_retries_for_plan_errors() {
    // The mock provider always fails with a business-code 429. Even with 5
    // retries configured, the reliable wrapper must stop after the first
    // attempt instead of burning the retry budget.
    let calls = Arc::new(AtomicUsize::new(0));
    let provider = ReliableProvider::new(
        vec![(
            "primary".into(),
            Box::new(MockProvider {
                calls: Arc::clone(&calls),
                fail_until_attempt: usize::MAX,
                response: "never",
                error: "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"plan does not include glm-5\"}",
            }),
        )],
        5,
        1,
    );
    let result = provider.simple_chat("hello", "test", 0.0).await;
    assert!(
        result.is_err(),
        "plan-restricted 429 should fail quickly without retrying"
    );
    assert_eq!(
        calls.load(Ordering::SeqCst),
        1,
        "must not retry non-retryable 429 business errors"
    );
}
// ── Arc<ModelAwareMock> Provider impl for test ──
#[async_trait]