fix(provider): preserve full history in responses fallback
This commit is contained in:
parent
48b51e7152
commit
63aacb09ff
1 changed files with 140 additions and 88 deletions
|
|
@ -585,6 +585,43 @@ fn first_nonempty(text: Option<&str>) -> Option<String> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn normalize_responses_role(role: &str) -> &'static str {
|
||||||
|
match role {
|
||||||
|
"assistant" => "assistant",
|
||||||
|
"tool" => "assistant",
|
||||||
|
_ => "user",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_responses_prompt(messages: &[ChatMessage]) -> (Option<String>, Vec<ResponsesInput>) {
|
||||||
|
let mut instructions_parts = Vec::new();
|
||||||
|
let mut input = Vec::new();
|
||||||
|
|
||||||
|
for message in messages {
|
||||||
|
if message.content.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.role == "system" {
|
||||||
|
instructions_parts.push(message.content.clone());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
input.push(ResponsesInput {
|
||||||
|
role: normalize_responses_role(&message.role).to_string(),
|
||||||
|
content: message.content.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let instructions = if instructions_parts.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(instructions_parts.join("\n\n"))
|
||||||
|
};
|
||||||
|
|
||||||
|
(instructions, input)
|
||||||
|
}
|
||||||
|
|
||||||
fn extract_responses_text(response: ResponsesResponse) -> Option<String> {
|
fn extract_responses_text(response: ResponsesResponse) -> Option<String> {
|
||||||
if let Some(text) = first_nonempty(response.output_text.as_deref()) {
|
if let Some(text) = first_nonempty(response.output_text.as_deref()) {
|
||||||
return Some(text);
|
return Some(text);
|
||||||
|
|
@ -655,17 +692,21 @@ impl OpenAiCompatibleProvider {
|
||||||
async fn chat_via_responses(
|
async fn chat_via_responses(
|
||||||
&self,
|
&self,
|
||||||
credential: &str,
|
credential: &str,
|
||||||
system_prompt: Option<&str>,
|
messages: &[ChatMessage],
|
||||||
message: &str,
|
|
||||||
model: &str,
|
model: &str,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
|
let (instructions, input) = build_responses_prompt(messages);
|
||||||
|
if input.is_empty() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"{} Responses API fallback requires at least one non-system message",
|
||||||
|
self.name
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
let request = ResponsesRequest {
|
let request = ResponsesRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
input: vec![ResponsesInput {
|
input,
|
||||||
role: "user".to_string(),
|
instructions,
|
||||||
content: message.to_string(),
|
|
||||||
}],
|
|
||||||
instructions: system_prompt.map(str::to_string),
|
|
||||||
stream: Some(false),
|
stream: Some(false),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -909,6 +950,17 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
let url = self.chat_completions_url();
|
let url = self.chat_completions_url();
|
||||||
|
|
||||||
|
let mut fallback_messages = Vec::new();
|
||||||
|
if let Some(system_prompt) = system_prompt {
|
||||||
|
fallback_messages.push(ChatMessage::system(system_prompt));
|
||||||
|
}
|
||||||
|
fallback_messages.push(ChatMessage::user(message));
|
||||||
|
let fallback_messages = if self.merge_system_into_user {
|
||||||
|
Self::flatten_system_messages(&fallback_messages)
|
||||||
|
} else {
|
||||||
|
fallback_messages
|
||||||
|
};
|
||||||
|
|
||||||
let response = match self
|
let response = match self
|
||||||
.apply_auth_header(self.http_client().post(&url).json(&request), credential)
|
.apply_auth_header(self.http_client().post(&url).json(&request), credential)
|
||||||
.send()
|
.send()
|
||||||
|
|
@ -919,7 +971,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
if self.supports_responses_fallback {
|
if self.supports_responses_fallback {
|
||||||
let sanitized = super::sanitize_api_error(&chat_error.to_string());
|
let sanitized = super::sanitize_api_error(&chat_error.to_string());
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(credential, system_prompt, message, model)
|
.chat_via_responses(credential, &fallback_messages, model)
|
||||||
.await
|
.await
|
||||||
.map_err(|responses_err| {
|
.map_err(|responses_err| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
|
|
@ -940,7 +992,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(credential, system_prompt, message, model)
|
.chat_via_responses(credential, &fallback_messages, model)
|
||||||
.await
|
.await
|
||||||
.map_err(|responses_err| {
|
.map_err(|responses_err| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
|
|
@ -1023,17 +1075,9 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
Ok(response) => response,
|
Ok(response) => response,
|
||||||
Err(chat_error) => {
|
Err(chat_error) => {
|
||||||
if self.supports_responses_fallback {
|
if self.supports_responses_fallback {
|
||||||
let system = messages.iter().find(|m| m.role == "system");
|
|
||||||
let last_user = messages.iter().rfind(|m| m.role == "user");
|
|
||||||
if let Some(user_msg) = last_user {
|
|
||||||
let sanitized = super::sanitize_api_error(&chat_error.to_string());
|
let sanitized = super::sanitize_api_error(&chat_error.to_string());
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(
|
.chat_via_responses(credential, &effective_messages, model)
|
||||||
credential,
|
|
||||||
system.map(|m| m.content.as_str()),
|
|
||||||
&user_msg.content,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|responses_err| {
|
.map_err(|responses_err| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
|
|
@ -1042,7 +1086,6 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return Err(chat_error.into());
|
return Err(chat_error.into());
|
||||||
}
|
}
|
||||||
|
|
@ -1053,17 +1096,8 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
// Mirror chat_with_system: 404 may mean this provider uses the Responses API
|
// Mirror chat_with_system: 404 may mean this provider uses the Responses API
|
||||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||||
// Extract system prompt and last user message for responses fallback
|
|
||||||
let system = messages.iter().find(|m| m.role == "system");
|
|
||||||
let last_user = messages.iter().rfind(|m| m.role == "user");
|
|
||||||
if let Some(user_msg) = last_user {
|
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(
|
.chat_via_responses(credential, &effective_messages, model)
|
||||||
credential,
|
|
||||||
system.map(|m| m.content.as_str()),
|
|
||||||
&user_msg.content,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map_err(|responses_err| {
|
.map_err(|responses_err| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
|
|
@ -1072,7 +1106,6 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return Err(super::api_error(&self.name, response).await);
|
return Err(super::api_error(&self.name, response).await);
|
||||||
}
|
}
|
||||||
|
|
@ -1240,17 +1273,9 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
Ok(response) => response,
|
Ok(response) => response,
|
||||||
Err(chat_error) => {
|
Err(chat_error) => {
|
||||||
if self.supports_responses_fallback {
|
if self.supports_responses_fallback {
|
||||||
let system = request.messages.iter().find(|m| m.role == "system");
|
|
||||||
let last_user = request.messages.iter().rfind(|m| m.role == "user");
|
|
||||||
if let Some(user_msg) = last_user {
|
|
||||||
let sanitized = super::sanitize_api_error(&chat_error.to_string());
|
let sanitized = super::sanitize_api_error(&chat_error.to_string());
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(
|
.chat_via_responses(credential, &effective_messages, model)
|
||||||
credential,
|
|
||||||
system.map(|m| m.content.as_str()),
|
|
||||||
&user_msg.content,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map(|text| ProviderChatResponse {
|
.map(|text| ProviderChatResponse {
|
||||||
text: Some(text),
|
text: Some(text),
|
||||||
|
|
@ -1263,7 +1288,6 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return Err(chat_error.into());
|
return Err(chat_error.into());
|
||||||
}
|
}
|
||||||
|
|
@ -1287,16 +1311,8 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
}
|
}
|
||||||
|
|
||||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||||
let system = request.messages.iter().find(|m| m.role == "system");
|
|
||||||
let last_user = request.messages.iter().rfind(|m| m.role == "user");
|
|
||||||
if let Some(user_msg) = last_user {
|
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(
|
.chat_via_responses(credential, &effective_messages, model)
|
||||||
credential,
|
|
||||||
system.map(|m| m.content.as_str()),
|
|
||||||
&user_msg.content,
|
|
||||||
model,
|
|
||||||
)
|
|
||||||
.await
|
.await
|
||||||
.map(|text| ProviderChatResponse {
|
.map(|text| ProviderChatResponse {
|
||||||
text: Some(text),
|
text: Some(text),
|
||||||
|
|
@ -1309,7 +1325,6 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
|
anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
|
||||||
}
|
}
|
||||||
|
|
@ -1643,6 +1658,43 @@ mod tests {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_responses_prompt_preserves_multi_turn_history() {
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::system("policy"),
|
||||||
|
ChatMessage::user("step 1"),
|
||||||
|
ChatMessage::assistant("ack 1"),
|
||||||
|
ChatMessage::tool("{\"result\":\"ok\"}"),
|
||||||
|
ChatMessage::user("step 2"),
|
||||||
|
];
|
||||||
|
|
||||||
|
let (instructions, input) = build_responses_prompt(&messages);
|
||||||
|
|
||||||
|
assert_eq!(instructions.as_deref(), Some("policy"));
|
||||||
|
assert_eq!(input.len(), 4);
|
||||||
|
assert_eq!(input[0].role, "user");
|
||||||
|
assert_eq!(input[0].content, "step 1");
|
||||||
|
assert_eq!(input[1].role, "assistant");
|
||||||
|
assert_eq!(input[1].content, "ack 1");
|
||||||
|
assert_eq!(input[2].role, "assistant");
|
||||||
|
assert_eq!(input[2].content, "{\"result\":\"ok\"}");
|
||||||
|
assert_eq!(input[3].role, "user");
|
||||||
|
assert_eq!(input[3].content, "step 2");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn chat_via_responses_requires_non_system_message() {
|
||||||
|
let provider = make_provider("custom", "https://api.example.com", Some("test-key"));
|
||||||
|
let err = provider
|
||||||
|
.chat_via_responses("test-key", &[ChatMessage::system("policy")], "gpt-test")
|
||||||
|
.await
|
||||||
|
.expect_err("system-only fallback payload should fail");
|
||||||
|
|
||||||
|
assert!(err
|
||||||
|
.to_string()
|
||||||
|
.contains("requires at least one non-system message"));
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------
|
// ----------------------------------------------------------
|
||||||
// Custom endpoint path tests (Issue #114)
|
// Custom endpoint path tests (Issue #114)
|
||||||
// ----------------------------------------------------------
|
// ----------------------------------------------------------
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue