fix(memory): avoid tokio runtime panic when initializing postgres backend
This commit is contained in:
parent
7c2c370180
commit
654f822430
1 changed files with 58 additions and 14 deletions
|
|
@ -30,24 +30,16 @@ impl PostgresMemory {
|
||||||
validate_identifier(schema, "storage schema")?;
|
validate_identifier(schema, "storage schema")?;
|
||||||
validate_identifier(table, "storage table")?;
|
validate_identifier(table, "storage table")?;
|
||||||
|
|
||||||
let mut config: postgres::Config = db_url
|
|
||||||
.parse()
|
|
||||||
.context("invalid PostgreSQL connection URL")?;
|
|
||||||
|
|
||||||
if let Some(timeout_secs) = connect_timeout_secs {
|
|
||||||
let bounded = timeout_secs.min(POSTGRES_CONNECT_TIMEOUT_CAP_SECS);
|
|
||||||
config.connect_timeout(Duration::from_secs(bounded));
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut client = config
|
|
||||||
.connect(NoTls)
|
|
||||||
.context("failed to connect to PostgreSQL memory backend")?;
|
|
||||||
|
|
||||||
let schema_ident = quote_identifier(schema);
|
let schema_ident = quote_identifier(schema);
|
||||||
let table_ident = quote_identifier(table);
|
let table_ident = quote_identifier(table);
|
||||||
let qualified_table = format!("{schema_ident}.{table_ident}");
|
let qualified_table = format!("{schema_ident}.{table_ident}");
|
||||||
|
|
||||||
Self::init_schema(&mut client, &schema_ident, &qualified_table)?;
|
let client = Self::initialize_client(
|
||||||
|
db_url.to_string(),
|
||||||
|
connect_timeout_secs,
|
||||||
|
schema_ident.clone(),
|
||||||
|
qualified_table.clone(),
|
||||||
|
)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
client: Arc::new(Mutex::new(client)),
|
client: Arc::new(Mutex::new(client)),
|
||||||
|
|
@ -55,6 +47,40 @@ impl PostgresMemory {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn initialize_client(
|
||||||
|
db_url: String,
|
||||||
|
connect_timeout_secs: Option<u64>,
|
||||||
|
schema_ident: String,
|
||||||
|
qualified_table: String,
|
||||||
|
) -> Result<Client> {
|
||||||
|
let init_handle = std::thread::Builder::new()
|
||||||
|
.name("postgres-memory-init".to_string())
|
||||||
|
.spawn(move || -> Result<Client> {
|
||||||
|
let mut config: postgres::Config = db_url
|
||||||
|
.parse()
|
||||||
|
.context("invalid PostgreSQL connection URL")?;
|
||||||
|
|
||||||
|
if let Some(timeout_secs) = connect_timeout_secs {
|
||||||
|
let bounded = timeout_secs.min(POSTGRES_CONNECT_TIMEOUT_CAP_SECS);
|
||||||
|
config.connect_timeout(Duration::from_secs(bounded));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut client = config
|
||||||
|
.connect(NoTls)
|
||||||
|
.context("failed to connect to PostgreSQL memory backend")?;
|
||||||
|
|
||||||
|
Self::init_schema(&mut client, &schema_ident, &qualified_table)?;
|
||||||
|
Ok(client)
|
||||||
|
})
|
||||||
|
.context("failed to spawn PostgreSQL initializer thread")?;
|
||||||
|
|
||||||
|
let init_result = init_handle
|
||||||
|
.join()
|
||||||
|
.map_err(|_| anyhow::anyhow!("PostgreSQL initializer thread panicked"))?;
|
||||||
|
|
||||||
|
init_result
|
||||||
|
}
|
||||||
|
|
||||||
fn init_schema(client: &mut Client, schema_ident: &str, qualified_table: &str) -> Result<()> {
|
fn init_schema(client: &mut Client, schema_ident: &str, qualified_table: &str) -> Result<()> {
|
||||||
client.batch_execute(&format!(
|
client.batch_execute(&format!(
|
||||||
"
|
"
|
||||||
|
|
@ -346,4 +372,22 @@ mod tests {
|
||||||
MemoryCategory::Custom("custom_notes".into())
|
MemoryCategory::Custom("custom_notes".into())
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test(flavor = "current_thread")]
|
||||||
|
async fn new_does_not_panic_inside_tokio_runtime() {
|
||||||
|
let outcome = std::panic::catch_unwind(|| {
|
||||||
|
PostgresMemory::new(
|
||||||
|
"postgres://zeroclaw:password@127.0.0.1:1/zeroclaw",
|
||||||
|
"public",
|
||||||
|
"memories",
|
||||||
|
Some(1),
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(outcome.is_ok(), "PostgresMemory::new should not panic");
|
||||||
|
assert!(
|
||||||
|
outcome.unwrap().is_err(),
|
||||||
|
"PostgresMemory::new should return a connect error for an unreachable endpoint"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue