feat(provider): add native tool calling API (supersedes #450)
Co-authored-by: YubinghanBai <baiyubinghan@gmail.com>
This commit is contained in:
parent
767c66f3c8
commit
b2690f6809
2 changed files with 439 additions and 22 deletions
|
|
@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory;
|
|||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use parking_lot::Mutex;
|
||||
use std::collections::HashSet;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::process::Command;
|
||||
use tokio::time::timeout;
|
||||
|
|
@ -116,7 +116,9 @@ impl LucidMemory {
|
|||
}
|
||||
|
||||
fn in_failure_cooldown(&self) -> bool {
|
||||
let guard = self.last_failure_at.lock();
|
||||
let Ok(guard) = self.last_failure_at.lock() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
guard
|
||||
.as_ref()
|
||||
|
|
@ -124,11 +126,15 @@ impl LucidMemory {
|
|||
}
|
||||
|
||||
fn mark_failure_now(&self) {
|
||||
*self.last_failure_at.lock() = Some(Instant::now());
|
||||
if let Ok(mut guard) = self.last_failure_at.lock() {
|
||||
*guard = Some(Instant::now());
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_failure(&self) {
|
||||
*self.last_failure_at.lock() = None;
|
||||
if let Ok(mut guard) = self.last_failure_at.lock() {
|
||||
*guard = None;
|
||||
}
|
||||
}
|
||||
|
||||
fn to_lucid_type(category: &MemoryCategory) -> &'static str {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ use crate::tools::ToolSpec;
|
|||
use async_trait::async_trait;
|
||||
use futures_util::{stream, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
|
||||
/// A single message in a conversation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -76,13 +77,6 @@ pub struct ChatRequest<'a> {
|
|||
pub tools: Option<&'a [ToolSpec]>,
|
||||
}
|
||||
|
||||
/// Declares optional provider features.
|
||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
||||
pub struct ProviderCapabilities {
|
||||
/// Provider can perform native tool calling without prompt-level emulation.
|
||||
pub native_tool_calling: bool,
|
||||
}
|
||||
|
||||
/// A tool result to feed back to the LLM.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResultMessage {
|
||||
|
|
@ -198,6 +192,40 @@ pub enum StreamError {
|
|||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
/// Provider capabilities declaration.
|
||||
///
|
||||
/// Describes what features a provider supports, enabling intelligent
|
||||
/// adaptation of tool calling modes and request formatting.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct ProviderCapabilities {
|
||||
/// Whether the provider supports native tool calling via API primitives.
|
||||
///
|
||||
/// When `true`, the provider can convert tool definitions to API-native
|
||||
/// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema).
|
||||
///
|
||||
/// When `false`, tools must be injected via system prompt as text.
|
||||
pub native_tool_calling: bool,
|
||||
}
|
||||
|
||||
/// Provider-specific tool payload formats.
|
||||
///
|
||||
/// Different LLM providers require different formats for tool definitions.
|
||||
/// This enum encapsulates those variations, enabling providers to convert
|
||||
/// from the unified `ToolSpec` format to their native API requirements.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolsPayload {
|
||||
/// Gemini API format (functionDeclarations).
|
||||
Gemini {
|
||||
function_declarations: Vec<serde_json::Value>,
|
||||
},
|
||||
/// Anthropic Messages API format (tools with input_schema).
|
||||
Anthropic { tools: Vec<serde_json::Value> },
|
||||
/// OpenAI Chat Completions API format (tools with function).
|
||||
OpenAI { tools: Vec<serde_json::Value> },
|
||||
/// Prompt-guided fallback (tools injected as text in system prompt).
|
||||
PromptGuided { instructions: String },
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Query provider capabilities.
|
||||
|
|
@ -207,6 +235,19 @@ pub trait Provider: Send + Sync {
|
|||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities::default()
|
||||
}
|
||||
|
||||
/// Convert tool specifications to provider-native format.
|
||||
///
|
||||
/// Default implementation returns `PromptGuided` payload, which injects
|
||||
/// tool documentation into the system prompt as text. Providers with
|
||||
/// native tool calling support should override this to return their
|
||||
/// specific format (Gemini, Anthropic, OpenAI).
|
||||
fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
|
||||
ToolsPayload::PromptGuided {
|
||||
instructions: build_tool_instructions_text(tools),
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||
///
|
||||
/// This is the preferred API for non-agentic direct interactions.
|
||||
|
|
@ -259,6 +300,43 @@ pub trait Provider: Send + Sync {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
// If tools are provided but provider doesn't support native tools,
|
||||
// inject tool instructions into system prompt as fallback.
|
||||
if let Some(tools) = request.tools {
|
||||
if !tools.is_empty() && !self.supports_native_tools() {
|
||||
let tool_instructions = match self.convert_tools(tools) {
|
||||
ToolsPayload::PromptGuided { instructions } => instructions,
|
||||
payload => {
|
||||
anyhow::bail!(
|
||||
"Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
|
||||
)
|
||||
}
|
||||
};
|
||||
let mut modified_messages = request.messages.to_vec();
|
||||
|
||||
// Inject tool instructions into an existing system message.
|
||||
// If none exists, prepend one to the conversation.
|
||||
if let Some(system_message) =
|
||||
modified_messages.iter_mut().find(|m| m.role == "system")
|
||||
{
|
||||
if !system_message.content.is_empty() {
|
||||
system_message.content.push_str("\n\n");
|
||||
}
|
||||
system_message.content.push_str(&tool_instructions);
|
||||
} else {
|
||||
modified_messages.insert(0, ChatMessage::system(tool_instructions));
|
||||
}
|
||||
|
||||
let text = self
|
||||
.chat_with_history(&modified_messages, model, temperature)
|
||||
.await?;
|
||||
return Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
|
@ -321,21 +399,11 @@ pub trait Provider: Send + Sync {
|
|||
/// Default implementation falls back to stream_chat_with_system with last user message.
|
||||
fn stream_chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
_options: StreamOptions,
|
||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||
let _system = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.map(|m| m.content.clone());
|
||||
let _last_user = messages
|
||||
.iter()
|
||||
.rfind(|m| m.role == "user")
|
||||
.map(|m| m.content.clone())
|
||||
.unwrap_or_default();
|
||||
|
||||
// For default implementation, we need to convert to owned strings
|
||||
// This is a limitation of the default implementation
|
||||
let provider_name = "unknown".to_string();
|
||||
|
|
@ -346,6 +414,39 @@ pub trait Provider: Send + Sync {
|
|||
}
|
||||
}
|
||||
|
||||
/// Build tool instructions text for prompt-guided tool calling.
|
||||
///
|
||||
/// Generates a formatted text block describing available tools and how to
|
||||
/// invoke them using XML-style tags. This is used as a fallback when the
|
||||
/// provider doesn't support native tool calling.
|
||||
pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
|
||||
let mut instructions = String::new();
|
||||
|
||||
instructions.push_str("## Tool Use Protocol\n\n");
|
||||
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
||||
instructions.push_str("<tool_call>\n");
|
||||
instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#);
|
||||
instructions.push_str("\n</tool_call>\n\n");
|
||||
instructions.push_str("You may use multiple tool calls in a single response. ");
|
||||
instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
|
||||
instructions
|
||||
.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
||||
instructions.push_str("### Available Tools\n\n");
|
||||
|
||||
for tool in tools {
|
||||
writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description)
|
||||
.expect("writing to String cannot fail");
|
||||
|
||||
let parameters =
|
||||
serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string());
|
||||
writeln!(&mut instructions, "Parameters: `{parameters}`")
|
||||
.expect("writing to String cannot fail");
|
||||
instructions.push('\n');
|
||||
}
|
||||
|
||||
instructions
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -461,4 +562,314 @@ mod tests {
|
|||
let provider = CapabilityMockProvider;
|
||||
assert!(provider.supports_native_tools());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tools_payload_variants() {
|
||||
// Test Gemini variant
|
||||
let gemini = ToolsPayload::Gemini {
|
||||
function_declarations: vec![serde_json::json!({"name": "test"})],
|
||||
};
|
||||
assert!(matches!(gemini, ToolsPayload::Gemini { .. }));
|
||||
|
||||
// Test Anthropic variant
|
||||
let anthropic = ToolsPayload::Anthropic {
|
||||
tools: vec![serde_json::json!({"name": "test"})],
|
||||
};
|
||||
assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));
|
||||
|
||||
// Test OpenAI variant
|
||||
let openai = ToolsPayload::OpenAI {
|
||||
tools: vec![serde_json::json!({"type": "function"})],
|
||||
};
|
||||
assert!(matches!(openai, ToolsPayload::OpenAI { .. }));
|
||||
|
||||
// Test PromptGuided variant
|
||||
let prompt_guided = ToolsPayload::PromptGuided {
|
||||
instructions: "Use tools...".to_string(),
|
||||
};
|
||||
assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_tool_instructions_text_format() {
|
||||
let tools = vec![
|
||||
ToolSpec {
|
||||
name: "shell".to_string(),
|
||||
description: "Execute commands".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
},
|
||||
ToolSpec {
|
||||
name: "file_read".to_string(),
|
||||
description: "Read files".to_string(),
|
||||
parameters: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
},
|
||||
];
|
||||
|
||||
let instructions = build_tool_instructions_text(&tools);
|
||||
|
||||
// Check for protocol description
|
||||
assert!(instructions.contains("Tool Use Protocol"));
|
||||
assert!(instructions.contains("<tool_call>"));
|
||||
assert!(instructions.contains("</tool_call>"));
|
||||
|
||||
// Check for tool listings
|
||||
assert!(instructions.contains("**shell**"));
|
||||
assert!(instructions.contains("Execute commands"));
|
||||
assert!(instructions.contains("**file_read**"));
|
||||
assert!(instructions.contains("Read files"));
|
||||
|
||||
// Check for parameters
|
||||
assert!(instructions.contains("Parameters:"));
|
||||
assert!(instructions.contains(r#""type":"object""#));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_tool_instructions_text_empty() {
|
||||
let instructions = build_tool_instructions_text(&[]);
|
||||
|
||||
// Should still have protocol description
|
||||
assert!(instructions.contains("Tool Use Protocol"));
|
||||
|
||||
// Should have empty tools section
|
||||
assert!(instructions.contains("Available Tools"));
|
||||
}
|
||||
|
||||
// Mock provider for testing.
|
||||
struct MockProvider {
|
||||
supports_native: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for MockProvider {
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
self.supports_native
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("response".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_convert_tools_default() {
|
||||
let provider = MockProvider {
|
||||
supports_native: false,
|
||||
};
|
||||
|
||||
let tools = vec![ToolSpec {
|
||||
name: "test_tool".to_string(),
|
||||
description: "A test tool".to_string(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
}];
|
||||
|
||||
let payload = provider.convert_tools(&tools);
|
||||
|
||||
// Default implementation should return PromptGuided.
|
||||
assert!(matches!(payload, ToolsPayload::PromptGuided { .. }));
|
||||
|
||||
if let ToolsPayload::PromptGuided { instructions } = payload {
|
||||
assert!(instructions.contains("test_tool"));
|
||||
assert!(instructions.contains("A test tool"));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_chat_prompt_guided_fallback() {
|
||||
let provider = MockProvider {
|
||||
supports_native: false,
|
||||
};
|
||||
|
||||
let tools = vec![ToolSpec {
|
||||
name: "shell".to_string(),
|
||||
description: "Run commands".to_string(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
}];
|
||||
|
||||
let request = ChatRequest {
|
||||
messages: &[ChatMessage::user("Hello")],
|
||||
tools: Some(&tools),
|
||||
};
|
||||
|
||||
let response = provider.chat(request, "model", 0.7).await.unwrap();
|
||||
|
||||
// Should return a response (default impl calls chat_with_history).
|
||||
assert!(response.text.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_chat_without_tools() {
|
||||
let provider = MockProvider {
|
||||
supports_native: true,
|
||||
};
|
||||
|
||||
let request = ChatRequest {
|
||||
messages: &[ChatMessage::user("Hello")],
|
||||
tools: None,
|
||||
};
|
||||
|
||||
let response = provider.chat(request, "model", 0.7).await.unwrap();
|
||||
|
||||
// Should work normally without tools.
|
||||
assert!(response.text.is_some());
|
||||
}
|
||||
|
||||
// Provider that echoes the system prompt for assertions.
|
||||
struct EchoSystemProvider {
|
||||
supports_native: bool,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for EchoSystemProvider {
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
self.supports_native
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok(system.unwrap_or_default().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// Provider with custom prompt-guided conversion.
|
||||
struct CustomConvertProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for CustomConvertProvider {
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
|
||||
ToolsPayload::PromptGuided {
|
||||
instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok(system.unwrap_or_default().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// Provider returning an invalid payload for non-native mode.
|
||||
struct InvalidConvertProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for InvalidConvertProvider {
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
|
||||
ToolsPayload::OpenAI {
|
||||
tools: vec![serde_json::json!({"type": "function"})],
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("should_not_reach".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
|
||||
let provider = EchoSystemProvider {
|
||||
supports_native: false,
|
||||
};
|
||||
|
||||
let tools = vec![ToolSpec {
|
||||
name: "shell".to_string(),
|
||||
description: "Run commands".to_string(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
}];
|
||||
|
||||
let request = ChatRequest {
|
||||
messages: &[
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::system("BASE_SYSTEM_PROMPT"),
|
||||
],
|
||||
tools: Some(&tools),
|
||||
};
|
||||
|
||||
let response = provider.chat(request, "model", 0.7).await.unwrap();
|
||||
let text = response.text.unwrap_or_default();
|
||||
|
||||
assert!(text.contains("BASE_SYSTEM_PROMPT"));
|
||||
assert!(text.contains("Tool Use Protocol"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_chat_prompt_guided_uses_convert_tools_override() {
|
||||
let provider = CustomConvertProvider;
|
||||
|
||||
let tools = vec![ToolSpec {
|
||||
name: "shell".to_string(),
|
||||
description: "Run commands".to_string(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
}];
|
||||
|
||||
let request = ChatRequest {
|
||||
messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
|
||||
tools: Some(&tools),
|
||||
};
|
||||
|
||||
let response = provider.chat(request, "model", 0.7).await.unwrap();
|
||||
let text = response.text.unwrap_or_default();
|
||||
|
||||
assert!(text.contains("BASE"));
|
||||
assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
|
||||
let provider = InvalidConvertProvider;
|
||||
|
||||
let tools = vec![ToolSpec {
|
||||
name: "shell".to_string(),
|
||||
description: "Run commands".to_string(),
|
||||
parameters: serde_json::json!({"type": "object"}),
|
||||
}];
|
||||
|
||||
let request = ChatRequest {
|
||||
messages: &[ChatMessage::user("Hello")],
|
||||
tools: Some(&tools),
|
||||
};
|
||||
|
||||
let err = provider.chat(request, "model", 0.7).await.unwrap_err();
|
||||
let message = err.to_string();
|
||||
|
||||
assert!(message.contains("non-prompt-guided"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue