feat(provider): add native tool calling API (supersedes #450)

Co-authored-by: YubinghanBai <baiyubinghan@gmail.com>
This commit is contained in:
Chummy 2026-02-17 22:46:31 +08:00
parent 767c66f3c8
commit b2690f6809
2 changed files with 439 additions and 22 deletions

View file

@@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory;
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use async_trait::async_trait;
use chrono::Local;
use parking_lot::Mutex;
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokio::process::Command;
use tokio::time::timeout;
@@ -116,7 +116,9 @@ impl LucidMemory {
}
fn in_failure_cooldown(&self) -> bool {
let guard = self.last_failure_at.lock();
let Ok(guard) = self.last_failure_at.lock() else {
return false;
};
guard
.as_ref()
@@ -124,11 +126,15 @@ impl LucidMemory {
}
fn mark_failure_now(&self) {
*self.last_failure_at.lock() = Some(Instant::now());
if let Ok(mut guard) = self.last_failure_at.lock() {
*guard = Some(Instant::now());
}
}
fn clear_failure(&self) {
*self.last_failure_at.lock() = None;
if let Ok(mut guard) = self.last_failure_at.lock() {
*guard = None;
}
}
fn to_lucid_type(category: &MemoryCategory) -> &'static str {

View file

@@ -2,6 +2,7 @@ use crate::tools::ToolSpec;
use async_trait::async_trait;
use futures_util::{stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
/// A single message in a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -76,13 +77,6 @@ pub struct ChatRequest<'a> {
pub tools: Option<&'a [ToolSpec]>,
}
/// Declares optional provider features.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
/// Provider can perform native tool calling without prompt-level emulation.
pub native_tool_calling: bool,
}
/// A tool result to feed back to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
@@ -198,6 +192,40 @@ pub enum StreamError {
Io(#[from] std::io::Error),
}
/// Provider capabilities declaration.
///
/// Describes what features a provider supports, enabling intelligent
/// adaptation of tool calling modes and request formatting.
/// Provider capabilities declaration.
///
/// Describes what features a provider supports, enabling intelligent
/// adaptation of tool calling modes and request formatting.
///
/// This is a plain bundle of flags, so it derives `Copy` (as the earlier
/// revision of this type did): callers may pass it by value without cloning.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Whether the provider supports native tool calling via API primitives.
    ///
    /// When `true`, the provider can convert tool definitions to API-native
    /// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema).
    ///
    /// When `false`, tools must be injected via system prompt as text.
    pub native_tool_calling: bool,
}
/// Provider-specific tool payload formats.
///
/// Each LLM provider expects tool definitions in its own wire format; this
/// enum captures those shapes so a provider can translate the unified
/// `ToolSpec` list into whatever its API understands.
#[derive(Debug, Clone)]
pub enum ToolsPayload {
    /// Google Gemini: tools delivered as `functionDeclarations` entries.
    Gemini {
        function_declarations: Vec<serde_json::Value>,
    },
    /// Anthropic Messages API: `tools` entries carrying an `input_schema`.
    Anthropic { tools: Vec<serde_json::Value> },
    /// OpenAI Chat Completions: `tools` entries of type `function`.
    OpenAI { tools: Vec<serde_json::Value> },
    /// No native support: tool documentation injected into the system prompt
    /// as plain text.
    PromptGuided { instructions: String },
}
#[async_trait]
pub trait Provider: Send + Sync {
/// Query provider capabilities.
@@ -207,6 +235,19 @@ pub trait Provider: Send + Sync {
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities::default()
}
/// Convert tool specifications to provider-native format.
///
/// Default implementation returns `PromptGuided` payload, which injects
/// tool documentation into the system prompt as text. Providers with
/// native tool calling support should override this to return their
/// specific format (Gemini, Anthropic, OpenAI).
fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
ToolsPayload::PromptGuided {
instructions: build_tool_instructions_text(tools),
}
}
/// Simple one-shot chat (single user message, no explicit system prompt).
///
/// This is the preferred API for non-agentic direct interactions.
@@ -259,6 +300,43 @@ pub trait Provider: Send + Sync {
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
// If tools are provided but provider doesn't support native tools,
// inject tool instructions into system prompt as fallback.
if let Some(tools) = request.tools {
if !tools.is_empty() && !self.supports_native_tools() {
let tool_instructions = match self.convert_tools(tools) {
ToolsPayload::PromptGuided { instructions } => instructions,
payload => {
anyhow::bail!(
"Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
)
}
};
let mut modified_messages = request.messages.to_vec();
// Inject tool instructions into an existing system message.
// If none exists, prepend one to the conversation.
if let Some(system_message) =
modified_messages.iter_mut().find(|m| m.role == "system")
{
if !system_message.content.is_empty() {
system_message.content.push_str("\n\n");
}
system_message.content.push_str(&tool_instructions);
} else {
modified_messages.insert(0, ChatMessage::system(tool_instructions));
}
let text = self
.chat_with_history(&modified_messages, model, temperature)
.await?;
return Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
});
}
}
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
@@ -321,21 +399,11 @@ pub trait Provider: Send + Sync {
/// Default implementation falls back to stream_chat_with_system with last user message.
fn stream_chat_with_history(
&self,
messages: &[ChatMessage],
_messages: &[ChatMessage],
_model: &str,
_temperature: f64,
_options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
let _system = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.clone());
let _last_user = messages
.iter()
.rfind(|m| m.role == "user")
.map(|m| m.content.clone())
.unwrap_or_default();
// For default implementation, we need to convert to owned strings
// This is a limitation of the default implementation
let provider_name = "unknown".to_string();
@@ -346,6 +414,39 @@ pub trait Provider: Send + Sync {
}
}
/// Build tool instructions text for prompt-guided tool calling.
///
/// Generates a formatted text block describing available tools and how to
/// invoke them using XML-style tags. This is used as a fallback when the
/// provider doesn't support native tool calling.
pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
let mut instructions = String::new();
instructions.push_str("## Tool Use Protocol\n\n");
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
instructions.push_str("<tool_call>\n");
instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#);
instructions.push_str("\n</tool_call>\n\n");
instructions.push_str("You may use multiple tool calls in a single response. ");
instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
instructions
.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
instructions.push_str("### Available Tools\n\n");
for tool in tools {
writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description)
.expect("writing to String cannot fail");
let parameters =
serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string());
writeln!(&mut instructions, "Parameters: `{parameters}`")
.expect("writing to String cannot fail");
instructions.push('\n');
}
instructions
}
#[cfg(test)]
mod tests {
use super::*;
@@ -461,4 +562,314 @@ mod tests {
let provider = CapabilityMockProvider;
assert!(provider.supports_native_tools());
}
#[test]
fn tools_payload_variants() {
    // Construct each payload variant and confirm it matches its own pattern;
    // this pins variant and field names at compile time.
    let prompt_guided = ToolsPayload::PromptGuided {
        instructions: "Use tools...".to_string(),
    };
    assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));

    let gemini = ToolsPayload::Gemini {
        function_declarations: vec![serde_json::json!({"name": "test"})],
    };
    assert!(matches!(gemini, ToolsPayload::Gemini { .. }));

    let anthropic = ToolsPayload::Anthropic {
        tools: vec![serde_json::json!({"name": "test"})],
    };
    assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));

    let openai = ToolsPayload::OpenAI {
        tools: vec![serde_json::json!({"type": "function"})],
    };
    assert!(matches!(openai, ToolsPayload::OpenAI { .. }));
}
#[test]
fn build_tool_instructions_text_format() {
    // Small builder closure keeps the two fixtures terse.
    let spec = |name: &str, description: &str, properties: serde_json::Value| ToolSpec {
        name: name.to_string(),
        description: description.to_string(),
        parameters: serde_json::json!({"type": "object", "properties": properties}),
    };
    let tools = vec![
        spec(
            "shell",
            "Execute commands",
            serde_json::json!({"command": {"type": "string"}}),
        ),
        spec(
            "file_read",
            "Read files",
            serde_json::json!({"path": {"type": "string"}}),
        ),
    ];

    let instructions = build_tool_instructions_text(&tools);

    // Protocol description, per-tool listings, and parameter schemas must
    // all appear in the generated text.
    for needle in [
        "Tool Use Protocol",
        "<tool_call>",
        "</tool_call>",
        "**shell**",
        "Execute commands",
        "**file_read**",
        "Read files",
        "Parameters:",
        r#""type":"object""#,
    ] {
        assert!(instructions.contains(needle), "missing {needle:?}");
    }
}
#[test]
fn build_tool_instructions_text_empty() {
    // Even with no tools, the protocol header and the (empty) tools
    // section must both be emitted.
    let text = build_tool_instructions_text(&[]);
    assert!(text.contains("Tool Use Protocol"));
    assert!(text.contains("Available Tools"));
}
// Minimal provider stub whose native-tool support is configurable per test.
struct MockProvider {
    supports_native: bool,
}

#[async_trait]
impl Provider for MockProvider {
    fn supports_native_tools(&self) -> bool {
        self.supports_native
    }

    async fn chat_with_system(
        &self,
        _system: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        // Fixed reply; callers only check that some response comes back.
        Ok("response".to_owned())
    }
}
#[test]
fn provider_convert_tools_default() {
    let provider = MockProvider {
        supports_native: false,
    };
    let tools = vec![ToolSpec {
        name: "test_tool".to_string(),
        description: "A test tool".to_string(),
        parameters: serde_json::json!({"type": "object"}),
    }];

    // The trait's default conversion must fall back to prompt guidance,
    // mentioning every tool by name and description.
    match provider.convert_tools(&tools) {
        ToolsPayload::PromptGuided { instructions } => {
            assert!(instructions.contains("test_tool"));
            assert!(instructions.contains("A test tool"));
        }
        other => panic!("expected PromptGuided, got {other:?}"),
    }
}
#[tokio::test]
async fn provider_chat_prompt_guided_fallback() {
    let provider = MockProvider {
        supports_native: false,
    };
    let tools = vec![ToolSpec {
        name: "shell".to_string(),
        description: "Run commands".to_string(),
        parameters: serde_json::json!({"type": "object"}),
    }];
    let messages = [ChatMessage::user("Hello")];
    let request = ChatRequest {
        messages: &messages,
        tools: Some(&tools),
    };

    // A provider without native tool support should still answer via the
    // prompt-guided fallback path (default impl calls chat_with_history).
    let response = provider.chat(request, "model", 0.7).await.unwrap();
    assert!(response.text.is_some());
}
#[tokio::test]
async fn provider_chat_without_tools() {
    let provider = MockProvider {
        supports_native: true,
    };
    let messages = [ChatMessage::user("Hello")];
    let request = ChatRequest {
        messages: &messages,
        tools: None,
    };

    // With no tools attached, chat should behave like a plain completion.
    let response = provider.chat(request, "model", 0.7).await.unwrap();
    assert!(response.text.is_some());
}
// Test double that returns the system prompt verbatim, so tests can
// inspect what ended up in it after tool-instruction injection.
struct EchoSystemProvider {
    supports_native: bool,
}

#[async_trait]
impl Provider for EchoSystemProvider {
    fn supports_native_tools(&self) -> bool {
        self.supports_native
    }

    async fn chat_with_system(
        &self,
        system: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        // Echo the (possibly absent) system prompt back as the reply.
        Ok(system.map(str::to_string).unwrap_or_default())
    }
}
// Provider whose prompt-guided conversion is overridden with a sentinel
// string, letting tests verify the override is honored by chat().
struct CustomConvertProvider;

#[async_trait]
impl Provider for CustomConvertProvider {
    fn supports_native_tools(&self) -> bool {
        false
    }

    fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
        ToolsPayload::PromptGuided {
            instructions: String::from("CUSTOM_TOOL_INSTRUCTIONS"),
        }
    }

    async fn chat_with_system(
        &self,
        system: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        // Echo the system prompt so the caller can assert on its contents.
        Ok(system.map(str::to_string).unwrap_or_default())
    }
}
// Misbehaving provider: claims no native tool support yet returns an
// OpenAI-format payload — the chat() fallback must reject this.
struct InvalidConvertProvider;

#[async_trait]
impl Provider for InvalidConvertProvider {
    fn supports_native_tools(&self) -> bool {
        false
    }

    fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
        ToolsPayload::OpenAI {
            tools: vec![serde_json::json!({"type": "function"})],
        }
    }

    async fn chat_with_system(
        &self,
        _system: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        // chat() should error out before this is ever invoked.
        Ok("should_not_reach".to_owned())
    }
}
#[tokio::test]
async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
    let provider = EchoSystemProvider {
        supports_native: false,
    };
    let tools = vec![ToolSpec {
        name: "shell".to_string(),
        description: "Run commands".to_string(),
        parameters: serde_json::json!({"type": "object"}),
    }];
    // System message deliberately placed AFTER the user turn: injection must
    // locate it rather than prepending a second system message.
    let messages = [
        ChatMessage::user("Hello"),
        ChatMessage::system("BASE_SYSTEM_PROMPT"),
    ];
    let request = ChatRequest {
        messages: &messages,
        tools: Some(&tools),
    };

    let response = provider.chat(request, "model", 0.7).await.unwrap();
    let text = response.text.unwrap_or_default();
    assert!(text.contains("BASE_SYSTEM_PROMPT"));
    assert!(text.contains("Tool Use Protocol"));
}
#[tokio::test]
async fn provider_chat_prompt_guided_uses_convert_tools_override() {
    let provider = CustomConvertProvider;
    let tools = vec![ToolSpec {
        name: "shell".to_string(),
        description: "Run commands".to_string(),
        parameters: serde_json::json!({"type": "object"}),
    }];
    let messages = [ChatMessage::system("BASE"), ChatMessage::user("Hello")];
    let request = ChatRequest {
        messages: &messages,
        tools: Some(&tools),
    };

    // The echoed system prompt must contain both the original text and the
    // sentinel produced by the overridden convert_tools().
    let response = provider.chat(request, "model", 0.7).await.unwrap();
    let text = response.text.unwrap_or_default();
    assert!(text.contains("BASE"));
    assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS"));
}
#[tokio::test]
async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
    let provider = InvalidConvertProvider;
    let tools = vec![ToolSpec {
        name: "shell".to_string(),
        description: "Run commands".to_string(),
        parameters: serde_json::json!({"type": "object"}),
    }];
    let messages = [ChatMessage::user("Hello")];
    let request = ChatRequest {
        messages: &messages,
        tools: Some(&tools),
    };

    // A non-native provider returning a native payload is a contract
    // violation and must surface as an error, not a silent fallback.
    let failure = provider.chat(request, "model", 0.7).await.unwrap_err();
    assert!(failure.to_string().contains("non-prompt-guided"));
}
}