fix(provider): complete ChatResponse integration across runtime surfaces
This commit is contained in:
parent
3b4a4de457
commit
34306e32d8
3 changed files with 229 additions and 66 deletions
|
|
@ -10,8 +10,14 @@
|
||||||
use crate::channels::{Channel, WhatsAppChannel};
|
use crate::channels::{Channel, WhatsAppChannel};
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::memory::{self, Memory, MemoryCategory};
|
use crate::memory::{self, Memory, MemoryCategory};
|
||||||
use crate::providers::{self, ChatResponse, Provider};
|
use crate::observability::{self, Observer};
|
||||||
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
|
use crate::providers::{self, ChatMessage, Provider};
|
||||||
|
use crate::runtime;
|
||||||
|
use crate::security::{
|
||||||
|
pairing::{constant_time_eq, is_public_bind, PairingGuard},
|
||||||
|
SecurityPolicy,
|
||||||
|
};
|
||||||
|
use crate::tools::{self, Tool};
|
||||||
use crate::util::truncate_with_ellipsis;
|
use crate::util::truncate_with_ellipsis;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use axum::{
|
use axum::{
|
||||||
|
|
@ -45,29 +51,33 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
|
||||||
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gateway_reply_from_response(response: ChatResponse) -> String {
|
fn normalize_gateway_reply(reply: String) -> String {
|
||||||
let has_tool_calls = response.has_tool_calls();
|
|
||||||
let tool_call_count = response.tool_calls.len();
|
|
||||||
let mut reply = response.text.unwrap_or_default();
|
|
||||||
|
|
||||||
if has_tool_calls {
|
|
||||||
tracing::warn!(
|
|
||||||
tool_call_count,
|
|
||||||
"Provider requested tool calls in gateway mode; tool calls are not executed here"
|
|
||||||
);
|
|
||||||
if reply.trim().is_empty() {
|
|
||||||
reply = "I need to use tools to answer that, but tool execution is not enabled for gateway requests yet."
|
|
||||||
.to_string();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if reply.trim().is_empty() {
|
if reply.trim().is_empty() {
|
||||||
reply = "Model returned an empty response.".to_string();
|
return "Model returned an empty response.".to_string();
|
||||||
}
|
}
|
||||||
|
|
||||||
reply
|
reply
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn gateway_agent_reply(state: &AppState, message: &str) -> Result<String> {
|
||||||
|
let mut history = vec![
|
||||||
|
ChatMessage::system(state.system_prompt.as_str()),
|
||||||
|
ChatMessage::user(message),
|
||||||
|
];
|
||||||
|
|
||||||
|
let reply = crate::agent::loop_::run_tool_call_loop(
|
||||||
|
state.provider.as_ref(),
|
||||||
|
&mut history,
|
||||||
|
state.tools_registry.as_ref(),
|
||||||
|
state.observer.as_ref(),
|
||||||
|
&state.model,
|
||||||
|
state.temperature,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(normalize_gateway_reply(reply))
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct SlidingWindowRateLimiter {
|
struct SlidingWindowRateLimiter {
|
||||||
limit_per_window: u32,
|
limit_per_window: u32,
|
||||||
|
|
@ -182,6 +192,9 @@ fn client_key_from_headers(headers: &HeaderMap) -> String {
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub provider: Arc<dyn Provider>,
|
pub provider: Arc<dyn Provider>,
|
||||||
|
pub observer: Arc<dyn Observer>,
|
||||||
|
pub tools_registry: Arc<Vec<Box<dyn Tool>>>,
|
||||||
|
pub system_prompt: Arc<String>,
|
||||||
pub model: String,
|
pub model: String,
|
||||||
pub temperature: f64,
|
pub temperature: f64,
|
||||||
pub mem: Arc<dyn Memory>,
|
pub mem: Arc<dyn Memory>,
|
||||||
|
|
@ -228,6 +241,47 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
)?);
|
)?);
|
||||||
|
let observer: Arc<dyn Observer> =
|
||||||
|
Arc::from(observability::create_observer(&config.observability));
|
||||||
|
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||||
|
Arc::from(runtime::create_runtime(&config.runtime)?);
|
||||||
|
let security = Arc::new(SecurityPolicy::from_config(
|
||||||
|
&config.autonomy,
|
||||||
|
&config.workspace_dir,
|
||||||
|
));
|
||||||
|
|
||||||
|
let composio_key = if config.composio.enabled {
|
||||||
|
config.composio.api_key.as_deref()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let tools_registry = Arc::new(tools::all_tools_with_runtime(
|
||||||
|
&security,
|
||||||
|
runtime,
|
||||||
|
Arc::clone(&mem),
|
||||||
|
composio_key,
|
||||||
|
&config.browser,
|
||||||
|
&config.agents,
|
||||||
|
config.api_key.as_deref(),
|
||||||
|
));
|
||||||
|
let skills = crate::skills::load_skills(&config.workspace_dir);
|
||||||
|
let tool_descs: Vec<(&str, &str)> = tools_registry
|
||||||
|
.iter()
|
||||||
|
.map(|tool| (tool.name(), tool.description()))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let mut system_prompt = crate::channels::build_system_prompt(
|
||||||
|
&config.workspace_dir,
|
||||||
|
&model,
|
||||||
|
&tool_descs,
|
||||||
|
&skills,
|
||||||
|
Some(&config.identity),
|
||||||
|
);
|
||||||
|
system_prompt.push_str(&crate::agent::loop_::build_tool_instructions(
|
||||||
|
tools_registry.as_ref(),
|
||||||
|
));
|
||||||
|
let system_prompt = Arc::new(system_prompt);
|
||||||
|
|
||||||
// Extract webhook secret for authentication
|
// Extract webhook secret for authentication
|
||||||
let webhook_secret: Option<Arc<str>> = config
|
let webhook_secret: Option<Arc<str>> = config
|
||||||
|
|
@ -331,6 +385,9 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
// Build shared state
|
// Build shared state
|
||||||
let state = AppState {
|
let state = AppState {
|
||||||
provider,
|
provider,
|
||||||
|
observer,
|
||||||
|
tools_registry,
|
||||||
|
system_prompt,
|
||||||
model,
|
model,
|
||||||
temperature,
|
temperature,
|
||||||
mem,
|
mem,
|
||||||
|
|
@ -514,13 +571,8 @@ async fn handle_webhook(
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
match state
|
match gateway_agent_reply(&state, message).await {
|
||||||
.provider
|
Ok(reply) => {
|
||||||
.chat(message, &state.model, state.temperature)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(response) => {
|
|
||||||
let reply = gateway_reply_from_response(response);
|
|
||||||
let body = serde_json::json!({"response": reply, "model": state.model});
|
let body = serde_json::json!({"response": reply, "model": state.model});
|
||||||
(StatusCode::OK, Json(body))
|
(StatusCode::OK, Json(body))
|
||||||
}
|
}
|
||||||
|
|
@ -669,13 +721,8 @@ async fn handle_whatsapp_message(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call the LLM
|
// Call the LLM
|
||||||
match state
|
match gateway_agent_reply(&state, &msg.content).await {
|
||||||
.provider
|
Ok(reply) => {
|
||||||
.chat(&msg.content, &state.model, state.temperature)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(response) => {
|
|
||||||
let reply = gateway_reply_from_response(response);
|
|
||||||
// Send reply via WhatsApp
|
// Send reply via WhatsApp
|
||||||
if let Err(e) = wa.send(&reply, &msg.sender).await {
|
if let Err(e) = wa.send(&reply, &msg.sender).await {
|
||||||
tracing::error!("Failed to send WhatsApp reply: {e}");
|
tracing::error!("Failed to send WhatsApp reply: {e}");
|
||||||
|
|
@ -847,9 +894,9 @@ mod tests {
|
||||||
_message: &str,
|
_message: &str,
|
||||||
_model: &str,
|
_model: &str,
|
||||||
_temperature: f64,
|
_temperature: f64,
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<crate::providers::ChatResponse> {
|
||||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
Ok(ChatResponse::with_text("ok"))
|
Ok(crate::providers::ChatResponse::with_text("ok"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -910,25 +957,36 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
fn test_app_state(
|
||||||
async fn webhook_idempotency_skips_duplicate_provider_calls() {
|
provider: Arc<dyn Provider>,
|
||||||
let provider_impl = Arc::new(MockProvider::default());
|
memory: Arc<dyn Memory>,
|
||||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
auto_save: bool,
|
||||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
) -> AppState {
|
||||||
|
AppState {
|
||||||
let state = AppState {
|
|
||||||
provider,
|
provider,
|
||||||
|
observer: Arc::new(crate::observability::NoopObserver),
|
||||||
|
tools_registry: Arc::new(Vec::new()),
|
||||||
|
system_prompt: Arc::new("test-system-prompt".into()),
|
||||||
model: "test-model".into(),
|
model: "test-model".into(),
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
mem: memory,
|
mem: memory,
|
||||||
auto_save: false,
|
auto_save,
|
||||||
webhook_secret: None,
|
webhook_secret: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
whatsapp: None,
|
whatsapp: None,
|
||||||
whatsapp_app_secret: None,
|
whatsapp_app_secret: None,
|
||||||
};
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_idempotency_skips_duplicate_provider_calls() {
|
||||||
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let state = test_app_state(provider, memory, false);
|
||||||
|
|
||||||
let mut headers = HeaderMap::new();
|
let mut headers = HeaderMap::new();
|
||||||
headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));
|
headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));
|
||||||
|
|
@ -964,19 +1022,7 @@ mod tests {
|
||||||
let tracking_impl = Arc::new(TrackingMemory::default());
|
let tracking_impl = Arc::new(TrackingMemory::default());
|
||||||
let memory: Arc<dyn Memory> = tracking_impl.clone();
|
let memory: Arc<dyn Memory> = tracking_impl.clone();
|
||||||
|
|
||||||
let state = AppState {
|
let state = test_app_state(provider, memory, true);
|
||||||
provider,
|
|
||||||
model: "test-model".into(),
|
|
||||||
temperature: 0.0,
|
|
||||||
mem: memory,
|
|
||||||
auto_save: true,
|
|
||||||
webhook_secret: None,
|
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
|
||||||
whatsapp: None,
|
|
||||||
whatsapp_app_secret: None,
|
|
||||||
};
|
|
||||||
|
|
||||||
let headers = HeaderMap::new();
|
let headers = HeaderMap::new();
|
||||||
|
|
||||||
|
|
@ -1008,6 +1054,110 @@ mod tests {
|
||||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct StructuredToolCallProvider {
|
||||||
|
calls: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for StructuredToolCallProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<crate::providers::ChatResponse> {
|
||||||
|
let turn = self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
|
||||||
|
if turn == 0 {
|
||||||
|
return Ok(crate::providers::ChatResponse {
|
||||||
|
text: Some("Running tool...".into()),
|
||||||
|
tool_calls: vec![crate::providers::ToolCall {
|
||||||
|
id: "call_1".into(),
|
||||||
|
name: "mock_tool".into(),
|
||||||
|
arguments: r#"{"query":"gateway"}"#.into(),
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(crate::providers::ChatResponse::with_text(
|
||||||
|
"Gateway tool result ready.",
|
||||||
|
))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct MockTool {
|
||||||
|
calls: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for MockTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"mock_tool"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Mock tool for gateway tests"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string"}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(
|
||||||
|
&self,
|
||||||
|
args: serde_json::Value,
|
||||||
|
) -> anyhow::Result<crate::tools::ToolResult> {
|
||||||
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
assert_eq!(args["query"], "gateway");
|
||||||
|
|
||||||
|
Ok(crate::tools::ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: "ok".into(),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_executes_structured_tool_calls() {
|
||||||
|
let provider_impl = Arc::new(StructuredToolCallProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let tool_calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockTool {
|
||||||
|
calls: Arc::clone(&tool_calls),
|
||||||
|
})];
|
||||||
|
|
||||||
|
let mut state = test_app_state(provider, memory, false);
|
||||||
|
state.tools_registry = Arc::new(tools);
|
||||||
|
|
||||||
|
let response = handle_webhook(
|
||||||
|
State(state),
|
||||||
|
HeaderMap::new(),
|
||||||
|
Ok(Json(WebhookBody {
|
||||||
|
message: "please use tool".into(),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
let payload = response.into_body().collect().await.unwrap().to_bytes();
|
||||||
|
let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
|
||||||
|
assert_eq!(parsed["response"], "Gateway tool result ready.");
|
||||||
|
assert_eq!(tool_calls.load(Ordering::SeqCst), 1);
|
||||||
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
||||||
|
}
|
||||||
|
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,7 @@ pub mod reliable;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
|
#[allow(unused_imports)]
|
||||||
pub use traits::{ChatMessage, ChatResponse, Provider, ToolCall};
|
pub use traits::{ChatMessage, ChatResponse, Provider, ToolCall};
|
||||||
|
|
||||||
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
||||||
|
|
|
||||||
|
|
@ -220,15 +220,27 @@ impl Tool for DelegateTool {
|
||||||
};
|
};
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
Ok(response) => Ok(ToolResult {
|
Ok(response) => {
|
||||||
success: true,
|
let has_tool_calls = response.has_tool_calls();
|
||||||
output: format!(
|
let mut rendered = response.text.unwrap_or_default();
|
||||||
"[Agent '{agent_name}' ({provider}/{model})]\n{response}",
|
if rendered.trim().is_empty() {
|
||||||
provider = agent_config.provider,
|
if has_tool_calls {
|
||||||
model = agent_config.model
|
rendered = "[Tool-only response; no text content]".to_string();
|
||||||
),
|
} else {
|
||||||
error: None,
|
rendered = "[Empty response]".to_string();
|
||||||
}),
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!(
|
||||||
|
"[Agent '{agent_name}' ({provider}/{model})]\n{rendered}",
|
||||||
|
provider = agent_config.provider,
|
||||||
|
model = agent_config.model
|
||||||
|
),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
Err(e) => Ok(ToolResult {
|
Err(e) => Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue