feat(memory): add configurable postgres storage backend

This commit is contained in:
Chummy 2026-02-18 17:58:23 +08:00
parent b13e230942
commit 483acccdb7
14 changed files with 859 additions and 27 deletions

View file

@ -229,8 +229,9 @@ impl Agent {
&config.workspace_dir,
));
let memory: Arc<dyn Memory> = Arc::from(memory::create_memory(
let memory: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
&config.memory,
Some(&config.storage.provider.config),
&config.workspace_dir,
config.api_key.as_deref(),
)?);

View file

@ -1165,8 +1165,9 @@ pub async fn run(
));
// ── Memory (the brain) ────────────────────────────────────────
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
&config.memory,
Some(&config.storage.provider.config),
&config.workspace_dir,
config.api_key.as_deref(),
)?);
@ -1625,8 +1626,9 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
&config.autonomy,
&config.workspace_dir,
));
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
&config.memory,
Some(&config.storage.provider.config),
&config.workspace_dir,
config.api_key.as_deref(),
)?);

View file

@ -1148,8 +1148,9 @@ pub async fn start_channels(config: Config) -> Result<()> {
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
let temperature = config.default_temperature;
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
&config.memory,
Some(&config.storage.provider.config),
&config.workspace_dir,
config.api_key.as_deref(),
)?);
@ -1382,9 +1383,13 @@ pub async fn start_channels(config: Config) -> Result<()> {
println!("🦀 ZeroClaw Channel Server");
println!(" 🤖 Model: {model}");
let effective_backend = memory::effective_memory_backend_name(
&config.memory.backend,
Some(&config.storage.provider.config),
);
println!(
" 🧠 Memory: {} (auto-save: {})",
config.memory.backend,
effective_backend,
if config.memory.auto_save { "on" } else { "off" }
);
println!(

View file

@ -9,8 +9,9 @@ pub use schema::{
LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig,
PeripheralBoardConfig, PeripheralsConfig, QueryClassificationConfig, ReliabilityConfig,
ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig,
SecretsConfig, SecurityConfig, SlackConfig, StreamMode, TelegramConfig, TunnelConfig,
WebSearchConfig, WebhookConfig,
SecretsConfig, SecurityConfig, SlackConfig, StorageConfig, StorageProviderConfig,
StorageProviderSection, StreamMode, TelegramConfig, TunnelConfig, WebSearchConfig,
WebhookConfig,
};
#[cfg(test)]

View file

@ -63,6 +63,9 @@ pub struct Config {
#[serde(default)]
pub memory: MemoryConfig,
#[serde(default)]
pub storage: StorageConfig,
#[serde(default)]
pub tunnel: TunnelConfig,
@ -771,10 +774,73 @@ impl Default for WebSearchConfig {
// ── Memory ───────────────────────────────────────────────────
/// Top-level `[storage]` section of the config file.
///
/// Currently only wraps the provider selection; kept as a struct so the
/// section can grow without breaking existing TOML files.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StorageConfig {
    /// `[storage.provider]` subsection.
    #[serde(default)]
    pub provider: StorageProviderSection,
}
/// `[storage.provider]` subsection — an intermediate table so the on-disk
/// path is `[storage.provider.config]`.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct StorageProviderSection {
    /// `[storage.provider.config]` settings.
    #[serde(default)]
    pub config: StorageProviderConfig,
}
/// `[storage.provider.config]` — connection settings for remote storage
/// backends (consumed by the `postgres` memory backend).
///
/// An empty `provider` means "no override"; a non-empty value overrides the
/// memory backend chosen in `[memory]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StorageProviderConfig {
    /// Storage engine key (e.g. "postgres", "sqlite").
    #[serde(default)]
    pub provider: String,
    /// Connection URL for remote providers.
    /// Accepts legacy aliases: dbURL, database_url, databaseUrl.
    #[serde(
        default,
        alias = "dbURL",
        alias = "database_url",
        alias = "databaseUrl"
    )]
    pub db_url: Option<String>,
    /// Database schema for SQL backends.
    #[serde(default = "default_storage_schema")]
    pub schema: String,
    /// Table name for memory entries.
    #[serde(default = "default_storage_table")]
    pub table: String,
    /// Optional connection timeout in seconds for remote providers.
    #[serde(default)]
    pub connect_timeout_secs: Option<u64>,
}
/// Serde default for `StorageProviderConfig::schema`.
fn default_storage_schema() -> String {
    String::from("public")
}
/// Serde default for `StorageProviderConfig::table`.
fn default_storage_table() -> String {
    String::from("memories")
}
impl Default for StorageProviderConfig {
    /// Mirrors the serde defaults: empty provider key (no override), no URL,
    /// and the shared `default_storage_schema`/`default_storage_table`
    /// fallbacks, so programmatic and deserialized defaults agree.
    fn default() -> Self {
        Self {
            provider: String::new(),
            db_url: None,
            schema: default_storage_schema(),
            table: default_storage_table(),
            connect_timeout_secs: None,
        }
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(clippy::struct_excessive_bools)]
pub struct MemoryConfig {
/// "sqlite" | "lucid" | "markdown" | "none" (`none` = explicit no-op memory)
/// "sqlite" | "lucid" | "postgres" | "markdown" | "none" (`none` = explicit no-op memory)
///
/// `postgres` requires `[storage.provider.config]` with `db_url` (`dbURL` alias supported).
pub backend: String,
/// Auto-save conversation context to memory
pub auto_save: bool,
@ -1844,6 +1910,7 @@ impl Default for Config {
cron: CronConfig::default(),
channels_config: ChannelsConfig::default(),
memory: MemoryConfig::default(),
storage: StorageConfig::default(),
tunnel: TunnelConfig::default(),
gateway: GatewayConfig::default(),
composio: ComposioConfig::default(),
@ -2113,6 +2180,12 @@ impl Config {
"config.web_search.brave_api_key",
)?;
decrypt_optional_secret(
&store,
&mut config.storage.provider.config.db_url,
"config.storage.provider.config.db_url",
)?;
for agent in config.agents.values_mut() {
decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
}
@ -2266,6 +2339,31 @@ impl Config {
}
}
}
// Storage provider key (optional backend override): ZEROCLAW_STORAGE_PROVIDER
if let Ok(provider) = std::env::var("ZEROCLAW_STORAGE_PROVIDER") {
let provider = provider.trim();
if !provider.is_empty() {
self.storage.provider.config.provider = provider.to_string();
}
}
// Storage connection URL (for remote backends): ZEROCLAW_STORAGE_DB_URL
if let Ok(db_url) = std::env::var("ZEROCLAW_STORAGE_DB_URL") {
let db_url = db_url.trim();
if !db_url.is_empty() {
self.storage.provider.config.db_url = Some(db_url.to_string());
}
}
// Storage connect timeout: ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS
if let Ok(timeout_secs) = std::env::var("ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS") {
if let Ok(timeout_secs) = timeout_secs.parse::<u64>() {
if timeout_secs > 0 {
self.storage.provider.config.connect_timeout_secs = Some(timeout_secs);
}
}
}
}
pub fn save(&self) -> Result<()> {
@ -2296,6 +2394,12 @@ impl Config {
"config.web_search.brave_api_key",
)?;
encrypt_optional_secret(
&store,
&mut config_to_save.storage.provider.config.db_url,
"config.storage.provider.config.db_url",
)?;
for agent in config_to_save.agents.values_mut() {
encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
}
@ -2483,6 +2587,16 @@ default_temperature = 0.7
assert!(m.sqlite_open_timeout_secs.is_none());
}
#[test]
fn storage_provider_config_defaults() {
    // The default storage section must be inert: no provider override,
    // no connection URL, and the documented schema/table fallbacks.
    let cfg = StorageConfig::default().provider.config;
    assert!(cfg.provider.is_empty());
    assert!(cfg.db_url.is_none());
    assert_eq!(cfg.schema, "public");
    assert_eq!(cfg.table, "memories");
    assert!(cfg.connect_timeout_secs.is_none());
}
#[test]
fn channels_config_default() {
let c = ChannelsConfig::default();
@ -2556,6 +2670,7 @@ default_temperature = 0.7
qq: None,
},
memory: MemoryConfig::default(),
storage: StorageConfig::default(),
tunnel: TunnelConfig::default(),
gateway: GatewayConfig::default(),
composio: ComposioConfig::default(),
@ -2612,6 +2727,33 @@ default_temperature = 0.7
assert_eq!(parsed.memory.conversation_retention_days, 30);
}
// Verifies that the legacy `dbURL` key deserializes into `db_url` via the
// serde alias, alongside the rest of the section's fields.
#[test]
fn storage_provider_dburl_alias_deserializes() {
    let raw = r#"
default_temperature = 0.7

[storage.provider.config]
provider = "postgres"
dbURL = "postgres://postgres:postgres@localhost:5432/zeroclaw"
schema = "public"
table = "memories"
connect_timeout_secs = 12
"#;
    let parsed: Config = toml::from_str(raw).unwrap();
    assert_eq!(parsed.storage.provider.config.provider, "postgres");
    assert_eq!(
        parsed.storage.provider.config.db_url.as_deref(),
        Some("postgres://postgres:postgres@localhost:5432/zeroclaw")
    );
    assert_eq!(parsed.storage.provider.config.schema, "public");
    assert_eq!(parsed.storage.provider.config.table, "memories");
    assert_eq!(
        parsed.storage.provider.config.connect_timeout_secs,
        Some(12)
    );
}
#[test]
fn agent_config_defaults() {
let cfg = AgentConfig::default();
@ -2667,6 +2809,7 @@ tool_dispatcher = "xml"
cron: CronConfig::default(),
channels_config: ChannelsConfig::default(),
memory: MemoryConfig::default(),
storage: StorageConfig::default(),
tunnel: TunnelConfig::default(),
gateway: GatewayConfig::default(),
composio: ComposioConfig::default(),
@ -2715,6 +2858,7 @@ tool_dispatcher = "xml"
config.composio.api_key = Some("composio-credential".into());
config.browser.computer_use.api_key = Some("browser-credential".into());
config.web_search.brave_api_key = Some("brave-credential".into());
config.storage.provider.config.db_url = Some("postgres://user:pw@host/db".into());
config.agents.insert(
"worker".into(),
@ -2770,6 +2914,13 @@ tool_dispatcher = "xml"
assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");
let storage_db_url = stored.storage.provider.config.db_url.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(storage_db_url));
assert_eq!(
store.decrypt(storage_db_url).unwrap(),
"postgres://user:pw@host/db"
);
let _ = fs::remove_dir_all(&dir);
}
@ -3927,6 +4078,32 @@ default_model = "legacy-model"
std::env::remove_var("WEB_SEARCH_TIMEOUT_SECS");
}
#[test]
fn env_override_storage_provider_config() {
    // Process environment is global state; the guard serializes env tests.
    let _env_guard = env_override_test_guard();
    let mut config = Config::default();
    std::env::set_var("ZEROCLAW_STORAGE_PROVIDER", "postgres");
    std::env::set_var("ZEROCLAW_STORAGE_DB_URL", "postgres://example/db");
    std::env::set_var("ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS", "15");
    config.apply_env_overrides();
    // All three overrides should land on [storage.provider.config].
    assert_eq!(config.storage.provider.config.provider, "postgres");
    assert_eq!(
        config.storage.provider.config.db_url.as_deref(),
        Some("postgres://example/db")
    );
    assert_eq!(
        config.storage.provider.config.connect_timeout_secs,
        Some(15)
    );
    // Clean up so later tests see an unpolluted environment.
    std::env::remove_var("ZEROCLAW_STORAGE_PROVIDER");
    std::env::remove_var("ZEROCLAW_STORAGE_DB_URL");
    std::env::remove_var("ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS");
}
#[test]
fn gateway_config_default_values() {
let g = GatewayConfig::default();

View file

@ -313,8 +313,9 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4".into());
let temperature = config.default_temperature;
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory_with_storage(
&config.memory,
Some(&config.storage.provider.config),
&config.workspace_dir,
config.api_key.as_deref(),
)?);

View file

@ -576,6 +576,10 @@ async fn main() -> Result<()> {
println!("📊 Observability: {}", config.observability.backend);
println!("🛡️ Autonomy: {:?}", config.autonomy.level);
println!("⚙️ Runtime: {}", config.runtime.kind);
let effective_memory_backend = memory::effective_memory_backend_name(
&config.memory.backend,
Some(&config.storage.provider.config),
);
println!(
"💓 Heartbeat: {}",
if config.heartbeat.enabled {
@ -586,7 +590,7 @@ async fn main() -> Result<()> {
);
println!(
"🧠 Memory: {} (auto-save: {})",
config.memory.backend,
effective_memory_backend,
if config.memory.auto_save { "on" } else { "off" }
);

View file

@ -2,6 +2,7 @@
pub enum MemoryBackendKind {
Sqlite,
Lucid,
Postgres,
Markdown,
None,
Unknown,
@ -45,6 +46,15 @@ const MARKDOWN_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
optional_dependency: false,
};
/// Profile for the remote PostgreSQL backend: auto-save defaults to on, and
/// the local SQLite hygiene passes are skipped since data lives server-side.
const POSTGRES_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
    key: "postgres",
    label: "PostgreSQL — remote durable storage via [storage.provider.config]",
    auto_save_default: true,
    uses_sqlite_hygiene: false,
    sqlite_based: false,
    optional_dependency: false,
};
const NONE_PROFILE: MemoryBackendProfile = MemoryBackendProfile {
key: "none",
label: "None — disable persistent memory",
@ -82,6 +92,7 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind {
match backend {
"sqlite" => MemoryBackendKind::Sqlite,
"lucid" => MemoryBackendKind::Lucid,
"postgres" => MemoryBackendKind::Postgres,
"markdown" => MemoryBackendKind::Markdown,
"none" => MemoryBackendKind::None,
_ => MemoryBackendKind::Unknown,
@ -92,6 +103,7 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile {
match classify_memory_backend(backend) {
MemoryBackendKind::Sqlite => SQLITE_PROFILE,
MemoryBackendKind::Lucid => LUCID_PROFILE,
MemoryBackendKind::Postgres => POSTGRES_PROFILE,
MemoryBackendKind::Markdown => MARKDOWN_PROFILE,
MemoryBackendKind::None => NONE_PROFILE,
MemoryBackendKind::Unknown => CUSTOM_PROFILE,
@ -106,6 +118,10 @@ mod tests {
fn classify_known_backends() {
assert_eq!(classify_memory_backend("sqlite"), MemoryBackendKind::Sqlite);
assert_eq!(classify_memory_backend("lucid"), MemoryBackendKind::Lucid);
assert_eq!(
classify_memory_backend("postgres"),
MemoryBackendKind::Postgres
);
assert_eq!(
classify_memory_backend("markdown"),
MemoryBackendKind::Markdown

View file

@ -5,6 +5,7 @@ pub mod hygiene;
pub mod lucid;
pub mod markdown;
pub mod none;
pub mod postgres;
pub mod response_cache;
pub mod snapshot;
pub mod sqlite;
@ -19,24 +20,28 @@ pub use backend::{
pub use lucid::LucidMemory;
pub use markdown::MarkdownMemory;
pub use none::NoneMemory;
pub use postgres::PostgresMemory;
pub use response_cache::ResponseCache;
pub use sqlite::SqliteMemory;
pub use traits::Memory;
#[allow(unused_imports)]
pub use traits::{MemoryCategory, MemoryEntry};
use crate::config::MemoryConfig;
use crate::config::{MemoryConfig, StorageProviderConfig};
use anyhow::Context;
use std::path::Path;
use std::sync::Arc;
fn create_memory_with_sqlite_builder<F>(
fn create_memory_with_builders<F, G>(
backend_name: &str,
workspace_dir: &Path,
mut sqlite_builder: F,
mut postgres_builder: G,
unknown_context: &str,
) -> anyhow::Result<Box<dyn Memory>>
where
F: FnMut() -> anyhow::Result<SqliteMemory>,
G: FnMut() -> anyhow::Result<PostgresMemory>,
{
match classify_memory_backend(backend_name) {
MemoryBackendKind::Sqlite => Ok(Box::new(sqlite_builder()?)),
@ -44,6 +49,7 @@ where
let local = sqlite_builder()?;
Ok(Box::new(LucidMemory::new(workspace_dir, local)))
}
MemoryBackendKind::Postgres => Ok(Box::new(postgres_builder()?)),
MemoryBackendKind::Markdown => Ok(Box::new(MarkdownMemory::new(workspace_dir))),
MemoryBackendKind::None => Ok(Box::new(NoneMemory::new())),
MemoryBackendKind::Unknown => {
@ -55,19 +61,52 @@ where
}
}
/// Resolve the memory backend name actually in effect.
///
/// A non-empty `[storage.provider.config].provider` overrides the
/// `[memory].backend` value; either way the result is trimmed and lowercased
/// so downstream matching is case-insensitive.
pub fn effective_memory_backend_name(
    memory_backend: &str,
    storage_provider: Option<&StorageProviderConfig>,
) -> String {
    let override_key = storage_provider.map(|cfg| cfg.provider.trim());
    match override_key {
        Some(provider) if !provider.is_empty() => provider.to_ascii_lowercase(),
        _ => memory_backend.trim().to_ascii_lowercase(),
    }
}
/// Factory: create the right memory backend from config
///
/// Convenience wrapper around `create_memory_with_storage` with no
/// storage-provider override, preserved for existing callers.
pub fn create_memory(
    config: &MemoryConfig,
    workspace_dir: &Path,
    api_key: Option<&str>,
) -> anyhow::Result<Box<dyn Memory>> {
    create_memory_with_storage(config, None, workspace_dir, api_key)
}
/// Factory: create memory with optional storage-provider override.
pub fn create_memory_with_storage(
config: &MemoryConfig,
storage_provider: Option<&StorageProviderConfig>,
workspace_dir: &Path,
api_key: Option<&str>,
) -> anyhow::Result<Box<dyn Memory>> {
let backend_name = effective_memory_backend_name(&config.backend, storage_provider);
let backend_kind = classify_memory_backend(&backend_name);
// Best-effort memory hygiene/retention pass (throttled by state file).
if let Err(e) = hygiene::run_if_due(config, workspace_dir) {
tracing::warn!("memory hygiene skipped: {e}");
}
// If snapshot_on_hygiene is enabled, export core memories during hygiene.
if config.snapshot_enabled && config.snapshot_on_hygiene {
if config.snapshot_enabled
&& config.snapshot_on_hygiene
&& matches!(
backend_kind,
MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid
)
{
if let Err(e) = snapshot::export_snapshot(workspace_dir) {
tracing::warn!("memory snapshot skipped: {e}");
}
@ -77,7 +116,7 @@ pub fn create_memory(
// restore the "soul" from the snapshot before creating the backend.
if config.auto_hydrate
&& matches!(
classify_memory_backend(&config.backend),
backend_kind,
MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid
)
&& snapshot::should_hydrate(workspace_dir)
@ -120,10 +159,33 @@ pub fn create_memory(
Ok(mem)
}
create_memory_with_sqlite_builder(
&config.backend,
/// Build a `PostgresMemory` from the optional `[storage.provider.config]`
/// section.
///
/// # Errors
/// Fails when the section is absent or `db_url` is missing/blank — the
/// postgres backend cannot run without a connection URL.
fn build_postgres_memory(
    storage_provider: Option<&StorageProviderConfig>,
) -> anyhow::Result<PostgresMemory> {
    let storage_provider = storage_provider
        .context("memory backend 'postgres' requires [storage.provider.config] settings")?;
    // Treat whitespace-only URLs the same as a missing one.
    let db_url = storage_provider
        .db_url
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .context(
            "memory backend 'postgres' requires [storage.provider.config].db_url (or dbURL)",
        )?;
    PostgresMemory::new(
        db_url,
        &storage_provider.schema,
        &storage_provider.table,
        storage_provider.connect_timeout_secs,
    )
}
create_memory_with_builders(
&backend_name,
workspace_dir,
|| build_sqlite_memory(config, workspace_dir, api_key),
|| build_postgres_memory(storage_provider),
"",
)
}
@ -138,10 +200,20 @@ pub fn create_memory_for_migration(
);
}
create_memory_with_sqlite_builder(
if matches!(
classify_memory_backend(backend),
MemoryBackendKind::Postgres
) {
anyhow::bail!(
"memory migration for backend 'postgres' is unsupported; migrate with sqlite or markdown first"
);
}
create_memory_with_builders(
backend,
workspace_dir,
|| SqliteMemory::new(workspace_dir),
|| anyhow::bail!("postgres backend is not available in migration context"),
" during migration",
)
}
@ -175,6 +247,7 @@ pub fn create_response_cache(config: &MemoryConfig, workspace_dir: &Path) -> Opt
#[cfg(test)]
mod tests {
use super::*;
use crate::config::StorageProviderConfig;
use tempfile::TempDir;
#[test]
@ -247,4 +320,37 @@ mod tests {
.expect("backend=none should be rejected for migration");
assert!(error.to_string().contains("disables persistence"));
}
// A non-empty storage provider key must win over the [memory].backend value.
#[test]
fn effective_backend_name_prefers_storage_override() {
    let storage = StorageProviderConfig {
        provider: "postgres".into(),
        ..StorageProviderConfig::default()
    };
    assert_eq!(
        effective_memory_backend_name("sqlite", Some(&storage)),
        "postgres"
    );
}
// The factory must refuse to build a postgres backend when no connection URL
// is configured, and the error should point the user at `db_url`.
#[test]
fn factory_postgres_without_db_url_is_rejected() {
    let tmp = TempDir::new().unwrap();
    let cfg = MemoryConfig {
        backend: "postgres".into(),
        ..MemoryConfig::default()
    };
    let storage = StorageProviderConfig {
        provider: "postgres".into(),
        db_url: None,
        ..StorageProviderConfig::default()
    };
    let error = create_memory_with_storage(&cfg, Some(&storage), tmp.path(), None)
        .err()
        .expect("postgres without db_url should be rejected");
    assert!(error.to_string().contains("db_url"));
}
}

352
src/memory/postgres.rs Normal file
View file

@ -0,0 +1,352 @@
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use anyhow::{Context, Result};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use parking_lot::Mutex;
use postgres::{Client, NoTls, Row};
use std::sync::Arc;
use std::time::Duration;
use uuid::Uuid;
/// Maximum allowed connect timeout (seconds) to avoid unreasonable waits.
const POSTGRES_CONNECT_TIMEOUT_CAP_SECS: u64 = 300;
/// PostgreSQL-backed persistent memory.
///
/// This backend focuses on reliable CRUD and keyword recall using SQL, without
/// requiring extension setup (for example pgvector).
pub struct PostgresMemory {
    // Synchronous `postgres` client, shared behind a mutex so the async trait
    // methods can drive it from `spawn_blocking` workers.
    client: Arc<Mutex<Client>>,
    // Pre-quoted `"schema"."table"` target, validated at construction time.
    qualified_table: String,
}
impl PostgresMemory {
    /// Connect to PostgreSQL and ensure the backing schema/table exist.
    ///
    /// `schema` and `table` are validated as strict identifiers before being
    /// interpolated into DDL; `connect_timeout_secs` is capped at
    /// `POSTGRES_CONNECT_TIMEOUT_CAP_SECS`.
    ///
    /// # Errors
    /// Fails on invalid identifiers, an unparsable connection URL, a failed
    /// connection, or failed schema initialization.
    pub fn new(
        db_url: &str,
        schema: &str,
        table: &str,
        connect_timeout_secs: Option<u64>,
    ) -> Result<Self> {
        // Reject unsafe identifiers up front; they are later spliced into SQL.
        validate_identifier(schema, "storage schema")?;
        validate_identifier(table, "storage table")?;
        let mut config: postgres::Config = db_url
            .parse()
            .context("invalid PostgreSQL connection URL")?;
        if let Some(timeout_secs) = connect_timeout_secs {
            // Cap the timeout so a config typo cannot stall startup.
            let bounded = timeout_secs.min(POSTGRES_CONNECT_TIMEOUT_CAP_SECS);
            config.connect_timeout(Duration::from_secs(bounded));
        }
        let mut client = config
            .connect(NoTls)
            .context("failed to connect to PostgreSQL memory backend")?;
        let schema_ident = quote_identifier(schema);
        let table_ident = quote_identifier(table);
        let qualified_table = format!("{schema_ident}.{table_ident}");
        Self::init_schema(&mut client, &schema_ident, &qualified_table)?;
        Ok(Self {
            client: Arc::new(Mutex::new(client)),
            qualified_table,
        })
    }

    /// Idempotently create the schema, the memories table, and the indexes
    /// used by list/recall queries (all statements are `IF NOT EXISTS`).
    fn init_schema(client: &mut Client, schema_ident: &str, qualified_table: &str) -> Result<()> {
        client.batch_execute(&format!(
            "
            CREATE SCHEMA IF NOT EXISTS {schema_ident};
            CREATE TABLE IF NOT EXISTS {qualified_table} (
                id TEXT PRIMARY KEY,
                key TEXT UNIQUE NOT NULL,
                content TEXT NOT NULL,
                category TEXT NOT NULL,
                created_at TIMESTAMPTZ NOT NULL,
                updated_at TIMESTAMPTZ NOT NULL,
                session_id TEXT
            );
            CREATE INDEX IF NOT EXISTS idx_memories_category ON {qualified_table}(category);
            CREATE INDEX IF NOT EXISTS idx_memories_session_id ON {qualified_table}(session_id);
            CREATE INDEX IF NOT EXISTS idx_memories_updated_at ON {qualified_table}(updated_at DESC);
            "
        ))?;
        Ok(())
    }

    /// Map a `MemoryCategory` to its stored string form.
    fn category_to_str(category: &MemoryCategory) -> String {
        match category {
            MemoryCategory::Core => "core".to_string(),
            MemoryCategory::Daily => "daily".to_string(),
            MemoryCategory::Conversation => "conversation".to_string(),
            MemoryCategory::Custom(name) => name.clone(),
        }
    }

    /// Inverse of `category_to_str`; unknown values become `Custom`.
    fn parse_category(value: &str) -> MemoryCategory {
        match value {
            "core" => MemoryCategory::Core,
            "daily" => MemoryCategory::Daily,
            "conversation" => MemoryCategory::Conversation,
            other => MemoryCategory::Custom(other.to_string()),
        }
    }

    /// Convert a result row laid out as
    /// (id, key, content, category, created_at, session_id[, score])
    /// into a `MemoryEntry`. The optional 7th column is the recall relevance
    /// score; `try_get(6).ok()` yields `None` for queries that don't select it.
    fn row_to_entry(row: &Row) -> Result<MemoryEntry> {
        let timestamp: DateTime<Utc> = row.get(4);
        Ok(MemoryEntry {
            id: row.get(0),
            key: row.get(1),
            content: row.get(2),
            category: Self::parse_category(&row.get::<_, String>(3)),
            timestamp: timestamp.to_rfc3339(),
            session_id: row.get(5),
            score: row.try_get(6).ok(),
        })
    }
}
fn validate_identifier(value: &str, field_name: &str) -> Result<()> {
if value.is_empty() {
anyhow::bail!("{field_name} must not be empty");
}
let mut chars = value.chars();
let Some(first) = chars.next() else {
anyhow::bail!("{field_name} must not be empty");
};
if !(first.is_ascii_alphabetic() || first == '_') {
anyhow::bail!("{field_name} must start with an ASCII letter or underscore; got '{value}'");
}
if !chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') {
anyhow::bail!(
"{field_name} can only contain ASCII letters, numbers, and underscores; got '{value}'"
);
}
Ok(())
}
/// Quote a PostgreSQL identifier with double quotes, doubling any embedded
/// `"` per SQL delimited-identifier rules.
///
/// Callers run `validate_identifier` first (which forbids quotes entirely),
/// so the escaping is defense in depth against future call sites that skip
/// validation.
fn quote_identifier(value: &str) -> String {
    format!("\"{}\"", value.replace('"', "\"\""))
}
#[async_trait]
impl Memory for PostgresMemory {
    fn name(&self) -> &str {
        "postgres"
    }

    /// Insert or update a memory entry keyed by `key`.
    ///
    /// The underlying `postgres` client is synchronous, so the query runs on
    /// a `spawn_blocking` worker. On key conflict only content, category,
    /// updated_at, and session_id are refreshed; the row keeps its id.
    async fn store(
        &self,
        key: &str,
        content: &str,
        category: MemoryCategory,
        session_id: Option<&str>,
    ) -> Result<()> {
        // Owned copies so the 'static closure can move them.
        let client = self.client.clone();
        let qualified_table = self.qualified_table.clone();
        let key = key.to_string();
        let content = content.to_string();
        let category = Self::category_to_str(&category);
        let session_id = session_id.map(str::to_string);
        tokio::task::spawn_blocking(move || -> Result<()> {
            let now = Utc::now();
            let mut client = client.lock();
            let stmt = format!(
                "
                INSERT INTO {qualified_table}
                    (id, key, content, category, created_at, updated_at, session_id)
                VALUES
                    ($1, $2, $3, $4, $5, $6, $7)
                ON CONFLICT (key) DO UPDATE SET
                    content = EXCLUDED.content,
                    category = EXCLUDED.category,
                    updated_at = EXCLUDED.updated_at,
                    session_id = EXCLUDED.session_id
                "
            );
            // Fresh id is only used for brand-new rows; conflicts keep theirs.
            let id = Uuid::new_v4().to_string();
            client.execute(
                &stmt,
                &[&id, &key, &content, &category, &now, &now, &session_id],
            )?;
            Ok(())
        })
        .await?
    }

    /// Keyword recall via case-insensitive substring match on key and
    /// content; key hits score 2.0, content hits 1.0. An empty query matches
    /// everything, ordered by recency.
    async fn recall(
        &self,
        query: &str,
        limit: usize,
        session_id: Option<&str>,
    ) -> Result<Vec<MemoryEntry>> {
        let client = self.client.clone();
        let qualified_table = self.qualified_table.clone();
        let query = query.trim().to_string();
        let session_id = session_id.map(str::to_string);
        tokio::task::spawn_blocking(move || -> Result<Vec<MemoryEntry>> {
            let mut client = client.lock();
            let stmt = format!(
                "
                SELECT id, key, content, category, created_at, session_id,
                    (
                        CASE WHEN key ILIKE '%' || $1 || '%' THEN 2.0 ELSE 0.0 END +
                        CASE WHEN content ILIKE '%' || $1 || '%' THEN 1.0 ELSE 0.0 END
                    ) AS score
                FROM {qualified_table}
                WHERE ($2::TEXT IS NULL OR session_id = $2)
                  AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
                ORDER BY score DESC, updated_at DESC
                LIMIT $3
                "
            );
            // LIMIT binds as i64 in postgres; a usize above i64::MAX is not a
            // realistic limit here, hence the allowed wrap lint.
            #[allow(clippy::cast_possible_wrap)]
            let limit_i64 = limit as i64;
            let rows = client.query(&stmt, &[&query, &session_id, &limit_i64])?;
            rows.iter()
                .map(Self::row_to_entry)
                .collect::<Result<Vec<MemoryEntry>>>()
        })
        .await?
    }

    /// Fetch a single entry by exact key, if present. No score column is
    /// selected, so the resulting entry's `score` is `None`.
    async fn get(&self, key: &str) -> Result<Option<MemoryEntry>> {
        let client = self.client.clone();
        let qualified_table = self.qualified_table.clone();
        let key = key.to_string();
        tokio::task::spawn_blocking(move || -> Result<Option<MemoryEntry>> {
            let mut client = client.lock();
            let stmt = format!(
                "
                SELECT id, key, content, category, created_at, session_id
                FROM {qualified_table}
                WHERE key = $1
                LIMIT 1
                "
            );
            let row = client.query_opt(&stmt, &[&key])?;
            row.as_ref().map(Self::row_to_entry).transpose()
        })
        .await?
    }

    /// List entries, optionally filtered by category and/or session, newest
    /// first. `NULL` parameters disable the corresponding filter in SQL.
    async fn list(
        &self,
        category: Option<&MemoryCategory>,
        session_id: Option<&str>,
    ) -> Result<Vec<MemoryEntry>> {
        let client = self.client.clone();
        let qualified_table = self.qualified_table.clone();
        let category = category.map(Self::category_to_str);
        let session_id = session_id.map(str::to_string);
        tokio::task::spawn_blocking(move || -> Result<Vec<MemoryEntry>> {
            let mut client = client.lock();
            let stmt = format!(
                "
                SELECT id, key, content, category, created_at, session_id
                FROM {qualified_table}
                WHERE ($1::TEXT IS NULL OR category = $1)
                  AND ($2::TEXT IS NULL OR session_id = $2)
                ORDER BY updated_at DESC
                "
            );
            let category_ref = category.as_deref();
            let session_ref = session_id.as_deref();
            let rows = client.query(&stmt, &[&category_ref, &session_ref])?;
            rows.iter()
                .map(Self::row_to_entry)
                .collect::<Result<Vec<MemoryEntry>>>()
        })
        .await?
    }

    /// Delete the entry with the given key; returns whether a row was removed.
    async fn forget(&self, key: &str) -> Result<bool> {
        let client = self.client.clone();
        let qualified_table = self.qualified_table.clone();
        let key = key.to_string();
        tokio::task::spawn_blocking(move || -> Result<bool> {
            let mut client = client.lock();
            let stmt = format!("DELETE FROM {qualified_table} WHERE key = $1");
            let deleted = client.execute(&stmt, &[&key])?;
            Ok(deleted > 0)
        })
        .await?
    }

    /// Count all stored entries.
    async fn count(&self) -> Result<usize> {
        let client = self.client.clone();
        let qualified_table = self.qualified_table.clone();
        tokio::task::spawn_blocking(move || -> Result<usize> {
            let mut client = client.lock();
            let stmt = format!("SELECT COUNT(*) FROM {qualified_table}");
            let count: i64 = client.query_one(&stmt, &[])?.get(0);
            let count =
                usize::try_from(count).context("PostgreSQL returned a negative memory count")?;
            Ok(count)
        })
        .await?
    }

    /// Lightweight liveness probe: `SELECT 1` on the shared connection.
    /// Any error (including a panicked worker) reports unhealthy.
    async fn health_check(&self) -> bool {
        let client = self.client.clone();
        tokio::task::spawn_blocking(move || client.lock().simple_query("SELECT 1").is_ok())
            .await
            .unwrap_or(false)
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for identifier validation and category parsing; none of
    //! these require a live PostgreSQL connection.
    use super::*;

    #[test]
    fn valid_identifiers_pass_validation() {
        assert!(validate_identifier("public", "schema").is_ok());
        assert!(validate_identifier("_memories_01", "table").is_ok());
    }

    #[test]
    fn invalid_identifiers_are_rejected() {
        // Empty, leading digit, and hyphen are all disallowed.
        assert!(validate_identifier("", "schema").is_err());
        assert!(validate_identifier("1bad", "schema").is_err());
        assert!(validate_identifier("bad-name", "table").is_err());
    }

    #[test]
    fn parse_category_maps_known_and_custom_values() {
        assert_eq!(PostgresMemory::parse_category("core"), MemoryCategory::Core);
        assert_eq!(
            PostgresMemory::parse_category("daily"),
            MemoryCategory::Daily
        );
        assert_eq!(
            PostgresMemory::parse_category("conversation"),
            MemoryCategory::Conversation
        );
        // Unknown strings round-trip as Custom rather than erroring.
        assert_eq!(
            PostgresMemory::parse_category("custom_notes"),
            MemoryCategory::Custom("custom_notes".into())
        );
    }
}

View file

@ -2,7 +2,7 @@ use crate::config::schema::{DingTalkConfig, IrcConfig, QQConfig, StreamMode, Wha
use crate::config::{
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, ObservabilityConfig,
RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, WebhookConfig,
RuntimeConfig, SecretsConfig, SlackConfig, StorageConfig, TelegramConfig, WebhookConfig,
};
use crate::hardware::{self, HardwareConfig};
use crate::memory::{
@ -125,6 +125,7 @@ pub fn run_wizard() -> Result<Config> {
cron: crate::config::CronConfig::default(),
channels_config,
memory: memory_config, // User-selected memory backend
storage: StorageConfig::default(),
tunnel: tunnel_config,
gateway: crate::config::GatewayConfig::default(),
composio: composio_config,
@ -347,6 +348,7 @@ pub fn run_quick_setup(
cron: crate::config::CronConfig::default(),
channels_config: ChannelsConfig::default(),
memory: memory_config,
storage: StorageConfig::default(),
tunnel: crate::config::TunnelConfig::default(),
gateway: crate::config::GatewayConfig::default(),
composio: ComposioConfig::default(),