fix(providers): implement chat_with_tools for OpenAiProvider
The OpenAiProvider overrode chat() with native tool support but never overrode chat_with_tools(), which is the method called by run_tool_call_loop in channel mode (IRC/Discord/etc). The trait default for chat_with_tools() silently drops the tools parameter, sending plain ChatRequest with no tools — causing the model to never use native tool calls in channel mode. Add chat_with_tools() override that deserializes tool specs, uses convert_messages() for proper tool_call_id handling, and sends NativeChatRequest with tools and tool_choice. Also add Deserialize derive to NativeToolSpec and NativeToolFunctionSpec to support deserialization from OpenAI-format JSON.
This commit is contained in:
parent
d8409b0878
commit
f76c1226f1
1 changed files with 98 additions and 2 deletions
|
|
@ -75,14 +75,14 @@ struct NativeMessage {
|
||||||
tool_calls: Option<Vec<NativeToolCall>>,
|
tool_calls: Option<Vec<NativeToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct NativeToolSpec {
|
struct NativeToolSpec {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
kind: String,
|
kind: String,
|
||||||
function: NativeToolFunctionSpec,
|
function: NativeToolFunctionSpec,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
struct NativeToolFunctionSpec {
|
struct NativeToolFunctionSpec {
|
||||||
name: String,
|
name: String,
|
||||||
description: String,
|
description: String,
|
||||||
|
|
@ -354,6 +354,58 @@ impl Provider for OpenAiProvider {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_tools(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
tools: &[serde_json::Value],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.filter_map(|t| serde_json::from_value(t.clone()).ok())
|
||||||
|
.collect(),
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
let native_request = NativeChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages: Self::convert_messages(messages),
|
||||||
|
temperature,
|
||||||
|
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||||
|
tools: native_tools,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/chat/completions", self.base_url))
|
||||||
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
|
.json(&native_request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(super::api_error("OpenAI", response).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let native_response: NativeChatResponse = response.json().await?;
|
||||||
|
let message = native_response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.map(|c| c.message)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
|
||||||
|
Ok(Self::parse_native_response(message))
|
||||||
|
}
|
||||||
|
|
||||||
async fn warmup(&self) -> anyhow::Result<()> {
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
if let Some(credential) = self.credential.as_ref() {
|
if let Some(credential) = self.credential.as_ref() {
|
||||||
self.http_client()
|
self.http_client()
|
||||||
|
|
@ -537,4 +589,48 @@ mod tests {
|
||||||
let msg = &resp.choices[0].message;
|
let msg = &resp.choices[0].message;
|
||||||
assert_eq!(msg.effective_content(), Some("Real answer".to_string()));
|
assert_eq!(msg.effective_content(), Some("Real answer".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn chat_with_tools_fails_without_key() {
|
||||||
|
let p = OpenAiProvider::new(None);
|
||||||
|
let messages = vec![ChatMessage::user("hello".to_string())];
|
||||||
|
let tools = vec![serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "shell",
|
||||||
|
"description": "Run a shell command",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["command"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})];
|
||||||
|
let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().to_string().contains("API key not set"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn native_tool_spec_deserializes_from_openai_format() {
|
||||||
|
let json = serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "shell",
|
||||||
|
"description": "Run a shell command",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": { "type": "string" }
|
||||||
|
},
|
||||||
|
"required": ["command"]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let spec: NativeToolSpec = serde_json::from_value(json).unwrap();
|
||||||
|
assert_eq!(spec.kind, "function");
|
||||||
|
assert_eq!(spec.function.name, "shell");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue