feat(provider): add native tool calling API (supersedes #450)
Co-authored-by: YubinghanBai <baiyubinghan@gmail.com>
This commit is contained in:
parent
767c66f3c8
commit
b2690f6809
2 changed files with 439 additions and 22 deletions
|
|
@ -2,9 +2,9 @@ use super::sqlite::SqliteMemory;
|
||||||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use parking_lot::Mutex;
|
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Mutex;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
use tokio::time::timeout;
|
use tokio::time::timeout;
|
||||||
|
|
@ -116,7 +116,9 @@ impl LucidMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn in_failure_cooldown(&self) -> bool {
|
fn in_failure_cooldown(&self) -> bool {
|
||||||
let guard = self.last_failure_at.lock();
|
let Ok(guard) = self.last_failure_at.lock() else {
|
||||||
|
return false;
|
||||||
|
};
|
||||||
|
|
||||||
guard
|
guard
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
@ -124,11 +126,15 @@ impl LucidMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mark_failure_now(&self) {
|
fn mark_failure_now(&self) {
|
||||||
*self.last_failure_at.lock() = Some(Instant::now());
|
if let Ok(mut guard) = self.last_failure_at.lock() {
|
||||||
|
*guard = Some(Instant::now());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clear_failure(&self) {
|
fn clear_failure(&self) {
|
||||||
*self.last_failure_at.lock() = None;
|
if let Ok(mut guard) = self.last_failure_at.lock() {
|
||||||
|
*guard = None;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn to_lucid_type(category: &MemoryCategory) -> &'static str {
|
fn to_lucid_type(category: &MemoryCategory) -> &'static str {
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ use crate::tools::ToolSpec;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use futures_util::{stream, StreamExt};
|
use futures_util::{stream, StreamExt};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
/// A single message in a conversation.
|
/// A single message in a conversation.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -76,13 +77,6 @@ pub struct ChatRequest<'a> {
|
||||||
pub tools: Option<&'a [ToolSpec]>,
|
pub tools: Option<&'a [ToolSpec]>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Declares optional provider features.
|
|
||||||
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
|
|
||||||
pub struct ProviderCapabilities {
|
|
||||||
/// Provider can perform native tool calling without prompt-level emulation.
|
|
||||||
pub native_tool_calling: bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
/// A tool result to feed back to the LLM.
|
/// A tool result to feed back to the LLM.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct ToolResultMessage {
|
pub struct ToolResultMessage {
|
||||||
|
|
@ -198,6 +192,40 @@ pub enum StreamError {
|
||||||
Io(#[from] std::io::Error),
|
Io(#[from] std::io::Error),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Provider capabilities declaration.
|
||||||
|
///
|
||||||
|
/// Describes what features a provider supports, enabling intelligent
|
||||||
|
/// adaptation of tool calling modes and request formatting.
|
||||||
|
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||||
|
pub struct ProviderCapabilities {
|
||||||
|
/// Whether the provider supports native tool calling via API primitives.
|
||||||
|
///
|
||||||
|
/// When `true`, the provider can convert tool definitions to API-native
|
||||||
|
/// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema).
|
||||||
|
///
|
||||||
|
/// When `false`, tools must be injected via system prompt as text.
|
||||||
|
pub native_tool_calling: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Provider-specific tool payload formats.
|
||||||
|
///
|
||||||
|
/// Different LLM providers require different formats for tool definitions.
|
||||||
|
/// This enum encapsulates those variations, enabling providers to convert
|
||||||
|
/// from the unified `ToolSpec` format to their native API requirements.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub enum ToolsPayload {
|
||||||
|
/// Gemini API format (functionDeclarations).
|
||||||
|
Gemini {
|
||||||
|
function_declarations: Vec<serde_json::Value>,
|
||||||
|
},
|
||||||
|
/// Anthropic Messages API format (tools with input_schema).
|
||||||
|
Anthropic { tools: Vec<serde_json::Value> },
|
||||||
|
/// OpenAI Chat Completions API format (tools with function).
|
||||||
|
OpenAI { tools: Vec<serde_json::Value> },
|
||||||
|
/// Prompt-guided fallback (tools injected as text in system prompt).
|
||||||
|
PromptGuided { instructions: String },
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
/// Query provider capabilities.
|
/// Query provider capabilities.
|
||||||
|
|
@ -207,6 +235,19 @@ pub trait Provider: Send + Sync {
|
||||||
fn capabilities(&self) -> ProviderCapabilities {
|
fn capabilities(&self) -> ProviderCapabilities {
|
||||||
ProviderCapabilities::default()
|
ProviderCapabilities::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Convert tool specifications to provider-native format.
|
||||||
|
///
|
||||||
|
/// Default implementation returns `PromptGuided` payload, which injects
|
||||||
|
/// tool documentation into the system prompt as text. Providers with
|
||||||
|
/// native tool calling support should override this to return their
|
||||||
|
/// specific format (Gemini, Anthropic, OpenAI).
|
||||||
|
fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
|
||||||
|
ToolsPayload::PromptGuided {
|
||||||
|
instructions: build_tool_instructions_text(tools),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Simple one-shot chat (single user message, no explicit system prompt).
|
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||||
///
|
///
|
||||||
/// This is the preferred API for non-agentic direct interactions.
|
/// This is the preferred API for non-agentic direct interactions.
|
||||||
|
|
@ -259,6 +300,43 @@ pub trait Provider: Send + Sync {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<ChatResponse> {
|
||||||
|
// If tools are provided but provider doesn't support native tools,
|
||||||
|
// inject tool instructions into system prompt as fallback.
|
||||||
|
if let Some(tools) = request.tools {
|
||||||
|
if !tools.is_empty() && !self.supports_native_tools() {
|
||||||
|
let tool_instructions = match self.convert_tools(tools) {
|
||||||
|
ToolsPayload::PromptGuided { instructions } => instructions,
|
||||||
|
payload => {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let mut modified_messages = request.messages.to_vec();
|
||||||
|
|
||||||
|
// Inject tool instructions into an existing system message.
|
||||||
|
// If none exists, prepend one to the conversation.
|
||||||
|
if let Some(system_message) =
|
||||||
|
modified_messages.iter_mut().find(|m| m.role == "system")
|
||||||
|
{
|
||||||
|
if !system_message.content.is_empty() {
|
||||||
|
system_message.content.push_str("\n\n");
|
||||||
|
}
|
||||||
|
system_message.content.push_str(&tool_instructions);
|
||||||
|
} else {
|
||||||
|
modified_messages.insert(0, ChatMessage::system(tool_instructions));
|
||||||
|
}
|
||||||
|
|
||||||
|
let text = self
|
||||||
|
.chat_with_history(&modified_messages, model, temperature)
|
||||||
|
.await?;
|
||||||
|
return Ok(ChatResponse {
|
||||||
|
text: Some(text),
|
||||||
|
tool_calls: Vec::new(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let text = self
|
let text = self
|
||||||
.chat_with_history(request.messages, model, temperature)
|
.chat_with_history(request.messages, model, temperature)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
@ -321,21 +399,11 @@ pub trait Provider: Send + Sync {
|
||||||
/// Default implementation falls back to stream_chat_with_system with last user message.
|
/// Default implementation falls back to stream_chat_with_system with last user message.
|
||||||
fn stream_chat_with_history(
|
fn stream_chat_with_history(
|
||||||
&self,
|
&self,
|
||||||
messages: &[ChatMessage],
|
_messages: &[ChatMessage],
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
_options: StreamOptions,
|
_options: StreamOptions,
|
||||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||||
let _system = messages
|
|
||||||
.iter()
|
|
||||||
.find(|m| m.role == "system")
|
|
||||||
.map(|m| m.content.clone());
|
|
||||||
let _last_user = messages
|
|
||||||
.iter()
|
|
||||||
.rfind(|m| m.role == "user")
|
|
||||||
.map(|m| m.content.clone())
|
|
||||||
.unwrap_or_default();
|
|
||||||
|
|
||||||
// For default implementation, we need to convert to owned strings
|
// For default implementation, we need to convert to owned strings
|
||||||
// This is a limitation of the default implementation
|
// This is a limitation of the default implementation
|
||||||
let provider_name = "unknown".to_string();
|
let provider_name = "unknown".to_string();
|
||||||
|
|
@ -346,6 +414,39 @@ pub trait Provider: Send + Sync {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build tool instructions text for prompt-guided tool calling.
|
||||||
|
///
|
||||||
|
/// Generates a formatted text block describing available tools and how to
|
||||||
|
/// invoke them using XML-style tags. This is used as a fallback when the
|
||||||
|
/// provider doesn't support native tool calling.
|
||||||
|
pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
|
||||||
|
let mut instructions = String::new();
|
||||||
|
|
||||||
|
instructions.push_str("## Tool Use Protocol\n\n");
|
||||||
|
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
||||||
|
instructions.push_str("<tool_call>\n");
|
||||||
|
instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#);
|
||||||
|
instructions.push_str("\n</tool_call>\n\n");
|
||||||
|
instructions.push_str("You may use multiple tool calls in a single response. ");
|
||||||
|
instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
|
||||||
|
instructions
|
||||||
|
.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
||||||
|
instructions.push_str("### Available Tools\n\n");
|
||||||
|
|
||||||
|
for tool in tools {
|
||||||
|
writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description)
|
||||||
|
.expect("writing to String cannot fail");
|
||||||
|
|
||||||
|
let parameters =
|
||||||
|
serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string());
|
||||||
|
writeln!(&mut instructions, "Parameters: `{parameters}`")
|
||||||
|
.expect("writing to String cannot fail");
|
||||||
|
instructions.push('\n');
|
||||||
|
}
|
||||||
|
|
||||||
|
instructions
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -461,4 +562,314 @@ mod tests {
|
||||||
let provider = CapabilityMockProvider;
|
let provider = CapabilityMockProvider;
|
||||||
assert!(provider.supports_native_tools());
|
assert!(provider.supports_native_tools());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tools_payload_variants() {
|
||||||
|
// Test Gemini variant
|
||||||
|
let gemini = ToolsPayload::Gemini {
|
||||||
|
function_declarations: vec![serde_json::json!({"name": "test"})],
|
||||||
|
};
|
||||||
|
assert!(matches!(gemini, ToolsPayload::Gemini { .. }));
|
||||||
|
|
||||||
|
// Test Anthropic variant
|
||||||
|
let anthropic = ToolsPayload::Anthropic {
|
||||||
|
tools: vec![serde_json::json!({"name": "test"})],
|
||||||
|
};
|
||||||
|
assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));
|
||||||
|
|
||||||
|
// Test OpenAI variant
|
||||||
|
let openai = ToolsPayload::OpenAI {
|
||||||
|
tools: vec![serde_json::json!({"type": "function"})],
|
||||||
|
};
|
||||||
|
assert!(matches!(openai, ToolsPayload::OpenAI { .. }));
|
||||||
|
|
||||||
|
// Test PromptGuided variant
|
||||||
|
let prompt_guided = ToolsPayload::PromptGuided {
|
||||||
|
instructions: "Use tools...".to_string(),
|
||||||
|
};
|
||||||
|
assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_tool_instructions_text_format() {
|
||||||
|
let tools = vec![
|
||||||
|
ToolSpec {
|
||||||
|
name: "shell".to_string(),
|
||||||
|
description: "Execute commands".to_string(),
|
||||||
|
parameters: serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"command": {"type": "string"}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
ToolSpec {
|
||||||
|
name: "file_read".to_string(),
|
||||||
|
description: "Read files".to_string(),
|
||||||
|
parameters: serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"path": {"type": "string"}
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let instructions = build_tool_instructions_text(&tools);
|
||||||
|
|
||||||
|
// Check for protocol description
|
||||||
|
assert!(instructions.contains("Tool Use Protocol"));
|
||||||
|
assert!(instructions.contains("<tool_call>"));
|
||||||
|
assert!(instructions.contains("</tool_call>"));
|
||||||
|
|
||||||
|
// Check for tool listings
|
||||||
|
assert!(instructions.contains("**shell**"));
|
||||||
|
assert!(instructions.contains("Execute commands"));
|
||||||
|
assert!(instructions.contains("**file_read**"));
|
||||||
|
assert!(instructions.contains("Read files"));
|
||||||
|
|
||||||
|
// Check for parameters
|
||||||
|
assert!(instructions.contains("Parameters:"));
|
||||||
|
assert!(instructions.contains(r#""type":"object""#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_tool_instructions_text_empty() {
|
||||||
|
let instructions = build_tool_instructions_text(&[]);
|
||||||
|
|
||||||
|
// Should still have protocol description
|
||||||
|
assert!(instructions.contains("Tool Use Protocol"));
|
||||||
|
|
||||||
|
// Should have empty tools section
|
||||||
|
assert!(instructions.contains("Available Tools"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock provider for testing.
|
||||||
|
struct MockProvider {
|
||||||
|
supports_native: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for MockProvider {
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
self.supports_native
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
Ok("response".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_convert_tools_default() {
|
||||||
|
let provider = MockProvider {
|
||||||
|
supports_native: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let tools = vec![ToolSpec {
|
||||||
|
name: "test_tool".to_string(),
|
||||||
|
description: "A test tool".to_string(),
|
||||||
|
parameters: serde_json::json!({"type": "object"}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let payload = provider.convert_tools(&tools);
|
||||||
|
|
||||||
|
// Default implementation should return PromptGuided.
|
||||||
|
assert!(matches!(payload, ToolsPayload::PromptGuided { .. }));
|
||||||
|
|
||||||
|
if let ToolsPayload::PromptGuided { instructions } = payload {
|
||||||
|
assert!(instructions.contains("test_tool"));
|
||||||
|
assert!(instructions.contains("A test tool"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn provider_chat_prompt_guided_fallback() {
|
||||||
|
let provider = MockProvider {
|
||||||
|
supports_native: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let tools = vec![ToolSpec {
|
||||||
|
name: "shell".to_string(),
|
||||||
|
description: "Run commands".to_string(),
|
||||||
|
parameters: serde_json::json!({"type": "object"}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
messages: &[ChatMessage::user("Hello")],
|
||||||
|
tools: Some(&tools),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.chat(request, "model", 0.7).await.unwrap();
|
||||||
|
|
||||||
|
// Should return a response (default impl calls chat_with_history).
|
||||||
|
assert!(response.text.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn provider_chat_without_tools() {
|
||||||
|
let provider = MockProvider {
|
||||||
|
supports_native: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
messages: &[ChatMessage::user("Hello")],
|
||||||
|
tools: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.chat(request, "model", 0.7).await.unwrap();
|
||||||
|
|
||||||
|
// Should work normally without tools.
|
||||||
|
assert!(response.text.is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider that echoes the system prompt for assertions.
|
||||||
|
struct EchoSystemProvider {
|
||||||
|
supports_native: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for EchoSystemProvider {
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
self.supports_native
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
system: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
Ok(system.unwrap_or_default().to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider with custom prompt-guided conversion.
|
||||||
|
struct CustomConvertProvider;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for CustomConvertProvider {
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
|
||||||
|
ToolsPayload::PromptGuided {
|
||||||
|
instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
system: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
Ok(system.unwrap_or_default().to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Provider returning an invalid payload for non-native mode.
|
||||||
|
struct InvalidConvertProvider;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for InvalidConvertProvider {
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
|
||||||
|
ToolsPayload::OpenAI {
|
||||||
|
tools: vec![serde_json::json!({"type": "function"})],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
Ok("should_not_reach".to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
|
||||||
|
let provider = EchoSystemProvider {
|
||||||
|
supports_native: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let tools = vec![ToolSpec {
|
||||||
|
name: "shell".to_string(),
|
||||||
|
description: "Run commands".to_string(),
|
||||||
|
parameters: serde_json::json!({"type": "object"}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
messages: &[
|
||||||
|
ChatMessage::user("Hello"),
|
||||||
|
ChatMessage::system("BASE_SYSTEM_PROMPT"),
|
||||||
|
],
|
||||||
|
tools: Some(&tools),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.chat(request, "model", 0.7).await.unwrap();
|
||||||
|
let text = response.text.unwrap_or_default();
|
||||||
|
|
||||||
|
assert!(text.contains("BASE_SYSTEM_PROMPT"));
|
||||||
|
assert!(text.contains("Tool Use Protocol"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn provider_chat_prompt_guided_uses_convert_tools_override() {
|
||||||
|
let provider = CustomConvertProvider;
|
||||||
|
|
||||||
|
let tools = vec![ToolSpec {
|
||||||
|
name: "shell".to_string(),
|
||||||
|
description: "Run commands".to_string(),
|
||||||
|
parameters: serde_json::json!({"type": "object"}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
|
||||||
|
tools: Some(&tools),
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = provider.chat(request, "model", 0.7).await.unwrap();
|
||||||
|
let text = response.text.unwrap_or_default();
|
||||||
|
|
||||||
|
assert!(text.contains("BASE"));
|
||||||
|
assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
|
||||||
|
let provider = InvalidConvertProvider;
|
||||||
|
|
||||||
|
let tools = vec![ToolSpec {
|
||||||
|
name: "shell".to_string(),
|
||||||
|
description: "Run commands".to_string(),
|
||||||
|
parameters: serde_json::json!({"type": "object"}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
messages: &[ChatMessage::user("Hello")],
|
||||||
|
tools: Some(&tools),
|
||||||
|
};
|
||||||
|
|
||||||
|
let err = provider.chat(request, "model", 0.7).await.unwrap_err();
|
||||||
|
let message = err.to_string();
|
||||||
|
|
||||||
|
assert!(message.contains("non-prompt-guided"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue