From 6e53341bb1da6c3f3f3a6fd7f2541af963129e1b Mon Sep 17 00:00:00 2001 From: Edvard Date: Tue, 17 Feb 2026 19:53:37 -0500 Subject: [PATCH] feat(agent): add rule-based query classification for automatic model routing Classify incoming user messages by keyword/pattern and route to the appropriate model hint automatically, feeding into the existing RouterProvider. Disabled by default; opt-in via [query_classification] config section. Co-Authored-By: Claude Opus 4.6 --- src/agent/agent.rs | 40 +++++++++- src/agent/classifier.rs | 172 ++++++++++++++++++++++++++++++++++++++++ src/agent/mod.rs | 1 + src/config/mod.rs | 14 ++-- src/config/schema.rs | 39 +++++++++ src/onboard/wizard.rs | 2 + 6 files changed, 260 insertions(+), 8 deletions(-) create mode 100644 src/agent/classifier.rs diff --git a/src/agent/agent.rs b/src/agent/agent.rs index fbb5ec6..0002799 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -33,6 +33,8 @@ pub struct Agent { skills: Vec, auto_save: bool, history: Vec, + classification_config: crate::config::QueryClassificationConfig, + available_hints: Vec, } pub struct AgentBuilder { @@ -50,6 +52,8 @@ pub struct AgentBuilder { identity_config: Option, skills: Option>, auto_save: Option, + classification_config: Option, + available_hints: Option>, } impl AgentBuilder { @@ -69,6 +73,8 @@ impl AgentBuilder { identity_config: None, skills: None, auto_save: None, + classification_config: None, + available_hints: None, } } @@ -142,6 +148,19 @@ impl AgentBuilder { self } + pub fn classification_config( + mut self, + classification_config: crate::config::QueryClassificationConfig, + ) -> Self { + self.classification_config = Some(classification_config); + self + } + + pub fn available_hints(mut self, available_hints: Vec) -> Self { + self.available_hints = Some(available_hints); + self + } + pub fn build(self) -> Result { let tools = self .tools @@ -181,6 +200,8 @@ impl AgentBuilder { skills: self.skills.unwrap_or_default(), auto_save: self.auto_save.unwrap_or(false), history: Vec::new(), + classification_config: self.classification_config.unwrap_or_default(), + available_hints: self.available_hints.unwrap_or_default(), }) } } @@ -265,6 +286,9 @@ impl Agent { _ => Box::new(XmlToolDispatcher), }; + let available_hints: Vec = + config.model_routes.iter().map(|r| r.hint.clone()).collect(); + Agent::builder() .provider(provider) .tools(tools) @@ -280,6 +304,8 @@ impl Agent { .model_name(model_name) .temperature(config.default_temperature) .workspace_dir(config.workspace_dir.clone()) + .classification_config(config.query_classification.clone()) + .available_hints(available_hints) .identity_config(config.identity.clone()) .skills(crate::skills::load_skills(&config.workspace_dir)) .auto_save(config.memory.auto_save) @@ -380,6 +406,16 @@ impl Agent { results } + fn classify_model(&self, user_message: &str) -> String { + if let Some(hint) = super::classifier::classify(&self.classification_config, user_message) { + if self.available_hints.contains(&hint) { + tracing::info!(hint = hint.as_str(), "Auto-classified query"); + return format!("hint:{hint}"); + } + } + self.model_name.clone() + } + pub async fn turn(&mut self, user_message: &str) -> Result { if self.history.is_empty() { let system_prompt = self.build_system_prompt()?; @@ -411,6 +447,8 @@ impl Agent { self.history .push(ConversationMessage::Chat(ChatMessage::user(enriched))); + let effective_model = self.classify_model(user_message); + for _ in 0..self.config.max_tool_iterations { let messages = self.tool_dispatcher.to_provider_messages(&self.history); let response = match self @@ -424,7 +462,7 @@ impl Agent { None }, }, - &self.model_name, + &effective_model, self.temperature, ) .await diff --git a/src/agent/classifier.rs b/src/agent/classifier.rs new file mode 100644 index 0000000..76c965a --- /dev/null +++ b/src/agent/classifier.rs @@ -0,0 +1,172 @@ +use crate::config::schema::QueryClassificationConfig; + +/// Classify a user message against the configured rules and return the +/// matching hint string, if any. +/// +/// Returns `None` when classification is disabled, no rules are configured, +/// or no rule matches the message. +pub fn classify(config: &QueryClassificationConfig, message: &str) -> Option { + if !config.enabled || config.rules.is_empty() { + return None; + } + + let lower = message.to_lowercase(); + let len = message.len(); + + let mut rules: Vec<_> = config.rules.iter().collect(); + rules.sort_by(|a, b| b.priority.cmp(&a.priority)); + + for rule in rules { + // Length constraints + if let Some(min) = rule.min_length { + if len < min { + continue; + } + } + if let Some(max) = rule.max_length { + if len > max { + continue; + } + } + + // Check keywords (case-insensitive) and patterns (case-sensitive) + let keyword_hit = rule + .keywords + .iter() + .any(|kw: &String| lower.contains(&kw.to_lowercase())); + let pattern_hit = rule + .patterns + .iter() + .any(|pat: &String| message.contains(pat.as_str())); + + if keyword_hit || pattern_hit { + return Some(rule.hint.clone()); + } + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::schema::{ClassificationRule, QueryClassificationConfig}; + + fn make_config(enabled: bool, rules: Vec) -> QueryClassificationConfig { + QueryClassificationConfig { enabled, rules } + } + + #[test] + fn disabled_returns_none() { + let config = make_config( + false, + vec![ClassificationRule { + hint: "fast".into(), + keywords: vec!["hello".into()], + ..Default::default() + }], + ); + assert_eq!(classify(&config, "hello"), None); + } + + #[test] + fn empty_rules_returns_none() { + let config = make_config(true, vec![]); + assert_eq!(classify(&config, "hello"), None); + } + + #[test] + fn keyword_match_case_insensitive() { + let config = make_config( + true, + vec![ClassificationRule { + hint: "fast".into(), + keywords: vec!["hello".into()], + ..Default::default() + }], + ); + assert_eq!(classify(&config, "HELLO world"), Some("fast".into())); + } + + #[test] + fn pattern_match_case_sensitive() { + let config = make_config( + true, + vec![ClassificationRule { + hint: "code".into(), + patterns: vec!["fn ".into()], + ..Default::default() + }], + ); + assert_eq!(classify(&config, "fn main()"), Some("code".into())); + assert_eq!(classify(&config, "FN MAIN()"), None); + } + + #[test] + fn length_constraints() { + let config = make_config( + true, + vec![ClassificationRule { + hint: "fast".into(), + keywords: vec!["hi".into()], + max_length: Some(10), + ..Default::default() + }], + ); + assert_eq!(classify(&config, "hi"), Some("fast".into())); + assert_eq!( + classify(&config, "hi there, how are you doing today?"), + None + ); + + let config2 = make_config( + true, + vec![ClassificationRule { + hint: "reasoning".into(), + keywords: vec!["explain".into()], + min_length: Some(20), + ..Default::default() + }], + ); + assert_eq!(classify(&config2, "explain"), None); + assert_eq!( + classify(&config2, "explain how this works in detail"), + Some("reasoning".into()) + ); + } + + #[test] + fn priority_ordering() { + let config = make_config( + true, + vec![ + ClassificationRule { + hint: "fast".into(), + keywords: vec!["code".into()], + priority: 1, + ..Default::default() + }, + ClassificationRule { + hint: "code".into(), + keywords: vec!["code".into()], + priority: 10, + ..Default::default() + }, + ], + ); + assert_eq!(classify(&config, "write some code"), Some("code".into())); + } + + #[test] + fn no_match_returns_none() { + let config = make_config( + true, + vec![ClassificationRule { + hint: "fast".into(), + keywords: vec!["hello".into()], + ..Default::default() + }], + ); + assert_eq!(classify(&config, "something completely different"), None); + } +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs index 29c96a5..3d33bb4 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,5 +1,6 @@ #[allow(clippy::module_inception)] pub mod agent; +pub mod classifier; pub mod dispatcher; pub mod loop_; pub mod memory_loader; diff --git a/src/config/mod.rs b/src/config/mod.rs index 8e37cce..a78b132 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,13 +3,13 @@ pub mod schema; #[allow(unused_imports)] pub use schema::{ AgentConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig, - ChannelsConfig, ComposioConfig, Config, CostConfig, CronConfig, DelegateAgentConfig, - DiscordConfig, DockerRuntimeConfig, GatewayConfig, HardwareConfig, HardwareTransport, - HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, - MemoryConfig, ModelRouteConfig, ObservabilityConfig, PeripheralBoardConfig, PeripheralsConfig, - ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, - SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig, TelegramConfig, TunnelConfig, - WebhookConfig, + ChannelsConfig, ClassificationRule, ComposioConfig, Config, CostConfig, CronConfig, + DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HardwareConfig, + HardwareTransport, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, + LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, + PeripheralBoardConfig, PeripheralsConfig, QueryClassificationConfig, ReliabilityConfig, + ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, + SecretsConfig, SecurityConfig, SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig, }; #[cfg(test)] diff --git a/src/config/schema.rs b/src/config/schema.rs index 4d987a8..e7aed1f 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -47,6 +47,10 @@ pub struct Config { #[serde(default)] pub model_routes: Vec, + /// Automatic query classification — maps user messages to model hints. + #[serde(default)] + pub query_classification: QueryClassificationConfig, + #[serde(default)] pub heartbeat: HeartbeatConfig, @@ -1205,6 +1209,40 @@ pub struct ModelRouteConfig { pub api_key: Option, } +// ── Query Classification ───────────────────────────────────────── + +/// Automatic query classification — classifies user messages by keyword/pattern +/// and routes to the appropriate model hint. Disabled by default. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct QueryClassificationConfig { + #[serde(default)] + pub enabled: bool, + #[serde(default)] + pub rules: Vec, +} + +/// A single classification rule mapping message patterns to a model hint. +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct ClassificationRule { + /// Must match a `[[model_routes]]` hint value. + pub hint: String, + /// Case-insensitive substring matches. + #[serde(default)] + pub keywords: Vec, + /// Case-sensitive literal matches (for "```", "fn ", etc.). + #[serde(default)] + pub patterns: Vec, + /// Only match if message length >= N chars. + #[serde(default)] + pub min_length: Option, + /// Only match if message length <= N chars. + #[serde(default)] + pub max_length: Option, + /// Higher priority rules are checked first. + #[serde(default)] + pub priority: i32, +} + // ── Heartbeat ──────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] @@ -1740,6 +1778,7 @@ impl Default for Config { peripherals: PeripheralsConfig::default(), agents: HashMap::new(), hardware: HardwareConfig::default(), + query_classification: QueryClassificationConfig::default(), } } } diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index df71d58..fdba1f9 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -136,6 +136,7 @@ pub fn run_wizard() -> Result { peripherals: crate::config::PeripheralsConfig::default(), agents: std::collections::HashMap::new(), hardware: hardware_config, + query_classification: crate::config::QueryClassificationConfig::default(), }; println!( @@ -356,6 +357,7 @@ pub fn run_quick_setup( peripherals: crate::config::PeripheralsConfig::default(), agents: std::collections::HashMap::new(), hardware: crate::config::HardwareConfig::default(), + query_classification: crate::config::QueryClassificationConfig::default(), }; config.save()?;