From b2690f680993fa277fc5449ac67bdbefc5590a7e Mon Sep 17 00:00:00 2001 From: Chummy Date: Tue, 17 Feb 2026 22:46:31 +0800 Subject: [PATCH] feat(provider): add native tool calling API (supersedes #450) Co-authored-by: YubinghanBai --- src/memory/lucid.rs | 14 +- src/providers/traits.rs | 447 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 439 insertions(+), 22 deletions(-) diff --git a/src/memory/lucid.rs b/src/memory/lucid.rs index 7ea75a0..ab27840 100644 --- a/src/memory/lucid.rs +++ b/src/memory/lucid.rs @@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory; use super::traits::{Memory, MemoryCategory, MemoryEntry}; use async_trait::async_trait; use chrono::Local; -use parking_lot::Mutex; use std::collections::HashSet; use std::path::{Path, PathBuf}; +use std::sync::Mutex; use std::time::{Duration, Instant}; use tokio::process::Command; use tokio::time::timeout; @@ -116,7 +116,9 @@ impl LucidMemory { } fn in_failure_cooldown(&self) -> bool { - let guard = self.last_failure_at.lock(); + let Ok(guard) = self.last_failure_at.lock() else { + return false; + }; guard .as_ref() @@ -124,11 +126,15 @@ impl LucidMemory { } fn mark_failure_now(&self) { - *self.last_failure_at.lock() = Some(Instant::now()); + if let Ok(mut guard) = self.last_failure_at.lock() { + *guard = Some(Instant::now()); + } } fn clear_failure(&self) { - *self.last_failure_at.lock() = None; + if let Ok(mut guard) = self.last_failure_at.lock() { + *guard = None; + } } fn to_lucid_type(category: &MemoryCategory) -> &'static str { diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 1bb296b..1b7af06 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -2,6 +2,7 @@ use crate::tools::ToolSpec; use async_trait::async_trait; use futures_util::{stream, StreamExt}; use serde::{Deserialize, Serialize}; +use std::fmt::Write; /// A single message in a conversation. 
#[derive(Debug, Clone, Serialize, Deserialize)] @@ -76,13 +77,6 @@ pub struct ChatRequest<'a> { pub tools: Option<&'a [ToolSpec]>, } -/// Declares optional provider features. -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)] -pub struct ProviderCapabilities { - /// Provider can perform native tool calling without prompt-level emulation. - pub native_tool_calling: bool, -} - /// A tool result to feed back to the LLM. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolResultMessage { @@ -198,6 +192,40 @@ pub enum StreamError { Io(#[from] std::io::Error), } +/// Provider capabilities declaration. +/// +/// Describes what features a provider supports, enabling intelligent +/// adaptation of tool calling modes and request formatting. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ProviderCapabilities { + /// Whether the provider supports native tool calling via API primitives. + /// + /// When `true`, the provider can convert tool definitions to API-native + /// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema). + /// + /// When `false`, tools must be injected via system prompt as text. + pub native_tool_calling: bool, +} + +/// Provider-specific tool payload formats. +/// +/// Different LLM providers require different formats for tool definitions. +/// This enum encapsulates those variations, enabling providers to convert +/// from the unified `ToolSpec` format to their native API requirements. +#[derive(Debug, Clone)] +pub enum ToolsPayload { +    /// Gemini API format (functionDeclarations). +    Gemini { +        function_declarations: Vec<serde_json::Value>, +    }, +    /// Anthropic Messages API format (tools with input_schema). +    Anthropic { tools: Vec<serde_json::Value> }, +    /// OpenAI Chat Completions API format (tools with function). +    OpenAI { tools: Vec<serde_json::Value> }, +    /// Prompt-guided fallback (tools injected as text in system prompt). +    PromptGuided { instructions: String }, +} + #[async_trait] pub trait Provider: Send + Sync { /// Query provider capabilities. 
@@ -207,6 +235,19 @@ pub trait Provider: Send + Sync { fn capabilities(&self) -> ProviderCapabilities { ProviderCapabilities::default() } + + /// Convert tool specifications to provider-native format. + /// + /// Default implementation returns `PromptGuided` payload, which injects + /// tool documentation into the system prompt as text. Providers with + /// native tool calling support should override this to return their + /// specific format (Gemini, Anthropic, OpenAI). + fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload { + ToolsPayload::PromptGuided { + instructions: build_tool_instructions_text(tools), + } + } + /// Simple one-shot chat (single user message, no explicit system prompt). /// /// This is the preferred API for non-agentic direct interactions. @@ -259,6 +300,43 @@ pub trait Provider: Send + Sync { model: &str, temperature: f64, ) -> anyhow::Result<ChatResponse> { + // If tools are provided but provider doesn't support native tools, + // inject tool instructions into system prompt as fallback. + if let Some(tools) = request.tools { + if !tools.is_empty() && !self.supports_native_tools() { + let tool_instructions = match self.convert_tools(tools) { + ToolsPayload::PromptGuided { instructions } => instructions, + payload => { + anyhow::bail!( + "Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false" + ) + } + }; + let mut modified_messages = request.messages.to_vec(); + + // Inject tool instructions into an existing system message. + // If none exists, prepend one to the conversation. 
+ if let Some(system_message) = + modified_messages.iter_mut().find(|m| m.role == "system") + { + if !system_message.content.is_empty() { + system_message.content.push_str("\n\n"); + } + system_message.content.push_str(&tool_instructions); + } else { + modified_messages.insert(0, ChatMessage::system(tool_instructions)); + } + + let text = self + .chat_with_history(&modified_messages, model, temperature) + .await?; + return Ok(ChatResponse { + text: Some(text), + tool_calls: Vec::new(), + }); + } + } + let text = self .chat_with_history(request.messages, model, temperature) .await?; @@ -321,21 +399,11 @@ pub trait Provider: Send + Sync { /// Default implementation falls back to stream_chat_with_system with last user message. fn stream_chat_with_history( &self, - messages: &[ChatMessage], + _messages: &[ChatMessage], _model: &str, _temperature: f64, _options: StreamOptions, ) -> stream::BoxStream<'static, StreamResult> { - let _system = messages - .iter() - .find(|m| m.role == "system") - .map(|m| m.content.clone()); - let _last_user = messages - .iter() - .rfind(|m| m.role == "user") - .map(|m| m.content.clone()) - .unwrap_or_default(); - // For default implementation, we need to convert to owned strings // This is a limitation of the default implementation let provider_name = "unknown".to_string(); @@ -346,6 +414,39 @@ pub trait Provider: Send + Sync { } } +/// Build tool instructions text for prompt-guided tool calling. +/// +/// Generates a formatted text block describing available tools and how to +/// invoke them using XML-style tags. This is used as a fallback when the +/// provider doesn't support native tool calling. 
+pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String { + let mut instructions = String::new(); + + instructions.push_str("## Tool Use Protocol\n\n"); + instructions.push_str("To use a tool, wrap a JSON object in <tool_call> tags:\n\n"); + instructions.push_str("<tool_call>\n"); + instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#); + instructions.push_str("\n</tool_call>\n\n"); + instructions.push_str("You may use multiple tool calls in a single response. "); + instructions.push_str("After tool execution, results appear in <tool_result> tags. "); + instructions + .push_str("Continue reasoning with the results until you can give a final answer.\n\n"); + instructions.push_str("### Available Tools\n\n"); + + for tool in tools { + writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description) + .expect("writing to String cannot fail"); + + let parameters = + serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string()); + writeln!(&mut instructions, "Parameters: `{parameters}`") + .expect("writing to String cannot fail"); + instructions.push('\n'); + } + + instructions +} + #[cfg(test)] mod tests { use super::*; @@ -461,4 +562,314 @@ mod tests { let provider = CapabilityMockProvider; assert!(provider.supports_native_tools()); } + + #[test] + fn tools_payload_variants() { + // Test Gemini variant + let gemini = ToolsPayload::Gemini { + function_declarations: vec![serde_json::json!({"name": "test"})], + }; + assert!(matches!(gemini, ToolsPayload::Gemini { .. })); + + // Test Anthropic variant + let anthropic = ToolsPayload::Anthropic { + tools: vec![serde_json::json!({"name": "test"})], + }; + assert!(matches!(anthropic, ToolsPayload::Anthropic { .. })); + + // Test OpenAI variant + let openai = ToolsPayload::OpenAI { + tools: vec![serde_json::json!({"type": "function"})], + }; + assert!(matches!(openai, ToolsPayload::OpenAI { .. 
})); + + // Test PromptGuided variant + let prompt_guided = ToolsPayload::PromptGuided { + instructions: "Use tools...".to_string(), + }; + assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. })); + } + + #[test] + fn build_tool_instructions_text_format() { + let tools = vec![ + ToolSpec { + name: "shell".to_string(), + description: "Execute commands".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "command": {"type": "string"} + } + }), + }, + ToolSpec { + name: "file_read".to_string(), + description: "Read files".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "path": {"type": "string"} + } + }), + }, + ]; + + let instructions = build_tool_instructions_text(&tools); + + // Check for protocol description + assert!(instructions.contains("Tool Use Protocol")); + assert!(instructions.contains("<tool_call>")); + assert!(instructions.contains("</tool_call>")); + + // Check for tool listings + assert!(instructions.contains("**shell**")); + assert!(instructions.contains("Execute commands")); + assert!(instructions.contains("**file_read**")); + assert!(instructions.contains("Read files")); + + // Check for parameters + assert!(instructions.contains("Parameters:")); + assert!(instructions.contains(r#""type":"object""#)); + } + + #[test] + fn build_tool_instructions_text_empty() { + let instructions = build_tool_instructions_text(&[]); + + // Should still have protocol description + assert!(instructions.contains("Tool Use Protocol")); + + // Should have empty tools section + assert!(instructions.contains("Available Tools")); + } + + // Mock provider for testing. 
+ struct MockProvider { + supports_native: bool, + } + + #[async_trait] + impl Provider for MockProvider { + fn supports_native_tools(&self) -> bool { + self.supports_native + } + + async fn chat_with_system( + &self, + _system: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result<String> { + Ok("response".to_string()) + } + } + + #[test] + fn provider_convert_tools_default() { + let provider = MockProvider { + supports_native: false, + }; + + let tools = vec![ToolSpec { + name: "test_tool".to_string(), + description: "A test tool".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let payload = provider.convert_tools(&tools); + + // Default implementation should return PromptGuided. + assert!(matches!(payload, ToolsPayload::PromptGuided { .. })); + + if let ToolsPayload::PromptGuided { instructions } = payload { + assert!(instructions.contains("test_tool")); + assert!(instructions.contains("A test tool")); + } + } + + #[tokio::test] + async fn provider_chat_prompt_guided_fallback() { + let provider = MockProvider { + supports_native: false, + }; + + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let request = ChatRequest { + messages: &[ChatMessage::user("Hello")], + tools: Some(&tools), + }; + + let response = provider.chat(request, "model", 0.7).await.unwrap(); + + // Should return a response (default impl calls chat_with_history). + assert!(response.text.is_some()); + } + + #[tokio::test] + async fn provider_chat_without_tools() { + let provider = MockProvider { + supports_native: true, + }; + + let request = ChatRequest { + messages: &[ChatMessage::user("Hello")], + tools: None, + }; + + let response = provider.chat(request, "model", 0.7).await.unwrap(); + + // Should work normally without tools. 
+ assert!(response.text.is_some()); + } + + // Provider that echoes the system prompt for assertions. + struct EchoSystemProvider { + supports_native: bool, + } + + #[async_trait] + impl Provider for EchoSystemProvider { + fn supports_native_tools(&self) -> bool { + self.supports_native + } + + async fn chat_with_system( + &self, + system: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result<String> { + Ok(system.unwrap_or_default().to_string()) + } + } + + // Provider with custom prompt-guided conversion. + struct CustomConvertProvider; + + #[async_trait] + impl Provider for CustomConvertProvider { + fn supports_native_tools(&self) -> bool { + false + } + + fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload { + ToolsPayload::PromptGuided { + instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(), + } + } + + async fn chat_with_system( + &self, + system: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result<String> { + Ok(system.unwrap_or_default().to_string()) + } + } + + // Provider returning an invalid payload for non-native mode. 
+ struct InvalidConvertProvider; + + #[async_trait] + impl Provider for InvalidConvertProvider { + fn supports_native_tools(&self) -> bool { + false + } + + fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload { + ToolsPayload::OpenAI { + tools: vec![serde_json::json!({"type": "function"})], + } + } + + async fn chat_with_system( + &self, + _system: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result<String> { + Ok("should_not_reach".to_string()) + } + } + + #[tokio::test] + async fn provider_chat_prompt_guided_preserves_existing_system_not_first() { + let provider = EchoSystemProvider { + supports_native: false, + }; + + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let request = ChatRequest { + messages: &[ + ChatMessage::user("Hello"), + ChatMessage::system("BASE_SYSTEM_PROMPT"), + ], + tools: Some(&tools), + }; + + let response = provider.chat(request, "model", 0.7).await.unwrap(); + let text = response.text.unwrap_or_default(); + + assert!(text.contains("BASE_SYSTEM_PROMPT")); + assert!(text.contains("Tool Use Protocol")); + } + + #[tokio::test] + async fn provider_chat_prompt_guided_uses_convert_tools_override() { + let provider = CustomConvertProvider; + + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let request = ChatRequest { + messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")], + tools: Some(&tools), + }; + + let response = provider.chat(request, "model", 0.7).await.unwrap(); + let text = response.text.unwrap_or_default(); + + assert!(text.contains("BASE")); + assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS")); + } + + #[tokio::test] + async fn provider_chat_prompt_guided_rejects_non_prompt_payload() { + let provider = InvalidConvertProvider; + + let tools = 
vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object"}), + }]; + + let request = ChatRequest { + messages: &[ChatMessage::user("Hello")], + tools: Some(&tools), + }; + + let err = provider.chat(request, "model", 0.7).await.unwrap_err(); + let message = err.to_string(); + + assert!(message.contains("non-prompt-guided")); + } }