fix: use safe Unicode string truncation to prevent panics (CWE-119)

Merge pull request #117 from theonlyhennygod/fix/unicode-truncation-panic
This commit is contained in:
Argenis 2026-02-15 06:49:48 -05:00 committed by GitHub
parent 5cc02c5813
commit 7b5e77f03c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 1689 additions and 143 deletions

View file

@ -5,6 +5,7 @@ use crate::providers::{self, Provider};
use crate::runtime;
use crate::security::SecurityPolicy;
use crate::tools;
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use std::fmt::Write;
use std::sync::Arc;
@ -150,11 +151,7 @@ pub async fn run(
// Auto-save assistant response to daily log
if config.memory.auto_save {
let summary = if response.len() > 100 {
format!("{}...", &response[..100])
} else {
response.clone()
};
let summary = truncate_with_ellipsis(&response, 100);
let _ = mem
.store("assistant_resp", &summary, MemoryCategory::Daily)
.await;
@ -193,11 +190,7 @@ pub async fn run(
println!("\n{response}\n");
if config.memory.auto_save {
let summary = if response.len() > 100 {
format!("{}...", &response[..100])
} else {
response.clone()
};
let summary = truncate_with_ellipsis(&response, 100);
let _ = mem
.store("assistant_resp", &summary, MemoryCategory::Daily)
.await;

View file

@ -0,0 +1,446 @@
#![allow(clippy::uninlined_format_args)]
#![allow(clippy::map_unwrap_or)]
#![allow(clippy::redundant_closure_for_method_calls)]
#![allow(clippy::cast_lossless)]
#![allow(clippy::trim_split_whitespace)]
#![allow(clippy::doc_link_with_quotes)]
#![allow(clippy::doc_markdown)]
#![allow(clippy::too_many_lines)]
#![allow(clippy::unnecessary_map_or)]
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use lettre::transport::smtp::authentication::Credentials;
use lettre::{Message, SmtpTransport, Transport};
use mail_parser::{MessageParser, MimeHeaders};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::io::Write as IoWrite;
use std::net::TcpStream;
use std::sync::Mutex;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::mpsc;
use tokio::time::{interval, sleep};
use tracing::{error, info, warn};
use uuid::Uuid;
use super::traits::{Channel, ChannelMessage};
/// Email channel configuration
///
/// Serialized and deserialized with serde. Fields marked `#[serde(default ...)]`
/// may be omitted from the input and then take the default noted on the field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmailConfig {
/// IMAP server hostname
pub imap_host: String,
/// IMAP server port (default: 993 for TLS)
#[serde(default = "default_imap_port")]
pub imap_port: u16,
/// IMAP folder to poll (default: INBOX)
#[serde(default = "default_imap_folder")]
pub imap_folder: String,
/// SMTP server hostname
pub smtp_host: String,
/// SMTP server port (default: 587 for STARTTLS)
#[serde(default = "default_smtp_port")]
pub smtp_port: u16,
/// Use TLS for SMTP (default: true)
#[serde(default = "default_true")]
pub smtp_tls: bool,
/// Email username for authentication
pub username: String,
/// Email password for authentication
// NOTE(review): stored as a plain `String`, and the derived `Debug` impl prints it —
// avoid logging this struct with `{:?}`.
pub password: String,
/// From address for outgoing emails
pub from_address: String,
/// Poll interval in seconds (default: 60)
#[serde(default = "default_poll_interval")]
pub poll_interval_secs: u64,
/// Allowed sender addresses/domains (empty = deny all, ["*"] = allow all)
///
/// Each entry is a full address (`user@example.com`), a domain with a leading `@`
/// (`@example.com`), or a bare domain (`example.com`).
#[serde(default)]
pub allowed_senders: Vec<String>,
}
/// Serde default for `EmailConfig::imap_port`.
fn default_imap_port() -> u16 {
    const IMAP_PORT: u16 = 993;
    IMAP_PORT
}

/// Serde default for `EmailConfig::smtp_port`.
fn default_smtp_port() -> u16 {
    const SMTP_PORT: u16 = 587;
    SMTP_PORT
}

/// Serde default for `EmailConfig::imap_folder`.
fn default_imap_folder() -> String {
    String::from("INBOX")
}

/// Serde default for `EmailConfig::poll_interval_secs`.
fn default_poll_interval() -> u64 {
    const POLL_SECS: u64 = 60;
    POLL_SECS
}

/// Serde default for boolean options that are enabled unless stated otherwise.
fn default_true() -> bool {
    !false
}
impl Default for EmailConfig {
fn default() -> Self {
Self {
imap_host: String::new(),
imap_port: default_imap_port(),
imap_folder: default_imap_folder(),
smtp_host: String::new(),
smtp_port: default_smtp_port(),
smtp_tls: true,
username: String::new(),
password: String::new(),
from_address: String::new(),
poll_interval_secs: default_poll_interval(),
allowed_senders: Vec::new(),
}
}
}
/// Email channel — IMAP polling for inbound, SMTP for outbound
pub struct EmailChannel {
/// Server settings, credentials and sender allowlist used by this channel
pub config: EmailConfig,
/// Message IDs that `listen` has already forwarded; used to skip duplicates.
/// Entries are only ever inserted in the code visible here, never removed.
seen_messages: Mutex<HashSet<String>>,
}
impl EmailChannel {
pub fn new(config: EmailConfig) -> Self {
Self {
config,
seen_messages: Mutex::new(HashSet::new()),
}
}
/// Decide whether `email` may talk to this channel.
///
/// * An empty allowlist rejects everyone.
/// * An entry equal to `"*"` accepts everyone.
/// * `@example.com` and `example.com` accept any address ending in `@example.com`
///   (compared in lowercase).
/// * Any other entry containing `@` must equal the address, ignoring ASCII case.
pub fn is_sender_allowed(&self, email: &str) -> bool {
    let allowlist = &self.config.allowed_senders;
    if allowlist.is_empty() {
        return false; // Empty = deny all
    }
    if allowlist.iter().any(|entry| entry == "*") {
        return true; // Wildcard = allow all
    }

    let sender = email.to_lowercase();
    allowlist.iter().any(|entry| {
        let has_at_prefix = entry.starts_with('@');
        if !has_at_prefix && entry.contains('@') {
            // Full email address match
            return entry.eq_ignore_ascii_case(email);
        }
        // Domain match, with or without the leading '@'
        let suffix = if has_at_prefix {
            entry.to_lowercase()
        } else {
            format!("@{}", entry.to_lowercase())
        };
        sender.ends_with(&suffix)
    })
}
/// Remove HTML tags from `html` and collapse whitespace (basic, not a real parser).
///
/// Everything from a `<` up to the next `>` is dropped, the angle brackets themselves
/// are never emitted, and the remaining text is re-joined with single spaces.
pub fn strip_html(html: &str) -> String {
    let mut inside_tag = false;
    let visible: String = html
        .chars()
        .filter(|&c| match c {
            '<' => {
                inside_tag = true;
                false
            }
            '>' => {
                inside_tag = false;
                false
            }
            _ => !inside_tag,
        })
        .collect();

    let words: Vec<&str> = visible.split_whitespace().collect();
    words.join(" ")
}
/// Return the address of the first `From` entry, or `"unknown"` when the message
/// has no `From` header, no entries, or the first entry carries no address.
fn extract_sender(parsed: &mail_parser::Message) -> String {
    let first_address = parsed
        .from()
        .and_then(|list| list.first())
        .and_then(|entry| entry.address());

    match first_address {
        Some(address) => address.to_string(),
        None => "unknown".into(),
    }
}
/// Pick the most readable representation of a parsed email.
///
/// Preference order: the first plain-text body, then the first HTML body with tags
/// stripped, then the first `text/*` attachment that is valid UTF-8 (prefixed with
/// its file name), and finally a fixed placeholder.
fn extract_text(parsed: &mail_parser::Message) -> String {
    if let Some(plain) = parsed.body_text(0) {
        return plain.to_string();
    }
    if let Some(markup) = parsed.body_html(0) {
        return Self::strip_html(markup.as_ref());
    }

    for part in parsed.attachments() {
        let part: &mail_parser::MessagePart = part;
        let is_text = MimeHeaders::content_type(part).map_or(false, |ct| ct.ctype() == "text");
        if !is_text {
            continue;
        }
        // Attachments that are not valid UTF-8 are skipped, not decoded lossily.
        let Ok(text) = std::str::from_utf8(part.contents()) else {
            continue;
        };
        let name = MimeHeaders::attachment_name(part).unwrap_or("file");
        return format!("[Attachment: {}]\n{}", name, text);
    }

    "(no readable content)".to_string()
}
/// Fetch unseen emails via IMAP (blocking, run in spawn_blocking)
///
/// Opens a TLS connection to `config.imap_host:config.imap_port`, logs in, selects
/// `config.imap_folder`, runs `SEARCH UNSEEN`, then for every hit issues
/// `FETCH <n> RFC822` followed by `STORE <n> +FLAGS (\Seen)`, and finally `LOGOUT`.
///
/// Returns one `(message_id, sender, content, unix_timestamp)` tuple per message that
/// could be parsed; `content` is `"Subject: <subject>\n\n<body>"`.
///
/// # Errors
/// Fails on connection/TLS/I/O errors, when a credential or folder name contains a
/// character that cannot be sent in an IMAP quoted string, when LOGIN is not answered
/// with a tagged `OK`, or when SELECT is not answered with a tagged `OK`.
/// Failures of the per-message `STORE` and of `LOGOUT` are deliberately ignored.
fn fetch_unseen_imap(config: &EmailConfig) -> Result<Vec<(String, String, String, u64)>> {
    use rustls::ClientConfig as TlsConfig;
    use rustls_pki_types::ServerName;
    use std::sync::Arc;
    use tokio_rustls::rustls;

    /// Render `value` as an IMAP quoted string, escaping `"` and `\`.
    /// CR, LF and NUL cannot appear in a quoted string and would let a value
    /// smuggle in a second command, so they are rejected outright.
    fn imap_quoted(value: &str) -> Result<String> {
        if value.chars().any(|c| matches!(c, '\r' | '\n' | '\0')) {
            return Err(anyhow!(
                "IMAP argument contains a forbidden control character"
            ));
        }
        let mut quoted = String::with_capacity(value.len() + 2);
        quoted.push('"');
        for ch in value.chars() {
            if matches!(ch, '"' | '\\') {
                quoted.push('\\');
            }
            quoted.push(ch);
        }
        quoted.push('"');
        Ok(quoted)
    }

    /// If `line` is the tagged completion for `tag` (`<tag> OK|NO|BAD ...`),
    /// return its status word.
    fn tagged_status<'a>(line: &'a str, tag: &str) -> Option<&'a str> {
        let rest = line.strip_prefix(tag)?.strip_prefix(' ')?;
        let status = rest.split_whitespace().next()?;
        ["OK", "NO", "BAD"]
            .iter()
            .any(|known| status.eq_ignore_ascii_case(known))
            .then_some(status)
    }

    /// True when the last line of a response is `<tag> OK ...`.
    fn command_ok(lines: &[String], tag: &str) -> bool {
        lines
            .last()
            .and_then(|line| tagged_status(line, tag))
            .map_or(false, |status| status.eq_ignore_ascii_case("OK"))
    }

    // Connect TCP
    let tcp = TcpStream::connect((&*config.imap_host, config.imap_port))?;
    tcp.set_read_timeout(Some(Duration::from_secs(30)))?;
    tcp.set_write_timeout(Some(Duration::from_secs(30)))?;

    // TLS
    let mut root_store = rustls::RootCertStore::empty();
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    let tls_config = Arc::new(
        TlsConfig::builder()
            .with_root_certificates(root_store)
            .with_no_client_auth(),
    );
    let server_name: ServerName<'_> = ServerName::try_from(config.imap_host.clone())?;
    let conn = rustls::ClientConnection::new(tls_config, server_name)?;
    let mut tls = rustls::StreamOwned::new(conn, tcp);

    // Read one CRLF-terminated line from the server.
    let read_line =
        |tls: &mut rustls::StreamOwned<rustls::ClientConnection, TcpStream>| -> Result<String> {
            let mut buf = Vec::new();
            loop {
                let mut byte = [0u8; 1];
                match std::io::Read::read(tls, &mut byte) {
                    Ok(0) => return Err(anyhow!("IMAP connection closed")),
                    Ok(_) => {
                        buf.push(byte[0]);
                        if buf.ends_with(b"\r\n") {
                            return Ok(String::from_utf8_lossy(&buf).to_string());
                        }
                    }
                    Err(e) => return Err(e.into()),
                }
            }
        };

    // Send `<tag> <cmd>` and collect every line up to and including the tagged
    // completion. Requiring `<tag> SP OK|NO|BAD` (rather than a bare prefix match)
    // keeps "A1" from matching "A10 ..." or arbitrary text that starts with the tag.
    // NOTE: responses are still consumed line by line; literal sizes (`{n}`) are not
    // honoured, so a message line that is itself shaped like the tagged completion
    // would still end the read early.
    let send_cmd = |tls: &mut rustls::StreamOwned<rustls::ClientConnection, TcpStream>,
                    tag: &str,
                    cmd: &str|
     -> Result<Vec<String>> {
        let full = format!("{} {}\r\n", tag, cmd);
        IoWrite::write_all(tls, full.as_bytes())?;
        IoWrite::flush(tls)?;
        let mut lines = Vec::new();
        loop {
            let line = read_line(tls)?;
            let done = tagged_status(&line, tag).is_some();
            lines.push(line);
            if done {
                break;
            }
        }
        Ok(lines)
    };

    // Read greeting
    let _greeting = read_line(&mut tls)?;

    // Login
    let login_cmd = format!(
        "LOGIN {} {}",
        imap_quoted(&config.username)?,
        imap_quoted(&config.password)?
    );
    let login_resp = send_cmd(&mut tls, "A1", &login_cmd)?;
    if !command_ok(&login_resp, "A1") {
        return Err(anyhow!("IMAP login failed"));
    }

    // Select folder
    let select_cmd = format!("SELECT {}", imap_quoted(&config.imap_folder)?);
    let select_resp = send_cmd(&mut tls, "A2", &select_cmd)?;
    if !command_ok(&select_resp, "A2") {
        return Err(anyhow!("IMAP SELECT failed for folder {}", config.imap_folder));
    }

    // Search unseen. The numbers returned by a plain SEARCH are used as-is in the
    // FETCH/STORE commands below.
    let search_resp = send_cmd(&mut tls, "A3", "SEARCH UNSEEN")?;
    let mut uids: Vec<&str> = Vec::new();
    for line in &search_resp {
        if line.starts_with("* SEARCH") {
            let parts: Vec<&str> = line.trim().split_whitespace().collect();
            if parts.len() > 2 {
                uids.extend_from_slice(&parts[2..]);
            }
        }
    }

    let mut results = Vec::new();
    let mut tag_counter = 4_u32; // Start after A1, A2, A3
    for uid in &uids {
        // Fetch RFC822 with unique tag
        let fetch_tag = format!("A{}", tag_counter);
        tag_counter += 1;
        let fetch_resp = send_cmd(&mut tls, &fetch_tag, &format!("FETCH {} RFC822", uid))?;

        // Reconstruct the raw email from the response (skip first and last lines)
        let raw: String = fetch_resp
            .iter()
            .skip(1)
            .take(fetch_resp.len().saturating_sub(2))
            .cloned()
            .collect();

        if let Some(parsed) = MessageParser::default().parse(raw.as_bytes()) {
            let sender = Self::extract_sender(&parsed);
            let subject = parsed.subject().unwrap_or("(no subject)").to_string();
            let body = Self::extract_text(&parsed);
            let content = format!("Subject: {}\n\n{}", subject, body);
            let msg_id = parsed
                .message_id()
                .map(|s| s.to_string())
                .unwrap_or_else(|| format!("gen-{}", Uuid::new_v4()));

            // Timestamp: the Date header's calendar fields interpreted as UTC (no
            // timezone offset is applied here); 0 if those fields do not form a valid
            // date; the current time if the message has no Date header at all.
            #[allow(clippy::cast_sign_loss)]
            let ts = parsed
                .date()
                .map(|d| {
                    let naive = chrono::NaiveDate::from_ymd_opt(
                        d.year as i32,
                        u32::from(d.month),
                        u32::from(d.day),
                    )
                    .and_then(|date| {
                        date.and_hms_opt(
                            u32::from(d.hour),
                            u32::from(d.minute),
                            u32::from(d.second),
                        )
                    });
                    naive.map_or(0, |n| n.and_utc().timestamp() as u64)
                })
                .unwrap_or_else(|| {
                    SystemTime::now()
                        .duration_since(UNIX_EPOCH)
                        .map(|d| d.as_secs())
                        .unwrap_or(0)
                });

            results.push((msg_id, sender, content, ts));
        }

        // Mark as seen with unique tag (best effort: a failure here is ignored)
        let store_tag = format!("A{tag_counter}");
        tag_counter += 1;
        let _ = send_cmd(
            &mut tls,
            &store_tag,
            &format!("STORE {uid} +FLAGS (\\Seen)"),
        );
    }

    // Logout with unique tag (best effort)
    let logout_tag = format!("A{tag_counter}");
    let _ = send_cmd(&mut tls, &logout_tag, "LOGOUT");

    Ok(results)
}
/// Build a blocking SMTP transport from the channel configuration.
///
/// * `smtp_tls == false`: unencrypted connection (`builder_dangerous`).
/// * `smtp_tls == true` on port 465: TLS from the first byte (`relay`).
/// * `smtp_tls == true` on any other port (including the default 587): plain
///   connection upgraded with STARTTLS (`starttls_relay`). `relay` always wraps the
///   socket in TLS immediately, which a STARTTLS submission port does not accept.
///
/// # Errors
/// Returns an error if the TLS parameters for `smtp_host` cannot be constructed.
fn create_smtp_transport(&self) -> Result<SmtpTransport> {
    /// Port conventionally used for SMTP over implicit TLS.
    const SMTPS_PORT: u16 = 465;

    let creds = Credentials::new(self.config.username.clone(), self.config.password.clone());
    let host = self.config.smtp_host.as_str();
    let port = self.config.smtp_port;

    let builder = if !self.config.smtp_tls {
        SmtpTransport::builder_dangerous(host)
    } else if port == SMTPS_PORT {
        SmtpTransport::relay(host)?
    } else {
        SmtpTransport::starttls_relay(host)?
    };

    Ok(builder.port(port).credentials(creds).build())
}
}
#[async_trait]
impl Channel for EmailChannel {
/// Channel identifier; always the literal `"email"`.
fn name(&self) -> &str {
"email"
}
async fn send(&self, message: &str, recipient: &str) -> Result<()> {
let (subject, body) = if message.starts_with("Subject: ") {
if let Some(pos) = message.find('\n') {
(&message[9..pos], message[pos + 1..].trim())
} else {
("ZeroClaw Message", message)
}
} else {
("ZeroClaw Message", message)
};
let email = Message::builder()
.from(self.config.from_address.parse()?)
.to(recipient.parse()?)
.subject(subject)
.body(body.to_string())?;
let transport = self.create_smtp_transport()?;
transport.send(&email)?;
info!("Email sent to {}", recipient);
Ok(())
}
/// Poll the IMAP folder forever and forward new, allowed messages to `tx`.
///
/// Each tick runs the blocking IMAP fetch on the blocking thread pool. Messages whose
/// ID was already forwarded, or whose sender is not on the allowlist, are skipped.
/// A failed poll is logged and followed by a 10 second pause.
///
/// Returns `Ok(())` once the receiving side of `tx` has been dropped.
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> Result<()> {
    info!(
        "Email polling every {}s on {}",
        self.config.poll_interval_secs, self.config.imap_folder
    );

    // `tokio::time::interval` panics on a zero period, so a configured interval of 0
    // is treated as 1 second instead of crashing the listener.
    let period = Duration::from_secs(self.config.poll_interval_secs.max(1));
    let mut tick = interval(period);
    let config = self.config.clone();

    loop {
        tick.tick().await;
        let cfg = config.clone();
        match tokio::task::spawn_blocking(move || Self::fetch_unseen_imap(&cfg)).await {
            Ok(Ok(messages)) => {
                for (id, sender, content, ts) in messages {
                    {
                        // A poisoned lock only means another thread panicked while
                        // holding it; the set of IDs is still usable, so recover it
                        // rather than propagating the panic.
                        let mut seen = self
                            .seen_messages
                            .lock()
                            .unwrap_or_else(std::sync::PoisonError::into_inner);
                        if seen.contains(&id) {
                            continue;
                        }
                        if !self.is_sender_allowed(&sender) {
                            warn!("Blocked email from {}", sender);
                            continue;
                        }
                        seen.insert(id.clone());
                    } // MutexGuard dropped before await

                    let msg = ChannelMessage {
                        id,
                        sender,
                        content,
                        channel: "email".to_string(),
                        timestamp: ts,
                    };
                    if tx.send(msg).await.is_err() {
                        // Receiver is gone: nobody is listening any more.
                        return Ok(());
                    }
                }
            }
            Ok(Err(e)) => {
                error!("Email poll failed: {}", e);
                sleep(Duration::from_secs(10)).await;
            }
            Err(e) => {
                error!("Email poll task panicked: {}", e);
                sleep(Duration::from_secs(10)).await;
            }
        }
    }
}
/// Report whether a plain TCP connection to the IMAP host/port can be opened.
///
/// The connect attempt runs on the blocking thread pool; if that task fails to
/// complete, the check reports `false`.
async fn health_check(&self) -> bool {
    let host = self.config.imap_host.clone();
    let port = self.config.imap_port;
    let probe =
        tokio::task::spawn_blocking(move || TcpStream::connect((host.as_str(), port)).is_ok());
    probe.await.unwrap_or_default()
}
}

View file

@ -1,5 +1,6 @@
pub mod cli;
pub mod discord;
pub mod email_channel;
pub mod imessage;
pub mod matrix;
pub mod slack;

View file

@ -89,10 +89,10 @@ impl Default for IdentityConfig {
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GatewayConfig {
/// Gateway port (default: 3000)
/// Gateway port (default: 8080)
#[serde(default = "default_gateway_port")]
pub port: u16,
/// Gateway host/bind address (default: 127.0.0.1)
/// Gateway host (default: 127.0.0.1)
#[serde(default = "default_gateway_host")]
pub host: String,
/// Require pairing before accepting requests (default: true)
@ -178,13 +178,13 @@ impl Default for SecretsConfig {
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct BrowserConfig {
/// Enable browser tools (`browser_open` and browser automation)
/// Enable `browser_open` tool (opens URLs in Brave without scraping)
#[serde(default)]
pub enabled: bool,
/// Allowed domains for browser tools (exact or subdomain match)
/// Allowed domains for `browser_open` (exact or subdomain match)
#[serde(default)]
pub allowed_domains: Vec<String>,
/// Session name for agent-browser (persists state across commands)
/// Browser session name (for agent-browser automation)
#[serde(default)]
pub session_name: Option<String>,
}
@ -604,8 +604,7 @@ pub struct WhatsAppConfig {
pub phone_number_id: String,
/// Webhook verify token (you define this, Meta sends it back for verification)
pub verify_token: String,
/// App secret from Meta Business Suite (for webhook signature verification)
/// Can also be set via `ZEROCLAW_WHATSAPP_APP_SECRET` environment variable
/// App secret for webhook signature verification (X-Hub-Signature-256)
#[serde(default)]
pub app_secret: Option<String>,
/// Allowed phone numbers (E.164 format: +1234567890) or "*" for all
@ -647,19 +646,10 @@ impl Default for Config {
impl Config {
pub fn load_or_init() -> Result<Self> {
// Check for workspace override from environment (Docker support)
let zeroclaw_dir = if let Ok(workspace) = std::env::var("ZEROCLAW_WORKSPACE") {
let ws_path = PathBuf::from(&workspace);
ws_path
.parent()
.map_or_else(|| PathBuf::from(&workspace), PathBuf::from)
} else {
let home = UserDirs::new()
.map(|u| u.home_dir().to_path_buf())
.context("Could not find home directory")?;
home.join(".zeroclaw")
};
let home = UserDirs::new()
.map(|u| u.home_dir().to_path_buf())
.context("Could not find home directory")?;
let zeroclaw_dir = home.join(".zeroclaw");
let config_path = zeroclaw_dir.join("config.toml");
if !zeroclaw_dir.exists() {
@ -668,35 +658,20 @@ impl Config {
.context("Failed to create workspace directory")?;
}
let mut config = if config_path.exists() {
if config_path.exists() {
let contents =
fs::read_to_string(&config_path).context("Failed to read config file")?;
toml::from_str(&contents).context("Failed to parse config file")?
let config: Config =
toml::from_str(&contents).context("Failed to parse config file")?;
Ok(config)
} else {
Config::default()
};
// Apply environment variable overrides (Docker/container support)
config.apply_env_overrides();
// Save config if it didn't exist (creates default config with env overrides)
if !config_path.exists() {
let config = Config::default();
config.save()?;
Ok(config)
}
Ok(config)
}
/// Apply environment variable overrides to config.
///
/// Supports:
/// - `ZEROCLAW_API_KEY` or `API_KEY` - LLM provider API key
/// - `ZEROCLAW_PROVIDER` or `PROVIDER` - Provider name (openrouter, openai, anthropic, ollama)
/// - `ZEROCLAW_MODEL` - Model name/ID
/// - `ZEROCLAW_WORKSPACE` - Workspace directory path
/// - `ZEROCLAW_GATEWAY_PORT` or `PORT` - Gateway server port
/// - `ZEROCLAW_GATEWAY_HOST` or `HOST` - Gateway bind address
/// - `ZEROCLAW_TEMPERATURE` - Default temperature (0.0-2.0)
/// Apply environment variable overrides to config
pub fn apply_env_overrides(&mut self) {
// API Key: ZEROCLAW_API_KEY or API_KEY
if let Ok(key) = std::env::var("ZEROCLAW_API_KEY").or_else(|_| std::env::var("API_KEY")) {
@ -721,15 +696,6 @@ impl Config {
}
}
// Temperature: ZEROCLAW_TEMPERATURE
if let Ok(temp_str) = std::env::var("ZEROCLAW_TEMPERATURE") {
if let Ok(temp) = temp_str.parse::<f64>() {
if (0.0..=2.0).contains(&temp) {
self.default_temperature = temp;
}
}
}
// Workspace directory: ZEROCLAW_WORKSPACE
if let Ok(workspace) = std::env::var("ZEROCLAW_WORKSPACE") {
if !workspace.is_empty() {
@ -753,6 +719,15 @@ impl Config {
self.gateway.host = host;
}
}
// Temperature: ZEROCLAW_TEMPERATURE
if let Ok(temp_str) = std::env::var("ZEROCLAW_TEMPERATURE") {
if let Ok(temp) = temp_str.parse::<f64>() {
if (0.0..=2.0).contains(&temp) {
self.default_temperature = temp;
}
}
}
}
pub fn save(&self) -> Result<()> {
@ -1193,7 +1168,7 @@ channel_id = "C123"
access_token: "tok".into(),
phone_number_id: "12345".into(),
verify_token: "verify".into(),
app_secret: Some("secret123".into()),
app_secret: None,
allowed_numbers: vec!["+1".into()],
};
let toml_str = toml::to_string(&wc).unwrap();
@ -1482,53 +1457,49 @@ default_temperature = 0.7
#[test]
fn env_override_api_key() {
// Primary and fallback tested together to avoid env-var races.
std::env::remove_var("ZEROCLAW_API_KEY");
std::env::remove_var("API_KEY");
// Primary: ZEROCLAW_API_KEY
let mut config = Config::default();
assert!(config.api_key.is_none());
std::env::set_var("ZEROCLAW_API_KEY", "sk-test-env-key");
config.apply_env_overrides();
assert_eq!(config.api_key.as_deref(), Some("sk-test-env-key"));
std::env::remove_var("ZEROCLAW_API_KEY");
// Fallback: API_KEY
let mut config2 = Config::default();
std::env::remove_var("ZEROCLAW_API_KEY");
}
#[test]
fn env_override_api_key_fallback() {
let mut config = Config::default();
std::env::remove_var("ZEROCLAW_API_KEY");
std::env::set_var("API_KEY", "sk-fallback-key");
config2.apply_env_overrides();
assert_eq!(config2.api_key.as_deref(), Some("sk-fallback-key"));
config.apply_env_overrides();
assert_eq!(config.api_key.as_deref(), Some("sk-fallback-key"));
std::env::remove_var("API_KEY");
}
#[test]
fn env_override_provider() {
// Primary, fallback, and empty-value tested together to avoid env-var races.
std::env::remove_var("ZEROCLAW_PROVIDER");
std::env::remove_var("PROVIDER");
// Primary: ZEROCLAW_PROVIDER
let mut config = Config::default();
std::env::set_var("ZEROCLAW_PROVIDER", "anthropic");
config.apply_env_overrides();
assert_eq!(config.default_provider.as_deref(), Some("anthropic"));
std::env::remove_var("ZEROCLAW_PROVIDER");
// Fallback: PROVIDER
let mut config2 = Config::default();
std::env::remove_var("ZEROCLAW_PROVIDER");
}
#[test]
fn env_override_provider_fallback() {
let mut config = Config::default();
std::env::remove_var("ZEROCLAW_PROVIDER");
std::env::set_var("PROVIDER", "openai");
config2.apply_env_overrides();
assert_eq!(config2.default_provider.as_deref(), Some("openai"));
std::env::remove_var("PROVIDER");
config.apply_env_overrides();
assert_eq!(config.default_provider.as_deref(), Some("openai"));
// Empty value should not override
let mut config3 = Config::default();
let original_provider = config3.default_provider.clone();
std::env::set_var("ZEROCLAW_PROVIDER", "");
config3.apply_env_overrides();
assert_eq!(config3.default_provider, original_provider);
std::env::remove_var("ZEROCLAW_PROVIDER");
std::env::remove_var("PROVIDER");
}
#[test]
@ -1539,7 +1510,6 @@ default_temperature = 0.7
config.apply_env_overrides();
assert_eq!(config.default_model.as_deref(), Some("gpt-4o"));
// Clean up
std::env::remove_var("ZEROCLAW_MODEL");
}
@ -1551,86 +1521,111 @@ default_temperature = 0.7
config.apply_env_overrides();
assert_eq!(config.workspace_dir, PathBuf::from("/custom/workspace"));
// Clean up
std::env::remove_var("ZEROCLAW_WORKSPACE");
}
#[test]
fn env_override_gateway_port() {
// Port, fallback, and invalid tested together to avoid env-var races.
std::env::remove_var("ZEROCLAW_GATEWAY_PORT");
std::env::remove_var("PORT");
fn env_override_empty_values_ignored() {
let mut config = Config::default();
let original_provider = config.default_provider.clone();
// Primary: ZEROCLAW_GATEWAY_PORT
std::env::set_var("ZEROCLAW_PROVIDER", "");
config.apply_env_overrides();
assert_eq!(config.default_provider, original_provider);
std::env::remove_var("ZEROCLAW_PROVIDER");
}
#[test]
fn env_override_gateway_port() {
let mut config = Config::default();
assert_eq!(config.gateway.port, 3000);
std::env::set_var("ZEROCLAW_GATEWAY_PORT", "8080");
config.apply_env_overrides();
assert_eq!(config.gateway.port, 8080);
std::env::remove_var("ZEROCLAW_GATEWAY_PORT");
}
// Fallback: PORT
let mut config2 = Config::default();
#[test]
fn env_override_port_fallback() {
let mut config = Config::default();
std::env::remove_var("ZEROCLAW_GATEWAY_PORT");
std::env::set_var("PORT", "9000");
config2.apply_env_overrides();
assert_eq!(config2.gateway.port, 9000);
// Invalid PORT is ignored
let mut config3 = Config::default();
let original_port = config3.gateway.port;
std::env::set_var("PORT", "not_a_number");
config3.apply_env_overrides();
assert_eq!(config3.gateway.port, original_port);
config.apply_env_overrides();
assert_eq!(config.gateway.port, 9000);
std::env::remove_var("PORT");
}
#[test]
fn env_override_gateway_host() {
// Primary and fallback tested together to avoid env-var races.
std::env::remove_var("ZEROCLAW_GATEWAY_HOST");
std::env::remove_var("HOST");
// Primary: ZEROCLAW_GATEWAY_HOST
let mut config = Config::default();
assert_eq!(config.gateway.host, "127.0.0.1");
std::env::set_var("ZEROCLAW_GATEWAY_HOST", "0.0.0.0");
config.apply_env_overrides();
assert_eq!(config.gateway.host, "0.0.0.0");
std::env::remove_var("ZEROCLAW_GATEWAY_HOST");
// Fallback: HOST
let mut config2 = Config::default();
std::env::remove_var("ZEROCLAW_GATEWAY_HOST");
}
#[test]
fn env_override_host_fallback() {
let mut config = Config::default();
std::env::remove_var("ZEROCLAW_GATEWAY_HOST");
std::env::set_var("HOST", "0.0.0.0");
config2.apply_env_overrides();
assert_eq!(config2.gateway.host, "0.0.0.0");
config.apply_env_overrides();
assert_eq!(config.gateway.host, "0.0.0.0");
std::env::remove_var("HOST");
}
#[test]
fn env_override_temperature() {
// Valid and out-of-range tested together to avoid env-var races.
std::env::remove_var("ZEROCLAW_TEMPERATURE");
// Valid temperature is applied
let mut config = Config::default();
std::env::set_var("ZEROCLAW_TEMPERATURE", "0.5");
config.apply_env_overrides();
assert!((config.default_temperature - 0.5).abs() < f64::EPSILON);
// Out-of-range temperature is ignored
let mut config2 = Config::default();
let original_temp = config2.default_temperature;
std::env::remove_var("ZEROCLAW_TEMPERATURE");
}
#[test]
fn env_override_temperature_out_of_range_ignored() {
// Clean up any leftover env vars from other tests
std::env::remove_var("ZEROCLAW_TEMPERATURE");
let mut config = Config::default();
let original_temp = config.default_temperature;
// Temperature > 2.0 should be ignored
std::env::set_var("ZEROCLAW_TEMPERATURE", "3.0");
config2.apply_env_overrides();
config.apply_env_overrides();
assert!(
(config2.default_temperature - original_temp).abs() < f64::EPSILON,
(config.default_temperature - original_temp).abs() < f64::EPSILON,
"Temperature 3.0 should be ignored (out of range)"
);
std::env::remove_var("ZEROCLAW_TEMPERATURE");
}
#[test]
fn env_override_invalid_port_ignored() {
let mut config = Config::default();
let original_port = config.gateway.port;
std::env::set_var("PORT", "not_a_number");
config.apply_env_overrides();
assert_eq!(config.gateway.port, original_port);
std::env::remove_var("PORT");
}
#[test]
fn gateway_config_default_values() {
let g = GatewayConfig::default();

View file

@ -12,6 +12,7 @@ use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::providers::{self, Provider};
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use axum::{
body::Bytes,
@ -457,11 +458,7 @@ async fn handle_whatsapp_message(
tracing::info!(
"WhatsApp message from {}: {}",
msg.sender,
if msg.content.len() > 50 {
format!("{}...", &msg.content[..50])
} else {
msg.content.clone()
}
truncate_with_ellipsis(&msg.content, 50)
);
// Auto-save to memory

View file

@ -11,10 +11,17 @@
dead_code
)]
pub mod channels;
pub mod config;
pub mod gateway;
pub mod health;
pub mod heartbeat;
pub mod memory;
pub mod observability;
pub mod providers;
pub mod runtime;
pub mod security;
pub mod skills;
pub mod tools;
pub mod tunnel;
pub mod util;

View file

@ -398,6 +398,7 @@ fn default_model_for_provider(provider: &str) -> String {
"ollama" => "llama3.2".into(),
"groq" => "llama-3.3-70b-versatile".into(),
"deepseek" => "deepseek-chat".into(),
"gemini" | "google" | "google-gemini" => "gemini-2.0-flash".into(),
_ => "anthropic/claude-sonnet-4-20250514".into(),
}
}
@ -466,7 +467,7 @@ fn setup_workspace() -> Result<(PathBuf, PathBuf)> {
fn setup_provider() -> Result<(String, String, String)> {
// ── Tier selection ──
let tiers = vec![
"⭐ Recommended (OpenRouter, Venice, Anthropic, OpenAI)",
"⭐ Recommended (OpenRouter, Venice, Anthropic, OpenAI, Gemini)",
"⚡ Fast inference (Groq, Fireworks, Together AI)",
"🌐 Gateway / proxy (Vercel AI, Cloudflare AI, Amazon Bedrock)",
"🔬 Specialized (Moonshot/Kimi, GLM/Zhipu, MiniMax, Qianfan, Z.AI, Synthetic, OpenCode Zen, Cohere)",
@ -493,6 +494,10 @@ fn setup_provider() -> Result<(String, String, String)> {
("mistral", "Mistral — Large & Codestral"),
("xai", "xAI — Grok 3 & 4"),
("perplexity", "Perplexity — search-augmented AI"),
(
"gemini",
"Google Gemini — Gemini 2.0 Flash & Pro (supports CLI auth)",
),
],
1 => vec![
("groq", "Groq — ultra-fast LPU inference"),
@ -575,6 +580,53 @@ fn setup_provider() -> Result<(String, String, String)> {
let api_key = if provider_name == "ollama" {
print_bullet("Ollama runs locally — no API key needed!");
String::new()
} else if provider_name == "gemini"
|| provider_name == "google"
|| provider_name == "google-gemini"
{
// Special handling for Gemini: check for CLI auth first
if crate::providers::gemini::GeminiProvider::has_cli_credentials() {
print_bullet(&format!(
"{} Gemini CLI credentials detected! You can skip the API key.",
style("").green().bold()
));
print_bullet("ZeroClaw will reuse your existing Gemini CLI authentication.");
println!();
let use_cli: bool = dialoguer::Confirm::new()
.with_prompt(" Use existing Gemini CLI authentication?")
.default(true)
.interact()?;
if use_cli {
println!(
" {} Using Gemini CLI OAuth tokens",
style("").green().bold()
);
String::new() // Empty key = will use CLI tokens
} else {
print_bullet("Get your API key at: https://aistudio.google.com/app/apikey");
Input::new()
.with_prompt(" Paste your Gemini API key")
.allow_empty(true)
.interact_text()?
}
} else if std::env::var("GEMINI_API_KEY").is_ok() {
print_bullet(&format!(
"{} GEMINI_API_KEY environment variable detected!",
style("").green().bold()
));
String::new()
} else {
print_bullet("Get your API key at: https://aistudio.google.com/app/apikey");
print_bullet("Or run `gemini` CLI to authenticate (tokens will be reused).");
println!();
Input::new()
.with_prompt(" Paste your Gemini API key (or press Enter to skip)")
.allow_empty(true)
.interact_text()?
}
} else {
let key_url = match provider_name {
"openrouter" => "https://openrouter.ai/keys",
@ -594,6 +646,7 @@ fn setup_provider() -> Result<(String, String, String)> {
"vercel" => "https://vercel.com/account/tokens",
"cloudflare" => "https://dash.cloudflare.com/profile/api-tokens",
"bedrock" => "https://console.aws.amazon.com/iam",
"gemini" | "google" | "google-gemini" => "https://aistudio.google.com/app/apikey",
_ => "",
};
@ -735,6 +788,15 @@ fn setup_provider() -> Result<(String, String, String)> {
("codellama", "Code Llama"),
("phi3", "Phi-3 (small, fast)"),
],
"gemini" | "google" | "google-gemini" => vec![
("gemini-2.0-flash", "Gemini 2.0 Flash (fast, recommended)"),
(
"gemini-2.0-flash-lite",
"Gemini 2.0 Flash Lite (fastest, cheapest)",
),
("gemini-1.5-pro", "Gemini 1.5 Pro (best quality)"),
("gemini-1.5-flash", "Gemini 1.5 Flash (balanced)"),
],
_ => vec![("default", "Default model")],
};
@ -783,6 +845,7 @@ fn provider_env_var(name: &str) -> &'static str {
"vercel" | "vercel-ai" => "VERCEL_API_KEY",
"cloudflare" | "cloudflare-ai" => "CLOUDFLARE_API_KEY",
"bedrock" | "aws-bedrock" => "AWS_ACCESS_KEY_ID",
"gemini" | "google" | "google-gemini" => "GEMINI_API_KEY",
_ => "API_KEY",
}
}
@ -1619,8 +1682,8 @@ fn setup_channels() -> Result<ChannelsConfig> {
access_token: access_token.trim().to_string(),
phone_number_id: phone_number_id.trim().to_string(),
verify_token: verify_token.trim().to_string(),
app_secret: None, // Can be set via ZEROCLAW_WHATSAPP_APP_SECRET env var
allowed_numbers,
app_secret: None, // Can be set via ZEROCLAW_WHATSAPP_APP_SECRET env var
});
}
6 => {

385
src/providers/gemini.rs Normal file
View file

@ -0,0 +1,385 @@
//! Google Gemini provider with support for:
//! - Direct API key (`GEMINI_API_KEY` env var or config)
//! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication)
//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`)
use crate::providers::traits::Provider;
use async_trait::async_trait;
use directories::UserDirs;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Gemini provider supporting multiple authentication methods.
pub struct GeminiProvider {
api_key: Option<String>,
client: Client,
}
// ══════════════════════════════════════════════════════════════════════════════
// API REQUEST/RESPONSE TYPES
// ══════════════════════════════════════════════════════════════════════════════
/// JSON body sent to the model: the conversation contents, an optional system
/// instruction, and the generation settings.
#[derive(Debug, Serialize)]
struct GenerateContentRequest {
contents: Vec<Content>,
// Omitted from the JSON entirely when `None`.
// NOTE(review): serialized under its snake_case name, while `generation_config`
// below is renamed to camelCase — confirm the API accepts both spellings.
#[serde(skip_serializing_if = "Option::is_none")]
system_instruction: Option<Content>,
#[serde(rename = "generationConfig")]
generation_config: GenerationConfig,
}
/// One block of content: an optional role plus its text parts.
#[derive(Debug, Serialize)]
struct Content {
// Omitted from the JSON when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
role: Option<String>,
parts: Vec<Part>,
}
/// A single text fragment inside a `Content`.
#[derive(Debug, Serialize)]
struct Part {
text: String,
}
/// Sampling settings sent with every request.
#[derive(Debug, Serialize)]
struct GenerationConfig {
temperature: f64,
// Serialized as `maxOutputTokens`.
#[serde(rename = "maxOutputTokens")]
max_output_tokens: u32,
}
/// Top-level response body. Both fields are optional, so a body carrying only
/// `candidates` or only `error` deserializes successfully.
#[derive(Debug, Deserialize)]
struct GenerateContentResponse {
candidates: Option<Vec<Candidate>>,
error: Option<ApiError>,
}
/// One generated candidate.
#[derive(Debug, Deserialize)]
struct Candidate {
content: CandidateContent,
}
/// The parts that make up a candidate's content.
#[derive(Debug, Deserialize)]
struct CandidateContent {
parts: Vec<ResponsePart>,
}
/// A response fragment; `text` is optional and may be absent from the JSON.
#[derive(Debug, Deserialize)]
struct ResponsePart {
text: Option<String>,
}
/// Error object from the response; only its `message` is captured.
#[derive(Debug, Deserialize)]
struct ApiError {
message: String,
}
// ══════════════════════════════════════════════════════════════════════════════
// GEMINI CLI TOKEN STRUCTURES
// ══════════════════════════════════════════════════════════════════════════════
/// OAuth token stored by Gemini CLI in `~/.gemini/oauth_creds.json`
///
/// Every field is optional, so a file missing any of them still deserializes.
#[derive(Debug, Deserialize)]
struct GeminiCliOAuthCreds {
/// Token handed back by `try_load_gemini_cli_token` when it is not known to be expired.
access_token: Option<String>,
/// Deserialized but not read by the code visible in this part of the file.
refresh_token: Option<String>,
/// Expiry timestamp; `try_load_gemini_cli_token` parses it as RFC 3339 and skips the
/// expiry check when the field is absent or does not parse.
// NOTE(review): confirm this key name and format match what the CLI actually writes;
// if they differ, this field is always `None` and expired tokens are never detected.
expiry: Option<String>,
}
/// Settings stored by Gemini CLI in ~/.gemini/settings.json
#[derive(Debug, Deserialize)]
struct GeminiCliSettings {
/// Read from the JSON key `selectedAuthType`; optional.
/// Not referenced by the code visible in this part of the file.
#[serde(rename = "selectedAuthType")]
selected_auth_type: Option<String>,
}
impl GeminiProvider {
/// Create a new Gemini provider.
///
/// Authentication priority:
/// 1. Explicit API key passed in
/// 2. `GEMINI_API_KEY` environment variable
/// 3. `GOOGLE_API_KEY` environment variable
/// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
///
/// A source whose value is empty or only whitespace counts as "not provided" and the
/// next source is tried. An empty key is what the onboarding flow stores to mean
/// "use the environment / CLI credentials", so it must not mask the real ones.
///
/// If the HTTP client cannot be built with the configured timeouts, a default
/// client is used instead.
pub fn new(api_key: Option<&str>) -> Self {
    /// Keep `key` only if it contains something other than whitespace.
    fn non_blank(key: String) -> Option<String> {
        if key.trim().is_empty() {
            None
        } else {
            Some(key)
        }
    }

    let resolved_key = api_key
        .map(String::from)
        .and_then(non_blank)
        .or_else(|| std::env::var("GEMINI_API_KEY").ok().and_then(non_blank))
        .or_else(|| std::env::var("GOOGLE_API_KEY").ok().and_then(non_blank))
        .or_else(|| Self::try_load_gemini_cli_token().and_then(non_blank));

    Self {
        api_key: resolved_key,
        client: Client::builder()
            .timeout(std::time::Duration::from_secs(120))
            .connect_timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap_or_else(|_| Client::new()),
    }
}
/// Try to load OAuth access token from Gemini CLI's cached credentials.
/// Location: `~/.gemini/oauth_creds.json`
fn try_load_gemini_cli_token() -> Option<String> {
let gemini_dir = Self::gemini_cli_dir()?;
let creds_path = gemini_dir.join("oauth_creds.json");
if !creds_path.exists() {
return None;
}
let content = std::fs::read_to_string(&creds_path).ok()?;
let creds: GeminiCliOAuthCreds = serde_json::from_str(&content).ok()?;
// Check if token is expired (basic check)
if let Some(ref expiry) = creds.expiry {
if let Ok(expiry_time) = chrono::DateTime::parse_from_rfc3339(expiry) {
if expiry_time < chrono::Utc::now() {
tracing::debug!("Gemini CLI OAuth token expired, skipping");
return None;
}
}
}
creds.access_token
}
/// Get the Gemini CLI config directory (~/.gemini)
fn gemini_cli_dir() -> Option<PathBuf> {
UserDirs::new().map(|u| u.home_dir().join(".gemini"))
}
/// Check if Gemini CLI is configured and has valid credentials
pub fn has_cli_credentials() -> bool {
Self::try_load_gemini_cli_token().is_some()
}
/// Check if any Gemini authentication is available
pub fn has_any_auth() -> bool {
std::env::var("GEMINI_API_KEY").is_ok()
|| std::env::var("GOOGLE_API_KEY").is_ok()
|| Self::has_cli_credentials()
}
/// Get authentication source description for diagnostics
pub fn auth_source(&self) -> &'static str {
if self.api_key.is_none() {
return "none";
}
if std::env::var("GEMINI_API_KEY").is_ok() {
return "GEMINI_API_KEY env var";
}
if std::env::var("GOOGLE_API_KEY").is_ok() {
return "GOOGLE_API_KEY env var";
}
if Self::has_cli_credentials() {
return "Gemini CLI OAuth";
}
"config"
}
}
#[async_trait]
impl Provider for GeminiProvider {
    /// Send a single user message (plus an optional system prompt) to the Gemini
    /// `generateContent` endpoint and return the text of the first part of the first
    /// candidate.
    ///
    /// `model` may be given with or without the `models/` prefix.
    ///
    /// # Errors
    /// Fails when no API key is configured, when the HTTP request fails or returns a
    /// non-success status, when the body cannot be decoded, when the body carries an
    /// `error` object, or when no text is present in the response.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let api_key = self.api_key.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Gemini API key not found. Options:\n\
                 1. Set GEMINI_API_KEY env var\n\
                 2. Run `gemini` CLI to authenticate (tokens will be reused)\n\
                 3. Get an API key from https://aistudio.google.com/app/apikey\n\
                 4. Run `zeroclaw onboard` to configure"
            )
        })?;

        // Build request
        let system_instruction = system_prompt.map(|sys| Content {
            role: None,
            parts: vec![Part {
                text: sys.to_string(),
            }],
        });

        let request = GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: message.to_string(),
                }],
            }],
            system_instruction,
            generation_config: GenerationConfig {
                temperature,
                max_output_tokens: 8192,
            },
        };

        // Gemini API endpoint
        // Model format: gemini-2.0-flash, gemini-1.5-pro, etc.
        let model_name = if model.starts_with("models/") {
            model.to_string()
        } else {
            format!("models/{model}")
        };

        // The key travels in the `x-goog-api-key` header rather than a `?key=` query
        // parameter: transport errors propagated by `?` below include the request URL,
        // so a key embedded in the URL would end up in error messages and logs.
        //
        // NOTE(review): the key may originate from Gemini CLI OAuth credentials (see
        // `GeminiProvider::new`). An OAuth access token is normally presented as
        // `Authorization: Bearer …`, not as an API key — confirm that path works.
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent"
        );

        let response = self
            .client
            .post(&url)
            .header("x-goog-api-key", api_key.as_str())
            .json(&request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Gemini API error ({status}): {error_text}");
        }

        let result: GenerateContentResponse = response.json().await?;

        // Check for API error in response body
        if let Some(err) = result.error {
            anyhow::bail!("Gemini API error: {}", err.message);
        }

        // Extract text from response
        result
            .candidates
            .and_then(|c| c.into_iter().next())
            .and_then(|c| c.content.parts.into_iter().next())
            .and_then(|p| p.text)
            .ok_or_else(|| anyhow::anyhow!("No response from Gemini"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the single-turn request shape used by the serialization test.
    fn sample_request() -> GenerateContentRequest {
        GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: "Hello".to_string(),
                }],
            }],
            system_instruction: Some(Content {
                role: None,
                parts: vec![Part {
                    text: "You are helpful".to_string(),
                }],
            }),
            generation_config: GenerationConfig {
                temperature: 0.7,
                max_output_tokens: 8192,
            },
        }
    }

    #[test]
    fn provider_creates_without_key() {
        // Construction must not panic. Whether a key gets resolved depends on the
        // environment (env vars, Gemini CLI credentials), so nothing is asserted about
        // it — the previous `is_none() || is_some()` check was always true.
        let _provider = GeminiProvider::new(None);
    }

    #[test]
    fn provider_creates_with_key() {
        let provider = GeminiProvider::new(Some("test-api-key"));
        assert!(provider.api_key.is_some());
        assert_eq!(provider.api_key.as_deref(), Some("test-api-key"));
    }

    #[test]
    fn gemini_cli_dir_returns_path() {
        let dir = GeminiProvider::gemini_cli_dir();
        // Should return Some on systems with home dir
        if UserDirs::new().is_some() {
            assert!(dir.is_some());
            assert!(dir.unwrap().ends_with(".gemini"));
        }
    }

    #[test]
    fn auth_source_reports_correctly() {
        let provider = GeminiProvider::new(Some("explicit-key"));
        // With explicit key, should report "config" (unless CLI credentials exist)
        let source = provider.auth_source();
        // Should be either "config" or "Gemini CLI OAuth" if CLI is configured
        assert!(source == "config" || source == "Gemini CLI OAuth");
    }

    #[test]
    fn model_name_formatting() {
        // Test that model names are formatted correctly
        let model = "gemini-2.0-flash";
        let formatted = if model.starts_with("models/") {
            model.to_string()
        } else {
            format!("models/{model}")
        };
        assert_eq!(formatted, "models/gemini-2.0-flash");

        // Already prefixed
        let model2 = "models/gemini-1.5-pro";
        let formatted2 = if model2.starts_with("models/") {
            model2.to_string()
        } else {
            format!("models/{model2}")
        };
        assert_eq!(formatted2, "models/gemini-1.5-pro");
    }

    #[test]
    fn request_serialization() {
        let json = serde_json::to_string(&sample_request()).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"text\":\"Hello\""));
        assert!(json.contains("\"temperature\":0.7"));
        assert!(json.contains("\"maxOutputTokens\":8192"));
    }

    #[test]
    fn response_deserialization() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello there!"}]
                }
            }]
        }"#;
        let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
        assert!(response.candidates.is_some());
        let text = response
            .candidates
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
            .content
            .parts
            .into_iter()
            .next()
            .unwrap()
            .text;
        assert_eq!(text, Some("Hello there!".to_string()));
    }

    #[test]
    fn error_response_deserialization() {
        let json = r#"{
            "error": {
                "message": "Invalid API key"
            }
        }"#;
        let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().message, "Invalid API key");
    }
}

View file

@ -1,5 +1,6 @@
pub mod anthropic;
pub mod compatible;
pub mod gemini;
pub mod ollama;
pub mod openai;
pub mod openrouter;
@ -100,6 +101,9 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(
api_key.filter(|k| !k.is_empty()),
))),
"gemini" | "google" | "google-gemini" => {
Ok(Box::new(gemini::GeminiProvider::new(api_key)))
}
// ── OpenAI-compatible providers ──────────────────────
"venice" => Ok(Box::new(OpenAiCompatibleProvider::new(
@ -253,6 +257,15 @@ mod tests {
assert!(create_provider("ollama", None).is_ok());
}
#[test]
fn factory_gemini() {
    // Every accepted alias must resolve to a provider when a key is supplied.
    for alias in ["gemini", "google", "google-gemini"] {
        assert!(create_provider(alias, Some("test-key")).is_ok());
    }
    // Construction also succeeds without a key (CLI auth is attempted as a fallback).
    let keyless = create_provider("gemini", None);
    assert!(keyless.is_ok());
}
// ── OpenAI-compatible providers ──────────────────────────
#[test]
@ -445,6 +458,7 @@ mod tests {
"anthropic",
"openai",
"ollama",
"gemini",
"venice",
"vercel",
"cloudflare",

134
src/util.rs Normal file
View file

@ -0,0 +1,134 @@
//! Utility functions for ZeroClaw.
//!
//! This module contains reusable helper functions used across the codebase.
/// Truncate a string to at most `max_chars` characters, appending "..." if truncated.
///
/// This function safely handles multi-byte UTF-8 characters (emoji, CJK, accented characters)
/// by cutting at a character boundary instead of a byte index, so it never panics on
/// non-ASCII input. "Characters" are Unicode scalar values (`char`), not grapheme clusters.
///
/// The kept prefix is returned verbatim: trailing whitespace that falls inside the first
/// `max_chars` characters is preserved in front of the "...".
///
/// # Arguments
/// * `s` - The string to truncate
/// * `max_chars` - Maximum number of characters to keep (excluding "...")
///
/// # Returns
/// * Original string if length <= `max_chars`
/// * The first `max_chars` characters with "..." appended if length > `max_chars`
///
/// # Examples
/// ```
/// use zeroclaw::util::truncate_with_ellipsis;
///
/// // ASCII string - no truncation needed
/// assert_eq!(truncate_with_ellipsis("hello", 10), "hello");
///
/// // ASCII string - truncation needed
/// assert_eq!(truncate_with_ellipsis("hello world", 5), "hello...");
///
/// // Multi-byte UTF-8 (emoji) - safe truncation; the 8th character is a space and is kept
/// assert_eq!(truncate_with_ellipsis("Hello 🦀 World", 8), "Hello 🦀 ...");
/// assert_eq!(truncate_with_ellipsis("😀😀😀😀", 2), "😀😀...");
///
/// // Empty string
/// assert_eq!(truncate_with_ellipsis("", 10), "");
/// ```
pub fn truncate_with_ellipsis(s: &str, max_chars: usize) -> String {
    // `nth(max_chars)` yields the byte offset of the first character that does NOT fit;
    // if there is no such character the whole string fits and is returned unchanged.
    match s.char_indices().nth(max_chars) {
        Some((idx, _)) => format!("{}...", &s[..idx]),
        None => s.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_truncate_ascii_no_truncation() {
        // ASCII string shorter than limit - no change
        assert_eq!(truncate_with_ellipsis("hello", 10), "hello");
        assert_eq!(truncate_with_ellipsis("hello world", 50), "hello world");
    }

    #[test]
    fn test_truncate_ascii_with_truncation() {
        // ASCII string longer than limit - truncates
        assert_eq!(truncate_with_ellipsis("hello world", 5), "hello...");
        // The kept prefix is verbatim, including its trailing space
        assert_eq!(truncate_with_ellipsis("This is a long message", 10), "This is a ...");
    }

    #[test]
    fn test_truncate_empty_string() {
        assert_eq!(truncate_with_ellipsis("", 10), "");
    }

    #[test]
    fn test_truncate_at_exact_boundary() {
        // String exactly at boundary - no truncation
        assert_eq!(truncate_with_ellipsis("hello", 5), "hello");
    }

    #[test]
    fn test_truncate_emoji_single() {
        // Single emoji (4 bytes) - should not panic
        let s = "🦀";
        assert_eq!(truncate_with_ellipsis(s, 10), s);
        assert_eq!(truncate_with_ellipsis(s, 1), s);
    }

    #[test]
    fn test_truncate_emoji_multiple() {
        // Multiple emoji - safe truncation at character boundary
        let s = "😀😀😀😀"; // 4 emoji, each 4 bytes = 16 bytes total
        assert_eq!(truncate_with_ellipsis(s, 2), "😀😀...");
        assert_eq!(truncate_with_ellipsis(s, 3), "😀😀😀...");
    }

    #[test]
    fn test_truncate_mixed_ascii_emoji() {
        // Mixed ASCII and emoji. The first 8 characters are "Hello 🦀 " — the trailing
        // space is part of the kept prefix.
        assert_eq!(truncate_with_ellipsis("Hello 🦀 World", 8), "Hello 🦀 ...");
        assert_eq!(truncate_with_ellipsis("Hi 😊", 10), "Hi 😊");
    }

    #[test]
    fn test_truncate_cjk_characters() {
        // CJK characters (Chinese - each is 3 bytes)
        // Byte slicing such as &s[..50] would panic here: 50 is not a character boundary.
        let s = "这是一个测试消息用来触发崩溃的中文"; // 17 characters, 51 bytes
        let result = truncate_with_ellipsis(s, 16);
        assert!(result.ends_with("..."));
        // 16 kept characters plus the three dots
        assert_eq!(result.chars().count(), 16 + 3);
        // Should not panic and should be valid UTF-8
        assert!(result.is_char_boundary(result.len() - 1));
    }

    #[test]
    fn test_truncate_accented_characters() {
        // Accented characters (2 bytes each in UTF-8)
        let s = "café résumé naïve";
        // "café résumé" is 11 characters
        assert_eq!(truncate_with_ellipsis(s, 11), "café résumé...");
    }

    #[test]
    fn test_truncate_unicode_edge_case() {
        // Mix of 1-byte, 2-byte, 3-byte, and 4-byte characters
        let s = "aé你好🦀"; // 1 + 2 + 3 + 3 + 4 bytes = 13 bytes, 5 chars
        assert_eq!(truncate_with_ellipsis(s, 4), "aé你好...");
    }

    #[test]
    fn test_truncate_long_string() {
        // Long ASCII string
        let s = "a".repeat(200);
        let result = truncate_with_ellipsis(&s, 50);
        assert_eq!(result.len(), 53); // 50 + "..."
        assert!(result.ends_with("..."));
    }

    #[test]
    fn test_truncate_zero_max_chars() {
        // Edge case: max_chars = 0
        assert_eq!(truncate_with_ellipsis("hello", 0), "...");
    }
}