From 483acccdb7cac864b2f425ba40e478ebf396d44a Mon Sep 17 00:00:00 2001 From: Chummy Date: Wed, 18 Feb 2026 17:58:23 +0800 Subject: [PATCH] feat(memory): add configurable postgres storage backend --- Cargo.lock | 157 +++++++++++++++++- Cargo.toml | 1 + README.md | 25 ++- src/agent/agent.rs | 3 +- src/agent/loop_.rs | 6 +- src/channels/mod.rs | 9 +- src/config/mod.rs | 5 +- src/config/schema.rs | 179 ++++++++++++++++++++- src/gateway/mod.rs | 3 +- src/main.rs | 6 +- src/memory/backend.rs | 16 ++ src/memory/mod.rs | 120 +++++++++++++- src/memory/postgres.rs | 352 +++++++++++++++++++++++++++++++++++++++++ src/onboard/wizard.rs | 4 +- 14 files changed, 859 insertions(+), 27 deletions(-) create mode 100644 src/memory/postgres.rs diff --git a/Cargo.lock b/Cargo.lock index 97819fa..76277e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -544,7 +544,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a6139a8597ed92cf816dfb33f5dd6cf0bb93a6adc938f11039f371bc5bcd26c3" dependencies = [ "chrono", - "phf", + "phf 0.12.1", ] [[package]] @@ -1231,6 +1231,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "fallible-iterator" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7" + [[package]] name = "fallible-iterator" version = "0.3.0" @@ -1453,7 +1459,7 @@ dependencies = [ "cfg-if", "js-sys", "libc", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", "wasm-bindgen", ] @@ -1490,7 +1496,7 @@ version = "0.32.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" dependencies = [ - "fallible-iterator", + "fallible-iterator 0.3.0", "indexmap", "stable_deref_trait", ] @@ -2394,7 +2400,7 @@ checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "log", - "wasi", + "wasi 0.11.1+wasi-snapshot-preview1", 
"windows-sys 0.61.2", ] @@ -2557,6 +2563,24 @@ dependencies = [ "windows-sys 0.60.2", ] +[[package]] +name = "objc2-core-foundation" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" +dependencies = [ + "bitflags 2.11.0", +] + +[[package]] +name = "objc2-system-configuration" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7216bd11cbda54ccabcab84d523dc93b858ec75ecfb3a7d89513fa22464da396" +dependencies = [ + "objc2-core-foundation", +] + [[package]] name = "object" version = "0.37.3" @@ -2750,7 +2774,17 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "913273894cec178f401a31ec4b656318d95473527be05c0752cc41cdc32be8b7" dependencies = [ - "phf_shared", + "phf_shared 0.12.1", +] + +[[package]] +name = "phf" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1562dc717473dbaa4c1f85a36410e03c047b2e7df7f45ee938fbef64ae7fadf" +dependencies = [ + "phf_shared 0.13.1", + "serde", ] [[package]] @@ -2762,6 +2796,15 @@ dependencies = [ "siphasher", ] +[[package]] +name = "phf_shared" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e57fef6bc5981e38c2ce2d63bfa546861309f875b8a75f092d1d54ae2d64f266" +dependencies = [ + "siphasher", +] + [[package]] name = "pin-project" version = "1.1.10" @@ -2859,6 +2902,50 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60f6ce597ecdcc9a098e7fddacb1065093a3d66446fa16c675e7e71d1b5c28e6" +[[package]] +name = "postgres" +version = "0.19.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7c48ece1c6cda0db61b058c1721378da76855140e9214339fa1317decacb176" +dependencies = [ + "bytes", + "fallible-iterator 0.2.0", + "futures-util", + "log", + "tokio", + "tokio-postgres", 
+] + +[[package]] +name = "postgres-protocol" +version = "0.6.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ee9dd5fe15055d2b6806f4736aa0c9637217074e224bbec46d4041b91bb9491" +dependencies = [ + "base64", + "byteorder", + "bytes", + "fallible-iterator 0.2.0", + "hmac", + "md-5", + "memchr", + "rand 0.9.2", + "sha2", + "stringprep", +] + +[[package]] +name = "postgres-types" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54b858f82211e84682fecd373f68e1ceae642d8d751a1ebd13f33de6257b3e20" +dependencies = [ + "bytes", + "chrono", + "fallible-iterator 0.2.0", + "postgres-protocol", +] + [[package]] name = "postscript" version = "0.14.1" @@ -3352,7 +3439,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1c93dd1c9683b438c392c492109cb702b8090b2bfc8fed6f6e4eb4523f17af3" dependencies = [ "bitflags 2.11.0", - "fallible-iterator", + "fallible-iterator 0.3.0", "fallible-streaming-iterator", "hashlink", "libsqlite3-sys", @@ -4051,6 +4138,32 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio-postgres" +version = "0.7.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcea47c8f71744367793f16c2db1f11cb859d28f436bdb4ca9193eb1f787ee42" +dependencies = [ + "async-trait", + "byteorder", + "bytes", + "fallible-iterator 0.2.0", + "futures-channel", + "futures-util", + "log", + "parking_lot", + "percent-encoding", + "phf 0.13.1", + "pin-project-lite", + "postgres-protocol", + "postgres-types", + "rand 0.9.2", + "socket2", + "tokio", + "tokio-util", + "whoami", +] + [[package]] name = "tokio-rustls" version = "0.26.4" @@ -4632,6 +4745,15 @@ version = "0.11.1+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ccf3ec651a847eb01de73ccad15eb7d99f80485de043efb2f370cd654f4ea44b" +[[package]] +name = "wasi" +version = "0.14.7+wasi-0.2.4" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + [[package]] name = "wasip2" version = "1.0.2+wasi-0.2.9" @@ -4650,6 +4772,15 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasite" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66fe902b4a6b8028a753d5424909b764ccf79b7a209eac9bf97e59cda9f71a42" +dependencies = [ + "wasi 0.14.7+wasi-0.2.4", +] + [[package]] name = "wasm-bindgen" version = "0.2.108" @@ -4820,6 +4951,19 @@ version = "0.1.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a28ac98ddc8b9274cb41bb4d9d4d5c425b6020c50c46f25559911905610b4a88" +[[package]] +name = "whoami" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6a5b12f9df4f978d2cfdb1bd3bac52433f44393342d7ee9c25f5a1c14c0f45d" +dependencies = [ + "libc", + "libredox", + "objc2-system-configuration", + "wasite", + "web-sys", +] + [[package]] name = "winapi" version = "0.3.9" @@ -5350,6 +5494,7 @@ dependencies = [ "opentelemetry_sdk", "parking_lot", "pdf-extract", + "postgres", "probe-rs", "prometheus", "prost", diff --git a/Cargo.toml b/Cargo.toml index c1bbccc..fda7f36 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,6 +82,7 @@ prost = { version = "0.14", default-features = false } # Memory / persistence rusqlite = { version = "0.38", features = ["bundled"] } +postgres = { version = "0.19", features = ["with-chrono-0_4"] } chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] } chrono-tz = "0.10" cron = "0.15" diff --git a/README.md b/README.md index addee32..a217a0a 100644 --- a/README.md +++ b/README.md @@ -309,7 +309,7 @@ Every subsystem is a **trait** — swap implementations with a config change, ze |-----------|-------|------------|--------| | **AI Models** | `Provider` | Provider catalog via `zeroclaw 
providers` (currently 28 built-ins + aliases, plus custom endpoints) | `custom:https://your-api.com` (OpenAI-compatible) or `anthropic-custom:https://your-api.com` | | **Channels** | `Channel` | CLI, Telegram, Discord, Slack, Mattermost, iMessage, Matrix, Signal, WhatsApp, Email, IRC, Lark, DingTalk, QQ, Webhook | Any messaging API | -| **Memory** | `Memory` | SQLite hybrid search, Lucid bridge, Markdown files, explicit `none` backend, snapshot/hydrate, optional response cache | Any persistence backend | +| **Memory** | `Memory` | SQLite hybrid search, PostgreSQL backend (configurable storage provider), Lucid bridge, Markdown files, explicit `none` backend, snapshot/hydrate, optional response cache | Any persistence backend | | **Tools** | `Tool` | shell/file/memory, cron/schedule, git, pushover, browser, http_request, screenshot/image_info, composio (opt-in), delegate, hardware tools | Any capability | | **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel | | **Runtime** | `RuntimeAdapter` | Native, Docker (sandboxed) | Additional runtimes can be added via adapter; unsupported kinds fail fast | @@ -345,7 +345,7 @@ The agent automatically recalls, saves, and manages memory via tools. ```toml [memory] -backend = "sqlite" # "sqlite", "lucid", "markdown", "none" +backend = "sqlite" # "sqlite", "lucid", "postgres", "markdown", "none" auto_save = true embedding_provider = "none" # "none", "openai", "custom:https://..." vector_weight = 0.7 @@ -353,6 +353,17 @@ keyword_weight = 0.3 # backend = "none" uses an explicit no-op memory backend (no persistence) +# Optional: storage-provider override for remote memory backends. +# When provider = "postgres", ZeroClaw uses PostgreSQL for memory persistence. +# The db_url key also accepts alias `dbURL` for backward compatibility. 
+# +# [storage.provider.config] +# provider = "postgres" +# db_url = "postgres://user:password@host:5432/zeroclaw" +# schema = "public" +# table = "memories" +# connect_timeout_secs = 15 + # Optional for backend = "sqlite": max seconds to wait when opening the DB (e.g. file locked). Omit or leave unset for no timeout. # sqlite_open_timeout_secs = 30 @@ -493,7 +504,7 @@ default_temperature = 0.7 # default_provider = "anthropic-custom:https://your-api.com" [memory] -backend = "sqlite" # "sqlite", "lucid", "markdown", "none" +backend = "sqlite" # "sqlite", "lucid", "postgres", "markdown", "none" auto_save = true embedding_provider = "none" # "none", "openai", "custom:https://..." vector_weight = 0.7 @@ -501,6 +512,14 @@ keyword_weight = 0.3 # backend = "none" disables persistent memory via no-op backend +# Optional remote storage-provider override (PostgreSQL example) +# [storage.provider.config] +# provider = "postgres" +# db_url = "postgres://user:password@host:5432/zeroclaw" +# schema = "public" +# table = "memories" +# connect_timeout_secs = 15 + [gateway] port = 3000 # default host = "127.0.0.1" # default diff --git a/src/agent/agent.rs b/src/agent/agent.rs index 0002799..dc8f74d 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -229,8 +229,9 @@ impl Agent { &config.workspace_dir, )); - let memory: Arc = Arc::from(memory::create_memory( + let memory: Arc = Arc::from(memory::create_memory_with_storage( &config.memory, + Some(&config.storage.provider.config), &config.workspace_dir, config.api_key.as_deref(), )?); diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 7aecc9a..455f588 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1165,8 +1165,9 @@ pub async fn run( )); // ── Memory (the brain) ──────────────────────────────────────── - let mem: Arc = Arc::from(memory::create_memory( + let mem: Arc = Arc::from(memory::create_memory_with_storage( &config.memory, + Some(&config.storage.provider.config), &config.workspace_dir, 
config.api_key.as_deref(), )?); @@ -1625,8 +1626,9 @@ pub async fn process_message(config: Config, message: &str) -> Result { &config.autonomy, &config.workspace_dir, )); - let mem: Arc = Arc::from(memory::create_memory( + let mem: Arc = Arc::from(memory::create_memory_with_storage( &config.memory, + Some(&config.storage.provider.config), &config.workspace_dir, config.api_key.as_deref(), )?); diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 9fd9381..dcc55ff 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1148,8 +1148,9 @@ pub async fn start_channels(config: Config) -> Result<()> { .clone() .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); let temperature = config.default_temperature; - let mem: Arc = Arc::from(memory::create_memory( + let mem: Arc = Arc::from(memory::create_memory_with_storage( &config.memory, + Some(&config.storage.provider.config), &config.workspace_dir, config.api_key.as_deref(), )?); @@ -1382,9 +1383,13 @@ pub async fn start_channels(config: Config) -> Result<()> { println!("🦀 ZeroClaw Channel Server"); println!(" 🤖 Model: {model}"); + let effective_backend = memory::effective_memory_backend_name( + &config.memory.backend, + Some(&config.storage.provider.config), + ); println!( " 🧠 Memory: {} (auto-save: {})", - config.memory.backend, + effective_backend, if config.memory.auto_save { "on" } else { "off" } ); println!( diff --git a/src/config/mod.rs b/src/config/mod.rs index 430e603..4521a4a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -9,8 +9,9 @@ pub use schema::{ LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, PeripheralBoardConfig, PeripheralsConfig, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, SchedulerConfig, - SecretsConfig, SecurityConfig, SlackConfig, StreamMode, TelegramConfig, TunnelConfig, - WebSearchConfig, WebhookConfig, + SecretsConfig, SecurityConfig, SlackConfig, 
StorageConfig, StorageProviderConfig, + StorageProviderSection, StreamMode, TelegramConfig, TunnelConfig, WebSearchConfig, + WebhookConfig, }; #[cfg(test)] diff --git a/src/config/schema.rs b/src/config/schema.rs index 2767201..778aa0b 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -63,6 +63,9 @@ pub struct Config { #[serde(default)] pub memory: MemoryConfig, + #[serde(default)] + pub storage: StorageConfig, + #[serde(default)] pub tunnel: TunnelConfig, @@ -771,10 +774,73 @@ impl Default for WebSearchConfig { // ── Memory ─────────────────────────────────────────────────── +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct StorageConfig { + #[serde(default)] + pub provider: StorageProviderSection, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +pub struct StorageProviderSection { + #[serde(default)] + pub config: StorageProviderConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct StorageProviderConfig { + /// Storage engine key (e.g. "postgres", "sqlite"). + #[serde(default)] + pub provider: String, + + /// Connection URL for remote providers. + /// Accepts legacy aliases: dbURL, database_url, databaseUrl. + #[serde( + default, + alias = "dbURL", + alias = "database_url", + alias = "databaseUrl" + )] + pub db_url: Option, + + /// Database schema for SQL backends. + #[serde(default = "default_storage_schema")] + pub schema: String, + + /// Table name for memory entries. + #[serde(default = "default_storage_table")] + pub table: String, + + /// Optional connection timeout in seconds for remote providers. 
+ #[serde(default)] + pub connect_timeout_secs: Option, +} + +fn default_storage_schema() -> String { + "public".into() +} + +fn default_storage_table() -> String { + "memories".into() +} + +impl Default for StorageProviderConfig { + fn default() -> Self { + Self { + provider: String::new(), + db_url: None, + schema: default_storage_schema(), + table: default_storage_table(), + connect_timeout_secs: None, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[allow(clippy::struct_excessive_bools)] pub struct MemoryConfig { - /// "sqlite" | "lucid" | "markdown" | "none" (`none` = explicit no-op memory) + /// "sqlite" | "lucid" | "postgres" | "markdown" | "none" (`none` = explicit no-op memory) + /// + /// `postgres` requires `[storage.provider.config]` with `db_url` (`dbURL` alias supported). pub backend: String, /// Auto-save conversation context to memory pub auto_save: bool, @@ -1844,6 +1910,7 @@ impl Default for Config { cron: CronConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), + storage: StorageConfig::default(), tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), @@ -2113,6 +2180,12 @@ impl Config { "config.web_search.brave_api_key", )?; + decrypt_optional_secret( + &store, + &mut config.storage.provider.config.db_url, + "config.storage.provider.config.db_url", + )?; + for agent in config.agents.values_mut() { decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; } @@ -2266,6 +2339,31 @@ impl Config { } } } + + // Storage provider key (optional backend override): ZEROCLAW_STORAGE_PROVIDER + if let Ok(provider) = std::env::var("ZEROCLAW_STORAGE_PROVIDER") { + let provider = provider.trim(); + if !provider.is_empty() { + self.storage.provider.config.provider = provider.to_string(); + } + } + + // Storage connection URL (for remote backends): ZEROCLAW_STORAGE_DB_URL + if let Ok(db_url) = std::env::var("ZEROCLAW_STORAGE_DB_URL") { 
+ let db_url = db_url.trim(); + if !db_url.is_empty() { + self.storage.provider.config.db_url = Some(db_url.to_string()); + } + } + + // Storage connect timeout: ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS + if let Ok(timeout_secs) = std::env::var("ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS") { + if let Ok(timeout_secs) = timeout_secs.parse::() { + if timeout_secs > 0 { + self.storage.provider.config.connect_timeout_secs = Some(timeout_secs); + } + } + } } pub fn save(&self) -> Result<()> { @@ -2296,6 +2394,12 @@ impl Config { "config.web_search.brave_api_key", )?; + encrypt_optional_secret( + &store, + &mut config_to_save.storage.provider.config.db_url, + "config.storage.provider.config.db_url", + )?; + for agent in config_to_save.agents.values_mut() { encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; } @@ -2483,6 +2587,16 @@ default_temperature = 0.7 assert!(m.sqlite_open_timeout_secs.is_none()); } + #[test] + fn storage_provider_config_defaults() { + let storage = StorageConfig::default(); + assert!(storage.provider.config.provider.is_empty()); + assert!(storage.provider.config.db_url.is_none()); + assert_eq!(storage.provider.config.schema, "public"); + assert_eq!(storage.provider.config.table, "memories"); + assert!(storage.provider.config.connect_timeout_secs.is_none()); + } + #[test] fn channels_config_default() { let c = ChannelsConfig::default(); @@ -2556,6 +2670,7 @@ default_temperature = 0.7 qq: None, }, memory: MemoryConfig::default(), + storage: StorageConfig::default(), tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), @@ -2612,6 +2727,33 @@ default_temperature = 0.7 assert_eq!(parsed.memory.conversation_retention_days, 30); } + #[test] + fn storage_provider_dburl_alias_deserializes() { + let raw = r#" +default_temperature = 0.7 + +[storage.provider.config] +provider = "postgres" +dbURL = "postgres://postgres:postgres@localhost:5432/zeroclaw" +schema = "public" +table = 
"memories" +connect_timeout_secs = 12 +"#; + + let parsed: Config = toml::from_str(raw).unwrap(); + assert_eq!(parsed.storage.provider.config.provider, "postgres"); + assert_eq!( + parsed.storage.provider.config.db_url.as_deref(), + Some("postgres://postgres:postgres@localhost:5432/zeroclaw") + ); + assert_eq!(parsed.storage.provider.config.schema, "public"); + assert_eq!(parsed.storage.provider.config.table, "memories"); + assert_eq!( + parsed.storage.provider.config.connect_timeout_secs, + Some(12) + ); + } + #[test] fn agent_config_defaults() { let cfg = AgentConfig::default(); @@ -2667,6 +2809,7 @@ tool_dispatcher = "xml" cron: CronConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), + storage: StorageConfig::default(), tunnel: TunnelConfig::default(), gateway: GatewayConfig::default(), composio: ComposioConfig::default(), @@ -2715,6 +2858,7 @@ tool_dispatcher = "xml" config.composio.api_key = Some("composio-credential".into()); config.browser.computer_use.api_key = Some("browser-credential".into()); config.web_search.brave_api_key = Some("brave-credential".into()); + config.storage.provider.config.db_url = Some("postgres://user:pw@host/db".into()); config.agents.insert( "worker".into(), @@ -2770,6 +2914,13 @@ tool_dispatcher = "xml" assert!(crate::security::SecretStore::is_encrypted(worker_encrypted)); assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential"); + let storage_db_url = stored.storage.provider.config.db_url.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted(storage_db_url)); + assert_eq!( + store.decrypt(storage_db_url).unwrap(), + "postgres://user:pw@host/db" + ); + let _ = fs::remove_dir_all(&dir); } @@ -3927,6 +4078,32 @@ default_model = "legacy-model" std::env::remove_var("WEB_SEARCH_TIMEOUT_SECS"); } + #[test] + fn env_override_storage_provider_config() { + let _env_guard = env_override_test_guard(); + let mut config = Config::default(); + + 
std::env::set_var("ZEROCLAW_STORAGE_PROVIDER", "postgres"); + std::env::set_var("ZEROCLAW_STORAGE_DB_URL", "postgres://example/db"); + std::env::set_var("ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS", "15"); + + config.apply_env_overrides(); + + assert_eq!(config.storage.provider.config.provider, "postgres"); + assert_eq!( + config.storage.provider.config.db_url.as_deref(), + Some("postgres://example/db") + ); + assert_eq!( + config.storage.provider.config.connect_timeout_secs, + Some(15) + ); + + std::env::remove_var("ZEROCLAW_STORAGE_PROVIDER"); + std::env::remove_var("ZEROCLAW_STORAGE_DB_URL"); + std::env::remove_var("ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS"); + } + #[test] fn gateway_config_default_values() { let g = GatewayConfig::default(); diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 45f9734..3027638 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -313,8 +313,9 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { .clone() .unwrap_or_else(|| "anthropic/claude-sonnet-4".into()); let temperature = config.default_temperature; - let mem: Arc = Arc::from(memory::create_memory( + let mem: Arc = Arc::from(memory::create_memory_with_storage( &config.memory, + Some(&config.storage.provider.config), &config.workspace_dir, config.api_key.as_deref(), )?); diff --git a/src/main.rs b/src/main.rs index 21b22f3..a8413cd 100644 --- a/src/main.rs +++ b/src/main.rs @@ -576,6 +576,10 @@ async fn main() -> Result<()> { println!("📊 Observability: {}", config.observability.backend); println!("🛡️ Autonomy: {:?}", config.autonomy.level); println!("⚙️ Runtime: {}", config.runtime.kind); + let effective_memory_backend = memory::effective_memory_backend_name( + &config.memory.backend, + Some(&config.storage.provider.config), + ); println!( "💓 Heartbeat: {}", if config.heartbeat.enabled { @@ -586,7 +590,7 @@ async fn main() -> Result<()> { ); println!( "🧠 Memory: {} (auto-save: {})", - config.memory.backend, + effective_memory_backend, if 
config.memory.auto_save { "on" } else { "off" } ); diff --git a/src/memory/backend.rs b/src/memory/backend.rs index 8ba7ec3..14a57bc 100644 --- a/src/memory/backend.rs +++ b/src/memory/backend.rs @@ -2,6 +2,7 @@ pub enum MemoryBackendKind { Sqlite, Lucid, + Postgres, Markdown, None, Unknown, @@ -45,6 +46,15 @@ const MARKDOWN_PROFILE: MemoryBackendProfile = MemoryBackendProfile { optional_dependency: false, }; +const POSTGRES_PROFILE: MemoryBackendProfile = MemoryBackendProfile { + key: "postgres", + label: "PostgreSQL — remote durable storage via [storage.provider.config]", + auto_save_default: true, + uses_sqlite_hygiene: false, + sqlite_based: false, + optional_dependency: false, +}; + const NONE_PROFILE: MemoryBackendProfile = MemoryBackendProfile { key: "none", label: "None — disable persistent memory", @@ -82,6 +92,7 @@ pub fn classify_memory_backend(backend: &str) -> MemoryBackendKind { match backend { "sqlite" => MemoryBackendKind::Sqlite, "lucid" => MemoryBackendKind::Lucid, + "postgres" => MemoryBackendKind::Postgres, "markdown" => MemoryBackendKind::Markdown, "none" => MemoryBackendKind::None, _ => MemoryBackendKind::Unknown, @@ -92,6 +103,7 @@ pub fn memory_backend_profile(backend: &str) -> MemoryBackendProfile { match classify_memory_backend(backend) { MemoryBackendKind::Sqlite => SQLITE_PROFILE, MemoryBackendKind::Lucid => LUCID_PROFILE, + MemoryBackendKind::Postgres => POSTGRES_PROFILE, MemoryBackendKind::Markdown => MARKDOWN_PROFILE, MemoryBackendKind::None => NONE_PROFILE, MemoryBackendKind::Unknown => CUSTOM_PROFILE, @@ -106,6 +118,10 @@ mod tests { fn classify_known_backends() { assert_eq!(classify_memory_backend("sqlite"), MemoryBackendKind::Sqlite); assert_eq!(classify_memory_backend("lucid"), MemoryBackendKind::Lucid); + assert_eq!( + classify_memory_backend("postgres"), + MemoryBackendKind::Postgres + ); assert_eq!( classify_memory_backend("markdown"), MemoryBackendKind::Markdown diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 
6798ee4..b4ea5e7 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -5,6 +5,7 @@ pub mod hygiene; pub mod lucid; pub mod markdown; pub mod none; +pub mod postgres; pub mod response_cache; pub mod snapshot; pub mod sqlite; @@ -19,24 +20,28 @@ pub use backend::{ pub use lucid::LucidMemory; pub use markdown::MarkdownMemory; pub use none::NoneMemory; +pub use postgres::PostgresMemory; pub use response_cache::ResponseCache; pub use sqlite::SqliteMemory; pub use traits::Memory; #[allow(unused_imports)] pub use traits::{MemoryCategory, MemoryEntry}; -use crate::config::MemoryConfig; +use crate::config::{MemoryConfig, StorageProviderConfig}; +use anyhow::Context; use std::path::Path; use std::sync::Arc; -fn create_memory_with_sqlite_builder( +fn create_memory_with_builders( backend_name: &str, workspace_dir: &Path, mut sqlite_builder: F, + mut postgres_builder: G, unknown_context: &str, ) -> anyhow::Result> where F: FnMut() -> anyhow::Result, + G: FnMut() -> anyhow::Result, { match classify_memory_backend(backend_name) { MemoryBackendKind::Sqlite => Ok(Box::new(sqlite_builder()?)), @@ -44,6 +49,7 @@ where let local = sqlite_builder()?; Ok(Box::new(LucidMemory::new(workspace_dir, local))) } + MemoryBackendKind::Postgres => Ok(Box::new(postgres_builder()?)), MemoryBackendKind::Markdown => Ok(Box::new(MarkdownMemory::new(workspace_dir))), MemoryBackendKind::None => Ok(Box::new(NoneMemory::new())), MemoryBackendKind::Unknown => { @@ -55,19 +61,52 @@ where } } +pub fn effective_memory_backend_name( + memory_backend: &str, + storage_provider: Option<&StorageProviderConfig>, +) -> String { + if let Some(override_provider) = storage_provider + .map(|cfg| cfg.provider.trim()) + .filter(|provider| !provider.is_empty()) + { + return override_provider.to_ascii_lowercase(); + } + + memory_backend.trim().to_ascii_lowercase() +} + /// Factory: create the right memory backend from config pub fn create_memory( config: &MemoryConfig, workspace_dir: &Path, api_key: Option<&str>, ) -> 
anyhow::Result> { + create_memory_with_storage(config, None, workspace_dir, api_key) +} + +/// Factory: create memory with optional storage-provider override. +pub fn create_memory_with_storage( + config: &MemoryConfig, + storage_provider: Option<&StorageProviderConfig>, + workspace_dir: &Path, + api_key: Option<&str>, +) -> anyhow::Result> { + let backend_name = effective_memory_backend_name(&config.backend, storage_provider); + let backend_kind = classify_memory_backend(&backend_name); + // Best-effort memory hygiene/retention pass (throttled by state file). if let Err(e) = hygiene::run_if_due(config, workspace_dir) { tracing::warn!("memory hygiene skipped: {e}"); } // If snapshot_on_hygiene is enabled, export core memories during hygiene. - if config.snapshot_enabled && config.snapshot_on_hygiene { + if config.snapshot_enabled + && config.snapshot_on_hygiene + && matches!( + backend_kind, + MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid + ) + { if let Err(e) = snapshot::export_snapshot(workspace_dir) { tracing::warn!("memory snapshot skipped: {e}"); } @@ -77,7 +116,7 @@ pub fn create_memory( // restore the "soul" from the snapshot before creating the backend. 
if config.auto_hydrate && matches!( - classify_memory_backend(&config.backend), + backend_kind, MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid ) && snapshot::should_hydrate(workspace_dir) @@ -120,10 +159,33 @@ pub fn create_memory( Ok(mem) } - create_memory_with_sqlite_builder( - &config.backend, + fn build_postgres_memory( + storage_provider: Option<&StorageProviderConfig>, + ) -> anyhow::Result { + let storage_provider = storage_provider + .context("memory backend 'postgres' requires [storage.provider.config] settings")?; + let db_url = storage_provider + .db_url + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .context( + "memory backend 'postgres' requires [storage.provider.config].db_url (or dbURL)", + )?; + + PostgresMemory::new( + db_url, + &storage_provider.schema, + &storage_provider.table, + storage_provider.connect_timeout_secs, + ) + } + + create_memory_with_builders( + &backend_name, workspace_dir, || build_sqlite_memory(config, workspace_dir, api_key), + || build_postgres_memory(storage_provider), "", ) } @@ -138,10 +200,20 @@ pub fn create_memory_for_migration( ); } - create_memory_with_sqlite_builder( + if matches!( + classify_memory_backend(backend), + MemoryBackendKind::Postgres + ) { + anyhow::bail!( + "memory migration for backend 'postgres' is unsupported; migrate with sqlite or markdown first" + ); + } + + create_memory_with_builders( backend, workspace_dir, || SqliteMemory::new(workspace_dir), + || anyhow::bail!("postgres backend is not available in migration context"), " during migration", ) } @@ -175,6 +247,7 @@ pub fn create_response_cache(config: &MemoryConfig, workspace_dir: &Path) -> Opt #[cfg(test)] mod tests { use super::*; + use crate::config::StorageProviderConfig; use tempfile::TempDir; #[test] @@ -247,4 +320,37 @@ mod tests { .expect("backend=none should be rejected for migration"); assert!(error.to_string().contains("disables persistence")); } + + #[test] + fn 
effective_backend_name_prefers_storage_override() { + let storage = StorageProviderConfig { + provider: "postgres".into(), + ..StorageProviderConfig::default() + }; + + assert_eq!( + effective_memory_backend_name("sqlite", Some(&storage)), + "postgres" + ); + } + + #[test] + fn factory_postgres_without_db_url_is_rejected() { + let tmp = TempDir::new().unwrap(); + let cfg = MemoryConfig { + backend: "postgres".into(), + ..MemoryConfig::default() + }; + + let storage = StorageProviderConfig { + provider: "postgres".into(), + db_url: None, + ..StorageProviderConfig::default() + }; + + let error = create_memory_with_storage(&cfg, Some(&storage), tmp.path(), None) + .err() + .expect("postgres without db_url should be rejected"); + assert!(error.to_string().contains("db_url")); + } } diff --git a/src/memory/postgres.rs b/src/memory/postgres.rs new file mode 100644 index 0000000..4f21293 --- /dev/null +++ b/src/memory/postgres.rs @@ -0,0 +1,352 @@ +use super::traits::{Memory, MemoryCategory, MemoryEntry}; +use anyhow::{Context, Result}; +use async_trait::async_trait; +use chrono::{DateTime, Utc}; +use parking_lot::Mutex; +use postgres::{Client, NoTls, Row}; +use std::sync::Arc; +use std::time::Duration; +use uuid::Uuid; + +/// Maximum allowed connect timeout (seconds) to avoid unreasonable waits. +const POSTGRES_CONNECT_TIMEOUT_CAP_SECS: u64 = 300; + +/// PostgreSQL-backed persistent memory. +/// +/// This backend focuses on reliable CRUD and keyword recall using SQL, without +/// requiring extension setup (for example pgvector). 
+pub struct PostgresMemory {
+    /// Blocking `postgres` client, shared with `spawn_blocking` tasks.
+    client: Arc<Mutex<Client>>,
+    /// Pre-quoted `"schema"."table"` target, validated at construction time.
+    qualified_table: String,
+}
+
+impl PostgresMemory {
+    /// Connects to PostgreSQL, validates the schema/table identifiers, and
+    /// ensures the backing table and indexes exist.
+    ///
+    /// # Errors
+    /// Fails when the URL is invalid, an identifier is malformed, the
+    /// connection cannot be established, or schema setup fails.
+    pub fn new(
+        db_url: &str,
+        schema: &str,
+        table: &str,
+        connect_timeout_secs: Option<u64>,
+    ) -> Result<Self> {
+        validate_identifier(schema, "storage schema")?;
+        validate_identifier(table, "storage table")?;
+
+        let mut config: postgres::Config = db_url
+            .parse()
+            .context("invalid PostgreSQL connection URL")?;
+
+        if let Some(timeout_secs) = connect_timeout_secs {
+            // Clamp user-provided timeouts so a typo cannot stall startup indefinitely.
+            let bounded = timeout_secs.min(POSTGRES_CONNECT_TIMEOUT_CAP_SECS);
+            config.connect_timeout(Duration::from_secs(bounded));
+        }
+
+        let mut client = config
+            .connect(NoTls)
+            .context("failed to connect to PostgreSQL memory backend")?;
+
+        let schema_ident = quote_identifier(schema);
+        let table_ident = quote_identifier(table);
+        let qualified_table = format!("{schema_ident}.{table_ident}");
+
+        Self::init_schema(&mut client, &schema_ident, &qualified_table)?;
+
+        Ok(Self {
+            client: Arc::new(Mutex::new(client)),
+            qualified_table,
+        })
+    }
+
+    /// Creates the schema, memories table, and lookup indexes if missing.
+    /// Identifiers are interpolated directly into the DDL, but they have
+    /// already been validated and quoted by the constructor.
+    fn init_schema(client: &mut Client, schema_ident: &str, qualified_table: &str) -> Result<()> {
+        client.batch_execute(&format!(
+            "
+            CREATE SCHEMA IF NOT EXISTS {schema_ident};
+
+            CREATE TABLE IF NOT EXISTS {qualified_table} (
+                id TEXT PRIMARY KEY,
+                key TEXT UNIQUE NOT NULL,
+                content TEXT NOT NULL,
+                category TEXT NOT NULL,
+                created_at TIMESTAMPTZ NOT NULL,
+                updated_at TIMESTAMPTZ NOT NULL,
+                session_id TEXT
+            );
+
+            CREATE INDEX IF NOT EXISTS idx_memories_category ON {qualified_table}(category);
+            CREATE INDEX IF NOT EXISTS idx_memories_session_id ON {qualified_table}(session_id);
+            CREATE INDEX IF NOT EXISTS idx_memories_updated_at ON {qualified_table}(updated_at DESC);
+            "
+        ))?;
+
+        Ok(())
+    }
+
+    /// Serializes a category to its canonical column value.
+    fn category_to_str(category: &MemoryCategory) -> String {
+        match category {
+            MemoryCategory::Core => "core".to_string(),
+            MemoryCategory::Daily => "daily".to_string(),
+            MemoryCategory::Conversation => "conversation".to_string(),
+            MemoryCategory::Custom(name) => name.clone(),
+        }
+    }
+
+    /// Inverse of [`Self::category_to_str`]; unknown values map to `Custom`.
+    fn parse_category(value: &str) -> MemoryCategory {
+        match value {
+            "core" => MemoryCategory::Core,
+            "daily" => MemoryCategory::Daily,
+            "conversation" => MemoryCategory::Conversation,
+            other => MemoryCategory::Custom(other.to_string()),
+        }
+    }
+
+    /// Maps a result row (id, key, content, category, created_at, session_id
+    /// and, for `recall` only, a 7th score column) onto a `MemoryEntry`.
+    fn row_to_entry(row: &Row) -> Result<MemoryEntry> {
+        let timestamp: DateTime<Utc> = row.get(4);
+
+        Ok(MemoryEntry {
+            id: row.get(0),
+            key: row.get(1),
+            content: row.get(2),
+            category: Self::parse_category(&row.get::<_, String>(3)),
+            timestamp: timestamp.to_rfc3339(),
+            session_id: row.get(5),
+            // `try_get` errors when the query selected no score column; treat as absent.
+            score: row.try_get(6).ok(),
+        })
+    }
+}
+
+/// Ensures `value` is a safe SQL identifier: `[A-Za-z_][A-Za-z0-9_]*`.
+fn validate_identifier(value: &str, field_name: &str) -> Result<()> {
+    if value.is_empty() {
+        anyhow::bail!("{field_name} must not be empty");
+    }
+
+    let mut chars = value.chars();
+    let Some(first) = chars.next() else {
+        anyhow::bail!("{field_name} must not be empty");
+    };
+
+    if !(first.is_ascii_alphabetic() || first == '_') {
+        anyhow::bail!("{field_name} must start with an ASCII letter or underscore; got '{value}'");
+    }
+
+    if !chars.all(|ch| ch.is_ascii_alphanumeric() || ch == '_') {
+        anyhow::bail!(
+            "{field_name} can only contain ASCII letters, numbers, and underscores; got '{value}'"
+        );
+    }
+
+    Ok(())
+}
+
+/// Double-quotes an identifier; safe because `validate_identifier` rejects quotes.
+fn quote_identifier(value: &str) -> String {
+    format!("\"{value}\"")
+}
+
+#[async_trait]
+impl Memory for PostgresMemory {
+    fn name(&self) -> &str {
+        "postgres"
+    }
+
+    /// Upserts by `key`; on conflict the original row id is retained while
+    /// content, category, session and `updated_at` are refreshed.
+    async fn store(
+        &self,
+        key: &str,
+        content: &str,
+        category: MemoryCategory,
+        session_id: Option<&str>,
+    ) -> Result<()> {
+        let client = self.client.clone();
+        let qualified_table = self.qualified_table.clone();
+        let key = key.to_string();
+        let content = content.to_string();
+        let category = Self::category_to_str(&category);
+        let session_id = session_id.map(str::to_string);
+
+        // The `postgres` crate is blocking, so every query runs on the blocking pool.
+        tokio::task::spawn_blocking(move || -> Result<()> {
+            let now = Utc::now();
+            let mut client = client.lock();
+            let stmt = format!(
+                "
+                INSERT INTO {qualified_table}
+                    (id, key, content, category, created_at, updated_at, session_id)
+                VALUES
+                    ($1, $2, $3, $4, $5, $6, $7)
+                ON CONFLICT (key) DO UPDATE SET
+                    content = EXCLUDED.content,
+                    category = EXCLUDED.category,
+                    updated_at = EXCLUDED.updated_at,
+                    session_id = EXCLUDED.session_id
+                "
+            );
+
+            let id = Uuid::new_v4().to_string();
+            client.execute(
+                &stmt,
+                &[&id, &key, &content, &category, &now, &now, &session_id],
+            )?;
+            Ok(())
+        })
+        .await?
+    }
+
+    /// Keyword recall: a key match scores 2.0 and a content match 1.0; an
+    /// empty query matches every row. NOTE(review): `%`/`_` in the query act
+    /// as ILIKE wildcards — confirm that is intended.
+    async fn recall(
+        &self,
+        query: &str,
+        limit: usize,
+        session_id: Option<&str>,
+    ) -> Result<Vec<MemoryEntry>> {
+        let client = self.client.clone();
+        let qualified_table = self.qualified_table.clone();
+        let query = query.trim().to_string();
+        let session_id = session_id.map(str::to_string);
+
+        tokio::task::spawn_blocking(move || -> Result<Vec<MemoryEntry>> {
+            let mut client = client.lock();
+            let stmt = format!(
+                "
+                SELECT id, key, content, category, created_at, session_id,
+                    (
+                        CASE WHEN key ILIKE '%' || $1 || '%' THEN 2.0 ELSE 0.0 END
+                        + CASE WHEN content ILIKE '%' || $1 || '%' THEN 1.0 ELSE 0.0 END
+                    ) AS score
+                FROM {qualified_table}
+                WHERE ($2::TEXT IS NULL OR session_id = $2)
+                  AND ($1 = '' OR key ILIKE '%' || $1 || '%' OR content ILIKE '%' || $1 || '%')
+                ORDER BY score DESC, updated_at DESC
+                LIMIT $3
+                "
+            );
+
+            #[allow(clippy::cast_possible_wrap)]
+            let limit_i64 = limit as i64;
+
+            let rows = client.query(&stmt, &[&query, &session_id, &limit_i64])?;
+            rows.iter()
+                .map(Self::row_to_entry)
+                .collect::<Result<Vec<_>>>()
+        })
+        .await?
+    }
+
+    /// Fetches a single entry by its unique `key`, if present.
+    async fn get(&self, key: &str) -> Result<Option<MemoryEntry>> {
+        let client = self.client.clone();
+        let qualified_table = self.qualified_table.clone();
+        let key = key.to_string();
+
+        tokio::task::spawn_blocking(move || -> Result<Option<MemoryEntry>> {
+            let mut client = client.lock();
+            let stmt = format!(
+                "
+                SELECT id, key, content, category, created_at, session_id
+                FROM {qualified_table}
+                WHERE key = $1
+                LIMIT 1
+                "
+            );
+
+            let row = client.query_opt(&stmt, &[&key])?;
+            row.as_ref().map(Self::row_to_entry).transpose()
+        })
+        .await?
+    }
+
+    /// Lists entries, optionally filtered by category and/or session,
+    /// newest-updated first.
+    async fn list(
+        &self,
+        category: Option<&MemoryCategory>,
+        session_id: Option<&str>,
+    ) -> Result<Vec<MemoryEntry>> {
+        let client = self.client.clone();
+        let qualified_table = self.qualified_table.clone();
+        let category = category.map(Self::category_to_str);
+        let session_id = session_id.map(str::to_string);
+
+        tokio::task::spawn_blocking(move || -> Result<Vec<MemoryEntry>> {
+            let mut client = client.lock();
+            let stmt = format!(
+                "
+                SELECT id, key, content, category, created_at, session_id
+                FROM {qualified_table}
+                WHERE ($1::TEXT IS NULL OR category = $1)
+                  AND ($2::TEXT IS NULL OR session_id = $2)
+                ORDER BY updated_at DESC
+                "
+            );
+
+            let category_ref = category.as_deref();
+            let session_ref = session_id.as_deref();
+            let rows = client.query(&stmt, &[&category_ref, &session_ref])?;
+            rows.iter()
+                .map(Self::row_to_entry)
+                .collect::<Result<Vec<_>>>()
+        })
+        .await?
+    }
+
+    /// Deletes the entry with `key`; returns whether a row was removed.
+    async fn forget(&self, key: &str) -> Result<bool> {
+        let client = self.client.clone();
+        let qualified_table = self.qualified_table.clone();
+        let key = key.to_string();
+
+        tokio::task::spawn_blocking(move || -> Result<bool> {
+            let mut client = client.lock();
+            let stmt = format!("DELETE FROM {qualified_table} WHERE key = $1");
+            let deleted = client.execute(&stmt, &[&key])?;
+            Ok(deleted > 0)
+        })
+        .await?
+    }
+
+    /// Total number of stored entries.
+    async fn count(&self) -> Result<usize> {
+        let client = self.client.clone();
+        let qualified_table = self.qualified_table.clone();
+
+        tokio::task::spawn_blocking(move || -> Result<usize> {
+            let mut client = client.lock();
+            let stmt = format!("SELECT COUNT(*) FROM {qualified_table}");
+            let count: i64 = client.query_one(&stmt, &[])?.get(0);
+            let count =
+                usize::try_from(count).context("PostgreSQL returned a negative memory count")?;
+            Ok(count)
+        })
+        .await?
+    }
+
+    /// Cheap liveness probe; any failure (including a panicked task) is `false`.
+    async fn health_check(&self) -> bool {
+        let client = self.client.clone();
+        tokio::task::spawn_blocking(move || client.lock().simple_query("SELECT 1").is_ok())
+            .await
+            .unwrap_or(false)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn valid_identifiers_pass_validation() {
+        assert!(validate_identifier("public", "schema").is_ok());
+        assert!(validate_identifier("_memories_01", "table").is_ok());
+    }
+
+    #[test]
+    fn invalid_identifiers_are_rejected() {
+        assert!(validate_identifier("", "schema").is_err());
+        assert!(validate_identifier("1bad", "schema").is_err());
+        assert!(validate_identifier("bad-name", "table").is_err());
+    }
+
+    #[test]
+    fn parse_category_maps_known_and_custom_values() {
+        assert_eq!(PostgresMemory::parse_category("core"), MemoryCategory::Core);
+        assert_eq!(
+            PostgresMemory::parse_category("daily"),
+            MemoryCategory::Daily
+        );
+        assert_eq!(
+            PostgresMemory::parse_category("conversation"),
+            MemoryCategory::Conversation
+        );
+        assert_eq!(
+            PostgresMemory::parse_category("custom_notes"),
+            MemoryCategory::Custom("custom_notes".into())
+        );
+    }
+}
diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs
index ac30b18..f865b89 100644
--- a/src/onboard/wizard.rs
+++ b/src/onboard/wizard.rs
@@ -2,7 +2,7 @@ use crate::config::schema::{DingTalkConfig, IrcConfig, QQConfig, StreamMode, Wha
 use crate::config::{
     AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
     HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig,
ObservabilityConfig, - RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, WebhookConfig, + RuntimeConfig, SecretsConfig, SlackConfig, StorageConfig, TelegramConfig, WebhookConfig, }; use crate::hardware::{self, HardwareConfig}; use crate::memory::{ @@ -125,6 +125,7 @@ pub fn run_wizard() -> Result { cron: crate::config::CronConfig::default(), channels_config, memory: memory_config, // User-selected memory backend + storage: StorageConfig::default(), tunnel: tunnel_config, gateway: crate::config::GatewayConfig::default(), composio: composio_config, @@ -347,6 +348,7 @@ pub fn run_quick_setup( cron: crate::config::CronConfig::default(), channels_config: ChannelsConfig::default(), memory: memory_config, + storage: StorageConfig::default(), tunnel: crate::config::TunnelConfig::default(), gateway: crate::config::GatewayConfig::default(), composio: ComposioConfig::default(),