fix(providers): harden tool fallback and refresh model catalogs
This commit is contained in:
parent
43494f8331
commit
b4b379e3e7
9 changed files with 1111 additions and 367 deletions
|
|
@ -2,7 +2,7 @@ use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
|
|||
use crate::config::Config;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::observability::{self, Observer, ObserverEvent};
|
||||
use crate::providers::{self, ChatMessage, Provider, ToolCall};
|
||||
use crate::providers::{self, ChatMessage, ChatRequest, Provider, ToolCall};
|
||||
use crate::runtime;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::{self, Tool};
|
||||
|
|
@ -868,13 +868,9 @@ pub(crate) async fn run_tool_call_loop(
|
|||
max_tool_iterations
|
||||
};
|
||||
|
||||
// Build native tool definitions once if the provider supports them.
|
||||
let use_native_tools = provider.supports_native_tools() && !tools_registry.is_empty();
|
||||
let tool_definitions = if use_native_tools {
|
||||
tools_to_openai_format(tools_registry)
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
let tool_specs: Vec<crate::tools::ToolSpec> =
|
||||
tools_registry.iter().map(|tool| tool.spec()).collect();
|
||||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
|
||||
for _iteration in 0..max_iterations {
|
||||
observer.record_event(&ObserverEvent::LlmRequest {
|
||||
|
|
@ -885,101 +881,73 @@ pub(crate) async fn run_tool_call_loop(
|
|||
|
||||
let llm_started_at = Instant::now();
|
||||
|
||||
// Choose between native tool-call API and prompt-based tool use.
|
||||
// `native_tool_calls` preserves the structured ToolCall vec (with IDs) so
|
||||
// that tool results can later be sent back as proper `role: tool` messages.
|
||||
// Unified path via Provider::chat so provider-specific native tool logic
|
||||
// (OpenAI/Anthropic/OpenRouter/compatible adapters) is honored.
|
||||
let request_tools = if use_native_tools {
|
||||
Some(tool_specs.as_slice())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) =
|
||||
if use_native_tools {
|
||||
match provider
|
||||
.chat_with_tools(history, &tool_definitions, model, temperature)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
});
|
||||
let response_text = resp.text_or_empty().to_string();
|
||||
let mut calls = parse_structured_tool_calls(&resp.tool_calls);
|
||||
let mut parsed_text = String::new();
|
||||
match provider
|
||||
.chat(
|
||||
ChatRequest {
|
||||
messages: history,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
});
|
||||
|
||||
if calls.is_empty() {
|
||||
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
|
||||
if !fallback_text.is_empty() {
|
||||
parsed_text = fallback_text;
|
||||
}
|
||||
calls = fallback_calls;
|
||||
let response_text = resp.text_or_empty().to_string();
|
||||
let mut calls = parse_structured_tool_calls(&resp.tool_calls);
|
||||
let mut parsed_text = String::new();
|
||||
|
||||
if calls.is_empty() {
|
||||
let (fallback_text, fallback_calls) = parse_tool_calls(&response_text);
|
||||
if !fallback_text.is_empty() {
|
||||
parsed_text = fallback_text;
|
||||
}
|
||||
|
||||
// Use JSON format for native tools so convert_messages()
|
||||
// can reconstruct proper NativeMessage with tool_calls.
|
||||
let assistant_history_content = if resp.tool_calls.is_empty() {
|
||||
response_text.clone()
|
||||
} else {
|
||||
build_native_assistant_history(&response_text, &resp.tool_calls)
|
||||
};
|
||||
|
||||
let native_calls = resp.tool_calls;
|
||||
(
|
||||
response_text,
|
||||
parsed_text,
|
||||
calls,
|
||||
assistant_history_content,
|
||||
native_calls,
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(crate::providers::sanitize_api_error(
|
||||
&e.to_string(),
|
||||
)),
|
||||
});
|
||||
return Err(e);
|
||||
calls = fallback_calls;
|
||||
}
|
||||
|
||||
// Preserve native tool call IDs in assistant history so role=tool
|
||||
// follow-up messages can reference the exact call id.
|
||||
let assistant_history_content = if resp.tool_calls.is_empty() {
|
||||
response_text.clone()
|
||||
} else {
|
||||
build_native_assistant_history(&response_text, &resp.tool_calls)
|
||||
};
|
||||
|
||||
let native_calls = resp.tool_calls;
|
||||
(
|
||||
response_text,
|
||||
parsed_text,
|
||||
calls,
|
||||
assistant_history_content,
|
||||
native_calls,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
match provider
|
||||
.chat_with_history(history, model, temperature)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: true,
|
||||
error_message: None,
|
||||
});
|
||||
let response_text = resp;
|
||||
let assistant_history_content = response_text.clone();
|
||||
let (parsed_text, calls) = parse_tool_calls(&response_text);
|
||||
(
|
||||
response_text,
|
||||
parsed_text,
|
||||
calls,
|
||||
assistant_history_content,
|
||||
Vec::new(),
|
||||
)
|
||||
}
|
||||
Err(e) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(crate::providers::sanitize_api_error(
|
||||
&e.to_string(),
|
||||
)),
|
||||
});
|
||||
return Err(e);
|
||||
}
|
||||
Err(e) => {
|
||||
observer.record_event(&ObserverEvent::LlmResponse {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
duration: llm_started_at.elapsed(),
|
||||
success: false,
|
||||
error_message: Some(crate::providers::sanitize_api_error(&e.to_string())),
|
||||
});
|
||||
return Err(e);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -2402,7 +2402,7 @@ impl Default for Config {
|
|||
api_key: None,
|
||||
api_url: None,
|
||||
default_provider: Some("openrouter".to_string()),
|
||||
default_model: Some("anthropic/claude-sonnet-4".to_string()),
|
||||
default_model: Some("anthropic/claude-sonnet-4.6".to_string()),
|
||||
default_temperature: 0.7,
|
||||
observability: ObservabilityConfig::default(),
|
||||
autonomy: AutonomyConfig::default(),
|
||||
|
|
|
|||
|
|
@ -99,6 +99,132 @@ pub fn run(config: &Config) -> Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ModelProbeOutcome {
|
||||
Ok,
|
||||
Skipped,
|
||||
AuthOrAccess,
|
||||
Error,
|
||||
}
|
||||
|
||||
fn classify_model_probe_error(err_message: &str) -> ModelProbeOutcome {
|
||||
let lower = err_message.to_lowercase();
|
||||
|
||||
if lower.contains("does not support live model discovery") {
|
||||
return ModelProbeOutcome::Skipped;
|
||||
}
|
||||
|
||||
if [
|
||||
"401",
|
||||
"403",
|
||||
"429",
|
||||
"unauthorized",
|
||||
"forbidden",
|
||||
"api key",
|
||||
"token",
|
||||
"insufficient balance",
|
||||
"insufficient quota",
|
||||
"plan does not include",
|
||||
"rate limit",
|
||||
]
|
||||
.iter()
|
||||
.any(|hint| lower.contains(hint))
|
||||
{
|
||||
return ModelProbeOutcome::AuthOrAccess;
|
||||
}
|
||||
|
||||
ModelProbeOutcome::Error
|
||||
}
|
||||
|
||||
fn doctor_model_targets(provider_override: Option<&str>) -> Vec<String> {
|
||||
if let Some(provider) = provider_override.map(str::trim).filter(|p| !p.is_empty()) {
|
||||
return vec![provider.to_string()];
|
||||
}
|
||||
|
||||
crate::providers::list_providers()
|
||||
.into_iter()
|
||||
.map(|provider| provider.name.to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn run_models(config: &Config, provider_override: Option<&str>, use_cache: bool) -> Result<()> {
|
||||
let targets = doctor_model_targets(provider_override);
|
||||
|
||||
if targets.is_empty() {
|
||||
anyhow::bail!("No providers available for model probing");
|
||||
}
|
||||
|
||||
println!("🩺 ZeroClaw Doctor — Model Catalog Probe");
|
||||
println!(" Providers to probe: {}", targets.len());
|
||||
println!(
|
||||
" Mode: {}",
|
||||
if use_cache {
|
||||
"cache-first"
|
||||
} else {
|
||||
"force live refresh"
|
||||
}
|
||||
);
|
||||
println!();
|
||||
|
||||
let mut ok_count = 0usize;
|
||||
let mut skipped_count = 0usize;
|
||||
let mut auth_count = 0usize;
|
||||
let mut error_count = 0usize;
|
||||
|
||||
for provider_name in &targets {
|
||||
println!(" [{}]", provider_name);
|
||||
|
||||
match crate::onboard::run_models_refresh(config, Some(provider_name), !use_cache) {
|
||||
Ok(()) => {
|
||||
ok_count += 1;
|
||||
println!(" ✅ model catalog check passed");
|
||||
}
|
||||
Err(error) => {
|
||||
let error_text = format_error_chain(&error);
|
||||
match classify_model_probe_error(&error_text) {
|
||||
ModelProbeOutcome::Skipped => {
|
||||
skipped_count += 1;
|
||||
println!(" ⚪ skipped: {}", truncate_for_display(&error_text, 160));
|
||||
}
|
||||
ModelProbeOutcome::AuthOrAccess => {
|
||||
auth_count += 1;
|
||||
println!(
|
||||
" ⚠️ auth/access: {}",
|
||||
truncate_for_display(&error_text, 160)
|
||||
);
|
||||
}
|
||||
ModelProbeOutcome::Error => {
|
||||
error_count += 1;
|
||||
println!(" ❌ error: {}", truncate_for_display(&error_text, 160));
|
||||
}
|
||||
ModelProbeOutcome::Ok => {
|
||||
ok_count += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
println!(
|
||||
" Summary: {} ok, {} skipped, {} auth/access, {} errors",
|
||||
ok_count, skipped_count, auth_count, error_count
|
||||
);
|
||||
|
||||
if auth_count > 0 {
|
||||
println!(
|
||||
" 💡 Some providers need valid API keys/plan access before `/models` can be fetched."
|
||||
);
|
||||
}
|
||||
|
||||
if provider_override.is_some() && ok_count == 0 {
|
||||
anyhow::bail!("Model probe failed for target provider")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ── Config semantic validation ───────────────────────────────────
|
||||
|
||||
fn check_config_semantics(config: &Config, items: &mut Vec<DiagItem>) {
|
||||
|
|
@ -572,6 +698,22 @@ fn check_command_available(cmd: &str, args: &[&str], cat: &'static str, items: &
|
|||
}
|
||||
}
|
||||
|
||||
fn format_error_chain(error: &anyhow::Error) -> String {
|
||||
let mut parts = Vec::new();
|
||||
for cause in error.chain() {
|
||||
let message = cause.to_string();
|
||||
if !message.is_empty() {
|
||||
parts.push(message);
|
||||
}
|
||||
}
|
||||
|
||||
if parts.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
parts.join(": ")
|
||||
}
|
||||
|
||||
fn truncate_for_display(input: &str, max_chars: usize) -> String {
|
||||
let mut chars = input.chars();
|
||||
let preview: String = chars.by_ref().take(max_chars).collect();
|
||||
|
|
@ -615,6 +757,25 @@ mod tests {
|
|||
assert_eq!(DiagItem::error("t", "m").icon(), "❌");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_model_probe_error_marks_unsupported_as_skipped() {
|
||||
let outcome = classify_model_probe_error(
|
||||
"Provider 'copilot' does not support live model discovery yet",
|
||||
);
|
||||
assert_eq!(outcome, ModelProbeOutcome::Skipped);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn classify_model_probe_error_marks_auth_and_plan_issues() {
|
||||
let auth_outcome = classify_model_probe_error("OpenAI API error (401): unauthorized");
|
||||
assert_eq!(auth_outcome, ModelProbeOutcome::AuthOrAccess);
|
||||
|
||||
let plan_outcome = classify_model_probe_error(
|
||||
"Z.AI API error (429): plan does not include requested model",
|
||||
);
|
||||
assert_eq!(plan_outcome, ModelProbeOutcome::AuthOrAccess);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_validation_catches_bad_temperature() {
|
||||
let mut config = Config::default();
|
||||
|
|
|
|||
41
src/main.rs
41
src/main.rs
|
|
@ -178,7 +178,10 @@ enum Commands {
|
|||
},
|
||||
|
||||
/// Run diagnostics for daemon/scheduler/channel freshness
|
||||
Doctor,
|
||||
Doctor {
|
||||
#[command(subcommand)]
|
||||
doctor_command: Option<DoctorCommands>,
|
||||
},
|
||||
|
||||
/// Show system status (full details)
|
||||
Status,
|
||||
|
|
@ -404,6 +407,20 @@ enum ModelCommands {
|
|||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum DoctorCommands {
|
||||
/// Probe model catalogs across providers and report availability
|
||||
Models {
|
||||
/// Probe a specific provider only (default: all known providers)
|
||||
#[arg(long)]
|
||||
provider: Option<String>,
|
||||
|
||||
/// Prefer cached catalogs when available (skip forced live refresh)
|
||||
#[arg(long)]
|
||||
use_cache: bool,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum ChannelCommands {
|
||||
/// List configured channels
|
||||
|
|
@ -646,7 +663,12 @@ async fn main() -> Result<()> {
|
|||
|
||||
Commands::Models { model_command } => match model_command {
|
||||
ModelCommands::Refresh { provider, force } => {
|
||||
onboard::run_models_refresh(&config, provider.as_deref(), force)
|
||||
let config_for_refresh = config.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
onboard::run_models_refresh(&config_for_refresh, provider.as_deref(), force)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("models refresh task failed: {e}"))?
|
||||
}
|
||||
},
|
||||
|
||||
|
|
@ -685,7 +707,20 @@ async fn main() -> Result<()> {
|
|||
|
||||
Commands::Service { service_command } => service::handle_command(&service_command, &config),
|
||||
|
||||
Commands::Doctor => doctor::run(&config),
|
||||
Commands::Doctor { doctor_command } => match doctor_command {
|
||||
Some(DoctorCommands::Models {
|
||||
provider,
|
||||
use_cache,
|
||||
}) => {
|
||||
let config_for_models = config.clone();
|
||||
tokio::task::spawn_blocking(move || {
|
||||
doctor::run_models(&config_for_models, provider.as_deref(), use_cache)
|
||||
})
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("doctor models task failed: {e}"))?
|
||||
}
|
||||
None => doctor::run(&config),
|
||||
},
|
||||
|
||||
Commands::Channel { channel_command } => match channel_command {
|
||||
ChannelCommands::Start => channels::start_channels(config).await,
|
||||
|
|
|
|||
|
|
@ -476,6 +476,19 @@ fn canonical_provider_name(provider_name: &str) -> &str {
|
|||
}
|
||||
}
|
||||
|
||||
fn allows_unauthenticated_model_fetch(provider_name: &str) -> bool {
|
||||
matches!(
|
||||
canonical_provider_name(provider_name),
|
||||
"openrouter"
|
||||
| "ollama"
|
||||
| "venice"
|
||||
| "astrai"
|
||||
| "nvidia"
|
||||
| "nvidia-nim"
|
||||
| "build.nvidia.com"
|
||||
)
|
||||
}
|
||||
|
||||
/// Pick a sensible default model for the given provider.
|
||||
const MINIMAX_ONBOARD_MODELS: [(&str, &str); 5] = [
|
||||
("MiniMax-M2.5", "MiniMax M2.5 (latest, recommended)"),
|
||||
|
|
@ -488,16 +501,28 @@ const MINIMAX_ONBOARD_MODELS: [(&str, &str); 5] = [
|
|||
fn default_model_for_provider(provider: &str) -> String {
|
||||
match canonical_provider_name(provider) {
|
||||
"anthropic" => "claude-sonnet-4-5-20250929".into(),
|
||||
"openrouter" => "anthropic/claude-sonnet-4.6".into(),
|
||||
"openai" => "gpt-5.2".into(),
|
||||
"openai-codex" => "gpt-5-codex".into(),
|
||||
"venice" => "zai-org-glm-5".into(),
|
||||
"groq" => "llama-3.3-70b-versatile".into(),
|
||||
"mistral" => "mistral-large-latest".into(),
|
||||
"deepseek" => "deepseek-chat".into(),
|
||||
"xai" => "grok-4-1-fast-reasoning".into(),
|
||||
"perplexity" => "sonar-pro".into(),
|
||||
"fireworks" => "accounts/fireworks/models/llama-v3p3-70b-instruct".into(),
|
||||
"together-ai" => "meta-llama/Llama-3.3-70B-Instruct-Turbo".into(),
|
||||
"cohere" => "command-a-03-2025".into(),
|
||||
"moonshot" => "kimi-k2.5".into(),
|
||||
"glm" | "zai" => "glm-5".into(),
|
||||
"minimax" => "MiniMax-M2.5".into(),
|
||||
"qwen" => "qwen-plus".into(),
|
||||
"ollama" => "llama3.2".into(),
|
||||
"groq" => "llama-3.3-70b-versatile".into(),
|
||||
"deepseek" => "deepseek-chat".into(),
|
||||
"gemini" => "gemini-2.5-pro".into(),
|
||||
"kimi-code" => "kimi-for-coding".into(),
|
||||
_ => "anthropic/claude-sonnet-4.5".into(),
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => "meta/llama-3.3-70b-instruct".into(),
|
||||
"astrai" => "anthropic/claude-sonnet-4.6".into(),
|
||||
_ => "anthropic/claude-sonnet-4.6".into(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -505,8 +530,8 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
|
|||
match canonical_provider_name(provider_name) {
|
||||
"openrouter" => vec![
|
||||
(
|
||||
"anthropic/claude-sonnet-4.5".to_string(),
|
||||
"Claude Sonnet 4.5 (balanced, recommended)".to_string(),
|
||||
"anthropic/claude-sonnet-4.6".to_string(),
|
||||
"Claude Sonnet 4.6 (balanced, recommended)".to_string(),
|
||||
),
|
||||
(
|
||||
"openai/gpt-5.2".to_string(),
|
||||
|
|
@ -565,18 +590,33 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
|
|||
"GPT-5.2 Codex (agentic coding)".to_string(),
|
||||
),
|
||||
],
|
||||
"openai-codex" => vec![
|
||||
(
|
||||
"gpt-5-codex".to_string(),
|
||||
"GPT-5 Codex (recommended)".to_string(),
|
||||
),
|
||||
(
|
||||
"gpt-5.2-codex".to_string(),
|
||||
"GPT-5.2 Codex (agentic coding)".to_string(),
|
||||
),
|
||||
("o4-mini".to_string(), "o4-mini (fallback)".to_string()),
|
||||
],
|
||||
"venice" => vec![
|
||||
(
|
||||
"llama-3.3-70b".to_string(),
|
||||
"Llama 3.3 70B (default, fast)".to_string(),
|
||||
"zai-org-glm-5".to_string(),
|
||||
"GLM-5 via Venice (agentic flagship)".to_string(),
|
||||
),
|
||||
(
|
||||
"claude-opus-45".to_string(),
|
||||
"Claude Opus 4.5 via Venice (strongest)".to_string(),
|
||||
"claude-sonnet-4-6".to_string(),
|
||||
"Claude Sonnet 4.6 via Venice (best quality)".to_string(),
|
||||
),
|
||||
(
|
||||
"llama-3.1-405b".to_string(),
|
||||
"Llama 3.1 405B (largest open source)".to_string(),
|
||||
"deepseek-v3.2".to_string(),
|
||||
"DeepSeek V3.2 via Venice (strong value)".to_string(),
|
||||
),
|
||||
(
|
||||
"grok-41-fast".to_string(),
|
||||
"Grok 4.1 Fast via Venice (low latency)".to_string(),
|
||||
),
|
||||
],
|
||||
"groq" => vec![
|
||||
|
|
@ -701,27 +741,27 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
|
|||
],
|
||||
"moonshot" => vec![
|
||||
(
|
||||
"kimi-latest".to_string(),
|
||||
"Kimi Latest (rolling latest assistant model)".to_string(),
|
||||
"kimi-k2.5".to_string(),
|
||||
"Kimi K2.5 (latest flagship, recommended)".to_string(),
|
||||
),
|
||||
(
|
||||
"kimi-k2-thinking".to_string(),
|
||||
"Kimi K2 Thinking (deep reasoning + tool use)".to_string(),
|
||||
),
|
||||
(
|
||||
"kimi-k2-0905-preview".to_string(),
|
||||
"Kimi K2 0905 Preview (strong coding)".to_string(),
|
||||
),
|
||||
(
|
||||
"kimi-thinking-preview".to_string(),
|
||||
"Kimi Thinking Preview (deep reasoning)".to_string(),
|
||||
),
|
||||
],
|
||||
"glm" | "zai" => vec![
|
||||
(
|
||||
"glm-4.7".to_string(),
|
||||
"GLM-4.7 (latest flagship)".to_string(),
|
||||
),
|
||||
("glm-5".to_string(), "GLM-5 (high reasoning)".to_string()),
|
||||
(
|
||||
"glm-4-plus".to_string(),
|
||||
"GLM-4 Plus (stable baseline)".to_string(),
|
||||
"glm-4.7".to_string(),
|
||||
"GLM-4.7 (strong general-purpose quality)".to_string(),
|
||||
),
|
||||
(
|
||||
"glm-4.5-air".to_string(),
|
||||
"GLM-4.5 Air (lower latency)".to_string(),
|
||||
),
|
||||
],
|
||||
"minimax" => vec![
|
||||
|
|
@ -730,12 +770,12 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
|
|||
"MiniMax M2.5 (latest flagship)".to_string(),
|
||||
),
|
||||
(
|
||||
"MiniMax-M2.1".to_string(),
|
||||
"MiniMax M2.1 (strong coding/reasoning)".to_string(),
|
||||
"MiniMax-M2.5-highspeed".to_string(),
|
||||
"MiniMax M2.5 High-Speed (fast)".to_string(),
|
||||
),
|
||||
(
|
||||
"MiniMax-M2.1-lightning".to_string(),
|
||||
"MiniMax M2.1 Lightning (fast)".to_string(),
|
||||
"MiniMax-M2.1".to_string(),
|
||||
"MiniMax M2.1 (strong coding/reasoning)".to_string(),
|
||||
),
|
||||
],
|
||||
"qwen" => vec![
|
||||
|
|
@ -752,6 +792,42 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> {
|
|||
"Qwen Turbo (fast and cost-efficient)".to_string(),
|
||||
),
|
||||
],
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => vec![
|
||||
(
|
||||
"meta/llama-3.3-70b-instruct".to_string(),
|
||||
"Llama 3.3 70B Instruct (balanced default)".to_string(),
|
||||
),
|
||||
(
|
||||
"deepseek-ai/deepseek-v3.2".to_string(),
|
||||
"DeepSeek V3.2 (reasoning + coding)".to_string(),
|
||||
),
|
||||
(
|
||||
"google/gemma-3-27b-it".to_string(),
|
||||
"Gemma 3 27B IT (cost-efficient)".to_string(),
|
||||
),
|
||||
(
|
||||
"meta/llama-3.1-405b-instruct".to_string(),
|
||||
"Llama 3.1 405B Instruct (max quality)".to_string(),
|
||||
),
|
||||
],
|
||||
"astrai" => vec![
|
||||
(
|
||||
"anthropic/claude-sonnet-4.6".to_string(),
|
||||
"Claude Sonnet 4.6 (balanced default)".to_string(),
|
||||
),
|
||||
(
|
||||
"openai/gpt-5.2".to_string(),
|
||||
"GPT-5.2 (latest flagship)".to_string(),
|
||||
),
|
||||
(
|
||||
"deepseek/deepseek-v3.2".to_string(),
|
||||
"DeepSeek V3.2 (agentic + affordable)".to_string(),
|
||||
),
|
||||
(
|
||||
"z-ai/glm-5".to_string(),
|
||||
"GLM-5 (high reasoning)".to_string(),
|
||||
),
|
||||
],
|
||||
"ollama" => vec![
|
||||
(
|
||||
"llama3.2".to_string(),
|
||||
|
|
@ -797,9 +873,49 @@ fn supports_live_model_fetch(provider_name: &str) -> bool {
|
|||
| "gemini"
|
||||
| "ollama"
|
||||
| "astrai"
|
||||
| "venice"
|
||||
| "fireworks"
|
||||
| "cohere"
|
||||
| "moonshot"
|
||||
| "glm"
|
||||
| "zai"
|
||||
| "qwen"
|
||||
| "nvidia"
|
||||
| "nvidia-nim"
|
||||
| "build.nvidia.com"
|
||||
)
|
||||
}
|
||||
|
||||
fn models_endpoint_for_provider(provider_name: &str) -> Option<&'static str> {
|
||||
match provider_name {
|
||||
"qwen-intl" => Some("https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models"),
|
||||
"dashscope-us" => Some("https://dashscope-us.aliyuncs.com/compatible-mode/v1/models"),
|
||||
"moonshot-cn" | "kimi-cn" => Some("https://api.moonshot.cn/v1/models"),
|
||||
"glm-cn" | "bigmodel" => Some("https://open.bigmodel.cn/api/paas/v4/models"),
|
||||
"zai-cn" | "z.ai-cn" => Some("https://open.bigmodel.cn/api/coding/paas/v4/models"),
|
||||
_ => match canonical_provider_name(provider_name) {
|
||||
"openai" => Some("https://api.openai.com/v1/models"),
|
||||
"venice" => Some("https://api.venice.ai/api/v1/models"),
|
||||
"groq" => Some("https://api.groq.com/openai/v1/models"),
|
||||
"mistral" => Some("https://api.mistral.ai/v1/models"),
|
||||
"deepseek" => Some("https://api.deepseek.com/v1/models"),
|
||||
"xai" => Some("https://api.x.ai/v1/models"),
|
||||
"together-ai" => Some("https://api.together.xyz/v1/models"),
|
||||
"fireworks" => Some("https://api.fireworks.ai/inference/v1/models"),
|
||||
"cohere" => Some("https://api.cohere.com/compatibility/v1/models"),
|
||||
"moonshot" => Some("https://api.moonshot.ai/v1/models"),
|
||||
"glm" => Some("https://api.z.ai/api/paas/v4/models"),
|
||||
"zai" => Some("https://api.z.ai/api/coding/paas/v4/models"),
|
||||
"qwen" => Some("https://dashscope.aliyuncs.com/compatible-mode/v1/models"),
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => {
|
||||
Some("https://integrate.api.nvidia.com/v1/models")
|
||||
}
|
||||
"astrai" => Some("https://as-trai.com/v1/models"),
|
||||
_ => None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn build_model_fetch_client() -> Result<reqwest::blocking::Client> {
|
||||
reqwest::blocking::Client::builder()
|
||||
.timeout(Duration::from_secs(8))
|
||||
|
|
@ -882,15 +998,21 @@ fn parse_ollama_model_ids(payload: &Value) -> Vec<String> {
|
|||
normalize_model_ids(ids)
|
||||
}
|
||||
|
||||
fn fetch_openai_compatible_models(endpoint: &str, api_key: Option<&str>) -> Result<Vec<String>> {
|
||||
let Some(api_key) = api_key else {
|
||||
return Ok(Vec::new());
|
||||
};
|
||||
|
||||
fn fetch_openai_compatible_models(
|
||||
endpoint: &str,
|
||||
api_key: Option<&str>,
|
||||
allow_unauthenticated: bool,
|
||||
) -> Result<Vec<String>> {
|
||||
let client = build_model_fetch_client()?;
|
||||
let payload: Value = client
|
||||
.get(endpoint)
|
||||
.bearer_auth(api_key)
|
||||
let mut request = client.get(endpoint);
|
||||
|
||||
if let Some(api_key) = api_key {
|
||||
request = request.bearer_auth(api_key);
|
||||
} else if !allow_unauthenticated {
|
||||
bail!("model fetch requires API key for endpoint {endpoint}");
|
||||
}
|
||||
|
||||
let payload: Value = request
|
||||
.send()
|
||||
.and_then(reqwest::blocking::Response::error_for_status)
|
||||
.with_context(|| format!("model fetch failed: GET {endpoint}"))?
|
||||
|
|
@ -919,7 +1041,7 @@ fn fetch_openrouter_models(api_key: Option<&str>) -> Result<Vec<String>> {
|
|||
|
||||
fn fetch_anthropic_models(api_key: Option<&str>) -> Result<Vec<String>> {
|
||||
let Some(api_key) = api_key else {
|
||||
return Ok(Vec::new());
|
||||
bail!("Anthropic model fetch requires API key or OAuth token");
|
||||
};
|
||||
|
||||
let client = build_model_fetch_client()?;
|
||||
|
|
@ -954,7 +1076,7 @@ fn fetch_anthropic_models(api_key: Option<&str>) -> Result<Vec<String>> {
|
|||
|
||||
fn fetch_gemini_models(api_key: Option<&str>) -> Result<Vec<String>> {
|
||||
let Some(api_key) = api_key else {
|
||||
return Ok(Vec::new());
|
||||
bail!("Gemini model fetch requires API key");
|
||||
};
|
||||
|
||||
let client = build_model_fetch_client()?;
|
||||
|
|
@ -984,6 +1106,7 @@ fn fetch_ollama_models() -> Result<Vec<String>> {
|
|||
}
|
||||
|
||||
fn fetch_live_models_for_provider(provider_name: &str, api_key: &str) -> Result<Vec<String>> {
|
||||
let requested_provider_name = provider_name;
|
||||
let provider_name = canonical_provider_name(provider_name);
|
||||
let api_key = if api_key.trim().is_empty() {
|
||||
std::env::var(provider_env_var(provider_name))
|
||||
|
|
@ -1006,25 +1129,6 @@ fn fetch_live_models_for_provider(provider_name: &str, api_key: &str) -> Result<
|
|||
|
||||
let models = match provider_name {
|
||||
"openrouter" => fetch_openrouter_models(api_key.as_deref())?,
|
||||
"openai" => {
|
||||
fetch_openai_compatible_models("https://api.openai.com/v1/models", api_key.as_deref())?
|
||||
}
|
||||
"groq" => fetch_openai_compatible_models(
|
||||
"https://api.groq.com/openai/v1/models",
|
||||
api_key.as_deref(),
|
||||
)?,
|
||||
"mistral" => {
|
||||
fetch_openai_compatible_models("https://api.mistral.ai/v1/models", api_key.as_deref())?
|
||||
}
|
||||
"deepseek" => fetch_openai_compatible_models(
|
||||
"https://api.deepseek.com/v1/models",
|
||||
api_key.as_deref(),
|
||||
)?,
|
||||
"xai" => fetch_openai_compatible_models("https://api.x.ai/v1/models", api_key.as_deref())?,
|
||||
"together-ai" => fetch_openai_compatible_models(
|
||||
"https://api.together.xyz/v1/models",
|
||||
api_key.as_deref(),
|
||||
)?,
|
||||
"anthropic" => fetch_anthropic_models(api_key.as_deref())?,
|
||||
"gemini" => fetch_gemini_models(api_key.as_deref())?,
|
||||
"ollama" => {
|
||||
|
|
@ -1046,10 +1150,15 @@ fn fetch_live_models_for_provider(provider_name: &str, api_key: &str) -> Result<
|
|||
]
|
||||
}
|
||||
}
|
||||
"astrai" => {
|
||||
fetch_openai_compatible_models("https://as-trai.com/v1/models", api_key.as_deref())?
|
||||
_ => {
|
||||
if let Some(endpoint) = models_endpoint_for_provider(requested_provider_name) {
|
||||
let allow_unauthenticated =
|
||||
allows_unauthenticated_model_fetch(requested_provider_name);
|
||||
fetch_openai_compatible_models(endpoint, api_key.as_deref(), allow_unauthenticated)?
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
_ => Vec::new(),
|
||||
};
|
||||
|
||||
Ok(models)
|
||||
|
|
@ -1719,167 +1828,12 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Optio
|
|||
|
||||
// ── Model selection ──
|
||||
let canonical_provider = canonical_provider_name(provider_name);
|
||||
let models: Vec<(&str, &str)> = match canonical_provider {
|
||||
"openrouter" => vec![
|
||||
(
|
||||
"anthropic/claude-sonnet-4",
|
||||
"Claude Sonnet 4 (balanced, recommended)",
|
||||
),
|
||||
(
|
||||
"anthropic/claude-3.5-sonnet",
|
||||
"Claude 3.5 Sonnet (fast, affordable)",
|
||||
),
|
||||
("openai/gpt-4o", "GPT-4o (OpenAI flagship)"),
|
||||
("openai/gpt-4o-mini", "GPT-4o Mini (fast, cheap)"),
|
||||
(
|
||||
"google/gemini-2.0-flash-001",
|
||||
"Gemini 2.0 Flash (Google, fast)",
|
||||
),
|
||||
(
|
||||
"meta-llama/llama-3.3-70b-instruct",
|
||||
"Llama 3.3 70B (open source)",
|
||||
),
|
||||
("deepseek/deepseek-chat", "DeepSeek Chat (affordable)"),
|
||||
],
|
||||
"anthropic" => vec![
|
||||
(
|
||||
"claude-sonnet-4-20250514",
|
||||
"Claude Sonnet 4 (balanced, recommended)",
|
||||
),
|
||||
("claude-3-5-sonnet-20241022", "Claude 3.5 Sonnet (fast)"),
|
||||
(
|
||||
"claude-3-5-haiku-20241022",
|
||||
"Claude 3.5 Haiku (fastest, cheapest)",
|
||||
),
|
||||
],
|
||||
"openai" => vec![
|
||||
("gpt-4o", "GPT-4o (flagship)"),
|
||||
("gpt-4o-mini", "GPT-4o Mini (fast, cheap)"),
|
||||
("o1-mini", "o1-mini (reasoning)"),
|
||||
],
|
||||
"openai-codex" => vec![
|
||||
("gpt-5-codex", "GPT-5 Codex (recommended)"),
|
||||
("o4-mini", "o4-mini (fallback)"),
|
||||
],
|
||||
"venice" => vec![
|
||||
("llama-3.3-70b", "Llama 3.3 70B (default, fast)"),
|
||||
("claude-opus-45", "Claude Opus 4.5 via Venice (strongest)"),
|
||||
("llama-3.1-405b", "Llama 3.1 405B (largest open source)"),
|
||||
],
|
||||
"groq" => vec![
|
||||
(
|
||||
"llama-3.3-70b-versatile",
|
||||
"Llama 3.3 70B (fast, recommended)",
|
||||
),
|
||||
("llama-3.1-8b-instant", "Llama 3.1 8B (instant)"),
|
||||
("mixtral-8x7b-32768", "Mixtral 8x7B (32K context)"),
|
||||
],
|
||||
"mistral" => vec![
|
||||
("mistral-large-latest", "Mistral Large (flagship)"),
|
||||
("codestral-latest", "Codestral (code-focused)"),
|
||||
("mistral-small-latest", "Mistral Small (fast, cheap)"),
|
||||
],
|
||||
"deepseek" => vec![
|
||||
("deepseek-chat", "DeepSeek Chat (V3, recommended)"),
|
||||
("deepseek-reasoner", "DeepSeek Reasoner (R1)"),
|
||||
],
|
||||
"xai" => vec![
|
||||
("grok-3", "Grok 3 (flagship)"),
|
||||
("grok-3-mini", "Grok 3 Mini (fast)"),
|
||||
],
|
||||
"perplexity" => vec![
|
||||
("sonar-pro", "Sonar Pro (search + reasoning)"),
|
||||
("sonar", "Sonar (search, fast)"),
|
||||
],
|
||||
"fireworks" => vec![
|
||||
(
|
||||
"accounts/fireworks/models/llama-v3p3-70b-instruct",
|
||||
"Llama 3.3 70B",
|
||||
),
|
||||
(
|
||||
"accounts/fireworks/models/mixtral-8x22b-instruct",
|
||||
"Mixtral 8x22B",
|
||||
),
|
||||
],
|
||||
"together-ai" => vec![
|
||||
(
|
||||
"meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
|
||||
"Llama 3.1 70B Turbo",
|
||||
),
|
||||
(
|
||||
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
|
||||
"Llama 3.1 8B Turbo",
|
||||
),
|
||||
("mistralai/Mixtral-8x22B-Instruct-v0.1", "Mixtral 8x22B"),
|
||||
],
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => vec![
|
||||
("deepseek-ai/DeepSeek-R1", "DeepSeek R1 (reasoning)"),
|
||||
("meta/llama-3.1-70b-instruct", "Llama 3.1 70B Instruct"),
|
||||
("mistralai/Mistral-7B-Instruct-v0.3", "Mistral 7B Instruct"),
|
||||
("meta/llama-3.1-405b-instruct", "Llama 3.1 405B Instruct"),
|
||||
],
|
||||
"cohere" => vec![
|
||||
("command-r-plus", "Command R+ (flagship)"),
|
||||
("command-r", "Command R (fast)"),
|
||||
],
|
||||
"kimi-code" => vec![
|
||||
(
|
||||
"kimi-for-coding",
|
||||
"Kimi for Coding (official coding-agent model)",
|
||||
),
|
||||
("kimi-k2.5", "Kimi K2.5 (general coding endpoint model)"),
|
||||
],
|
||||
"moonshot" => vec![
|
||||
("moonshot-v1-128k", "Moonshot V1 128K"),
|
||||
("moonshot-v1-32k", "Moonshot V1 32K"),
|
||||
],
|
||||
"glm" | "zai" => vec![
|
||||
("glm-5", "GLM-5 (latest)"),
|
||||
("glm-4-plus", "GLM-4 Plus (flagship)"),
|
||||
("glm-4-flash", "GLM-4 Flash (fast)"),
|
||||
],
|
||||
"minimax" => MINIMAX_ONBOARD_MODELS.to_vec(),
|
||||
"qwen" => vec![
|
||||
("qwen-plus", "Qwen Plus (balanced default)"),
|
||||
("qwen-max", "Qwen Max (highest quality)"),
|
||||
("qwen-turbo", "Qwen Turbo (fast and cost-efficient)"),
|
||||
],
|
||||
"ollama" => vec![
|
||||
("llama3.2", "Llama 3.2 (recommended local)"),
|
||||
("mistral", "Mistral 7B"),
|
||||
("codellama", "Code Llama"),
|
||||
("phi3", "Phi-3 (small, fast)"),
|
||||
],
|
||||
"gemini" | "google" | "google-gemini" => vec![
|
||||
("gemini-2.0-flash", "Gemini 2.0 Flash (fast, recommended)"),
|
||||
(
|
||||
"gemini-2.0-flash-lite",
|
||||
"Gemini 2.0 Flash Lite (fastest, cheapest)",
|
||||
),
|
||||
("gemini-1.5-pro", "Gemini 1.5 Pro (best quality)"),
|
||||
("gemini-1.5-flash", "Gemini 1.5 Flash (balanced)"),
|
||||
],
|
||||
"astrai" => vec![
|
||||
("auto", "Auto — Astrai best execution routing (recommended)"),
|
||||
("gpt-4o", "GPT-4o (OpenAI via Astrai)"),
|
||||
(
|
||||
"claude-sonnet-4.5",
|
||||
"Claude Sonnet 4.5 (Anthropic via Astrai)",
|
||||
),
|
||||
("deepseek-v3", "DeepSeek V3 (best value via Astrai)"),
|
||||
("llama-3.3-70b", "Llama 3.3 70B (open source via Astrai)"),
|
||||
],
|
||||
_ => vec![("default", "Default model")],
|
||||
};
|
||||
let mut model_options: Vec<(String, String)> = curated_models_for_provider(canonical_provider);
|
||||
|
||||
let mut model_options: Vec<(String, String)> = models
|
||||
.into_iter()
|
||||
.map(|(model_id, label)| (model_id.to_string(), label.to_string()))
|
||||
.collect();
|
||||
let mut live_options: Option<Vec<(String, String)>> = None;
|
||||
|
||||
if supports_live_model_fetch(provider_name) {
|
||||
let can_fetch_without_key = matches!(provider_name, "openrouter" | "ollama");
|
||||
let can_fetch_without_key = allows_unauthenticated_model_fetch(provider_name);
|
||||
let has_api_key = !api_key.trim().is_empty()
|
||||
|| std::env::var(provider_env_var(provider_name))
|
||||
.ok()
|
||||
|
|
@ -4663,7 +4617,12 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn default_model_for_provider_uses_latest_defaults() {
|
||||
assert_eq!(
|
||||
default_model_for_provider("openrouter"),
|
||||
"anthropic/claude-sonnet-4.6"
|
||||
);
|
||||
assert_eq!(default_model_for_provider("openai"), "gpt-5.2");
|
||||
assert_eq!(default_model_for_provider("openai-codex"), "gpt-5-codex");
|
||||
assert_eq!(
|
||||
default_model_for_provider("anthropic"),
|
||||
"claude-sonnet-4-5-20250929"
|
||||
|
|
@ -4680,6 +4639,16 @@ mod tests {
|
|||
default_model_for_provider("google-gemini"),
|
||||
"gemini-2.5-pro"
|
||||
);
|
||||
assert_eq!(default_model_for_provider("venice"), "zai-org-glm-5");
|
||||
assert_eq!(default_model_for_provider("moonshot"), "kimi-k2.5");
|
||||
assert_eq!(
|
||||
default_model_for_provider("nvidia-nim"),
|
||||
"meta/llama-3.3-70b-instruct"
|
||||
);
|
||||
assert_eq!(
|
||||
default_model_for_provider("astrai"),
|
||||
"anthropic/claude-sonnet-4.6"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -4708,6 +4677,31 @@ mod tests {
|
|||
assert!(ids.contains(&"gpt-5-mini".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn curated_models_for_glm_removes_deprecated_flash_plus_aliases() {
|
||||
let ids: Vec<String> = curated_models_for_provider("glm")
|
||||
.into_iter()
|
||||
.map(|(id, _)| id)
|
||||
.collect();
|
||||
|
||||
assert!(ids.contains(&"glm-5".to_string()));
|
||||
assert!(ids.contains(&"glm-4.7".to_string()));
|
||||
assert!(ids.contains(&"glm-4.5-air".to_string()));
|
||||
assert!(!ids.contains(&"glm-4-plus".to_string()));
|
||||
assert!(!ids.contains(&"glm-4-flash".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn curated_models_for_openai_codex_include_codex_family() {
|
||||
let ids: Vec<String> = curated_models_for_provider("openai-codex")
|
||||
.into_iter()
|
||||
.map(|(id, _)| id)
|
||||
.collect();
|
||||
|
||||
assert!(ids.contains(&"gpt-5-codex".to_string()));
|
||||
assert!(ids.contains(&"gpt-5.2-codex".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn curated_models_for_openrouter_use_valid_anthropic_id() {
|
||||
let ids: Vec<String> = curated_models_for_provider("openrouter")
|
||||
|
|
@ -4715,7 +4709,33 @@ mod tests {
|
|||
.map(|(id, _)| id)
|
||||
.collect();
|
||||
|
||||
assert!(ids.contains(&"anthropic/claude-sonnet-4.5".to_string()));
|
||||
assert!(ids.contains(&"anthropic/claude-sonnet-4.6".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn curated_models_for_moonshot_drop_deprecated_aliases() {
|
||||
let ids: Vec<String> = curated_models_for_provider("moonshot")
|
||||
.into_iter()
|
||||
.map(|(id, _)| id)
|
||||
.collect();
|
||||
|
||||
assert!(ids.contains(&"kimi-k2.5".to_string()));
|
||||
assert!(ids.contains(&"kimi-k2-thinking".to_string()));
|
||||
assert!(!ids.contains(&"kimi-latest".to_string()));
|
||||
assert!(!ids.contains(&"kimi-thinking-preview".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allows_unauthenticated_model_fetch_for_public_catalogs() {
|
||||
assert!(allows_unauthenticated_model_fetch("openrouter"));
|
||||
assert!(allows_unauthenticated_model_fetch("venice"));
|
||||
assert!(allows_unauthenticated_model_fetch("nvidia"));
|
||||
assert!(allows_unauthenticated_model_fetch("nvidia-nim"));
|
||||
assert!(allows_unauthenticated_model_fetch("build.nvidia.com"));
|
||||
assert!(allows_unauthenticated_model_fetch("astrai"));
|
||||
assert!(allows_unauthenticated_model_fetch("ollama"));
|
||||
assert!(!allows_unauthenticated_model_fetch("openai"));
|
||||
assert!(!allows_unauthenticated_model_fetch("deepseek"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -4739,7 +4759,11 @@ mod tests {
|
|||
assert!(supports_live_model_fetch("together"));
|
||||
assert!(supports_live_model_fetch("ollama"));
|
||||
assert!(supports_live_model_fetch("astrai"));
|
||||
assert!(!supports_live_model_fetch("venice"));
|
||||
assert!(supports_live_model_fetch("venice"));
|
||||
assert!(supports_live_model_fetch("glm-cn"));
|
||||
assert!(supports_live_model_fetch("qwen-intl"));
|
||||
assert!(!supports_live_model_fetch("minimax-cn"));
|
||||
assert!(!supports_live_model_fetch("unknown-provider"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -4778,6 +4802,40 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn models_endpoint_for_provider_handles_region_aliases() {
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("glm-cn"),
|
||||
Some("https://open.bigmodel.cn/api/paas/v4/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("zai-cn"),
|
||||
Some("https://open.bigmodel.cn/api/coding/paas/v4/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("qwen-intl"),
|
||||
Some("https://dashscope-intl.aliyuncs.com/compatible-mode/v1/models")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn models_endpoint_for_provider_supports_additional_openai_compatible_providers() {
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("venice"),
|
||||
Some("https://api.venice.ai/api/v1/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("cohere"),
|
||||
Some("https://api.cohere.com/compatibility/v1/models")
|
||||
);
|
||||
assert_eq!(
|
||||
models_endpoint_for_provider("moonshot"),
|
||||
Some("https://api.moonshot.ai/v1/models")
|
||||
);
|
||||
assert_eq!(models_endpoint_for_provider("perplexity"), None);
|
||||
assert_eq!(models_endpoint_for_provider("unknown-provider"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_openai_model_ids_supports_data_array_payload() {
|
||||
let payload = json!({
|
||||
|
|
|
|||
|
|
@ -263,6 +263,8 @@ impl ResponseMessage {
|
|||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct ToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type")]
|
||||
kind: Option<String>,
|
||||
function: Option<Function>,
|
||||
|
|
@ -274,6 +276,30 @@ struct Function {
|
|||
arguments: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeChatRequest {
|
||||
model: String,
|
||||
messages: Vec<NativeMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
stream: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<serde_json::Value>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_choice: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeMessage {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ResponsesRequest {
|
||||
model: String,
|
||||
|
|
@ -571,6 +597,169 @@ impl OpenAiCompatibleProvider {
|
|||
extract_responses_text(responses)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
|
||||
}
|
||||
|
||||
fn convert_tool_specs(
|
||||
tools: Option<&[crate::tools::ToolSpec]>,
|
||||
) -> Option<Vec<serde_json::Value>> {
|
||||
tools.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| {
|
||||
serde_json::json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool.name,
|
||||
"description": tool.description,
|
||||
"parameters": tool.parameters,
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_messages_for_native(messages: &[ChatMessage]) -> Vec<NativeMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
if message.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content)
|
||||
{
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(
|
||||
tool_calls_value.clone(),
|
||||
)
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tc| ToolCall {
|
||||
id: Some(tc.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some(tc.name),
|
||||
arguments: Some(tc.arguments),
|
||||
}),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return NativeMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if message.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string)
|
||||
.or_else(|| Some(message.content.clone()));
|
||||
|
||||
return NativeMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
NativeMessage {
|
||||
role: message.role.clone(),
|
||||
content: Some(message.content.clone()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn with_prompt_guided_tool_instructions(
|
||||
messages: &[ChatMessage],
|
||||
tools: Option<&[crate::tools::ToolSpec]>,
|
||||
) -> Vec<ChatMessage> {
|
||||
let Some(tools) = tools else {
|
||||
return messages.to_vec();
|
||||
};
|
||||
|
||||
if tools.is_empty() {
|
||||
return messages.to_vec();
|
||||
}
|
||||
|
||||
let instructions = crate::providers::traits::build_tool_instructions_text(tools);
|
||||
let mut modified_messages = messages.to_vec();
|
||||
|
||||
if let Some(system_message) = modified_messages.iter_mut().find(|m| m.role == "system") {
|
||||
if !system_message.content.is_empty() {
|
||||
system_message.content.push_str("\n\n");
|
||||
}
|
||||
system_message.content.push_str(&instructions);
|
||||
} else {
|
||||
modified_messages.insert(0, ChatMessage::system(instructions));
|
||||
}
|
||||
|
||||
modified_messages
|
||||
}
|
||||
|
||||
fn parse_native_response(message: ResponseMessage) -> ProviderChatResponse {
|
||||
let tool_calls = message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|tc| {
|
||||
let function = tc.function?;
|
||||
let name = function.name?;
|
||||
let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
|
||||
Some(ProviderToolCall {
|
||||
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name,
|
||||
arguments,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ProviderChatResponse {
|
||||
text: message.content,
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_native_tool_schema_unsupported(status: reqwest::StatusCode, error: &str) -> bool {
|
||||
if !matches!(
|
||||
status,
|
||||
reqwest::StatusCode::BAD_REQUEST | reqwest::StatusCode::UNPROCESSABLE_ENTITY
|
||||
) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lower = error.to_lowercase();
|
||||
[
|
||||
"unknown parameter: tools",
|
||||
"unsupported parameter: tools",
|
||||
"unrecognized field `tools`",
|
||||
"does not support tools",
|
||||
"function calling is not supported",
|
||||
"tool_choice",
|
||||
]
|
||||
.iter()
|
||||
.any(|hint| lower.contains(hint))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -846,49 +1035,83 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
// If native tools are requested, delegate to chat_with_tools.
|
||||
if let Some(tools) = request.tools {
|
||||
if !tools.is_empty() && self.supports_native_tools() {
|
||||
let native_tools = Self::tool_specs_to_openai_format(tools);
|
||||
return self
|
||||
.chat_with_tools(request.messages, &native_tools, model, temperature)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
self.name
|
||||
)
|
||||
})?;
|
||||
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
let tools = Self::convert_tool_specs(request.tools);
|
||||
let native_request = NativeChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: Self::convert_messages_for_native(request.messages),
|
||||
temperature,
|
||||
stream: Some(false),
|
||||
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools,
|
||||
};
|
||||
|
||||
let url = self.chat_completions_url();
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&native_request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
// Backward compatible path: chat_with_history may serialize tool_calls JSON into content.
|
||||
if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) {
|
||||
let parsed_text = message.effective_content_optional();
|
||||
let tool_calls = message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|tc| {
|
||||
let function = tc.function?;
|
||||
let name = function.name?;
|
||||
let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
|
||||
Some(ProviderToolCall {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
name,
|
||||
arguments,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error = response.text().await?;
|
||||
let sanitized = super::sanitize_api_error(&error);
|
||||
|
||||
return Ok(ProviderChatResponse {
|
||||
text: parsed_text,
|
||||
tool_calls,
|
||||
});
|
||||
if Self::is_native_tool_schema_unsupported(status, &sanitized) {
|
||||
let fallback_messages =
|
||||
Self::with_prompt_guided_tool_instructions(request.messages, request.tools);
|
||||
let text = self
|
||||
.chat_with_history(&fallback_messages, model, temperature)
|
||||
.await?;
|
||||
return Ok(ProviderChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||
let system = request.messages.iter().find(|m| m.role == "system");
|
||||
let last_user = request.messages.iter().rfind(|m| m.role == "user");
|
||||
if let Some(user_msg) = last_user {
|
||||
return self
|
||||
.chat_via_responses(
|
||||
credential,
|
||||
system.map(|m| m.content.as_str()),
|
||||
&user_msg.content,
|
||||
model,
|
||||
)
|
||||
.await
|
||||
.map(|text| ProviderChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: vec![],
|
||||
})
|
||||
.map_err(|responses_err| {
|
||||
anyhow::anyhow!(
|
||||
"{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})",
|
||||
self.name
|
||||
)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
|
||||
}
|
||||
|
||||
Ok(ProviderChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: vec![],
|
||||
})
|
||||
let native_response: ApiChatResponse = response.json().await?;
|
||||
let message = native_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|choice| choice.message)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;
|
||||
|
||||
Ok(Self::parse_native_response(message))
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
|
|
@ -1400,6 +1623,76 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_native_response_preserves_tool_call_id() {
|
||||
let message = ResponseMessage {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: Some("call_123".to_string()),
|
||||
kind: Some("function".to_string()),
|
||||
function: Some(Function {
|
||||
name: Some("shell".to_string()),
|
||||
arguments: Some(r#"{"command":"pwd"}"#.to_string()),
|
||||
}),
|
||||
}]),
|
||||
};
|
||||
|
||||
let parsed = OpenAiCompatibleProvider::parse_native_response(message);
|
||||
assert_eq!(parsed.tool_calls.len(), 1);
|
||||
assert_eq!(parsed.tool_calls[0].id, "call_123");
|
||||
assert_eq!(parsed.tool_calls[0].name, "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_messages_for_native_maps_tool_result_payload() {
|
||||
let input = vec![ChatMessage::tool(
|
||||
r#"{"tool_call_id":"call_abc","content":"done"}"#,
|
||||
)];
|
||||
|
||||
let converted = OpenAiCompatibleProvider::convert_messages_for_native(&input);
|
||||
assert_eq!(converted.len(), 1);
|
||||
assert_eq!(converted[0].role, "tool");
|
||||
assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_abc"));
|
||||
assert_eq!(converted[0].content.as_deref(), Some("done"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_tool_schema_unsupported_detection_is_precise() {
|
||||
assert!(OpenAiCompatibleProvider::is_native_tool_schema_unsupported(
|
||||
reqwest::StatusCode::BAD_REQUEST,
|
||||
"unknown parameter: tools"
|
||||
));
|
||||
assert!(
|
||||
!OpenAiCompatibleProvider::is_native_tool_schema_unsupported(
|
||||
reqwest::StatusCode::UNAUTHORIZED,
|
||||
"unknown parameter: tools"
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_guided_tool_fallback_injects_system_instruction() {
|
||||
let input = vec![ChatMessage::user("check status")];
|
||||
let tools = vec![crate::tools::ToolSpec {
|
||||
name: "shell_exec".to_string(),
|
||||
description: "Execute shell command".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": { "type": "string" }
|
||||
},
|
||||
"required": ["command"]
|
||||
}),
|
||||
}];
|
||||
|
||||
let output =
|
||||
OpenAiCompatibleProvider::with_prompt_guided_tool_instructions(&input, Some(&tools));
|
||||
assert!(!output.is_empty());
|
||||
assert_eq!(output[0].role, "system");
|
||||
assert!(output[0].content.contains("Available Tools"));
|
||||
assert!(output[0].content.contains("shell_exec"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn warmup_without_key_is_noop() {
|
||||
let provider = make_provider("test", "https://example.com", None);
|
||||
|
|
|
|||
|
|
@ -67,6 +67,52 @@ fn is_rate_limited(err: &anyhow::Error) -> bool {
|
|||
&& (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
|
||||
}
|
||||
|
||||
/// Check if a 429 is a business/quota-plan error that retries cannot fix.
|
||||
///
|
||||
/// Examples:
|
||||
/// - plan does not include requested model
|
||||
/// - insufficient balance / package not active
|
||||
/// - known provider business codes (e.g. Z.AI: 1311, 1113)
|
||||
fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool {
|
||||
if !is_rate_limited(err) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let msg = err.to_string();
|
||||
let lower = msg.to_lowercase();
|
||||
|
||||
let business_hints = [
|
||||
"plan does not include",
|
||||
"doesn't include",
|
||||
"not include",
|
||||
"insufficient balance",
|
||||
"insufficient_balance",
|
||||
"insufficient quota",
|
||||
"insufficient_quota",
|
||||
"quota exhausted",
|
||||
"out of credits",
|
||||
"no available package",
|
||||
"package not active",
|
||||
"purchase package",
|
||||
"model not available for your plan",
|
||||
];
|
||||
|
||||
if business_hints.iter().any(|hint| lower.contains(hint)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// Known provider business codes observed for 429 where retry is futile.
|
||||
for token in lower.split(|c: char| !c.is_ascii_digit()) {
|
||||
if let Ok(code) = token.parse::<u16>() {
|
||||
if matches!(code, 1113 | 1311) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
/// Try to extract a Retry-After value (in milliseconds) from an error message.
|
||||
/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string.
|
||||
fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
|
||||
|
|
@ -101,7 +147,9 @@ fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
|
|||
}
|
||||
|
||||
fn failure_reason(rate_limited: bool, non_retryable: bool) -> &'static str {
|
||||
if rate_limited {
|
||||
if rate_limited && non_retryable {
|
||||
"rate_limited_non_retryable"
|
||||
} else if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
|
|
@ -244,7 +292,8 @@ impl Provider for ReliableProvider {
|
|||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
let non_retryable = is_non_retryable(&e);
|
||||
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
|
||||
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
let failure_reason = failure_reason(rate_limited, non_retryable);
|
||||
let error_detail = compact_error_detail(&e);
|
||||
|
|
@ -260,7 +309,7 @@ impl Provider for ReliableProvider {
|
|||
);
|
||||
|
||||
// On rate-limit, try rotating API key
|
||||
if rate_limited {
|
||||
if rate_limited && !non_retryable_rate_limit {
|
||||
if let Some(new_key) = self.rotate_key() {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
|
|
@ -352,7 +401,8 @@ impl Provider for ReliableProvider {
|
|||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
let non_retryable = is_non_retryable(&e);
|
||||
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
|
||||
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
let failure_reason = failure_reason(rate_limited, non_retryable);
|
||||
let error_detail = compact_error_detail(&e);
|
||||
|
|
@ -367,7 +417,7 @@ impl Provider for ReliableProvider {
|
|||
&error_detail,
|
||||
);
|
||||
|
||||
if rate_limited {
|
||||
if rate_limited && !non_retryable_rate_limit {
|
||||
if let Some(new_key) = self.rotate_key() {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
|
|
@ -459,7 +509,8 @@ impl Provider for ReliableProvider {
|
|||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
let non_retryable = is_non_retryable(&e);
|
||||
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
|
||||
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
let failure_reason = failure_reason(rate_limited, non_retryable);
|
||||
let error_detail = compact_error_detail(&e);
|
||||
|
|
@ -474,7 +525,7 @@ impl Provider for ReliableProvider {
|
|||
&error_detail,
|
||||
);
|
||||
|
||||
if rate_limited {
|
||||
if rate_limited && !non_retryable_rate_limit {
|
||||
if let Some(new_key) = self.rotate_key() {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
|
|
@ -1106,6 +1157,39 @@ mod tests {
|
|||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_retryable_rate_limit_detects_plan_restricted_model() {
|
||||
let err = anyhow::anyhow!(
|
||||
"{}",
|
||||
"API error (429 Too Many Requests): {\"code\":1311,\"message\":\"the current account plan does not include glm-5\"}"
|
||||
);
|
||||
assert!(
|
||||
is_non_retryable_rate_limit(&err),
|
||||
"plan-restricted 429 should skip retries"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_retryable_rate_limit_detects_insufficient_balance() {
|
||||
let err = anyhow::anyhow!(
|
||||
"{}",
|
||||
"API error (429 Too Many Requests): {\"code\":1113,\"message\":\"insufficient balance\"}"
|
||||
);
|
||||
assert!(
|
||||
is_non_retryable_rate_limit(&err),
|
||||
"insufficient-balance 429 should skip retries"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_retryable_rate_limit_does_not_flag_generic_429() {
|
||||
let err = anyhow::anyhow!("429 Too Many Requests: rate limit exceeded");
|
||||
assert!(
|
||||
!is_non_retryable_rate_limit(&err),
|
||||
"generic rate-limit 429 should remain retryable"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_backoff_uses_retry_after() {
|
||||
let provider = ReliableProvider::new(vec![], 0, 500);
|
||||
|
|
@ -1261,6 +1345,35 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn non_retryable_rate_limit_skips_retries_for_plan_errors() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"primary".into(),
|
||||
Box::new(MockProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
fail_until_attempt: usize::MAX,
|
||||
response: "never",
|
||||
error: "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"plan does not include glm-5\"}",
|
||||
}),
|
||||
)],
|
||||
5,
|
||||
1,
|
||||
);
|
||||
|
||||
let result = provider.simple_chat("hello", "test", 0.0).await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"plan-restricted 429 should fail quickly without retrying"
|
||||
);
|
||||
assert_eq!(
|
||||
calls.load(Ordering::SeqCst),
|
||||
1,
|
||||
"must not retry non-retryable 429 business errors"
|
||||
);
|
||||
}
|
||||
|
||||
// ── Arc<ModelAwareMock> Provider impl for test ──
|
||||
|
||||
#[async_trait]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue