feat: add agent structure and improve tooling for provider

This commit is contained in:
mai1015 2026-02-16 00:40:43 -05:00 committed by Chummy
parent e2c966d31e
commit b341fdb368
21 changed files with 2567 additions and 443 deletions

View file

@ -10,14 +10,8 @@
use crate::channels::{Channel, WhatsAppChannel};
use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::observability::{self, Observer};
use crate::providers::{self, ChatMessage, Provider};
use crate::runtime;
use crate::security::{
pairing::{constant_time_eq, is_public_bind, PairingGuard},
SecurityPolicy,
};
use crate::tools::{self, Tool};
use crate::providers::{self, Provider};
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use axum::{
@ -51,35 +45,6 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
format!("whatsapp_{}_{}", msg.sender, msg.id)
}
fn normalize_gateway_reply(reply: String) -> String {
if reply.trim().is_empty() {
return "Model returned an empty response.".to_string();
}
reply
}
async fn gateway_agent_reply(state: &AppState, message: &str) -> Result<String> {
let mut history = vec![
ChatMessage::system(state.system_prompt.as_str()),
ChatMessage::user(message),
];
let reply = crate::agent::loop_::run_tool_call_loop(
state.provider.as_ref(),
&mut history,
state.tools_registry.as_ref(),
state.observer.as_ref(),
"gateway",
&state.model,
state.temperature,
true, // silent — gateway responses go over HTTP
)
.await?;
Ok(normalize_gateway_reply(reply))
}
/// How often the rate limiter sweeps stale IP entries from its map.
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
@ -207,9 +172,6 @@ fn client_key_from_headers(headers: &HeaderMap) -> String {
#[derive(Clone)]
pub struct AppState {
pub provider: Arc<dyn Provider>,
pub observer: Arc<dyn Observer>,
pub tools_registry: Arc<Vec<Box<dyn Tool>>>,
pub system_prompt: Arc<String>,
pub model: String,
pub temperature: f64,
pub mem: Arc<dyn Memory>,
@ -256,55 +218,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
&config.workspace_dir,
config.api_key.as_deref(),
)?);
let observer: Arc<dyn Observer> =
Arc::from(observability::create_observer(&config.observability));
let runtime: Arc<dyn runtime::RuntimeAdapter> =
Arc::from(runtime::create_runtime(&config.runtime)?);
let security = Arc::new(SecurityPolicy::from_config(
&config.autonomy,
&config.workspace_dir,
));
let (composio_key, composio_entity_id) = if config.composio.enabled {
(
config.composio.api_key.as_deref(),
Some(config.composio.entity_id.as_str()),
)
} else {
(None, None)
};
let tools_registry = Arc::new(tools::all_tools_with_runtime(
&security,
runtime,
Arc::clone(&mem),
composio_key,
composio_entity_id,
&config.browser,
&config.http_request,
&config.workspace_dir,
&config.agents,
config.api_key.as_deref(),
&config,
));
let skills = crate::skills::load_skills(&config.workspace_dir);
let tool_descs: Vec<(&str, &str)> = tools_registry
.iter()
.map(|tool| (tool.name(), tool.description()))
.collect();
let mut system_prompt = crate::channels::build_system_prompt(
&config.workspace_dir,
&model,
&tool_descs,
&skills,
Some(&config.identity),
None, // bootstrap_max_chars — no compact context for gateway
);
system_prompt.push_str(&crate::agent::loop_::build_tool_instructions(
tools_registry.as_ref(),
));
let system_prompt = Arc::new(system_prompt);
// Extract webhook secret for authentication
let webhook_secret: Option<Arc<str>> = config
@ -408,9 +322,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
// Build shared state
let state = AppState {
provider,
observer,
tools_registry,
system_prompt,
model,
temperature,
mem,
@ -594,9 +505,13 @@ async fn handle_webhook(
.await;
}
match gateway_agent_reply(&state, message).await {
Ok(reply) => {
let body = serde_json::json!({"response": reply, "model": state.model});
match state
.provider
.simple_chat(message, &state.model, state.temperature)
.await
{
Ok(response) => {
let body = serde_json::json!({"response": response, "model": state.model});
(StatusCode::OK, Json(body))
}
Err(e) => {
@ -744,10 +659,14 @@ async fn handle_whatsapp_message(
}
// Call the LLM
match gateway_agent_reply(&state, &msg.content).await {
Ok(reply) => {
match state
.provider
.simple_chat(&msg.content, &state.model, state.temperature)
.await
{
Ok(response) => {
// Send reply via WhatsApp
if let Err(e) = wa.send(&reply, &msg.sender).await {
if let Err(e) = wa.send(&response, &msg.sender).await {
tracing::error!("Failed to send WhatsApp reply: {e}");
}
}
@ -966,9 +885,9 @@ mod tests {
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<crate::providers::ChatResponse> {
) -> anyhow::Result<String> {
self.calls.fetch_add(1, Ordering::SeqCst);
Ok(crate::providers::ChatResponse::with_text("ok"))
Ok("ok".into())
}
}
@ -1029,36 +948,25 @@ mod tests {
}
}
fn test_app_state(
provider: Arc<dyn Provider>,
memory: Arc<dyn Memory>,
auto_save: bool,
) -> AppState {
AppState {
provider,
observer: Arc::new(crate::observability::NoopObserver),
tools_registry: Arc::new(Vec::new()),
system_prompt: Arc::new("test-system-prompt".into()),
model: "test-model".into(),
temperature: 0.0,
mem: memory,
auto_save,
webhook_secret: None,
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
whatsapp: None,
whatsapp_app_secret: None,
}
}
#[tokio::test]
async fn webhook_idempotency_skips_duplicate_provider_calls() {
let provider_impl = Arc::new(MockProvider::default());
let provider: Arc<dyn Provider> = provider_impl.clone();
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
let state = test_app_state(provider, memory, false);
let state = AppState {
provider,
model: "test-model".into(),
temperature: 0.0,
mem: memory,
auto_save: false,
webhook_secret: None,
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
whatsapp: None,
whatsapp_app_secret: None,
};
let mut headers = HeaderMap::new();
headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));
@ -1094,7 +1002,19 @@ mod tests {
let tracking_impl = Arc::new(TrackingMemory::default());
let memory: Arc<dyn Memory> = tracking_impl.clone();
let state = test_app_state(provider, memory, true);
let state = AppState {
provider,
model: "test-model".into(),
temperature: 0.0,
mem: memory,
auto_save: true,
webhook_secret: None,
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
whatsapp: None,
whatsapp_app_secret: None,
};
let headers = HeaderMap::new();
@ -1126,110 +1046,6 @@ mod tests {
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
}
#[derive(Default)]
struct StructuredToolCallProvider {
calls: AtomicUsize,
}
#[async_trait]
impl Provider for StructuredToolCallProvider {
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<crate::providers::ChatResponse> {
let turn = self.calls.fetch_add(1, Ordering::SeqCst);
if turn == 0 {
return Ok(crate::providers::ChatResponse {
text: Some("Running tool...".into()),
tool_calls: vec![crate::providers::ToolCall {
id: "call_1".into(),
name: "mock_tool".into(),
arguments: r#"{"query":"gateway"}"#.into(),
}],
});
}
Ok(crate::providers::ChatResponse::with_text(
"Gateway tool result ready.",
))
}
}
struct MockTool {
calls: Arc<AtomicUsize>,
}
#[async_trait]
impl Tool for MockTool {
fn name(&self) -> &str {
"mock_tool"
}
fn description(&self) -> &str {
"Mock tool for gateway tests"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"query": {"type": "string"}
},
"required": ["query"]
})
}
async fn execute(
&self,
args: serde_json::Value,
) -> anyhow::Result<crate::tools::ToolResult> {
self.calls.fetch_add(1, Ordering::SeqCst);
assert_eq!(args["query"], "gateway");
Ok(crate::tools::ToolResult {
success: true,
output: "ok".into(),
error: None,
})
}
}
#[tokio::test]
async fn webhook_executes_structured_tool_calls() {
let provider_impl = Arc::new(StructuredToolCallProvider::default());
let provider: Arc<dyn Provider> = provider_impl.clone();
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
let tool_calls = Arc::new(AtomicUsize::new(0));
let tools: Vec<Box<dyn Tool>> = vec![Box::new(MockTool {
calls: Arc::clone(&tool_calls),
})];
let mut state = test_app_state(provider, memory, false);
state.tools_registry = Arc::new(tools);
let response = handle_webhook(
State(state),
HeaderMap::new(),
Ok(Json(WebhookBody {
message: "please use tool".into(),
})),
)
.await
.into_response();
assert_eq!(response.status(), StatusCode::OK);
let payload = response.into_body().collect().await.unwrap().to_bytes();
let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
assert_eq!(parsed["response"], "Gateway tool result ready.");
assert_eq!(tool_calls.load(Ordering::SeqCst), 1);
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
}
// ══════════════════════════════════════════════════════════
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
// ══════════════════════════════════════════════════════════