feat: add agent structure and improve tooling for provider

This commit is contained in:
mai1015 2026-02-16 00:40:43 -05:00 committed by Chummy
parent e2c966d31e
commit b341fdb368
21 changed files with 2567 additions and 443 deletions

View file

@ -1,4 +1,4 @@
use super::traits::{ChatMessage, ChatResponse};
use super::traits::ChatMessage;
use super::Provider;
use async_trait::async_trait;
use std::collections::HashMap;
@ -156,7 +156,7 @@ impl Provider for ReliableProvider {
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
) -> anyhow::Result<String> {
let models = self.model_chain(model);
let mut failures = Vec::new();
@ -254,7 +254,7 @@ impl Provider for ReliableProvider {
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
) -> anyhow::Result<String> {
let models = self.model_chain(model);
let mut failures = Vec::new();
@ -359,12 +359,12 @@ mod tests {
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<ChatResponse> {
) -> anyhow::Result<String> {
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
if attempt <= self.fail_until_attempt {
anyhow::bail!(self.error);
}
Ok(ChatResponse::with_text(self.response))
Ok(self.response.to_string())
}
async fn chat_with_history(
@ -372,12 +372,12 @@ mod tests {
_messages: &[ChatMessage],
_model: &str,
_temperature: f64,
) -> anyhow::Result<ChatResponse> {
) -> anyhow::Result<String> {
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
if attempt <= self.fail_until_attempt {
anyhow::bail!(self.error);
}
Ok(ChatResponse::with_text(self.response))
Ok(self.response.to_string())
}
}
@ -397,13 +397,13 @@ mod tests {
_message: &str,
model: &str,
_temperature: f64,
) -> anyhow::Result<ChatResponse> {
) -> anyhow::Result<String> {
self.calls.fetch_add(1, Ordering::SeqCst);
self.models_seen.lock().unwrap().push(model.to_string());
if self.fail_models.contains(&model) {
anyhow::bail!("500 model {} unavailable", model);
}
Ok(ChatResponse::with_text(self.response))
Ok(self.response.to_string())
}
}
@ -426,8 +426,8 @@ mod tests {
1,
);
let result = provider.chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result.text_or_empty(), "ok");
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result, "ok");
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
@ -448,8 +448,8 @@ mod tests {
1,
);
let result = provider.chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result.text_or_empty(), "recovered");
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result, "recovered");
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
@ -483,8 +483,8 @@ mod tests {
1,
);
let result = provider.chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result.text_or_empty(), "from fallback");
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result, "from fallback");
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
}
@ -517,7 +517,7 @@ mod tests {
);
let err = provider
.chat("hello", "test", 0.0)
.simple_chat("hello", "test", 0.0)
.await
.expect_err("all providers should fail");
let msg = err.to_string();
@ -572,8 +572,8 @@ mod tests {
1,
);
let result = provider.chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result.text_or_empty(), "from fallback");
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result, "from fallback");
// Primary should have been called only once (no retries)
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
@ -601,7 +601,7 @@ mod tests {
.chat_with_history(&messages, "test", 0.0)
.await
.unwrap();
assert_eq!(result.text_or_empty(), "history ok");
assert_eq!(result, "history ok");
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
@ -640,7 +640,7 @@ mod tests {
.chat_with_history(&messages, "test", 0.0)
.await
.unwrap();
assert_eq!(result.text_or_empty(), "fallback ok");
assert_eq!(result, "fallback ok");
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
}
@ -827,7 +827,7 @@ mod tests {
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
) -> anyhow::Result<String> {
self.as_ref()
.chat_with_system(system_prompt, message, model, temperature)
.await