diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index f89270d..e55e1f0 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -73,6 +73,129 @@ struct ResponseMessage { content: String, } +#[derive(Debug, Serialize)] +struct ResponsesRequest { + model: String, + input: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + stream: Option, +} + +#[derive(Debug, Serialize)] +struct ResponsesInput { + role: String, + content: String, +} + +#[derive(Debug, Deserialize)] +struct ResponsesResponse { + #[serde(default)] + output: Vec, + #[serde(default)] + output_text: Option, +} + +#[derive(Debug, Deserialize)] +struct ResponsesOutput { + #[serde(default)] + content: Vec, +} + +#[derive(Debug, Deserialize)] +struct ResponsesContent { + #[serde(rename = "type")] + kind: Option, + text: Option, +} + +fn first_nonempty(text: Option<&str>) -> Option { + text.and_then(|value| { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } + }) +} + +fn extract_responses_text(response: ResponsesResponse) -> Option { + if let Some(text) = first_nonempty(response.output_text.as_deref()) { + return Some(text); + } + + for item in &response.output { + for content in &item.content { + if content.kind.as_deref() == Some("output_text") { + if let Some(text) = first_nonempty(content.text.as_deref()) { + return Some(text); + } + } + } + } + + for item in &response.output { + for content in &item.content { + if let Some(text) = first_nonempty(content.text.as_deref()) { + return Some(text); + } + } + } + + None +} + +impl OpenAiCompatibleProvider { + fn apply_auth_header( + &self, + req: reqwest::RequestBuilder, + api_key: &str, + ) -> reqwest::RequestBuilder { + match &self.auth_header { + AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")), + AuthStyle::XApiKey => req.header("x-api-key", api_key), + AuthStyle::Custom(header) => req.header(header, api_key), + } + } + + async fn chat_via_responses( + &self, + api_key: &str, + system_prompt: Option<&str>, + message: &str, + model: &str, + ) -> anyhow::Result { + let request = ResponsesRequest { + model: model.to_string(), + input: vec![ResponsesInput { + role: "user".to_string(), + content: message.to_string(), + }], + instructions: system_prompt.map(str::to_string), + stream: Some(false), + }; + + let url = format!("{}/v1/responses", self.base_url); + + let response = self + .apply_auth_header(self.client.post(&url).json(&request), api_key) + .send() + .await?; + + if !response.status().is_success() { + let error = response.text().await?; + anyhow::bail!("{} Responses API error: {error}", self.name); + } + + let responses: ResponsesResponse = response.json().await?; + + extract_responses_text(responses) + .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name)) + } +} + #[async_trait] impl Provider for OpenAiCompatibleProvider { async fn chat_with_system( @@ -111,24 +234,28 @@ impl Provider for OpenAiCompatibleProvider { let url = format!("{}/v1/chat/completions", self.base_url); - let mut req = self.client.post(&url).json(&request); - - match &self.auth_header { - AuthStyle::Bearer => { - req = req.header("Authorization", format!("Bearer {api_key}")); - } - AuthStyle::XApiKey => { - req = req.header("x-api-key", api_key.as_str()); - } - AuthStyle::Custom(header) => { - req = req.header(header.as_str(), api_key.as_str()); - } - } - - let response = req.send().await?; + let response = self + .apply_auth_header(self.client.post(&url).json(&request), api_key) + .send() + .await?; if !response.status().is_success() { - return Err(super::api_error(&self.name, response).await); + let status = response.status(); + let error = response.text().await?; + + if status == reqwest::StatusCode::NOT_FOUND { + return self + .chat_via_responses(api_key, system_prompt, message, model) + .await + .map_err(|responses_err| { + anyhow::anyhow!( + "{} API error: {error} (chat completions unavailable; responses fallback failed: {responses_err})", + self.name + ) + }); + } + + anyhow::bail!("{} API error: {error}", self.name); } let chat_response: ChatResponse = response.json().await?; @@ -263,4 +390,35 @@ mod tests { ); } } + + #[test] + fn responses_extracts_top_level_output_text() { + let json = r#"{"output_text":"Hello from top-level","output":[]}"#; + let response: ResponsesResponse = serde_json::from_str(json).unwrap(); + assert_eq!( + extract_responses_text(response).as_deref(), + Some("Hello from top-level") + ); + } + + #[test] + fn responses_extracts_nested_output_text() { + let json = + r#"{"output":[{"content":[{"type":"output_text","text":"Hello from nested"}]}]}"#; + let response: ResponsesResponse = serde_json::from_str(json).unwrap(); + assert_eq!( + extract_responses_text(response).as_deref(), + Some("Hello from nested") + ); + } + + #[test] + fn responses_extracts_any_text_as_fallback() { + let json = r#"{"output":[{"content":[{"type":"message","text":"Fallback text"}]}]}"#; + let response: ResponsesResponse = serde_json::from_str(json).unwrap(); + assert_eq!( + extract_responses_text(response).as_deref(), + Some("Fallback text") + ); + } }