fix(providers): harden tool fallback and refresh model catalogs

This commit is contained in:
Chummy 2026-02-18 22:36:39 +08:00
parent 43494f8331
commit b4b379e3e7
9 changed files with 1111 additions and 367 deletions

View file

@ -263,6 +263,8 @@ impl ResponseMessage {
#[derive(Debug, Deserialize, Serialize)]
struct ToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(rename = "type")]
kind: Option<String>,
function: Option<Function>,
@ -274,6 +276,30 @@ struct Function {
arguments: Option<String>,
}
/// JSON request body used when tools are passed natively to the chat-completions endpoint.
///
/// Every `Option` field is dropped from the serialized payload when it is `None`
/// (see the `skip_serializing_if` attributes), so a request without tools carries
/// neither a `tools` nor a `tool_choice` key.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
/// Model identifier, forwarded as given by the caller.
model: String,
/// Conversation history in the provider's native message layout.
messages: Vec<NativeMessage>,
/// Temperature value passed through from the caller unchanged.
temperature: f64,
/// Streaming flag; omitted from the payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
stream: Option<bool>,
/// Tool definitions as raw JSON values; omitted from the payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<serde_json::Value>>,
/// Tool selection mode as a plain string; omitted from the payload when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<String>,
}
/// A single conversation message in the wire format used by `NativeChatRequest`.
///
/// All optional fields are skipped during serialization when they are `None`.
#[derive(Debug, Serialize)]
struct NativeMessage {
/// Message role string (for example `"assistant"` or `"tool"`).
role: String,
/// Text content of the message, if any.
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
/// Identifier of the tool call this message answers; only populated for tool results.
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
/// Tool calls attached to the message; only populated for assistant turns that requested tools.
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Debug, Serialize)]
struct ResponsesRequest {
model: String,
@ -571,6 +597,169 @@ impl OpenAiCompatibleProvider {
extract_responses_text(responses)
.ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
}
/// Converts tool specs into OpenAI-style `function` tool definitions.
///
/// Each spec becomes `{"type": "function", "function": {name, description, parameters}}`.
///
/// Returns `None` both when no tools were supplied and when the supplied slice is
/// empty. Mapping an empty slice to `Some(vec![])` would serialize as `"tools": []`
/// (and lead callers that derive `tool_choice` from this value to send
/// `tool_choice` as well), which strict OpenAI-compatible backends reject. Treating
/// "empty" the same as "absent" matches `with_prompt_guided_tool_instructions`.
fn convert_tool_specs(
    tools: Option<&[crate::tools::ToolSpec]>,
) -> Option<Vec<serde_json::Value>> {
    tools
        // An empty tool list carries no information; omit the field entirely.
        .filter(|items| !items.is_empty())
        .map(|items| {
            items
                .iter()
                .map(|tool| {
                    serde_json::json!({
                        "type": "function",
                        "function": {
                            "name": tool.name,
                            "description": tool.description,
                            "parameters": tool.parameters,
                        }
                    })
                })
                .collect()
        })
}
/// Maps generic chat messages onto the native chat-completions message layout.
///
/// * `assistant` messages whose content is a JSON object with a decodable
///   `tool_calls` array are expanded into a message carrying native tool calls.
/// * `tool` messages whose content is valid JSON have `tool_call_id` and `content`
///   lifted out of that JSON; when `content` is not a JSON string the raw message
///   content is used instead.
/// * Everything else (including messages that fail the checks above) is passed
///   through with its role and content unchanged.
fn convert_messages_for_native(messages: &[ChatMessage]) -> Vec<NativeMessage> {
    // Assistant turn stored as a JSON envelope: `{"content": ..., "tool_calls": [...]}`.
    // Any decoding failure yields `None` so the caller falls back to pass-through.
    fn assistant_with_tool_calls(raw: &str) -> Option<NativeMessage> {
        let envelope: serde_json::Value = serde_json::from_str(raw).ok()?;
        let calls_json = envelope.get("tool_calls")?.clone();
        let calls: Vec<ProviderToolCall> = serde_json::from_value(calls_json).ok()?;

        let mut native_calls = Vec::with_capacity(calls.len());
        for call in calls {
            native_calls.push(ToolCall {
                id: Some(call.id),
                kind: Some("function".to_string()),
                function: Some(Function {
                    name: Some(call.name),
                    arguments: Some(call.arguments),
                }),
            });
        }

        let text = envelope
            .get("content")
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned);

        Some(NativeMessage {
            role: "assistant".to_string(),
            content: text,
            tool_call_id: None,
            tool_calls: Some(native_calls),
        })
    }

    // Tool result stored as a JSON envelope: `{"tool_call_id": ..., "content": ...}`.
    // Only a JSON parse failure yields `None`; missing keys are tolerated.
    fn tool_result(raw: &str) -> Option<NativeMessage> {
        let envelope: serde_json::Value = serde_json::from_str(raw).ok()?;
        let string_field = |key: &str| {
            envelope
                .get(key)
                .and_then(serde_json::Value::as_str)
                .map(str::to_owned)
        };

        let call_id = string_field("tool_call_id");
        // Fall back to the raw payload when there is no string `content` field.
        let body = string_field("content").unwrap_or_else(|| raw.to_owned());

        Some(NativeMessage {
            role: "tool".to_string(),
            content: Some(body),
            tool_call_id: call_id,
            tool_calls: None,
        })
    }

    let mut converted = Vec::with_capacity(messages.len());
    for message in messages {
        let structured = match message.role.as_str() {
            "assistant" => assistant_with_tool_calls(&message.content),
            "tool" => tool_result(&message.content),
            _ => None,
        };
        let native = match structured {
            Some(native) => native,
            // Plain message: forward role and content verbatim.
            None => NativeMessage {
                role: message.role.clone(),
                content: Some(message.content.clone()),
                tool_call_id: None,
                tool_calls: None,
            },
        };
        converted.push(native);
    }
    converted
}
/// Returns a copy of `messages` with textual tool-usage instructions injected.
///
/// When `tools` is `None` or empty the messages are copied unchanged. Otherwise the
/// instruction text is appended to the first `system` message (separated by a blank
/// line if that message already has content), or a new system message is inserted
/// at the front when the history has none.
fn with_prompt_guided_tool_instructions(
    messages: &[ChatMessage],
    tools: Option<&[crate::tools::ToolSpec]>,
) -> Vec<ChatMessage> {
    let mut augmented = messages.to_vec();

    let specs = match tools {
        Some(specs) if !specs.is_empty() => specs,
        _ => return augmented,
    };
    let tool_prompt = crate::providers::traits::build_tool_instructions_text(specs);

    match augmented.iter().position(|m| m.role == "system") {
        Some(index) => {
            let system_text = &mut augmented[index].content;
            if !system_text.is_empty() {
                system_text.push_str("\n\n");
            }
            system_text.push_str(&tool_prompt);
        }
        None => augmented.insert(0, ChatMessage::system(tool_prompt)),
    }

    augmented
}
/// Turns a native response message into a `ProviderChatResponse`.
///
/// Tool calls lacking a `function` or a function `name` are dropped. Missing
/// arguments default to `"{}"`, and a missing call id is replaced by a freshly
/// generated UUID; an id supplied by the provider is kept as-is. The message
/// `content` becomes the response text unchanged.
fn parse_native_response(message: ResponseMessage) -> ProviderChatResponse {
    let mut parsed_calls = Vec::new();

    for raw_call in message.tool_calls.unwrap_or_default() {
        // A call without a function name cannot be dispatched, so skip it.
        let Some(function) = raw_call.function else {
            continue;
        };
        let Some(name) = function.name else {
            continue;
        };

        let arguments = match function.arguments {
            Some(arguments) => arguments,
            None => "{}".to_string(),
        };
        let id = match raw_call.id {
            Some(id) => id,
            None => uuid::Uuid::new_v4().to_string(),
        };

        parsed_calls.push(ProviderToolCall {
            id,
            name,
            arguments,
        });
    }

    ProviderChatResponse {
        text: message.content,
        tool_calls: parsed_calls,
    }
}
/// Decides whether an API error indicates that the backend rejects native tool fields.
///
/// Only HTTP 400 and 422 responses qualify; for those, the error text is lowercased
/// and searched for any of a fixed set of substrings.
fn is_native_tool_schema_unsupported(status: reqwest::StatusCode, error: &str) -> bool {
    // Compared against the lowercased error text, so every entry is lowercase.
    const UNSUPPORTED_HINTS: &[&str] = &[
        "unknown parameter: tools",
        "unsupported parameter: tools",
        "unrecognized field `tools`",
        "does not support tools",
        "function calling is not supported",
        "tool_choice",
    ];

    let is_request_rejection = status == reqwest::StatusCode::BAD_REQUEST
        || status == reqwest::StatusCode::UNPROCESSABLE_ENTITY;
    if !is_request_rejection {
        return false;
    }

    let haystack = error.to_lowercase();
    for hint in UNSUPPORTED_HINTS {
        if haystack.contains(hint) {
            return true;
        }
    }
    false
}
}
#[async_trait]
@ -846,49 +1035,83 @@ impl Provider for OpenAiCompatibleProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
// If native tools are requested, delegate to chat_with_tools.
if let Some(tools) = request.tools {
if !tools.is_empty() && self.supports_native_tools() {
let native_tools = Self::tool_specs_to_openai_format(tools);
return self
.chat_with_tools(request.messages, &native_tools, model, temperature)
.await;
}
}
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
)
})?;
let text = self
.chat_with_history(request.messages, model, temperature)
let tools = Self::convert_tool_specs(request.tools);
let native_request = NativeChatRequest {
model: model.to_string(),
messages: Self::convert_messages_for_native(request.messages),
temperature,
stream: Some(false),
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
tools,
};
let url = self.chat_completions_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&native_request), credential)
.send()
.await?;
// Backward compatible path: chat_with_history may serialize tool_calls JSON into content.
if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) {
let parsed_text = message.effective_content_optional();
let tool_calls = message
.tool_calls
.unwrap_or_default()
.into_iter()
.filter_map(|tc| {
let function = tc.function?;
let name = function.name?;
let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
Some(ProviderToolCall {
id: uuid::Uuid::new_v4().to_string(),
name,
arguments,
})
})
.collect::<Vec<_>>();
if !response.status().is_success() {
let status = response.status();
let error = response.text().await?;
let sanitized = super::sanitize_api_error(&error);
return Ok(ProviderChatResponse {
text: parsed_text,
tool_calls,
});
if Self::is_native_tool_schema_unsupported(status, &sanitized) {
let fallback_messages =
Self::with_prompt_guided_tool_instructions(request.messages, request.tools);
let text = self
.chat_with_history(&fallback_messages, model, temperature)
.await?;
return Ok(ProviderChatResponse {
text: Some(text),
tool_calls: vec![],
});
}
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
let system = request.messages.iter().find(|m| m.role == "system");
let last_user = request.messages.iter().rfind(|m| m.role == "user");
if let Some(user_msg) = last_user {
return self
.chat_via_responses(
credential,
system.map(|m| m.content.as_str()),
&user_msg.content,
model,
)
.await
.map(|text| ProviderChatResponse {
text: Some(text),
tool_calls: vec![],
})
.map_err(|responses_err| {
anyhow::anyhow!(
"{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})",
self.name
)
});
}
}
anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
}
Ok(ProviderChatResponse {
text: Some(text),
tool_calls: vec![],
})
let native_response: ApiChatResponse = response.json().await?;
let message = native_response
.choices
.into_iter()
.next()
.map(|choice| choice.message)
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;
Ok(Self::parse_native_response(message))
}
fn supports_native_tools(&self) -> bool {
@ -1400,6 +1623,76 @@ mod tests {
);
}
/// A provider-supplied tool call id must survive parsing instead of being replaced.
#[test]
fn parse_native_response_preserves_tool_call_id() {
    let shell_call = ToolCall {
        id: Some(String::from("call_123")),
        kind: Some(String::from("function")),
        function: Some(Function {
            name: Some(String::from("shell")),
            arguments: Some(String::from(r#"{"command":"pwd"}"#)),
        }),
    };

    let response = OpenAiCompatibleProvider::parse_native_response(ResponseMessage {
        content: None,
        tool_calls: Some(vec![shell_call]),
    });

    assert_eq!(response.tool_calls.len(), 1);
    let call = &response.tool_calls[0];
    assert_eq!(call.id, "call_123");
    assert_eq!(call.name, "shell");
}
/// A JSON tool-result envelope is split into `tool_call_id` and `content`.
#[test]
fn convert_messages_for_native_maps_tool_result_payload() {
    let payload = r#"{"tool_call_id":"call_abc","content":"done"}"#;
    let history = [ChatMessage::tool(payload)];

    let native = OpenAiCompatibleProvider::convert_messages_for_native(&history);

    assert_eq!(native.len(), 1);
    let tool_message = &native[0];
    assert_eq!(tool_message.role, "tool");
    assert_eq!(tool_message.tool_call_id.as_deref(), Some("call_abc"));
    assert_eq!(tool_message.content.as_deref(), Some("done"));
}
/// The same error text counts as "tools unsupported" on a 400 but not on a 401.
#[test]
fn native_tool_schema_unsupported_detection_is_precise() {
    let detected_for = |status: reqwest::StatusCode| {
        OpenAiCompatibleProvider::is_native_tool_schema_unsupported(
            status,
            "unknown parameter: tools",
        )
    };

    assert!(detected_for(reqwest::StatusCode::BAD_REQUEST));
    assert!(!detected_for(reqwest::StatusCode::UNAUTHORIZED));
}
/// With no system message in the history, the fallback prepends one that lists the tools.
#[test]
fn prompt_guided_tool_fallback_injects_system_instruction() {
    let shell_tool = crate::tools::ToolSpec {
        name: String::from("shell_exec"),
        description: String::from("Execute shell command"),
        parameters: serde_json::json!({
            "type": "object",
            "properties": {
                "command": { "type": "string" }
            },
            "required": ["command"]
        }),
    };
    let tools = vec![shell_tool];
    let history = vec![ChatMessage::user("check status")];

    let augmented = OpenAiCompatibleProvider::with_prompt_guided_tool_instructions(
        &history,
        Some(tools.as_slice()),
    );

    assert!(!augmented.is_empty());
    let head = &augmented[0];
    assert_eq!(head.role, "system");
    assert!(head.content.contains("Available Tools"));
    assert!(head.content.contains("shell_exec"));
}
#[tokio::test]
async fn warmup_without_key_is_noop() {
let provider = make_provider("test", "https://example.com", None);

View file

@ -67,6 +67,52 @@ fn is_rate_limited(err: &anyhow::Error) -> bool {
&& (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
}
/// Reports whether a rate-limit error is really a business/quota-plan failure that
/// retrying cannot resolve.
///
/// Errors that `is_rate_limited` does not recognise are never flagged. For the
/// rest, the lowercased message is checked for:
/// - plan/quota/balance phrases (e.g. "plan does not include", "insufficient balance"),
/// - known provider business codes appearing as a standalone run of digits
///   (e.g. Z.AI: 1311, 1113).
fn is_non_retryable_rate_limit(err: &anyhow::Error) -> bool {
    // Lowercase phrases that identify plan, quota, or balance problems.
    const BUSINESS_HINTS: &[&str] = &[
        "plan does not include",
        "doesn't include",
        "not include",
        "insufficient balance",
        "insufficient_balance",
        "insufficient quota",
        "insufficient_quota",
        "quota exhausted",
        "out of credits",
        "no available package",
        "package not active",
        "purchase package",
        "model not available for your plan",
    ];
    // Known provider business codes observed for 429 where retry is futile.
    const BUSINESS_CODES: &[u16] = &[1113, 1311];

    if !is_rate_limited(err) {
        return false;
    }

    let haystack = err.to_string().to_lowercase();

    let mentions_business_hint = BUSINESS_HINTS
        .iter()
        .any(|phrase| haystack.contains(phrase));
    if mentions_business_hint {
        return true;
    }

    // Split on every non-digit so each maximal digit run is examined on its own;
    // empty pieces and runs that do not fit in `u16` simply fail to parse.
    haystack
        .split(|c: char| !c.is_ascii_digit())
        .filter_map(|digits| digits.parse::<u16>().ok())
        .any(|code| BUSINESS_CODES.contains(&code))
}
/// Try to extract a Retry-After value (in milliseconds) from an error message.
/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string.
fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
@ -101,7 +147,9 @@ fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
}
fn failure_reason(rate_limited: bool, non_retryable: bool) -> &'static str {
if rate_limited {
if rate_limited && non_retryable {
"rate_limited_non_retryable"
} else if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
@ -244,7 +292,8 @@ impl Provider for ReliableProvider {
return Ok(resp);
}
Err(e) => {
let non_retryable = is_non_retryable(&e);
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
let failure_reason = failure_reason(rate_limited, non_retryable);
let error_detail = compact_error_detail(&e);
@ -260,7 +309,7 @@ impl Provider for ReliableProvider {
);
// On rate-limit, try rotating API key
if rate_limited {
if rate_limited && !non_retryable_rate_limit {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
@ -352,7 +401,8 @@ impl Provider for ReliableProvider {
return Ok(resp);
}
Err(e) => {
let non_retryable = is_non_retryable(&e);
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
let failure_reason = failure_reason(rate_limited, non_retryable);
let error_detail = compact_error_detail(&e);
@ -367,7 +417,7 @@ impl Provider for ReliableProvider {
&error_detail,
);
if rate_limited {
if rate_limited && !non_retryable_rate_limit {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
@ -459,7 +509,8 @@ impl Provider for ReliableProvider {
return Ok(resp);
}
Err(e) => {
let non_retryable = is_non_retryable(&e);
let non_retryable_rate_limit = is_non_retryable_rate_limit(&e);
let non_retryable = is_non_retryable(&e) || non_retryable_rate_limit;
let rate_limited = is_rate_limited(&e);
let failure_reason = failure_reason(rate_limited, non_retryable);
let error_detail = compact_error_detail(&e);
@ -474,7 +525,7 @@ impl Provider for ReliableProvider {
&error_detail,
);
if rate_limited {
if rate_limited && !non_retryable_rate_limit {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
@ -1106,6 +1157,39 @@ mod tests {
)));
}
/// A 429 whose body says the plan excludes the model is classified as non-retryable.
#[test]
fn non_retryable_rate_limit_detects_plan_restricted_model() {
    let body = "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"the current account plan does not include glm-5\"}";
    let err = anyhow::anyhow!("{}", body);

    let flagged = is_non_retryable_rate_limit(&err);

    assert!(flagged, "plan-restricted 429 should skip retries");
}
/// A 429 reporting insufficient balance is classified as non-retryable.
#[test]
fn non_retryable_rate_limit_detects_insufficient_balance() {
    let body =
        "API error (429 Too Many Requests): {\"code\":1113,\"message\":\"insufficient balance\"}";
    let err = anyhow::anyhow!("{}", body);

    let flagged = is_non_retryable_rate_limit(&err);

    assert!(flagged, "insufficient-balance 429 should skip retries");
}
/// An ordinary throttling 429 must not be mistaken for a business error.
#[test]
fn non_retryable_rate_limit_does_not_flag_generic_429() {
    let generic = anyhow::anyhow!("429 Too Many Requests: rate limit exceeded");

    let flagged = is_non_retryable_rate_limit(&generic);

    assert!(!flagged, "generic rate-limit 429 should remain retryable");
}
#[test]
fn compute_backoff_uses_retry_after() {
let provider = ReliableProvider::new(vec![], 0, 500);
@ -1261,6 +1345,35 @@ mod tests {
);
}
/// End-to-end check: a plan-restricted 429 makes the wrapper give up after one call.
#[tokio::test]
async fn non_retryable_rate_limit_skips_retries_for_plan_errors() {
    // Shared counter so the test can observe how many times the mock was invoked.
    let attempts = Arc::new(AtomicUsize::new(0));

    // Mock that fails on every attempt with a plan-restricted 429 body.
    let always_plan_error = MockProvider {
        calls: Arc::clone(&attempts),
        fail_until_attempt: usize::MAX,
        response: "never",
        error: "API error (429 Too Many Requests): {\"code\":1311,\"message\":\"plan does not include glm-5\"}",
    };

    // NOTE(review): the trailing `5, 1` are presumably the retry budget and backoff
    // settings; confirm against `ReliableProvider::new`.
    let provider = ReliableProvider::new(
        vec![("primary".into(), Box::new(always_plan_error))],
        5,
        1,
    );

    let outcome = provider.simple_chat("hello", "test", 0.0).await;
    assert!(
        outcome.is_err(),
        "plan-restricted 429 should fail quickly without retrying"
    );

    let observed_calls = attempts.load(Ordering::SeqCst);
    assert_eq!(
        observed_calls, 1,
        "must not retry non-retryable 429 business errors"
    );
}
// ── Arc<ModelAwareMock> Provider impl for test ──
#[async_trait]