feat(agent): add rule-based query classification for automatic model routing
Classify incoming user messages by keyword/pattern and route to the appropriate model hint automatically, feeding into the existing RouterProvider. Disabled by default; opt-in via [query_classification] config section. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
1336c2f03e
commit
6e53341bb1
6 changed files with 260 additions and 8 deletions
|
|
@ -33,6 +33,8 @@ pub struct Agent {
|
|||
skills: Vec<crate::skills::Skill>,
|
||||
auto_save: bool,
|
||||
history: Vec<ConversationMessage>,
|
||||
classification_config: crate::config::QueryClassificationConfig,
|
||||
available_hints: Vec<String>,
|
||||
}
|
||||
|
||||
pub struct AgentBuilder {
|
||||
|
|
@ -50,6 +52,8 @@ pub struct AgentBuilder {
|
|||
identity_config: Option<crate::config::IdentityConfig>,
|
||||
skills: Option<Vec<crate::skills::Skill>>,
|
||||
auto_save: Option<bool>,
|
||||
classification_config: Option<crate::config::QueryClassificationConfig>,
|
||||
available_hints: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl AgentBuilder {
|
||||
|
|
@ -69,6 +73,8 @@ impl AgentBuilder {
|
|||
identity_config: None,
|
||||
skills: None,
|
||||
auto_save: None,
|
||||
classification_config: None,
|
||||
available_hints: None,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -142,6 +148,19 @@ impl AgentBuilder {
|
|||
self
|
||||
}
|
||||
|
||||
pub fn classification_config(
|
||||
mut self,
|
||||
classification_config: crate::config::QueryClassificationConfig,
|
||||
) -> Self {
|
||||
self.classification_config = Some(classification_config);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn available_hints(mut self, available_hints: Vec<String>) -> Self {
|
||||
self.available_hints = Some(available_hints);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Result<Agent> {
|
||||
let tools = self
|
||||
.tools
|
||||
|
|
@ -181,6 +200,8 @@ impl AgentBuilder {
|
|||
skills: self.skills.unwrap_or_default(),
|
||||
auto_save: self.auto_save.unwrap_or(false),
|
||||
history: Vec::new(),
|
||||
classification_config: self.classification_config.unwrap_or_default(),
|
||||
available_hints: self.available_hints.unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -265,6 +286,9 @@ impl Agent {
|
|||
_ => Box::new(XmlToolDispatcher),
|
||||
};
|
||||
|
||||
let available_hints: Vec<String> =
|
||||
config.model_routes.iter().map(|r| r.hint.clone()).collect();
|
||||
|
||||
Agent::builder()
|
||||
.provider(provider)
|
||||
.tools(tools)
|
||||
|
|
@ -280,6 +304,8 @@ impl Agent {
|
|||
.model_name(model_name)
|
||||
.temperature(config.default_temperature)
|
||||
.workspace_dir(config.workspace_dir.clone())
|
||||
.classification_config(config.query_classification.clone())
|
||||
.available_hints(available_hints)
|
||||
.identity_config(config.identity.clone())
|
||||
.skills(crate::skills::load_skills(&config.workspace_dir))
|
||||
.auto_save(config.memory.auto_save)
|
||||
|
|
@ -380,6 +406,16 @@ impl Agent {
|
|||
results
|
||||
}
|
||||
|
||||
fn classify_model(&self, user_message: &str) -> String {
|
||||
if let Some(hint) = super::classifier::classify(&self.classification_config, user_message) {
|
||||
if self.available_hints.contains(&hint) {
|
||||
tracing::info!(hint = hint.as_str(), "Auto-classified query");
|
||||
return format!("hint:{hint}");
|
||||
}
|
||||
}
|
||||
self.model_name.clone()
|
||||
}
|
||||
|
||||
pub async fn turn(&mut self, user_message: &str) -> Result<String> {
|
||||
if self.history.is_empty() {
|
||||
let system_prompt = self.build_system_prompt()?;
|
||||
|
|
@ -411,6 +447,8 @@ impl Agent {
|
|||
self.history
|
||||
.push(ConversationMessage::Chat(ChatMessage::user(enriched)));
|
||||
|
||||
let effective_model = self.classify_model(user_message);
|
||||
|
||||
for _ in 0..self.config.max_tool_iterations {
|
||||
let messages = self.tool_dispatcher.to_provider_messages(&self.history);
|
||||
let response = match self
|
||||
|
|
@ -424,7 +462,7 @@ impl Agent {
|
|||
None
|
||||
},
|
||||
},
|
||||
&self.model_name,
|
||||
&effective_model,
|
||||
self.temperature,
|
||||
)
|
||||
.await
|
||||
|
|
|
|||
172
src/agent/classifier.rs
Normal file
172
src/agent/classifier.rs
Normal file
|
|
@ -0,0 +1,172 @@
|
|||
use crate::config::schema::QueryClassificationConfig;
|
||||
|
||||
/// Classify a user message against the configured rules and return the
|
||||
/// matching hint string, if any.
|
||||
///
|
||||
/// Returns `None` when classification is disabled, no rules are configured,
|
||||
/// or no rule matches the message.
|
||||
pub fn classify(config: &QueryClassificationConfig, message: &str) -> Option<String> {
|
||||
if !config.enabled || config.rules.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let lower = message.to_lowercase();
|
||||
let len = message.len();
|
||||
|
||||
let mut rules: Vec<_> = config.rules.iter().collect();
|
||||
rules.sort_by(|a, b| b.priority.cmp(&a.priority));
|
||||
|
||||
for rule in rules {
|
||||
// Length constraints
|
||||
if let Some(min) = rule.min_length {
|
||||
if len < min {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if let Some(max) = rule.max_length {
|
||||
if len > max {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Check keywords (case-insensitive) and patterns (case-sensitive)
|
||||
let keyword_hit = rule
|
||||
.keywords
|
||||
.iter()
|
||||
.any(|kw: &String| lower.contains(&kw.to_lowercase()));
|
||||
let pattern_hit = rule
|
||||
.patterns
|
||||
.iter()
|
||||
.any(|pat: &String| message.contains(pat.as_str()));
|
||||
|
||||
if keyword_hit || pattern_hit {
|
||||
return Some(rule.hint.clone());
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::schema::{ClassificationRule, QueryClassificationConfig};
|
||||
|
||||
fn make_config(enabled: bool, rules: Vec<ClassificationRule>) -> QueryClassificationConfig {
|
||||
QueryClassificationConfig { enabled, rules }
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn disabled_returns_none() {
|
||||
let config = make_config(
|
||||
false,
|
||||
vec![ClassificationRule {
|
||||
hint: "fast".into(),
|
||||
keywords: vec!["hello".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
);
|
||||
assert_eq!(classify(&config, "hello"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_rules_returns_none() {
|
||||
let config = make_config(true, vec![]);
|
||||
assert_eq!(classify(&config, "hello"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn keyword_match_case_insensitive() {
|
||||
let config = make_config(
|
||||
true,
|
||||
vec![ClassificationRule {
|
||||
hint: "fast".into(),
|
||||
keywords: vec!["hello".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
);
|
||||
assert_eq!(classify(&config, "HELLO world"), Some("fast".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pattern_match_case_sensitive() {
|
||||
let config = make_config(
|
||||
true,
|
||||
vec![ClassificationRule {
|
||||
hint: "code".into(),
|
||||
patterns: vec!["fn ".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
);
|
||||
assert_eq!(classify(&config, "fn main()"), Some("code".into()));
|
||||
assert_eq!(classify(&config, "FN MAIN()"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn length_constraints() {
|
||||
let config = make_config(
|
||||
true,
|
||||
vec![ClassificationRule {
|
||||
hint: "fast".into(),
|
||||
keywords: vec!["hi".into()],
|
||||
max_length: Some(10),
|
||||
..Default::default()
|
||||
}],
|
||||
);
|
||||
assert_eq!(classify(&config, "hi"), Some("fast".into()));
|
||||
assert_eq!(
|
||||
classify(&config, "hi there, how are you doing today?"),
|
||||
None
|
||||
);
|
||||
|
||||
let config2 = make_config(
|
||||
true,
|
||||
vec![ClassificationRule {
|
||||
hint: "reasoning".into(),
|
||||
keywords: vec!["explain".into()],
|
||||
min_length: Some(20),
|
||||
..Default::default()
|
||||
}],
|
||||
);
|
||||
assert_eq!(classify(&config2, "explain"), None);
|
||||
assert_eq!(
|
||||
classify(&config2, "explain how this works in detail"),
|
||||
Some("reasoning".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn priority_ordering() {
|
||||
let config = make_config(
|
||||
true,
|
||||
vec![
|
||||
ClassificationRule {
|
||||
hint: "fast".into(),
|
||||
keywords: vec!["code".into()],
|
||||
priority: 1,
|
||||
..Default::default()
|
||||
},
|
||||
ClassificationRule {
|
||||
hint: "code".into(),
|
||||
keywords: vec!["code".into()],
|
||||
priority: 10,
|
||||
..Default::default()
|
||||
},
|
||||
],
|
||||
);
|
||||
assert_eq!(classify(&config, "write some code"), Some("code".into()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn no_match_returns_none() {
|
||||
let config = make_config(
|
||||
true,
|
||||
vec![ClassificationRule {
|
||||
hint: "fast".into(),
|
||||
keywords: vec!["hello".into()],
|
||||
..Default::default()
|
||||
}],
|
||||
);
|
||||
assert_eq!(classify(&config, "something completely different"), None);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
#[allow(clippy::module_inception)]
|
||||
pub mod agent;
|
||||
pub mod classifier;
|
||||
pub mod dispatcher;
|
||||
pub mod loop_;
|
||||
pub mod memory_loader;
|
||||
|
|
|
|||
|
|
@ -3,13 +3,13 @@ pub mod schema;
|
|||
#[allow(unused_imports)]
|
||||
pub use schema::{
|
||||
AgentConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig,
|
||||
ChannelsConfig, ComposioConfig, Config, CostConfig, CronConfig, DelegateAgentConfig,
|
||||
DiscordConfig, DockerRuntimeConfig, GatewayConfig, HardwareConfig, HardwareTransport,
|
||||
HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig,
|
||||
MemoryConfig, ModelRouteConfig, ObservabilityConfig, PeripheralBoardConfig, PeripheralsConfig,
|
||||
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
|
||||
SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig, TelegramConfig, TunnelConfig,
|
||||
WebhookConfig,
|
||||
ChannelsConfig, ClassificationRule, ComposioConfig, Config, CostConfig, CronConfig,
|
||||
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HardwareConfig,
|
||||
HardwareTransport, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig,
|
||||
LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig,
|
||||
PeripheralBoardConfig, PeripheralsConfig, QueryClassificationConfig, ReliabilityConfig,
|
||||
ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig,
|
||||
SecretsConfig, SecurityConfig, SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -47,6 +47,10 @@ pub struct Config {
|
|||
#[serde(default)]
|
||||
pub model_routes: Vec<ModelRouteConfig>,
|
||||
|
||||
/// Automatic query classification — maps user messages to model hints.
|
||||
#[serde(default)]
|
||||
pub query_classification: QueryClassificationConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub heartbeat: HeartbeatConfig,
|
||||
|
||||
|
|
@ -1205,6 +1209,40 @@ pub struct ModelRouteConfig {
|
|||
pub api_key: Option<String>,
|
||||
}
|
||||
|
||||
// ── Query Classification ─────────────────────────────────────────
|
||||
|
||||
/// Automatic query classification — classifies user messages by keyword/pattern
|
||||
/// and routes to the appropriate model hint. Disabled by default.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct QueryClassificationConfig {
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
#[serde(default)]
|
||||
pub rules: Vec<ClassificationRule>,
|
||||
}
|
||||
|
||||
/// A single classification rule mapping message patterns to a model hint.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ClassificationRule {
|
||||
/// Must match a `[[model_routes]]` hint value.
|
||||
pub hint: String,
|
||||
/// Case-insensitive substring matches.
|
||||
#[serde(default)]
|
||||
pub keywords: Vec<String>,
|
||||
/// Case-sensitive literal matches (for "```", "fn ", etc.).
|
||||
#[serde(default)]
|
||||
pub patterns: Vec<String>,
|
||||
/// Only match if message length >= N chars.
|
||||
#[serde(default)]
|
||||
pub min_length: Option<usize>,
|
||||
/// Only match if message length <= N chars.
|
||||
#[serde(default)]
|
||||
pub max_length: Option<usize>,
|
||||
/// Higher priority rules are checked first.
|
||||
#[serde(default)]
|
||||
pub priority: i32,
|
||||
}
|
||||
|
||||
// ── Heartbeat ────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -1740,6 +1778,7 @@ impl Default for Config {
|
|||
peripherals: PeripheralsConfig::default(),
|
||||
agents: HashMap::new(),
|
||||
hardware: HardwareConfig::default(),
|
||||
query_classification: QueryClassificationConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -136,6 +136,7 @@ pub fn run_wizard() -> Result<Config> {
|
|||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
hardware: hardware_config,
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
};
|
||||
|
||||
println!(
|
||||
|
|
@ -356,6 +357,7 @@ pub fn run_quick_setup(
|
|||
peripherals: crate::config::PeripheralsConfig::default(),
|
||||
agents: std::collections::HashMap::new(),
|
||||
hardware: crate::config::HardwareConfig::default(),
|
||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
};
|
||||
|
||||
config.save()?;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue