feat: initial release — ZeroClaw v0.1.0
- 22 AI providers (OpenRouter, Anthropic, OpenAI, Mistral, etc.) - 7 channels (CLI, Telegram, Discord, Slack, iMessage, Matrix, Webhook) - 5-step onboarding wizard with Project Context personalization - OpenClaw-aligned system prompt (SOUL.md, IDENTITY.md, USER.md, AGENTS.md, etc.) - SQLite memory backend with auto-save - Skills system with on-demand loading - Security: autonomy levels, command allowlists, cost limits - 532 tests passing, 0 clippy warnings
This commit is contained in:
commit
05cb353f7f
71 changed files with 15757 additions and 0 deletions
344
src/memory/markdown.rs
Normal file
344
src/memory/markdown.rs
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::fs;
|
||||
|
||||
/// Markdown-based memory — plain files as source of truth
|
||||
///
|
||||
/// Layout:
|
||||
/// workspace/MEMORY.md — curated long-term memory (core)
|
||||
/// workspace/memory/YYYY-MM-DD.md — daily logs (append-only)
|
||||
pub struct MarkdownMemory {
|
||||
workspace_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl MarkdownMemory {
|
||||
pub fn new(workspace_dir: &Path) -> Self {
|
||||
Self {
|
||||
workspace_dir: workspace_dir.to_path_buf(),
|
||||
}
|
||||
}
|
||||
|
||||
fn memory_dir(&self) -> PathBuf {
|
||||
self.workspace_dir.join("memory")
|
||||
}
|
||||
|
||||
fn core_path(&self) -> PathBuf {
|
||||
self.workspace_dir.join("MEMORY.md")
|
||||
}
|
||||
|
||||
fn daily_path(&self) -> PathBuf {
|
||||
let date = Local::now().format("%Y-%m-%d").to_string();
|
||||
self.memory_dir().join(format!("{date}.md"))
|
||||
}
|
||||
|
||||
async fn ensure_dirs(&self) -> anyhow::Result<()> {
|
||||
fs::create_dir_all(self.memory_dir()).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn append_to_file(&self, path: &Path, content: &str) -> anyhow::Result<()> {
|
||||
self.ensure_dirs().await?;
|
||||
|
||||
let existing = if path.exists() {
|
||||
fs::read_to_string(path).await.unwrap_or_default()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let updated = if existing.is_empty() {
|
||||
let header = if path == self.core_path() {
|
||||
"# Long-Term Memory\n\n"
|
||||
} else {
|
||||
let date = Local::now().format("%Y-%m-%d").to_string();
|
||||
&format!("# Daily Log — {date}\n\n")
|
||||
};
|
||||
format!("{header}{content}\n")
|
||||
} else {
|
||||
format!("{existing}\n{content}\n")
|
||||
};
|
||||
|
||||
fs::write(path, updated).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_entries_from_file(
|
||||
path: &Path,
|
||||
content: &str,
|
||||
category: &MemoryCategory,
|
||||
) -> Vec<MemoryEntry> {
|
||||
let filename = path
|
||||
.file_stem()
|
||||
.and_then(|s| s.to_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
content
|
||||
.lines()
|
||||
.filter(|line| {
|
||||
let trimmed = line.trim();
|
||||
!trimmed.is_empty() && !trimmed.starts_with('#')
|
||||
})
|
||||
.enumerate()
|
||||
.map(|(i, line)| {
|
||||
let trimmed = line.trim();
|
||||
let clean = trimmed.strip_prefix("- ").unwrap_or(trimmed);
|
||||
MemoryEntry {
|
||||
id: format!("{filename}:{i}"),
|
||||
key: format!("{filename}:{i}"),
|
||||
content: clean.to_string(),
|
||||
category: category.clone(),
|
||||
timestamp: filename.to_string(),
|
||||
session_id: None,
|
||||
score: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn read_all_entries(&self) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let mut entries = Vec::new();
|
||||
|
||||
// Read MEMORY.md (core)
|
||||
let core_path = self.core_path();
|
||||
if core_path.exists() {
|
||||
let content = fs::read_to_string(&core_path).await?;
|
||||
entries.extend(Self::parse_entries_from_file(
|
||||
&core_path,
|
||||
&content,
|
||||
&MemoryCategory::Core,
|
||||
));
|
||||
}
|
||||
|
||||
// Read daily logs
|
||||
let mem_dir = self.memory_dir();
|
||||
if mem_dir.exists() {
|
||||
let mut dir = fs::read_dir(&mem_dir).await?;
|
||||
while let Some(entry) = dir.next_entry().await? {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|e| e.to_str()) == Some("md") {
|
||||
let content = fs::read_to_string(&path).await?;
|
||||
entries.extend(Self::parse_entries_from_file(
|
||||
&path,
|
||||
&content,
|
||||
&MemoryCategory::Daily,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
entries.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
|
||||
Ok(entries)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for MarkdownMemory {
|
||||
fn name(&self) -> &str {
|
||||
"markdown"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
) -> anyhow::Result<()> {
|
||||
let entry = format!("- **{key}**: {content}");
|
||||
let path = match category {
|
||||
MemoryCategory::Core => self.core_path(),
|
||||
_ => self.daily_path(),
|
||||
};
|
||||
self.append_to_file(&path, &entry).await
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
let query_lower = query.to_lowercase();
|
||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
|
||||
let mut scored: Vec<MemoryEntry> = all
|
||||
.into_iter()
|
||||
.filter_map(|mut entry| {
|
||||
let content_lower = entry.content.to_lowercase();
|
||||
let matched = keywords
|
||||
.iter()
|
||||
.filter(|kw| content_lower.contains(**kw))
|
||||
.count();
|
||||
if matched > 0 {
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
let score = matched as f64 / keywords.len() as f64;
|
||||
entry.score = Some(score);
|
||||
Some(entry)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
scored.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
scored.truncate(limit);
|
||||
Ok(scored)
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
Ok(all
|
||||
.into_iter()
|
||||
.find(|e| e.key == key || e.content.contains(key)))
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
match category {
|
||||
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
||||
None => Ok(all),
|
||||
}
|
||||
}
|
||||
|
||||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||
// Markdown memory is append-only by design (audit trail)
|
||||
// Return false to indicate the entry wasn't removed
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let all = self.read_all_entries().await?;
|
||||
Ok(all.len())
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.workspace_dir.exists()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::fs as sync_fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn temp_workspace() -> (TempDir, MarkdownMemory) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = MarkdownMemory::new(tmp.path());
|
||||
(tmp, mem)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_name() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
assert_eq!(mem.name(), "markdown");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_health_check() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
assert!(mem.health_check().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_store_core() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("pref", "User likes Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
||||
assert!(content.contains("User likes Rust"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_store_daily() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("note", "Finished tests", MemoryCategory::Daily)
|
||||
.await
|
||||
.unwrap();
|
||||
let path = mem.daily_path();
|
||||
let content = sync_fs::read_to_string(path).unwrap();
|
||||
assert!(content.contains("Finished tests"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_recall_keyword() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is slow", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust and safety", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|r| r.content.to_lowercase().contains("rust")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_recall_no_match() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_count() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "second", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let count = mem.count().await.unwrap();
|
||||
assert!(count >= 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_list_by_category() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "core fact", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "daily note", MemoryCategory::Daily)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_forget_is_noop() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "permanent", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let removed = mem.forget("a").await.unwrap();
|
||||
assert!(!removed, "Markdown memory is append-only");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_empty_recall() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
let results = mem.recall("anything", 10).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_empty_count() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
assert_eq!(mem.count().await.unwrap(), 0);
|
||||
}
|
||||
}
|
||||
77
src/memory/mod.rs
Normal file
77
src/memory/mod.rs
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
pub mod markdown;
|
||||
pub mod sqlite;
|
||||
pub mod traits;
|
||||
|
||||
pub use markdown::MarkdownMemory;
|
||||
pub use sqlite::SqliteMemory;
|
||||
pub use traits::Memory;
|
||||
#[allow(unused_imports)]
|
||||
pub use traits::{MemoryCategory, MemoryEntry};
|
||||
|
||||
use crate::config::MemoryConfig;
|
||||
use std::path::Path;
|
||||
|
||||
/// Factory: create the right memory backend from config
|
||||
pub fn create_memory(
|
||||
config: &MemoryConfig,
|
||||
workspace_dir: &Path,
|
||||
) -> anyhow::Result<Box<dyn Memory>> {
|
||||
match config.backend.as_str() {
|
||||
"sqlite" => Ok(Box::new(SqliteMemory::new(workspace_dir)?)),
|
||||
"markdown" | "none" => Ok(Box::new(MarkdownMemory::new(workspace_dir))),
|
||||
other => {
|
||||
tracing::warn!("Unknown memory backend '{other}', falling back to markdown");
|
||||
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn factory_sqlite() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "sqlite".into(),
|
||||
auto_save: true,
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "sqlite");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_markdown() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "markdown".into(),
|
||||
auto_save: true,
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "markdown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_none_falls_back_to_markdown() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "none".into(),
|
||||
auto_save: true,
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "markdown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_unknown_falls_back_to_markdown() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let cfg = MemoryConfig {
|
||||
backend: "redis".into(),
|
||||
auto_save: true,
|
||||
};
|
||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
||||
assert_eq!(mem.name(), "markdown");
|
||||
}
|
||||
}
|
||||
481
src/memory/sqlite.rs
Normal file
481
src/memory/sqlite.rs
Normal file
|
|
@ -0,0 +1,481 @@
|
|||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Local;
|
||||
use rusqlite::{params, Connection};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// SQLite-backed persistent memory — the brain
|
||||
///
|
||||
/// Stores memories in a local `SQLite` database with keyword search.
|
||||
/// Zero external dependencies, works offline, survives restarts.
|
||||
pub struct SqliteMemory {
|
||||
conn: Mutex<Connection>,
|
||||
db_path: PathBuf,
|
||||
}
|
||||
|
||||
impl SqliteMemory {
|
||||
pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
|
||||
let db_path = workspace_dir.join("memory").join("brain.db");
|
||||
|
||||
if let Some(parent) = db_path.parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let conn = Connection::open(&db_path)?;
|
||||
|
||||
conn.execute_batch(
|
||||
"CREATE TABLE IF NOT EXISTS memories (
|
||||
id TEXT PRIMARY KEY,
|
||||
key TEXT NOT NULL UNIQUE,
|
||||
content TEXT NOT NULL,
|
||||
category TEXT NOT NULL DEFAULT 'core',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);",
|
||||
)?;
|
||||
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
db_path,
|
||||
})
|
||||
}
|
||||
|
||||
fn category_to_str(cat: &MemoryCategory) -> String {
|
||||
match cat {
|
||||
MemoryCategory::Core => "core".into(),
|
||||
MemoryCategory::Daily => "daily".into(),
|
||||
MemoryCategory::Conversation => "conversation".into(),
|
||||
MemoryCategory::Custom(name) => name.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn str_to_category(s: &str) -> MemoryCategory {
|
||||
match s {
|
||||
"core" => MemoryCategory::Core,
|
||||
"daily" => MemoryCategory::Daily,
|
||||
"conversation" => MemoryCategory::Conversation,
|
||||
other => MemoryCategory::Custom(other.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Memory for SqliteMemory {
|
||||
fn name(&self) -> &str {
|
||||
"sqlite"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
) -> anyhow::Result<()> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let now = Local::now().to_rfc3339();
|
||||
let cat = Self::category_to_str(&category);
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
updated_at = excluded.updated_at",
|
||||
params![id, key, content, cat, now, now],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
// Keyword search: split query into words, match any
|
||||
let keywords: Vec<String> = query.split_whitespace().map(|w| format!("%{w}%")).collect();
|
||||
|
||||
if keywords.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
// Build dynamic WHERE clause for keyword matching
|
||||
let conditions: Vec<String> = keywords
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, _)| format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2))
|
||||
.collect();
|
||||
|
||||
let where_clause = conditions.join(" OR ");
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
WHERE {where_clause}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{}",
|
||||
keywords.len() * 2 + 1
|
||||
);
|
||||
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
|
||||
// Build params: each keyword appears twice (for content and key)
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
for kw in &keywords {
|
||||
param_values.push(Box::new(kw.clone()));
|
||||
param_values.push(Box::new(kw.clone()));
|
||||
}
|
||||
#[allow(clippy::cast_possible_wrap)]
|
||||
param_values.push(Box::new(limit as i64));
|
||||
|
||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||
param_values.iter().map(AsRef::as_ref).collect();
|
||||
|
||||
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
score: Some(1.0),
|
||||
})
|
||||
})?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
}
|
||||
|
||||
// Score by keyword match count
|
||||
let query_lower = query.to_lowercase();
|
||||
let kw_list: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
for entry in &mut results {
|
||||
let content_lower = entry.content.to_lowercase();
|
||||
let matched = kw_list
|
||||
.iter()
|
||||
.filter(|kw| content_lower.contains(**kw))
|
||||
.count();
|
||||
#[allow(clippy::cast_precision_loss)]
|
||||
{
|
||||
entry.score = Some(matched as f64 / kw_list.len().max(1) as f64);
|
||||
}
|
||||
}
|
||||
|
||||
results.sort_by(|a, b| {
|
||||
b.score
|
||||
.partial_cmp(&a.score)
|
||||
.unwrap_or(std::cmp::Ordering::Equal)
|
||||
});
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
|
||||
)?;
|
||||
|
||||
let mut rows = stmt.query_map(params![key], |row| {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
score: None,
|
||||
})
|
||||
})?;
|
||||
|
||||
match rows.next() {
|
||||
Some(Ok(entry)) => Ok(Some(entry)),
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
let mut results = Vec::new();
|
||||
|
||||
let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result<MemoryEntry> {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
score: None,
|
||||
})
|
||||
};
|
||||
|
||||
if let Some(cat) = category {
|
||||
let cat_str = Self::category_to_str(cat);
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map([], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
||||
Ok(affected > 0)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok(count as usize)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.conn
|
||||
.lock()
|
||||
.map(|c| c.execute_batch("SELECT 1").is_ok())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn temp_sqlite() -> (TempDir, SqliteMemory) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
(tmp, mem)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_name() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
assert_eq!(mem.name(), "sqlite");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_health() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
assert!(mem.health_check().await);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_store_and_get() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entry = mem.get("user_lang").await.unwrap();
|
||||
assert!(entry.is_some());
|
||||
let entry = entry.unwrap();
|
||||
assert_eq!(entry.key, "user_lang");
|
||||
assert_eq!(entry.content, "Prefers Rust");
|
||||
assert_eq!(entry.category, MemoryCategory::Core);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_store_upsert() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entry = mem.get("pref").await.unwrap().unwrap();
|
||||
assert_eq!(entry.content, "loves Rust");
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_recall_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is interpreted", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|r| r.content.to_lowercase().contains("rust")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_recall_multi_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("fast safe", 10).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// Entry with both keywords should score higher
|
||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_recall_no_match() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_forget() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("temp", "temporary data", MemoryCategory::Conversation)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
|
||||
let removed = mem.forget("temp").await.unwrap();
|
||||
assert!(removed);
|
||||
assert_eq!(mem.count().await.unwrap(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_forget_nonexistent() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let removed = mem.forget("nope").await.unwrap();
|
||||
assert!(!removed);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_list_all() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
|
||||
mem.store("c", "three", MemoryCategory::Conversation)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let all = mem.list(None).await.unwrap();
|
||||
assert_eq!(all.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_list_by_category() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("c", "daily1", MemoryCategory::Daily)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
assert_eq!(core.len(), 2);
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
assert_eq!(daily.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_count_empty() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
assert_eq!(mem.count().await.unwrap(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_get_nonexistent() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
assert!(mem.get("nope").await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_db_persists() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("persist", "I survive restarts", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Reopen
|
||||
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let entry = mem2.get("persist").await.unwrap();
|
||||
assert!(entry.is_some());
|
||||
assert_eq!(entry.unwrap().content, "I survive restarts");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_category_roundtrip() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let categories = vec![
|
||||
MemoryCategory::Core,
|
||||
MemoryCategory::Daily,
|
||||
MemoryCategory::Conversation,
|
||||
MemoryCategory::Custom("project".into()),
|
||||
];
|
||||
|
||||
for (i, cat) in categories.iter().enumerate() {
|
||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
for (i, cat) in categories.iter().enumerate() {
|
||||
let entry = mem.get(&format!("k{i}")).await.unwrap().unwrap();
|
||||
assert_eq!(&entry.category, cat);
|
||||
}
|
||||
}
|
||||
}
|
||||
68
src/memory/traits.rs
Normal file
68
src/memory/traits.rs
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A single memory entry
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryEntry {
|
||||
pub id: String,
|
||||
pub key: String,
|
||||
pub content: String,
|
||||
pub category: MemoryCategory,
|
||||
pub timestamp: String,
|
||||
pub session_id: Option<String>,
|
||||
pub score: Option<f64>,
|
||||
}
|
||||
|
||||
/// Memory categories for organization
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MemoryCategory {
|
||||
/// Long-term facts, preferences, decisions
|
||||
Core,
|
||||
/// Daily session logs
|
||||
Daily,
|
||||
/// Conversation context
|
||||
Conversation,
|
||||
/// User-defined custom category
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for MemoryCategory {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Core => write!(f, "core"),
|
||||
Self::Daily => write!(f, "daily"),
|
||||
Self::Conversation => write!(f, "conversation"),
|
||||
Self::Custom(name) => write!(f, "{name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Core memory trait — implement for any persistence backend
|
||||
#[async_trait]
|
||||
pub trait Memory: Send + Sync {
|
||||
/// Backend name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Store a memory entry
|
||||
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
|
||||
-> anyhow::Result<()>;
|
||||
|
||||
/// Recall memories matching a query (keyword search)
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Get a specific memory by key
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
||||
|
||||
/// List all memory keys, optionally filtered by category
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Remove a memory by key
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||
|
||||
/// Count total memories
|
||||
async fn count(&self) -> anyhow::Result<usize>;
|
||||
|
||||
/// Health check
|
||||
async fn health_check(&self) -> bool;
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue