fix(whatsapp): complete wa-rs channel behavior and storage correctness

This commit is contained in:
Chummy 2026-02-19 17:25:08 +08:00
parent c2a1eb1088
commit b1ebd4b579
3 changed files with 884 additions and 271 deletions

View file

@ -21,22 +21,22 @@ use std::path::Path;
#[cfg(feature = "whatsapp-web")]
use std::sync::Arc;
#[cfg(feature = "whatsapp-web")]
use prost::Message;
#[cfg(feature = "whatsapp-web")]
use wa_rs_binary::jid::Jid;
#[cfg(feature = "whatsapp-web")]
use wa_rs_core::appstate::hash::HashState;
#[cfg(feature = "whatsapp-web")]
use wa_rs_core::appstate::processor::AppStateMutationMAC;
#[cfg(feature = "whatsapp-web")]
use wa_rs_core::store::Device as CoreDevice;
#[cfg(feature = "whatsapp-web")]
use wa_rs_core::store::traits::*;
use wa_rs_core::store::traits::DeviceInfo;
#[cfg(feature = "whatsapp-web")]
use wa_rs_core::store::traits::DeviceStore as DeviceStoreTrait;
#[cfg(feature = "whatsapp-web")]
use wa_rs_core::store::traits::DeviceInfo;
use wa_rs_core::store::traits::*;
#[cfg(feature = "whatsapp-web")]
use wa_rs_binary::jid::Jid;
#[cfg(feature = "whatsapp-web")]
use prost::Message;
use wa_rs_core::store::Device as CoreDevice;
/// Custom wa-rs storage backend using rusqlite
///
@ -59,15 +59,13 @@ pub struct RusqliteStore {
macro_rules! to_store_err {
// For expressions returning Result<usize, E>
(execute: $expr:expr) => {
$expr.map(|_| ()).map_err(|e| {
wa_rs_core::store::error::StoreError::Database(e.to_string())
})
$expr
.map(|_| ())
.map_err(|e| wa_rs_core::store::error::StoreError::Database(e.to_string()))
};
// For other expressions
($expr:expr) => {
$expr.map_err(|e| {
wa_rs_core::store::error::StoreError::Database(e.to_string())
})
$expr.map_err(|e| wa_rs_core::store::error::StoreError::Database(e.to_string()))
};
}
@ -268,7 +266,11 @@ impl RusqliteStore {
impl SignalStore for RusqliteStore {
// --- Identity Operations ---
async fn put_identity(&self, address: &str, key: [u8; 32]) -> wa_rs_core::store::error::Result<()> {
async fn put_identity(
&self,
address: &str,
key: [u8; 32],
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
to_store_err!(execute: conn.execute(
"INSERT OR REPLACE INTO identities (address, key, device_id)
@ -277,7 +279,10 @@ impl SignalStore for RusqliteStore {
))
}
async fn load_identity(&self, address: &str) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
async fn load_identity(
&self,
address: &str,
) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT key FROM identities WHERE address = ?1 AND device_id = ?2",
@ -288,7 +293,9 @@ impl SignalStore for RusqliteStore {
match result {
Ok(key) => Ok(Some(key)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
@ -302,7 +309,10 @@ impl SignalStore for RusqliteStore {
// --- Session Operations ---
async fn get_session(&self, address: &str) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
async fn get_session(
&self,
address: &str,
) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT record FROM sessions WHERE address = ?1 AND device_id = ?2",
@ -313,11 +323,17 @@ impl SignalStore for RusqliteStore {
match result {
Ok(record) => Ok(Some(record)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
async fn put_session(&self, address: &str, session: &[u8]) -> wa_rs_core::store::error::Result<()> {
async fn put_session(
&self,
address: &str,
session: &[u8],
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
to_store_err!(execute: conn.execute(
"INSERT OR REPLACE INTO sessions (address, record, device_id)
@ -336,7 +352,12 @@ impl SignalStore for RusqliteStore {
// --- PreKey Operations ---
async fn store_prekey(&self, id: u32, record: &[u8], uploaded: bool) -> wa_rs_core::store::error::Result<()> {
async fn store_prekey(
&self,
id: u32,
record: &[u8],
uploaded: bool,
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
to_store_err!(execute: conn.execute(
"INSERT OR REPLACE INTO prekeys (id, key, uploaded, device_id)
@ -356,7 +377,9 @@ impl SignalStore for RusqliteStore {
match result {
Ok(key) => Ok(Some(key)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
@ -370,7 +393,11 @@ impl SignalStore for RusqliteStore {
// --- Signed PreKey Operations ---
async fn store_signed_prekey(&self, id: u32, record: &[u8]) -> wa_rs_core::store::error::Result<()> {
async fn store_signed_prekey(
&self,
id: u32,
record: &[u8],
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
to_store_err!(execute: conn.execute(
"INSERT OR REPLACE INTO signed_prekeys (id, record, device_id)
@ -379,7 +406,10 @@ impl SignalStore for RusqliteStore {
))
}
async fn load_signed_prekey(&self, id: u32) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
async fn load_signed_prekey(
&self,
id: u32,
) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT record FROM signed_prekeys WHERE id = ?1 AND device_id = ?2",
@ -390,15 +420,19 @@ impl SignalStore for RusqliteStore {
match result {
Ok(record) => Ok(Some(record)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
async fn load_all_signed_prekeys(&self) -> wa_rs_core::store::error::Result<Vec<(u32, Vec<u8>)>> {
async fn load_all_signed_prekeys(
&self,
) -> wa_rs_core::store::error::Result<Vec<(u32, Vec<u8>)>> {
let conn = self.conn.lock();
let mut stmt = to_store_err!(conn.prepare(
"SELECT id, record FROM signed_prekeys WHERE device_id = ?1"
))?;
let mut stmt = to_store_err!(
conn.prepare("SELECT id, record FROM signed_prekeys WHERE device_id = ?1")
)?;
let rows = to_store_err!(stmt.query_map(params![self.device_id], |row| {
Ok((row.get::<_, u32>(0)?, row.get::<_, Vec<u8>>(1)?))
@ -422,7 +456,11 @@ impl SignalStore for RusqliteStore {
// --- Sender Key Operations ---
async fn put_sender_key(&self, address: &str, record: &[u8]) -> wa_rs_core::store::error::Result<()> {
async fn put_sender_key(
&self,
address: &str,
record: &[u8],
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
to_store_err!(execute: conn.execute(
"INSERT OR REPLACE INTO sender_keys (address, record, device_id)
@ -431,7 +469,10 @@ impl SignalStore for RusqliteStore {
))
}
async fn get_sender_key(&self, address: &str) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
async fn get_sender_key(
&self,
address: &str,
) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT record FROM sender_keys WHERE address = ?1 AND device_id = ?2",
@ -442,7 +483,9 @@ impl SignalStore for RusqliteStore {
match result {
Ok(record) => Ok(Some(record)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
@ -458,7 +501,10 @@ impl SignalStore for RusqliteStore {
#[cfg(feature = "whatsapp-web")]
#[async_trait]
impl AppSyncStore for RusqliteStore {
async fn get_sync_key(&self, key_id: &[u8]) -> wa_rs_core::store::error::Result<Option<AppStateSyncKey>> {
async fn get_sync_key(
&self,
key_id: &[u8],
) -> wa_rs_core::store::error::Result<Option<AppStateSyncKey>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT key_data FROM app_state_keys WHERE key_id = ?1 AND device_id = ?2",
@ -473,11 +519,17 @@ impl AppSyncStore for RusqliteStore {
match result {
Ok(key) => Ok(Some(key)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
async fn set_sync_key(&self, key_id: &[u8], key: AppStateSyncKey) -> wa_rs_core::store::error::Result<()> {
async fn set_sync_key(
&self,
key_id: &[u8],
key: AppStateSyncKey,
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
let key_data = to_store_err!(serde_json::to_vec(&key))?;
@ -499,7 +551,11 @@ impl AppSyncStore for RusqliteStore {
to_store_err!(serde_json::from_slice(&state_data))
}
async fn set_version(&self, name: &str, state: HashState) -> wa_rs_core::store::error::Result<()> {
async fn set_version(
&self,
name: &str,
state: HashState,
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
let state_data = to_store_err!(serde_json::to_vec(&state))?;
@ -533,7 +589,11 @@ impl AppSyncStore for RusqliteStore {
Ok(())
}
async fn get_mutation_mac(&self, name: &str, index_mac: &[u8]) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
async fn get_mutation_mac(
&self,
name: &str,
index_mac: &[u8],
) -> wa_rs_core::store::error::Result<Option<Vec<u8>>> {
let conn = self.conn.lock();
let index_mac_json = to_store_err!(serde_json::to_vec(index_mac))?;
@ -547,11 +607,17 @@ impl AppSyncStore for RusqliteStore {
match result {
Ok(mac) => Ok(Some(mac)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
async fn delete_mutation_macs(&self, name: &str, index_macs: &[Vec<u8>]) -> wa_rs_core::store::error::Result<()> {
async fn delete_mutation_macs(
&self,
name: &str,
index_macs: &[Vec<u8>],
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
for index_mac in index_macs {
@ -573,7 +639,10 @@ impl AppSyncStore for RusqliteStore {
impl ProtocolStore for RusqliteStore {
// --- SKDM Tracking ---
async fn get_skdm_recipients(&self, group_jid: &str) -> wa_rs_core::store::error::Result<Vec<Jid>> {
async fn get_skdm_recipients(
&self,
group_jid: &str,
) -> wa_rs_core::store::error::Result<Vec<Jid>> {
let conn = self.conn.lock();
let mut stmt = to_store_err!(conn.prepare(
"SELECT device_jid FROM skdm_recipients WHERE group_jid = ?1 AND device_id = ?2"
@ -594,7 +663,11 @@ impl ProtocolStore for RusqliteStore {
Ok(result)
}
async fn add_skdm_recipients(&self, group_jid: &str, device_jids: &[Jid]) -> wa_rs_core::store::error::Result<()> {
async fn add_skdm_recipients(
&self,
group_jid: &str,
device_jids: &[Jid],
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
let now = chrono::Utc::now().timestamp();
@ -619,7 +692,10 @@ impl ProtocolStore for RusqliteStore {
// --- LID-PN Mapping ---
async fn get_lid_mapping(&self, lid: &str) -> wa_rs_core::store::error::Result<Option<LidPnMappingEntry>> {
async fn get_lid_mapping(
&self,
lid: &str,
) -> wa_rs_core::store::error::Result<Option<LidPnMappingEntry>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT lid, phone_number, created_at, learning_source, updated_at
@ -630,8 +706,8 @@ impl ProtocolStore for RusqliteStore {
lid: row.get(0)?,
phone_number: row.get(1)?,
created_at: row.get(2)?,
updated_at: row.get(3)?,
learning_source: row.get(4)?,
learning_source: row.get(3)?,
updated_at: row.get(4)?,
})
},
);
@ -639,11 +715,16 @@ impl ProtocolStore for RusqliteStore {
match result {
Ok(entry) => Ok(Some(entry)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
async fn get_pn_mapping(&self, phone: &str) -> wa_rs_core::store::error::Result<Option<LidPnMappingEntry>> {
async fn get_pn_mapping(
&self,
phone: &str,
) -> wa_rs_core::store::error::Result<Option<LidPnMappingEntry>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT lid, phone_number, created_at, learning_source, updated_at
@ -655,8 +736,8 @@ impl ProtocolStore for RusqliteStore {
lid: row.get(0)?,
phone_number: row.get(1)?,
created_at: row.get(2)?,
updated_at: row.get(3)?,
learning_source: row.get(4)?,
learning_source: row.get(3)?,
updated_at: row.get(4)?,
})
},
);
@ -664,11 +745,16 @@ impl ProtocolStore for RusqliteStore {
match result {
Ok(entry) => Ok(Some(entry)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
async fn put_lid_mapping(&self, entry: &LidPnMappingEntry) -> wa_rs_core::store::error::Result<()> {
async fn put_lid_mapping(
&self,
entry: &LidPnMappingEntry,
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
to_store_err!(execute: conn.execute(
"INSERT OR REPLACE INTO lid_pn_mapping
@ -685,7 +771,9 @@ impl ProtocolStore for RusqliteStore {
))
}
async fn get_all_lid_mappings(&self) -> wa_rs_core::store::error::Result<Vec<LidPnMappingEntry>> {
async fn get_all_lid_mappings(
&self,
) -> wa_rs_core::store::error::Result<Vec<LidPnMappingEntry>> {
let conn = self.conn.lock();
let mut stmt = to_store_err!(conn.prepare(
"SELECT lid, phone_number, created_at, learning_source, updated_at
@ -697,8 +785,8 @@ impl ProtocolStore for RusqliteStore {
lid: row.get(0)?,
phone_number: row.get(1)?,
created_at: row.get(2)?,
updated_at: row.get(3)?,
learning_source: row.get(4)?,
learning_source: row.get(3)?,
updated_at: row.get(4)?,
})
}))?;
@ -712,7 +800,12 @@ impl ProtocolStore for RusqliteStore {
// --- Base Key Collision Detection ---
async fn save_base_key(&self, address: &str, message_id: &str, base_key: &[u8]) -> wa_rs_core::store::error::Result<()> {
async fn save_base_key(
&self,
address: &str,
message_id: &str,
base_key: &[u8],
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
let now = chrono::Utc::now().timestamp();
@ -743,11 +836,17 @@ impl ProtocolStore for RusqliteStore {
match result {
Ok(same) => Ok(same),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(false),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
async fn delete_base_key(&self, address: &str, message_id: &str) -> wa_rs_core::store::error::Result<()> {
async fn delete_base_key(
&self,
address: &str,
message_id: &str,
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
to_store_err!(execute: conn.execute(
"DELETE FROM base_keys WHERE address = ?1 AND message_id = ?2 AND device_id = ?3",
@ -757,7 +856,10 @@ impl ProtocolStore for RusqliteStore {
// --- Device Registry ---
async fn update_device_list(&self, record: DeviceListRecord) -> wa_rs_core::store::error::Result<()> {
async fn update_device_list(
&self,
record: DeviceListRecord,
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
let devices_json = to_store_err!(serde_json::to_string(&record.devices))?;
let now = chrono::Utc::now().timestamp();
@ -777,7 +879,10 @@ impl ProtocolStore for RusqliteStore {
))
}
async fn get_devices(&self, user: &str) -> wa_rs_core::store::error::Result<Option<DeviceListRecord>> {
async fn get_devices(
&self,
user: &str,
) -> wa_rs_core::store::error::Result<Option<DeviceListRecord>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT user_id, devices_json, timestamp, phash
@ -785,13 +890,15 @@ impl ProtocolStore for RusqliteStore {
params![user, self.device_id],
|row| {
// Helper to convert errors to rusqlite::Error
fn to_rusqlite_err<E: std::error::Error + Send + Sync + 'static>(e: E) -> rusqlite::Error {
fn to_rusqlite_err<E: std::error::Error + Send + Sync + 'static>(
e: E,
) -> rusqlite::Error {
rusqlite::Error::ToSqlConversionFailure(Box::new(e))
}
let devices_json: String = row.get(1)?;
let devices: Vec<DeviceInfo> = serde_json::from_str(&devices_json)
.map_err(to_rusqlite_err)?;
let devices: Vec<DeviceInfo> =
serde_json::from_str(&devices_json).map_err(to_rusqlite_err)?;
Ok(DeviceListRecord {
user: row.get(0)?,
devices,
@ -804,13 +911,19 @@ impl ProtocolStore for RusqliteStore {
match result {
Ok(record) => Ok(Some(record)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
// --- Sender Key Status (Lazy Deletion) ---
async fn mark_forget_sender_key(&self, group_jid: &str, participant: &str) -> wa_rs_core::store::error::Result<()> {
async fn mark_forget_sender_key(
&self,
group_jid: &str,
participant: &str,
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
let now = chrono::Utc::now().timestamp();
@ -821,7 +934,10 @@ impl ProtocolStore for RusqliteStore {
))
}
async fn consume_forget_marks(&self, group_jid: &str) -> wa_rs_core::store::error::Result<Vec<String>> {
async fn consume_forget_marks(
&self,
group_jid: &str,
) -> wa_rs_core::store::error::Result<Vec<String>> {
let conn = self.conn.lock();
let mut stmt = to_store_err!(conn.prepare(
"SELECT participant FROM sender_key_status
@ -848,7 +964,10 @@ impl ProtocolStore for RusqliteStore {
// --- TcToken Storage ---
async fn get_tc_token(&self, jid: &str) -> wa_rs_core::store::error::Result<Option<TcTokenEntry>> {
async fn get_tc_token(
&self,
jid: &str,
) -> wa_rs_core::store::error::Result<Option<TcTokenEntry>> {
let conn = self.conn.lock();
let result = conn.query_row(
"SELECT token, token_timestamp, sender_timestamp FROM tc_tokens
@ -866,11 +985,17 @@ impl ProtocolStore for RusqliteStore {
match result {
Ok(entry) => Ok(Some(entry)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
async fn put_tc_token(&self, jid: &str, entry: &TcTokenEntry) -> wa_rs_core::store::error::Result<()> {
async fn put_tc_token(
&self,
jid: &str,
entry: &TcTokenEntry,
) -> wa_rs_core::store::error::Result<()> {
let conn = self.conn.lock();
let now = chrono::Utc::now().timestamp();
@ -899,13 +1024,12 @@ impl ProtocolStore for RusqliteStore {
async fn get_all_tc_token_jids(&self) -> wa_rs_core::store::error::Result<Vec<String>> {
let conn = self.conn.lock();
let mut stmt = to_store_err!(conn.prepare(
"SELECT jid FROM tc_tokens WHERE device_id = ?1"
))?;
let mut stmt =
to_store_err!(conn.prepare("SELECT jid FROM tc_tokens WHERE device_id = ?1"))?;
let rows = to_store_err!(stmt.query_map(params![self.device_id], |row| {
row.get::<_, String>(0)
}))?;
let rows = to_store_err!(
stmt.query_map(params![self.device_id], |row| { row.get::<_, String>(0) })
)?;
let mut result = Vec::new();
for row in rows {
@ -915,14 +1039,25 @@ impl ProtocolStore for RusqliteStore {
Ok(result)
}
async fn delete_expired_tc_tokens(&self, cutoff_timestamp: i64) -> wa_rs_core::store::error::Result<u32> {
async fn delete_expired_tc_tokens(
&self,
cutoff_timestamp: i64,
) -> wa_rs_core::store::error::Result<u32> {
let conn = self.conn.lock();
// Note: We can't easily get the affected row count with the execute macro, so we'll just return 0 for now
to_store_err!(execute: conn.execute(
"DELETE FROM tc_tokens WHERE token_timestamp < ?1 AND device_id = ?2",
params![cutoff_timestamp, self.device_id],
))?;
Ok(0) // TODO: Return actual affected row count
let deleted = conn
.execute(
"DELETE FROM tc_tokens WHERE token_timestamp < ?1 AND device_id = ?2",
params![cutoff_timestamp, self.device_id],
)
.map_err(|e| wa_rs_core::store::error::StoreError::Database(e.to_string()))?;
let deleted = u32::try_from(deleted).map_err(|_| {
wa_rs_core::store::error::StoreError::Database(format!(
"Affected row count overflowed u32: {deleted}"
))
})?;
Ok(deleted)
}
}
@ -997,7 +1132,9 @@ impl DeviceStoreTrait for RusqliteStore {
params![self.device_id],
|row| {
// Helper to convert errors to rusqlite::Error
fn to_rusqlite_err<E: std::error::Error + Send + Sync + 'static>(e: E) -> rusqlite::Error {
fn to_rusqlite_err<E: std::error::Error + Send + Sync + 'static>(
e: E,
) -> rusqlite::Error {
rusqlite::Error::ToSqlConversionFailure(Box::new(e))
}
@ -1006,24 +1143,25 @@ impl DeviceStoreTrait for RusqliteStore {
let identity_key_bytes: Vec<u8> = row.get("identity_key")?;
let signed_pre_key_bytes: Vec<u8> = row.get("signed_pre_key")?;
if noise_key_bytes.len() != 64 || identity_key_bytes.len() != 64 || signed_pre_key_bytes.len() != 64 {
if noise_key_bytes.len() != 64
|| identity_key_bytes.len() != 64
|| signed_pre_key_bytes.len() != 64
{
return Err(rusqlite::Error::InvalidParameterName("key_pair".into()));
}
use wa_rs_core::libsignal::protocol::{PrivateKey, PublicKey, KeyPair};
use wa_rs_core::libsignal::protocol::{KeyPair, PrivateKey, PublicKey};
let noise_key = KeyPair::new(
PublicKey::from_djb_public_key_bytes(&noise_key_bytes[32..64])
.map_err(to_rusqlite_err)?,
PrivateKey::deserialize(&noise_key_bytes[0..32])
.map_err(to_rusqlite_err)?,
PrivateKey::deserialize(&noise_key_bytes[0..32]).map_err(to_rusqlite_err)?,
);
let identity_key = KeyPair::new(
PublicKey::from_djb_public_key_bytes(&identity_key_bytes[32..64])
.map_err(to_rusqlite_err)?,
PrivateKey::deserialize(&identity_key_bytes[0..32])
.map_err(to_rusqlite_err)?,
PrivateKey::deserialize(&identity_key_bytes[0..32]).map_err(to_rusqlite_err)?,
);
let signed_pre_key = KeyPair::new(
@ -1045,8 +1183,10 @@ impl DeviceStoreTrait for RusqliteStore {
adv_secret.copy_from_slice(&adv_secret_bytes);
let account = if let Some(bytes) = account_bytes {
Some(wa_rs_proto::whatsapp::AdvSignedDeviceIdentity::decode(&*bytes)
.map_err(to_rusqlite_err)?)
Some(
wa_rs_proto::whatsapp::AdvSignedDeviceIdentity::decode(&*bytes)
.map_err(to_rusqlite_err)?,
)
} else {
None
};
@ -1077,7 +1217,9 @@ impl DeviceStoreTrait for RusqliteStore {
match result {
Ok(device) => Ok(Some(device)),
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(e.to_string())),
Err(e) => Err(wa_rs_core::store::error::StoreError::Database(
e.to_string(),
)),
}
}
@ -1097,7 +1239,11 @@ impl DeviceStoreTrait for RusqliteStore {
Ok(self.device_id)
}
async fn snapshot_db(&self, name: &str, extra_content: Option<&[u8]>) -> wa_rs_core::store::error::Result<()> {
async fn snapshot_db(
&self,
name: &str,
extra_content: Option<&[u8]>,
) -> wa_rs_core::store::error::Result<()> {
// Create a snapshot by copying the database file
let snapshot_path = format!("{}.snapshot.{}", self.db_path, name);
@ -1116,6 +1262,8 @@ impl DeviceStoreTrait for RusqliteStore {
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "whatsapp-web")]
use wa_rs_core::store::traits::{LidPnMappingEntry, ProtocolStore, TcTokenEntry};
#[cfg(feature = "whatsapp-web")]
#[test]
@ -1124,4 +1272,74 @@ mod tests {
let store = RusqliteStore::new(tmp.path()).unwrap();
assert_eq!(store.device_id, 1);
}
#[cfg(feature = "whatsapp-web")]
#[tokio::test]
async fn lid_mapping_round_trip_preserves_learning_source_and_updated_at() {
let tmp = tempfile::NamedTempFile::new().unwrap();
let store = RusqliteStore::new(tmp.path()).unwrap();
let entry = LidPnMappingEntry {
lid: "100000012345678".to_string(),
phone_number: "15551234567".to_string(),
created_at: 1_700_000_000,
updated_at: 1_700_000_100,
learning_source: "usync".to_string(),
};
ProtocolStore::put_lid_mapping(&store, &entry)
.await
.unwrap();
let loaded = ProtocolStore::get_lid_mapping(&store, &entry.lid)
.await
.unwrap()
.expect("expected lid mapping to be present");
assert_eq!(loaded.learning_source, entry.learning_source);
assert_eq!(loaded.updated_at, entry.updated_at);
let loaded_by_pn = ProtocolStore::get_pn_mapping(&store, &entry.phone_number)
.await
.unwrap()
.expect("expected pn mapping to be present");
assert_eq!(loaded_by_pn.learning_source, entry.learning_source);
assert_eq!(loaded_by_pn.updated_at, entry.updated_at);
}
#[cfg(feature = "whatsapp-web")]
#[tokio::test]
async fn delete_expired_tc_tokens_returns_deleted_row_count() {
let tmp = tempfile::NamedTempFile::new().unwrap();
let store = RusqliteStore::new(tmp.path()).unwrap();
let expired = TcTokenEntry {
token: vec![1, 2, 3],
token_timestamp: 10,
sender_timestamp: None,
};
let fresh = TcTokenEntry {
token: vec![4, 5, 6],
token_timestamp: 1000,
sender_timestamp: Some(1000),
};
ProtocolStore::put_tc_token(&store, "15550000001", &expired)
.await
.unwrap();
ProtocolStore::put_tc_token(&store, "15550000002", &fresh)
.await
.unwrap();
let deleted = ProtocolStore::delete_expired_tc_tokens(&store, 100)
.await
.unwrap();
assert_eq!(deleted, 1);
assert!(ProtocolStore::get_tc_token(&store, "15550000001")
.await
.unwrap()
.is_none());
assert!(ProtocolStore::get_tc_token(&store, "15550000002")
.await
.unwrap()
.is_some());
}
}

View file

@ -28,7 +28,7 @@
use super::traits::{Channel, ChannelMessage, SendMessage};
use super::whatsapp_storage::RusqliteStore;
use anyhow::Result;
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use parking_lot::Mutex;
use std::sync::Arc;
@ -60,6 +60,8 @@ pub struct WhatsAppWebChannel {
allowed_numbers: Vec<String>,
/// Bot handle for shutdown
bot_handle: Arc<Mutex<Option<tokio::task::JoinHandle<()>>>>,
/// Client handle for sending messages and typing indicators
client: Arc<Mutex<Option<Arc<wa_rs::Client>>>>,
/// Message sender channel
tx: Arc<Mutex<Option<tokio::sync::mpsc::Sender<ChannelMessage>>>>,
}
@ -86,6 +88,7 @@ impl WhatsAppWebChannel {
pair_code,
allowed_numbers,
bot_handle: Arc::new(Mutex::new(None)),
client: Arc::new(Mutex::new(None)),
tx: Arc::new(Mutex::new(None)),
}
}
@ -100,12 +103,44 @@ impl WhatsAppWebChannel {
/// Normalize phone number to E.164 format
#[cfg(feature = "whatsapp-web")]
fn normalize_phone(&self, phone: &str) -> String {
if phone.starts_with('+') {
phone.to_string()
let trimmed = phone.trim();
let user_part = trimmed
.split_once('@')
.map(|(user, _)| user)
.unwrap_or(trimmed);
let normalized_user = user_part.trim_start_matches('+');
if user_part.starts_with('+') {
format!("+{normalized_user}")
} else {
format!("+{phone}")
format!("+{normalized_user}")
}
}
/// Convert a recipient to a wa-rs JID.
///
/// Supports:
/// - Full JIDs (e.g. "12345@s.whatsapp.net")
/// - E.164-like numbers (e.g. "+1234567890")
#[cfg(feature = "whatsapp-web")]
fn recipient_to_jid(&self, recipient: &str) -> Result<wa_rs_binary::jid::Jid> {
let trimmed = recipient.trim();
if trimmed.is_empty() {
anyhow::bail!("Recipient cannot be empty");
}
if trimmed.contains('@') {
return trimmed
.parse::<wa_rs_binary::jid::Jid>()
.map_err(|e| anyhow!("Invalid WhatsApp JID `{trimmed}`: {e}"));
}
let digits: String = trimmed.chars().filter(|c| c.is_ascii_digit()).collect();
if digits.is_empty() {
anyhow::bail!("Recipient `{trimmed}` does not contain a valid phone number");
}
Ok(wa_rs_binary::jid::Jid::pn(digits))
}
}
#[cfg(feature = "whatsapp-web")]
@ -116,23 +151,33 @@ impl Channel for WhatsAppWebChannel {
}
async fn send(&self, message: &SendMessage) -> Result<()> {
// Check if bot is running
let bot_handle_guard = self.bot_handle.lock();
if bot_handle_guard.is_none() {
let client = self.client.lock().clone();
let Some(client) = client else {
anyhow::bail!("WhatsApp Web client not connected. Initialize the bot first.");
}
drop(bot_handle_guard);
};
// Validate recipient is allowed
let normalized = self.normalize_phone(&message.recipient);
if !self.is_number_allowed(&normalized) {
tracing::warn!("WhatsApp Web: recipient {} not in allowed list", message.recipient);
tracing::warn!(
"WhatsApp Web: recipient {} not in allowed list",
message.recipient
);
return Ok(());
}
// TODO: Implement sending via wa-rs client
// This requires getting the client from the bot and using its send_message API
tracing::debug!("WhatsApp Web: sending message to {}: {}", message.recipient, message.content);
let to = self.recipient_to_jid(&message.recipient)?;
let outgoing = wa_rs_proto::whatsapp::Message {
conversation: Some(message.content.clone()),
..Default::default()
};
let message_id = client.send_message(to, outgoing).await?;
tracing::debug!(
"WhatsApp Web: sent message to {} (id: {})",
message.recipient,
message_id
);
Ok(())
}
@ -141,11 +186,13 @@ impl Channel for WhatsAppWebChannel {
*self.tx.lock() = Some(tx.clone());
use wa_rs::bot::Bot;
use wa_rs::pair_code::PairCodeOptions;
use wa_rs::store::{Device, DeviceStore};
use wa_rs_core::types::events::Event;
use wa_rs_ureq_http::UreqHttpClient;
use wa_rs_tokio_transport::TokioWebSocketTransportFactory;
use wa_rs_binary::jid::JidExt as _;
use wa_rs_core::proto_helpers::MessageExt;
use wa_rs_core::types::events::Event;
use wa_rs_tokio_transport::TokioWebSocketTransportFactory;
use wa_rs_ureq_http::UreqHttpClient;
tracing::info!(
"WhatsApp Web channel starting (session: {})",
@ -166,7 +213,9 @@ impl Channel for WhatsAppWebChannel {
anyhow::bail!("Device exists but failed to load");
}
} else {
tracing::info!("WhatsApp Web: no existing session, new device will be created during pairing");
tracing::info!(
"WhatsApp Web: no existing session, new device will be created during pairing"
);
};
// Create transport factory
@ -182,7 +231,7 @@ impl Channel for WhatsAppWebChannel {
let tx_clone = tx.clone();
let allowed_numbers = self.allowed_numbers.clone();
let mut bot = Bot::builder()
let mut builder = Bot::builder()
.with_backend(backend)
.with_transport_factory(transport_factory)
.with_http_client(http_client)
@ -194,7 +243,7 @@ impl Channel for WhatsAppWebChannel {
Event::Message(msg, info) => {
// Extract message content
let text = msg.text_content().unwrap_or("");
let sender = info.source.sender.to_string();
let sender = info.source.sender.user().to_string();
let chat = info.source.chat.to_string();
tracing::info!("📨 WhatsApp message from {} in {}: {}", sender, chat, text);
@ -209,14 +258,17 @@ impl Channel for WhatsAppWebChannel {
if allowed_numbers.is_empty()
|| allowed_numbers.iter().any(|n| n == "*" || n == &normalized)
{
if let Err(e) = tx_inner.send(ChannelMessage {
id: uuid::Uuid::new_v4().to_string(),
channel: "whatsapp".to_string(),
sender: normalized.clone(),
reply_target: normalized.clone(),
content: text.to_string(),
timestamp: chrono::Utc::now().timestamp_millis() as u64,
}).await {
if let Err(e) = tx_inner
.send(ChannelMessage {
id: uuid::Uuid::new_v4().to_string(),
channel: "whatsapp".to_string(),
sender: normalized.clone(),
reply_target: normalized.clone(),
content: text.to_string(),
timestamp: chrono::Utc::now().timestamp_millis() as u64,
})
.await
{
tracing::error!("Failed to send message to channel: {}", e);
}
} else {
@ -244,17 +296,25 @@ impl Channel for WhatsAppWebChannel {
}
}
})
.build()
.await?;
;
// Configure pair code options if pair_phone is set
// Configure pair-code flow when a phone number is provided.
if let Some(ref phone) = self.pair_phone {
// Set the phone number for pair code linking
// The exact API depends on wa-rs version
tracing::info!("Requesting pair code for phone: {}", phone);
// bot.request_pair_code(phone).await?;
tracing::info!("WhatsApp Web: pair-code flow enabled for configured phone number");
builder = builder.with_pair_code(PairCodeOptions {
phone_number: phone.clone(),
custom_code: self.pair_code.clone(),
..Default::default()
});
} else if self.pair_code.is_some() {
tracing::warn!(
"WhatsApp Web: pair_code is set but pair_phone is missing; pair code config is ignored"
);
}
let mut bot = builder.build().await?;
*self.client.lock() = Some(bot.client());
// Run the bot
let bot_handle = bot.run().await?;
@ -273,6 +333,11 @@ impl Channel for WhatsAppWebChannel {
}
}
*self.client.lock() = None;
if let Some(handle) = self.bot_handle.lock().take() {
handle.abort();
}
Ok(())
}
@ -282,14 +347,54 @@ impl Channel for WhatsAppWebChannel {
}
async fn start_typing(&self, recipient: &str) -> Result<()> {
let client = self.client.lock().clone();
let Some(client) = client else {
anyhow::bail!("WhatsApp Web client not connected. Initialize the bot first.");
};
let normalized = self.normalize_phone(recipient);
if !self.is_number_allowed(&normalized) {
tracing::warn!(
"WhatsApp Web: typing target {} not in allowed list",
recipient
);
return Ok(());
}
let to = self.recipient_to_jid(recipient)?;
client
.chatstate()
.send_composing(&to)
.await
.map_err(|e| anyhow!("Failed to send typing state (composing): {e}"))?;
tracing::debug!("WhatsApp Web: start typing for {}", recipient);
// TODO: Implement typing indicator via wa-rs client
Ok(())
}
async fn stop_typing(&self, recipient: &str) -> Result<()> {
let client = self.client.lock().clone();
let Some(client) = client else {
anyhow::bail!("WhatsApp Web client not connected. Initialize the bot first.");
};
let normalized = self.normalize_phone(recipient);
if !self.is_number_allowed(&normalized) {
tracing::warn!(
"WhatsApp Web: typing target {} not in allowed list",
recipient
);
return Ok(());
}
let to = self.recipient_to_jid(recipient)?;
client
.chatstate()
.send_paused(&to)
.await
.map_err(|e| anyhow!("Failed to send typing state (paused): {e}"))?;
tracing::debug!("WhatsApp Web: stop typing for {}", recipient);
// TODO: Implement typing indicator via wa-rs client
Ok(())
}
}
@ -308,10 +413,7 @@ impl WhatsAppWebChannel {
_pair_code: Option<String>,
_allowed_numbers: Vec<String>,
) -> Self {
panic!(
"WhatsApp Web channel requires the 'whatsapp-web' feature. \
Enable with: cargo build --features whatsapp-web"
);
Self { _private: () }
}
}
@ -323,11 +425,17 @@ impl Channel for WhatsAppWebChannel {
}
async fn send(&self, _message: &SendMessage) -> Result<()> {
unreachable!()
anyhow::bail!(
"WhatsApp Web channel requires the 'whatsapp-web' feature. \
Enable with: cargo build --features whatsapp-web"
);
}
async fn listen(&self, _tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> Result<()> {
unreachable!()
anyhow::bail!(
"WhatsApp Web channel requires the 'whatsapp-web' feature. \
Enable with: cargo build --features whatsapp-web"
);
}
async fn health_check(&self) -> bool {
@ -335,11 +443,17 @@ impl Channel for WhatsAppWebChannel {
}
async fn start_typing(&self, _recipient: &str) -> Result<()> {
unreachable!()
anyhow::bail!(
"WhatsApp Web channel requires the 'whatsapp-web' feature. \
Enable with: cargo build --features whatsapp-web"
);
}
async fn stop_typing(&self, _recipient: &str) -> Result<()> {
unreachable!()
anyhow::bail!(
"WhatsApp Web channel requires the 'whatsapp-web' feature. \
Enable with: cargo build --features whatsapp-web"
);
}
}