diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 90ed340..35fef81 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -75,14 +75,14 @@ struct NativeMessage { tool_calls: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] struct NativeToolSpec { #[serde(rename = "type")] kind: String, function: NativeToolFunctionSpec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] struct NativeToolFunctionSpec { name: String, description: String, @@ -354,6 +354,58 @@ impl Provider for OpenAiProvider { true } + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") + })?; + + let native_tools: Option> = if tools.is_empty() { + None + } else { + Some( + tools + .iter() + .filter_map(|t| serde_json::from_value(t.clone()).ok()) + .collect(), + ) + }; + + let native_request = NativeChatRequest { + model: model.to_string(), + messages: Self::convert_messages(messages), + temperature, + tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), + tools: native_tools, + }; + + let response = self + .client + .post(format!("{}/chat/completions", self.base_url)) + .header("Authorization", format!("Bearer {credential}")) + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenAI", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; + Ok(Self::parse_native_response(message)) + } + async fn warmup(&self) -> anyhow::Result<()> { if let Some(credential) = self.credential.as_ref() { self.http_client() @@ -537,4 +589,48 @@ mod tests { let msg = &resp.choices[0].message; assert_eq!(msg.effective_content(), Some("Real answer".to_string())); } + + #[tokio::test] + async fn chat_with_tools_fails_without_key() { + let p = OpenAiProvider::new(None); + let messages = vec![ChatMessage::user("hello".to_string())]; + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run a shell command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + })]; + let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[test] + fn native_tool_spec_deserializes_from_openai_format() { + let json = serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run a shell command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + }); + let spec: NativeToolSpec = serde_json::from_value(json).unwrap(); + assert_eq!(spec.kind, "function"); + assert_eq!(spec.function.name, "shell"); + } }