1use crate::storage_trait::*;
7use async_trait::async_trait;
8use libsignal_protocol::{
9 message_decrypt, message_encrypt, process_prekey_bundle, CiphertextMessage, Direction,
10 GenericSignedPreKey, IdentityChange, IdentityKey, IdentityKeyPair, IdentityKeyStore,
11 KyberPreKeyId, KyberPreKeyRecord, KyberPreKeyStore, PreKeyBundle, PreKeyId, PreKeyRecord,
12 PreKeyStore, ProtocolAddress, SessionRecord, SessionStore, SignalProtocolError, SignedPreKeyId,
13 SignedPreKeyRecord, SignedPreKeyStore, UsePQRatchet,
14};
15use rusqlite::Connection;
16use std::sync::{Arc, Mutex};
17
18type BundleMetadata = (u32, u32, u32);
20
21pub struct SqliteStorage {
23 connection: Arc<Mutex<Connection>>,
24 session_store: Option<SqliteSessionStore>,
25 identity_store: Option<SqliteIdentityStore>,
26 pre_key_store: Option<SqlitePreKeyStore>,
27 signed_pre_key_store: Option<SqliteSignedPreKeyStore>,
28 kyber_pre_key_store: Option<SqliteKyberPreKeyStore>,
29 message_history: Option<crate::message_history::MessageHistory>,
30 is_closed: bool,
31}
32
33impl SqliteStorage {
34 pub async fn new(db_path: &str) -> Result<Self, Box<dyn std::error::Error>> {
35 use crate::db_encryption;
36
37 let key = db_encryption::get_or_create_db_key(db_path)?;
38 let connection = Connection::open(db_path)?;
39
40 connection.pragma_update(None, "key", hex::encode(key))?;
41
42 let connection = Arc::new(Mutex::new(connection));
43
44 Ok(Self {
45 connection,
46 session_store: None,
47 identity_store: None,
48 pre_key_store: None,
49 signed_pre_key_store: None,
50 kyber_pre_key_store: None,
51 message_history: None,
52 is_closed: false,
53 })
54 }
55
56 pub fn is_closed(&self) -> bool {
57 self.is_closed
58 }
59
60 pub fn initialize_schema(&mut self) -> Result<(), Box<dyn std::error::Error>> {
61 {
62 let conn = self.connection.lock().unwrap();
63
64 conn.execute(
65 "CREATE TABLE IF NOT EXISTS schema_info (
66 version INTEGER NOT NULL DEFAULT 1,
67 updated_at INTEGER DEFAULT (strftime('%s', 'now'))
68 )",
69 [],
70 )?;
71
72 conn.execute("INSERT OR IGNORE INTO schema_info (version) VALUES (1)", [])?;
73
74 let current_version: i32 =
75 conn.query_row("SELECT version FROM schema_info", [], |row| row.get(0))?;
76
77 SqliteIdentityStore::create_tables(&conn)?;
78 SqliteSessionStore::create_tables(&conn)?;
79 SqlitePreKeyStore::create_tables(&conn)?;
80 SqliteSignedPreKeyStore::create_tables(&conn)?;
81 SqliteKyberPreKeyStore::create_tables(&conn)?;
82
83 conn.execute(
84 "CREATE TABLE IF NOT EXISTS contacts (
85 rdx_fingerprint TEXT PRIMARY KEY,
86 nostr_pubkey TEXT UNIQUE NOT NULL,
87 user_alias TEXT,
88 signal_identity_key BLOB NOT NULL,
89 first_seen INTEGER NOT NULL,
90 last_updated INTEGER NOT NULL
91 )",
92 [],
93 )?;
94
95 conn.execute(
96 "CREATE INDEX IF NOT EXISTS idx_contacts_alias
97 ON contacts(user_alias) WHERE user_alias IS NOT NULL",
98 [],
99 )?;
100
101 conn.execute(
102 "CREATE INDEX IF NOT EXISTS idx_contacts_nostr_pubkey
103 ON contacts(nostr_pubkey)",
104 [],
105 )?;
106
107 conn.execute(
108 "CREATE TABLE IF NOT EXISTS settings (
109 key TEXT PRIMARY KEY,
110 value TEXT NOT NULL,
111 updated_at INTEGER DEFAULT (strftime('%s', 'now'))
112 )",
113 [],
114 )?;
115
116 conn.execute(
117 "CREATE TABLE IF NOT EXISTS bundle_metadata (
118 id INTEGER PRIMARY KEY CHECK (id = 1),
119 pre_key_id INTEGER NOT NULL,
120 signed_pre_key_id INTEGER NOT NULL,
121 kyber_pre_key_id INTEGER NOT NULL,
122 published_at INTEGER NOT NULL
123 )",
124 [],
125 )?;
126
127 if current_version < 2 {
128 Self::migrate_to_v2(&conn)?;
129 }
130 }
131
132 self.session_store = Some(SqliteSessionStore::new(self.connection.clone()));
133 self.identity_store = Some(SqliteIdentityStore::new(self.connection.clone()));
134 self.pre_key_store = Some(SqlitePreKeyStore::new(self.connection.clone()));
135 self.signed_pre_key_store = Some(SqliteSignedPreKeyStore::new(self.connection.clone()));
136 self.kyber_pre_key_store = Some(SqliteKyberPreKeyStore::new(self.connection.clone()));
137 self.message_history = Some(crate::message_history::MessageHistory::new(
138 self.connection.clone(),
139 ));
140
141 Ok(())
142 }
143
144 fn migrate_to_v2(conn: &Connection) -> Result<(), Box<dyn std::error::Error>> {
145 conn.execute(
146 "CREATE TABLE IF NOT EXISTS conversations (
147 id INTEGER PRIMARY KEY AUTOINCREMENT,
148 rdx_fingerprint TEXT NOT NULL UNIQUE,
149 nostr_pubkey TEXT,
150 last_message_timestamp INTEGER NOT NULL,
151 unread_count INTEGER DEFAULT 0,
152 archived BOOLEAN DEFAULT 0,
153 FOREIGN KEY (rdx_fingerprint) REFERENCES contacts(rdx_fingerprint) ON DELETE CASCADE
154 )",
155 [],
156 )?;
157
158 conn.execute(
159 "CREATE INDEX IF NOT EXISTS idx_conversations_timestamp
160 ON conversations(last_message_timestamp DESC)",
161 [],
162 )?;
163
164 conn.execute(
165 "CREATE INDEX IF NOT EXISTS idx_conversations_unread
166 ON conversations(unread_count) WHERE unread_count > 0",
167 [],
168 )?;
169
170 conn.execute(
171 "CREATE TABLE IF NOT EXISTS messages (
172 id INTEGER PRIMARY KEY AUTOINCREMENT,
173 conversation_id INTEGER NOT NULL,
174 direction INTEGER NOT NULL,
175 timestamp INTEGER NOT NULL,
176 message_type INTEGER NOT NULL,
177 content BLOB NOT NULL,
178 delivery_status INTEGER DEFAULT 0,
179 was_prekey_message BOOLEAN DEFAULT 0,
180 session_established BOOLEAN DEFAULT 0,
181 FOREIGN KEY (conversation_id) REFERENCES conversations(id) ON DELETE CASCADE
182 )",
183 [],
184 )?;
185
186 conn.execute(
187 "CREATE INDEX IF NOT EXISTS idx_messages_conversation
188 ON messages(conversation_id, timestamp DESC)",
189 [],
190 )?;
191
192 conn.execute(
193 "CREATE INDEX IF NOT EXISTS idx_messages_timestamp
194 ON messages(timestamp DESC)",
195 [],
196 )?;
197
198 conn.execute(
199 "CREATE INDEX IF NOT EXISTS idx_messages_undelivered
200 ON messages(delivery_status) WHERE delivery_status IN (0, 3)",
201 [],
202 )?;
203
204 conn.execute(
205 "UPDATE schema_info SET version = 2, updated_at = strftime('%s', 'now')",
206 [],
207 )?;
208
209 Ok(())
210 }
211
212 pub fn get_schema_version(&self) -> Result<i32, Box<dyn std::error::Error>> {
213 let conn = self.connection.lock().unwrap();
214 let mut stmt = conn.prepare("SELECT version FROM schema_info")?;
215 let version: i32 = stmt.query_row([], |row| row.get(0))?;
216 Ok(version)
217 }
218
219 pub fn connection(&self) -> Arc<Mutex<Connection>> {
220 self.connection.clone()
221 }
222
223 pub fn message_history(&self) -> &crate::message_history::MessageHistory {
224 self.message_history
225 .as_ref()
226 .expect("MessageHistory not initialized")
227 }
228
229 pub fn get_last_message_timestamp(&self) -> Result<u64, Box<dyn std::error::Error>> {
230 let conn = self.connection.lock().unwrap();
231 let mut stmt =
232 conn.prepare("SELECT value FROM settings WHERE key = 'last_message_timestamp'")?;
233
234 match stmt.query_row([], |row| row.get::<_, String>(0)) {
235 Ok(value_str) => Ok(value_str.parse::<u64>()?),
236 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(0),
237 Err(e) => Err(Box::new(e)),
238 }
239 }
240
241 pub fn set_last_message_timestamp(
242 &mut self,
243 timestamp: u64,
244 ) -> Result<(), Box<dyn std::error::Error>> {
245 let conn = self.connection.lock().unwrap();
246 conn.execute(
247 "INSERT OR REPLACE INTO settings (key, value, updated_at) VALUES ('last_message_timestamp', ?1, strftime('%s', 'now'))",
248 [timestamp.to_string()],
249 )?;
250 Ok(())
251 }
252
253 pub fn record_published_bundle(
254 &mut self,
255 pre_key_id: u32,
256 signed_pre_key_id: u32,
257 kyber_pre_key_id: u32,
258 ) -> Result<(), Box<dyn std::error::Error>> {
259 let conn = self.connection.lock().unwrap();
260 let now = std::time::SystemTime::now()
261 .duration_since(std::time::UNIX_EPOCH)?
262 .as_secs();
263
264 conn.execute(
265 "INSERT OR REPLACE INTO bundle_metadata (id, pre_key_id, signed_pre_key_id, kyber_pre_key_id, published_at)
266 VALUES (1, ?1, ?2, ?3, ?4)",
267 rusqlite::params![pre_key_id, signed_pre_key_id, kyber_pre_key_id, now],
268 )?;
269 Ok(())
270 }
271
272 pub fn get_last_published_bundle_metadata(
273 &self,
274 ) -> Result<Option<BundleMetadata>, Box<dyn std::error::Error>> {
275 let conn = self.connection.lock().unwrap();
276 let mut stmt = conn.prepare(
277 "SELECT pre_key_id, signed_pre_key_id, kyber_pre_key_id FROM bundle_metadata WHERE id = 1",
278 )?;
279
280 match stmt.query_row([], |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?))) {
281 Ok(result) => Ok(Some(result)),
282 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
283 Err(e) => Err(Box::new(e)),
284 }
285 }
286}
287
288impl SignalStorageContainer for SqliteStorage {
289 type SessionStore = SqliteSessionStore;
290 type IdentityStore = SqliteIdentityStore;
291 type PreKeyStore = SqlitePreKeyStore;
292 type SignedPreKeyStore = SqliteSignedPreKeyStore;
293 type KyberPreKeyStore = SqliteKyberPreKeyStore;
294
295 fn session_store(&mut self) -> &mut Self::SessionStore {
296 if self.is_closed {
297 panic!("Storage has been closed");
298 }
299 self.session_store
300 .as_mut()
301 .expect("Storage not initialized")
302 }
303
304 fn identity_store(&mut self) -> &mut Self::IdentityStore {
305 if self.is_closed {
306 panic!("Storage has been closed");
307 }
308 self.identity_store
309 .as_mut()
310 .expect("Storage not initialized")
311 }
312
313 fn pre_key_store(&mut self) -> &mut Self::PreKeyStore {
314 if self.is_closed {
315 panic!("Storage has been closed");
316 }
317 self.pre_key_store
318 .as_mut()
319 .expect("Storage not initialized")
320 }
321
322 fn signed_pre_key_store(&mut self) -> &mut Self::SignedPreKeyStore {
323 if self.is_closed {
324 panic!("Storage has been closed");
325 }
326 self.signed_pre_key_store
327 .as_mut()
328 .expect("Storage not initialized")
329 }
330
331 fn kyber_pre_key_store(&mut self) -> &mut Self::KyberPreKeyStore {
332 if self.is_closed {
333 panic!("Storage has been closed");
334 }
335 self.kyber_pre_key_store
336 .as_mut()
337 .expect("Storage not initialized")
338 }
339
340 fn initialize(&mut self) -> Result<(), Box<dyn std::error::Error>> {
341 self.initialize_schema()
342 }
343
344 fn close(&mut self) -> Result<(), Box<dyn std::error::Error>> {
345 if self.is_closed {
346 return Ok(());
347 }
348
349 {
350 let conn = self.connection.lock().unwrap();
351 conn.execute("PRAGMA optimize", [])?;
352 conn.execute("PRAGMA wal_checkpoint(FULL)", []).ok(); }
354
355 self.is_closed = true;
356
357 self.session_store = None;
358 self.identity_store = None;
359 self.pre_key_store = None;
360 self.signed_pre_key_store = None;
361 self.kyber_pre_key_store = None;
362
363 Ok(())
364 }
365
366 fn storage_type(&self) -> &'static str {
367 "sqlite"
368 }
369}
370
371pub struct SqliteSessionStore {
373 connection: Arc<Mutex<Connection>>,
374}
375
376impl SqliteSessionStore {
377 pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
378 Self { connection }
379 }
380
381 pub fn create_tables(connection: &Connection) -> Result<(), Box<dyn std::error::Error>> {
382 connection.execute(
383 "CREATE TABLE IF NOT EXISTS sessions (
384 address TEXT NOT NULL,
385 device_id INTEGER NOT NULL DEFAULT 1,
386 session_data BLOB NOT NULL,
387 created_at INTEGER DEFAULT (strftime('%s', 'now')),
388 updated_at INTEGER DEFAULT (strftime('%s', 'now')),
389 PRIMARY KEY (address, device_id)
390 )",
391 [],
392 )?;
393
394 Ok(())
395 }
396}
397
398#[async_trait(?Send)]
399impl SessionStore for SqliteSessionStore {
400 async fn load_session(
401 &self,
402 address: &ProtocolAddress,
403 ) -> Result<Option<SessionRecord>, SignalProtocolError> {
404 let conn = self.connection.lock().unwrap();
405 let mut stmt = conn
406 .prepare("SELECT session_data FROM sessions WHERE address = ? AND device_id = ?")
407 .map_err(|e| {
408 SignalProtocolError::InvalidState(
409 "storage",
410 format!("Failed to prepare statement: {}", e),
411 )
412 })?;
413
414 let result = stmt.query_row(
415 [address.name(), &u32::from(address.device_id()).to_string()],
416 |row| {
417 let data: Vec<u8> = row.get(0)?;
418 Ok(data)
419 },
420 );
421
422 match result {
423 Ok(data) => {
424 let session = SessionRecord::deserialize(&data).map_err(|e| {
425 SignalProtocolError::InvalidState(
426 "storage",
427 format!("Failed to deserialize session: {}", e),
428 )
429 })?;
430 Ok(Some(session))
431 }
432 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
433 Err(e) => Err(SignalProtocolError::InvalidState(
434 "storage",
435 format!("Database error: {}", e),
436 )),
437 }
438 }
439
440 async fn store_session(
441 &mut self,
442 address: &ProtocolAddress,
443 record: &SessionRecord,
444 ) -> Result<(), SignalProtocolError> {
445 let conn = self.connection.lock().unwrap();
446 let serialized = record.serialize().map_err(|e| {
447 SignalProtocolError::InvalidState(
448 "storage",
449 format!("Failed to serialize session: {}", e),
450 )
451 })?;
452
453 conn.execute(
454 "INSERT OR REPLACE INTO sessions (address, device_id, session_data, updated_at)
455 VALUES (?, ?, ?, strftime('%s', 'now'))",
456 rusqlite::params![address.name(), u32::from(address.device_id()), &serialized],
457 )
458 .map_err(|e| {
459 SignalProtocolError::InvalidState("storage", format!("Failed to store session: {}", e))
460 })?;
461
462 Ok(())
463 }
464}
465
466#[async_trait(?Send)]
467impl ExtendedSessionStore for SqliteSessionStore {
468 async fn session_count(&self) -> usize {
469 let conn = self.connection.lock().unwrap();
470 let mut stmt = conn.prepare("SELECT COUNT(*) FROM sessions").unwrap();
471 stmt.query_row([], |row| {
472 let count: i64 = row.get(0)?;
473 Ok(count as usize)
474 })
475 .unwrap_or(0)
476 }
477
478 async fn clear_all_sessions(&mut self) -> Result<(), Box<dyn std::error::Error>> {
479 let conn = self.connection.lock().unwrap();
480 conn.execute("DELETE FROM sessions", [])?;
481 Ok(())
482 }
483 async fn delete_session(
484 &mut self,
485 address: &ProtocolAddress,
486 ) -> Result<(), Box<dyn std::error::Error>> {
487 let conn = self.connection.lock().unwrap();
488 conn.execute(
489 "DELETE FROM sessions WHERE address = ?1 AND device_id = ?2",
490 [address.name(), &u32::from(address.device_id()).to_string()],
491 )?;
492 Ok(())
493 }
494}
495
496pub struct SqliteIdentityStore {
498 connection: Arc<Mutex<Connection>>,
499}
500
501impl SqliteIdentityStore {
502 pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
503 Self { connection }
504 }
505
506 pub fn create_tables(connection: &Connection) -> Result<(), Box<dyn std::error::Error>> {
507 connection.execute(
508 "CREATE TABLE IF NOT EXISTS local_identity (
509 id INTEGER PRIMARY KEY CHECK (id = 1),
510 registration_id INTEGER NOT NULL,
511 private_key BLOB NOT NULL,
512 public_key BLOB NOT NULL,
513 created_at INTEGER DEFAULT (strftime('%s', 'now')),
514 updated_at INTEGER DEFAULT (strftime('%s', 'now'))
515 )",
516 [],
517 )?;
518
519 connection.execute(
520 "CREATE TABLE IF NOT EXISTS identity_keys (
521 address TEXT NOT NULL,
522 device_id INTEGER NOT NULL DEFAULT 1,
523 public_key BLOB NOT NULL,
524 trust_level INTEGER DEFAULT 0,
525 first_seen INTEGER DEFAULT (strftime('%s', 'now')),
526 last_seen INTEGER DEFAULT (strftime('%s', 'now')),
527 verified_at INTEGER NULL,
528 PRIMARY KEY (address, device_id)
529 )",
530 [],
531 )?;
532
533 Ok(())
534 }
535}
536
537#[async_trait(?Send)]
538impl IdentityKeyStore for SqliteIdentityStore {
539 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> {
540 let conn = self.connection.lock().unwrap();
541
542 let mut stmt = conn
543 .prepare("SELECT private_key FROM local_identity WHERE id = 1")
544 .map_err(|e| {
545 SignalProtocolError::InvalidState(
546 "storage",
547 format!("Failed to prepare statement: {}", e),
548 )
549 })?;
550
551 let result = stmt.query_row([], |row| {
552 let serialized_keypair: Vec<u8> = row.get(0)?;
553 Ok(serialized_keypair)
554 });
555
556 match result {
557 Ok(serialized_keypair) => {
558 let identity_key_pair = IdentityKeyPair::try_from(&serialized_keypair[..])
559 .map_err(|e| {
560 SignalProtocolError::InvalidState(
561 "storage",
562 format!("Failed to deserialize identity key pair: {}", e),
563 )
564 })?;
565 Ok(identity_key_pair)
566 }
567 Err(rusqlite::Error::QueryReturnedNoRows) => Err(SignalProtocolError::InvalidState(
568 "storage",
569 "Local identity key pair not set".to_string(),
570 )),
571 Err(e) => Err(SignalProtocolError::InvalidState(
572 "storage",
573 format!("Database error: {}", e),
574 )),
575 }
576 }
577
578 async fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError> {
579 let conn = self.connection.lock().unwrap();
580
581 let mut stmt = conn
582 .prepare("SELECT registration_id FROM local_identity WHERE id = 1")
583 .map_err(|e| {
584 SignalProtocolError::InvalidState(
585 "storage",
586 format!("Failed to prepare statement: {}", e),
587 )
588 })?;
589
590 let result = stmt.query_row([], |row| {
591 let registration_id: u32 = row.get(0)?;
592 Ok(registration_id)
593 });
594
595 match result {
596 Ok(registration_id) => Ok(registration_id),
597 Err(rusqlite::Error::QueryReturnedNoRows) => Err(SignalProtocolError::InvalidState(
598 "storage",
599 "Local registration ID not set".to_string(),
600 )),
601 Err(e) => Err(SignalProtocolError::InvalidState(
602 "storage",
603 format!("Database error: {}", e),
604 )),
605 }
606 }
607
608 async fn save_identity(
609 &mut self,
610 address: &ProtocolAddress,
611 identity_key: &IdentityKey,
612 ) -> Result<IdentityChange, SignalProtocolError> {
613 let conn = self.connection.lock().unwrap();
614 let serialized_key = identity_key.serialize();
615
616 let mut stmt = conn
617 .prepare("SELECT public_key FROM identity_keys WHERE address = ? AND device_id = ?")
618 .map_err(|e| {
619 SignalProtocolError::InvalidState(
620 "storage",
621 format!("Failed to prepare statement: {}", e),
622 )
623 })?;
624
625 let existing_key = stmt.query_row(
626 [address.name(), &u32::from(address.device_id()).to_string()],
627 |row| {
628 let key_data: Vec<u8> = row.get(0)?;
629 Ok(key_data)
630 },
631 );
632
633 let change_type = match existing_key {
634 Ok(existing_data) => {
635 if existing_data == serialized_key.as_ref() {
636 IdentityChange::NewOrUnchanged
637 } else {
638 IdentityChange::ReplacedExisting
639 }
640 }
641 Err(rusqlite::Error::QueryReturnedNoRows) => IdentityChange::NewOrUnchanged,
642 Err(e) => {
643 return Err(SignalProtocolError::InvalidState(
644 "storage",
645 format!("Database error: {}", e),
646 ))
647 }
648 };
649
650 conn.execute(
651 "INSERT OR REPLACE INTO identity_keys (address, device_id, public_key, last_seen)
652 VALUES (?, ?, ?, strftime('%s', 'now'))",
653 rusqlite::params![
654 address.name(),
655 u32::from(address.device_id()),
656 &serialized_key.as_ref()
657 ],
658 )
659 .map_err(|e| {
660 SignalProtocolError::InvalidState("storage", format!("Failed to store identity: {}", e))
661 })?;
662
663 Ok(change_type)
664 }
665
666 async fn is_trusted_identity(
667 &self,
668 address: &ProtocolAddress,
669 identity_key: &IdentityKey,
670 _direction: Direction,
671 ) -> Result<bool, SignalProtocolError> {
672 let stored_identity = self.get_identity(address).await?;
673
674 match stored_identity {
675 None => Ok(true), Some(stored_key) => Ok(stored_key == *identity_key), }
678 }
679
680 async fn get_identity(
681 &self,
682 address: &ProtocolAddress,
683 ) -> Result<Option<IdentityKey>, SignalProtocolError> {
684 let conn = self.connection.lock().unwrap();
685 let mut stmt = conn
686 .prepare("SELECT public_key FROM identity_keys WHERE address = ? AND device_id = ?")
687 .map_err(|e| {
688 SignalProtocolError::InvalidState(
689 "storage",
690 format!("Failed to prepare statement: {}", e),
691 )
692 })?;
693
694 let result = stmt.query_row(
695 [address.name(), &u32::from(address.device_id()).to_string()],
696 |row| {
697 let key_data: Vec<u8> = row.get(0)?;
698 Ok(key_data)
699 },
700 );
701
702 match result {
703 Ok(key_data) => {
704 let identity_key = IdentityKey::try_from(&key_data[..]).map_err(|e| {
705 SignalProtocolError::InvalidState(
706 "storage",
707 format!("Failed to deserialize identity key: {}", e),
708 )
709 })?;
710 Ok(Some(identity_key))
711 }
712 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
713 Err(e) => Err(SignalProtocolError::InvalidState(
714 "storage",
715 format!("Database error: {}", e),
716 )),
717 }
718 }
719}
720
721#[async_trait(?Send)]
722impl ExtendedIdentityStore for SqliteIdentityStore {
723 async fn identity_count(&self) -> usize {
724 let conn = self.connection.lock().unwrap();
725 let mut stmt = conn.prepare("SELECT COUNT(*) FROM identity_keys").unwrap();
726 stmt.query_row([], |row| {
727 let count: i64 = row.get(0)?;
728 Ok(count as usize)
729 })
730 .unwrap_or(0)
731 }
732
733 async fn set_local_identity_key_pair(
734 &self,
735 identity_key_pair: &IdentityKeyPair,
736 ) -> Result<(), Box<dyn std::error::Error>> {
737 let conn = self.connection.lock().unwrap();
738
739 let serialized_keypair = identity_key_pair.serialize();
740 let public_key_bytes = identity_key_pair.identity_key().serialize();
741
742 conn.execute(
743 "INSERT OR REPLACE INTO local_identity (id, private_key, public_key, registration_id, updated_at)
744 VALUES (1, ?, ?, COALESCE((SELECT registration_id FROM local_identity WHERE id = 1), 0), strftime('%s', 'now'))",
745 rusqlite::params![&serialized_keypair.as_ref(), &public_key_bytes.as_ref()],
746 )?;
747
748 Ok(())
749 }
750
751 async fn set_local_registration_id(
752 &self,
753 registration_id: u32,
754 ) -> Result<(), Box<dyn std::error::Error>> {
755 let conn = self.connection.lock().unwrap();
756
757 conn.execute(
758 "INSERT OR REPLACE INTO local_identity (id, registration_id, private_key, public_key, updated_at)
759 VALUES (1, ?, COALESCE((SELECT private_key FROM local_identity WHERE id = 1), X''), COALESCE((SELECT public_key FROM local_identity WHERE id = 1), X''), strftime('%s', 'now'))",
760 rusqlite::params![registration_id],
761 )?;
762
763 Ok(())
764 }
765 async fn get_peer_identity(
766 &self,
767 address: &ProtocolAddress,
768 ) -> Result<Option<IdentityKey>, Box<dyn std::error::Error>> {
769 let conn = self.connection.lock().unwrap();
770 let mut stmt = conn
771 .prepare("SELECT public_key FROM identity_keys WHERE address = ? AND device_id = ?")?;
772 match stmt.query_row(
773 [address.name(), &u32::from(address.device_id()).to_string()],
774 |row| {
775 let public_key_bytes: Vec<u8> = row.get(0)?;
776 Ok(public_key_bytes)
777 },
778 ) {
779 Ok(key_bytes) => Ok(Some(IdentityKey::decode(&key_bytes)?)),
780 Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
781 Err(e) => Err(e.into()),
782 }
783 }
784 async fn delete_identity(
785 &mut self,
786 address: &ProtocolAddress,
787 ) -> Result<(), Box<dyn std::error::Error>> {
788 let conn = self.connection.lock().unwrap();
789 conn.execute(
790 "DELETE FROM identity_keys WHERE address = ? AND device_id = ?",
791 [address.name(), &u32::from(address.device_id()).to_string()],
792 )?;
793 Ok(())
794 }
795 async fn clear_all_identities(&mut self) -> Result<(), Box<dyn std::error::Error>> {
796 let conn = self.connection.lock().unwrap();
797 conn.execute("DELETE FROM identity_keys", [])?;
798 Ok(())
799 }
800 async fn clear_local_identity(&mut self) -> Result<(), Box<dyn std::error::Error>> {
801 let conn = self.connection.lock().unwrap();
802 conn.execute("DELETE FROM local_identity", [])?;
803 Ok(())
804 }
805}
806
807pub struct SqlitePreKeyStore {
809 connection: Arc<Mutex<Connection>>,
810}
811
812impl SqlitePreKeyStore {
813 pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
814 Self { connection }
815 }
816
817 pub fn create_tables(conn: &Connection) -> Result<(), Box<dyn std::error::Error>> {
818 conn.execute(
819 "CREATE TABLE IF NOT EXISTS pre_keys (
820 id INTEGER PRIMARY KEY,
821 key_data BLOB NOT NULL,
822 created_at INTEGER DEFAULT (strftime('%s', 'now')),
823 used_at INTEGER NULL
824 )",
825 [],
826 )?;
827 Ok(())
828 }
829}
830
831#[async_trait(?Send)]
832impl PreKeyStore for SqlitePreKeyStore {
833 async fn get_pre_key(&self, prekey_id: PreKeyId) -> Result<PreKeyRecord, SignalProtocolError> {
834 let conn = self.connection.lock().unwrap();
835 let mut stmt = conn
836 .prepare("SELECT key_data FROM pre_keys WHERE id = ?")
837 .map_err(|e| {
838 SignalProtocolError::InvalidState(
839 "storage",
840 format!("Failed to prepare statement: {}", e),
841 )
842 })?;
843
844 let result = stmt.query_row([u32::from(prekey_id)], |row| {
845 let key_data: Vec<u8> = row.get(0)?;
846 Ok(key_data)
847 });
848
849 match result {
850 Ok(key_data) => {
851 let prekey_record = PreKeyRecord::deserialize(&key_data).map_err(|e| {
852 SignalProtocolError::InvalidState(
853 "storage",
854 format!("Failed to deserialize pre key: {}", e),
855 )
856 })?;
857 Ok(prekey_record)
858 }
859 Err(rusqlite::Error::QueryReturnedNoRows) => Err(SignalProtocolError::InvalidPreKeyId),
860 Err(e) => Err(SignalProtocolError::InvalidState(
861 "storage",
862 format!("Database error: {}", e),
863 )),
864 }
865 }
866
867 async fn save_pre_key(
868 &mut self,
869 prekey_id: PreKeyId,
870 record: &PreKeyRecord,
871 ) -> Result<(), SignalProtocolError> {
872 let conn = self.connection.lock().unwrap();
873 let serialized = record.serialize().map_err(|e| {
874 SignalProtocolError::InvalidState(
875 "storage",
876 format!("Failed to serialize pre key: {}", e),
877 )
878 })?;
879
880 conn.execute(
881 "INSERT OR REPLACE INTO pre_keys (id, key_data, created_at, used_at)
882 VALUES (?, ?, strftime('%s', 'now'), NULL)",
883 rusqlite::params![u32::from(prekey_id), &serialized],
884 )
885 .map_err(|e| {
886 SignalProtocolError::InvalidState("storage", format!("Failed to store pre key: {}", e))
887 })?;
888
889 Ok(())
890 }
891
892 async fn remove_pre_key(&mut self, prekey_id: PreKeyId) -> Result<(), SignalProtocolError> {
893 let conn = self.connection.lock().unwrap();
894 conn.execute(
895 "DELETE FROM pre_keys WHERE id = ?",
896 rusqlite::params![u32::from(prekey_id)],
897 )
898 .map_err(|e| {
899 SignalProtocolError::InvalidState("storage", format!("Failed to remove pre key: {}", e))
900 })?;
901 Ok(())
902 }
903}
904
905#[async_trait(?Send)]
906impl ExtendedPreKeyStore for SqlitePreKeyStore {
907 async fn pre_key_count(&self) -> usize {
908 let conn = self.connection.lock().unwrap();
909 let mut stmt = conn.prepare("SELECT COUNT(*) FROM pre_keys").unwrap();
910 stmt.query_row([], |row| {
911 let count: i64 = row.get(0)?;
912 Ok(count as usize)
913 })
914 .unwrap_or(0)
915 }
916
917 async fn clear_all_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
918 let conn = self.connection.lock().unwrap();
919 conn.execute("DELETE FROM pre_keys", [])?;
920 Ok(())
921 }
922
923 async fn get_max_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
924 let conn = self.connection.lock().unwrap();
925 let mut stmt = conn.prepare("SELECT MAX(id) FROM pre_keys")?;
926 let max_id: Option<u32> = stmt.query_row([], |row| row.get(0)).ok();
927 Ok(max_id)
928 }
929
930 async fn delete_pre_key(&mut self, id: PreKeyId) -> Result<(), Box<dyn std::error::Error>> {
931 let conn = self.connection.lock().unwrap();
932 conn.execute("DELETE FROM pre_keys WHERE id = ?", [u32::from(id)])?;
933 Ok(())
934 }
935}
936
937pub struct SqliteSignedPreKeyStore {
939 connection: Arc<Mutex<Connection>>,
940}
941
942impl SqliteSignedPreKeyStore {
943 pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
944 Self { connection }
945 }
946
947 pub fn create_tables(conn: &Connection) -> Result<(), Box<dyn std::error::Error>> {
948 conn.execute(
949 "CREATE TABLE IF NOT EXISTS signed_pre_keys (
950 id INTEGER PRIMARY KEY,
951 key_data BLOB NOT NULL,
952 signature BLOB NOT NULL,
953 created_at INTEGER DEFAULT (strftime('%s', 'now')),
954 expires_at INTEGER NOT NULL,
955 is_current BOOLEAN DEFAULT FALSE
956 )",
957 [],
958 )?;
959 Ok(())
960 }
961}
962
963#[async_trait(?Send)]
964impl SignedPreKeyStore for SqliteSignedPreKeyStore {
965 async fn get_signed_pre_key(
966 &self,
967 signed_prekey_id: SignedPreKeyId,
968 ) -> Result<SignedPreKeyRecord, SignalProtocolError> {
969 let conn = self.connection.lock().unwrap();
970 let mut stmt = conn
971 .prepare("SELECT key_data FROM signed_pre_keys WHERE id = ?")
972 .map_err(|e| {
973 SignalProtocolError::InvalidState(
974 "storage",
975 format!("Failed to prepare statement: {}", e),
976 )
977 })?;
978
979 let result = stmt.query_row([u32::from(signed_prekey_id)], |row| {
980 let key_data: Vec<u8> = row.get(0)?;
981 Ok(key_data)
982 });
983
984 match result {
985 Ok(key_data) => {
986 let signed_prekey_record =
987 SignedPreKeyRecord::deserialize(&key_data).map_err(|e| {
988 SignalProtocolError::InvalidState(
989 "storage",
990 format!("Failed to deserialize signed pre key: {}", e),
991 )
992 })?;
993 Ok(signed_prekey_record)
994 }
995 Err(rusqlite::Error::QueryReturnedNoRows) => {
996 Err(SignalProtocolError::InvalidSignedPreKeyId)
997 }
998 Err(e) => Err(SignalProtocolError::InvalidState(
999 "storage",
1000 format!("Database error: {}", e),
1001 )),
1002 }
1003 }
1004
1005 async fn save_signed_pre_key(
1006 &mut self,
1007 signed_prekey_id: SignedPreKeyId,
1008 record: &SignedPreKeyRecord,
1009 ) -> Result<(), SignalProtocolError> {
1010 let conn = self.connection.lock().unwrap();
1011 let serialized = record.serialize().map_err(|e| {
1012 SignalProtocolError::InvalidState(
1013 "storage",
1014 format!("Failed to serialize signed pre key: {}", e),
1015 )
1016 })?;
1017
1018 let record_timestamp_secs = record
1019 .timestamp()
1020 .map_err(|e| {
1021 SignalProtocolError::InvalidState(
1022 "storage",
1023 format!("Failed to get record timestamp: {}", e),
1024 )
1025 })?
1026 .epoch_millis()
1027 / 1000;
1028
1029 let expires_at = record_timestamp_secs + (30 * 24 * 60 * 60); conn.execute(
1033 "INSERT OR REPLACE INTO signed_pre_keys (id, key_data, signature, created_at, expires_at, is_current)
1034 VALUES (?, ?, ?, ?, ?, FALSE)",
1035 rusqlite::params![u32::from(signed_prekey_id), &serialized, &record.signature().map_err(|e| SignalProtocolError::InvalidState("storage", format!("Failed to get signature: {}", e)))?, record_timestamp_secs as i64, expires_at],
1036 ).map_err(|e| SignalProtocolError::InvalidState("storage", format!("Failed to store signed pre key: {}", e)))?;
1037
1038 Ok(())
1039 }
1040}
1041
1042#[async_trait(?Send)]
1043impl ExtendedSignedPreKeyStore for SqliteSignedPreKeyStore {
1044 async fn signed_pre_key_count(&self) -> usize {
1045 let conn = self.connection.lock().unwrap();
1046 let mut stmt = conn
1047 .prepare("SELECT COUNT(*) FROM signed_pre_keys")
1048 .unwrap();
1049 stmt.query_row([], |row| {
1050 let count: i64 = row.get(0)?;
1051 Ok(count as usize)
1052 })
1053 .unwrap_or(0)
1054 }
1055
1056 async fn clear_all_signed_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
1057 let conn = self.connection.lock().unwrap();
1058 conn.execute("DELETE FROM signed_pre_keys", [])?;
1059 Ok(())
1060 }
1061
1062 async fn get_max_signed_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
1063 let conn = self.connection.lock().unwrap();
1064 let mut stmt = conn.prepare("SELECT MAX(id) FROM signed_pre_keys")?;
1065 let max_id: Option<u32> = stmt.query_row([], |row| row.get(0)).ok();
1066 Ok(max_id)
1067 }
1068
1069 async fn delete_signed_pre_key(
1070 &mut self,
1071 id: SignedPreKeyId,
1072 ) -> Result<(), Box<dyn std::error::Error>> {
1073 let conn = self.connection.lock().unwrap();
1074 conn.execute("DELETE FROM signed_pre_keys WHERE id = ?", [u32::from(id)])?;
1075 Ok(())
1076 }
1077
1078 async fn get_signed_pre_keys_older_than(
1079 &self,
1080 timestamp_millis: u64,
1081 ) -> Result<Vec<SignedPreKeyId>, Box<dyn std::error::Error>> {
1082 let conn = self.connection.lock().unwrap();
1083 let timestamp_secs = (timestamp_millis / 1000) as i64;
1084 let mut stmt = conn.prepare("SELECT id FROM signed_pre_keys WHERE created_at < ?")?;
1085 let ids = stmt
1086 .query_map([timestamp_secs], |row| {
1087 let id: u32 = row.get(0)?;
1088 Ok(SignedPreKeyId::from(id))
1089 })?
1090 .collect::<Result<Vec<_>, _>>()?;
1091 Ok(ids)
1092 }
1093}
1094
1095pub struct SqliteKyberPreKeyStore {
1097 connection: Arc<Mutex<Connection>>,
1098}
1099
1100impl SqliteKyberPreKeyStore {
1101 pub fn new(connection: Arc<Mutex<Connection>>) -> Self {
1102 Self { connection }
1103 }
1104
1105 pub fn create_tables(conn: &Connection) -> Result<(), Box<dyn std::error::Error>> {
1106 conn.execute(
1107 "CREATE TABLE IF NOT EXISTS kyber_pre_keys (
1108 id INTEGER PRIMARY KEY,
1109 key_data BLOB NOT NULL,
1110 signature BLOB NOT NULL,
1111 created_at INTEGER DEFAULT (strftime('%s', 'now')),
1112 expires_at INTEGER NOT NULL,
1113 is_current BOOLEAN DEFAULT FALSE
1114 )",
1115 [],
1116 )?;
1117 Ok(())
1118 }
1119}
1120
1121#[async_trait(?Send)]
1122impl KyberPreKeyStore for SqliteKyberPreKeyStore {
1123 async fn get_kyber_pre_key(
1124 &self,
1125 kyber_prekey_id: KyberPreKeyId,
1126 ) -> Result<KyberPreKeyRecord, SignalProtocolError> {
1127 let conn = self.connection.lock().unwrap();
1128 let mut stmt = conn
1129 .prepare("SELECT key_data FROM kyber_pre_keys WHERE id = ?")
1130 .map_err(|e| {
1131 SignalProtocolError::InvalidState(
1132 "storage",
1133 format!("Failed to prepare statement: {}", e),
1134 )
1135 })?;
1136
1137 let result = stmt.query_row([u32::from(kyber_prekey_id)], |row| {
1138 let key_data: Vec<u8> = row.get(0)?;
1139 Ok(key_data)
1140 });
1141
1142 match result {
1143 Ok(key_data) => {
1144 let kyber_prekey_record =
1145 KyberPreKeyRecord::deserialize(&key_data).map_err(|e| {
1146 SignalProtocolError::InvalidState(
1147 "storage",
1148 format!("Failed to deserialize kyber pre key: {}", e),
1149 )
1150 })?;
1151 Ok(kyber_prekey_record)
1152 }
1153 Err(rusqlite::Error::QueryReturnedNoRows) => {
1154 Err(SignalProtocolError::InvalidKyberPreKeyId)
1155 }
1156 Err(e) => Err(SignalProtocolError::InvalidState(
1157 "storage",
1158 format!("Database error: {}", e),
1159 )),
1160 }
1161 }
1162
1163 async fn save_kyber_pre_key(
1164 &mut self,
1165 kyber_prekey_id: KyberPreKeyId,
1166 record: &KyberPreKeyRecord,
1167 ) -> Result<(), SignalProtocolError> {
1168 let conn = self.connection.lock().unwrap();
1169 let serialized = record.serialize().map_err(|e| {
1170 SignalProtocolError::InvalidState(
1171 "storage",
1172 format!("Failed to serialize kyber pre key: {}", e),
1173 )
1174 })?;
1175
1176 let record_timestamp_secs = record
1178 .timestamp()
1179 .map_err(|e| {
1180 SignalProtocolError::InvalidState(
1181 "storage",
1182 format!("Failed to get timestamp: {}", e),
1183 )
1184 })?
1185 .epoch_millis()
1186 / 1000;
1187
1188 let expires_at = record_timestamp_secs + (30 * 24 * 60 * 60); conn.execute(
1192 "INSERT OR REPLACE INTO kyber_pre_keys (id, key_data, signature, created_at, expires_at, is_current)
1193 VALUES (?, ?, ?, ?, ?, FALSE)",
1194 rusqlite::params![u32::from(kyber_prekey_id), &serialized, &record.signature().map_err(|e| SignalProtocolError::InvalidState("storage", format!("Failed to get signature: {}", e)))?, record_timestamp_secs as i64, expires_at],
1195 ).map_err(|e| SignalProtocolError::InvalidState("storage", format!("Failed to store kyber pre key: {}", e)))?;
1196
1197 Ok(())
1198 }
1199
1200 async fn mark_kyber_pre_key_used(
1201 &mut self,
1202 kyber_prekey_id: KyberPreKeyId,
1203 ) -> Result<(), SignalProtocolError> {
1204 let conn = self.connection.lock().unwrap();
1205 conn.execute(
1206 "UPDATE kyber_pre_keys SET is_current = FALSE WHERE id = ?",
1207 rusqlite::params![u32::from(kyber_prekey_id)],
1208 )
1209 .map_err(|e| {
1210 SignalProtocolError::InvalidState(
1211 "storage",
1212 format!("Failed to mark kyber pre key as used: {}", e),
1213 )
1214 })?;
1215 Ok(())
1216 }
1217}
1218
1219#[async_trait(?Send)]
1220impl ExtendedKyberPreKeyStore for SqliteKyberPreKeyStore {
1221 async fn kyber_pre_key_count(&self) -> usize {
1222 let conn = self.connection.lock().unwrap();
1223 let mut stmt = conn.prepare("SELECT COUNT(*) FROM kyber_pre_keys").unwrap();
1224 stmt.query_row([], |row| {
1225 let count: i64 = row.get(0)?;
1226 Ok(count as usize)
1227 })
1228 .unwrap_or(0)
1229 }
1230
1231 async fn clear_all_kyber_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
1232 let conn = self.connection.lock().unwrap();
1233 conn.execute("DELETE FROM kyber_pre_keys", [])?;
1234 Ok(())
1235 }
1236
1237 async fn get_max_kyber_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
1238 let conn = self.connection.lock().unwrap();
1239 let mut stmt = conn.prepare("SELECT MAX(id) FROM kyber_pre_keys")?;
1240 let max_id: Option<u32> = stmt.query_row([], |row| row.get(0)).ok();
1241 Ok(max_id)
1242 }
1243
1244 async fn delete_kyber_pre_key(
1245 &mut self,
1246 id: KyberPreKeyId,
1247 ) -> Result<(), Box<dyn std::error::Error>> {
1248 let conn = self.connection.lock().unwrap();
1249 conn.execute("DELETE FROM kyber_pre_keys WHERE id = ?", [u32::from(id)])?;
1250 Ok(())
1251 }
1252
1253 async fn get_kyber_pre_keys_older_than(
1254 &self,
1255 timestamp_millis: u64,
1256 ) -> Result<Vec<KyberPreKeyId>, Box<dyn std::error::Error>> {
1257 let conn = self.connection.lock().unwrap();
1258 let timestamp_secs = (timestamp_millis / 1000) as i64;
1259 let mut stmt = conn.prepare("SELECT id FROM kyber_pre_keys WHERE created_at < ?")?;
1260 let ids = stmt
1261 .query_map([timestamp_secs], |row| {
1262 let id: u32 = row.get(0)?;
1263 Ok(KyberPreKeyId::from(id))
1264 })?
1265 .collect::<Result<Vec<_>, _>>()?;
1266 Ok(ids)
1267 }
1268}
1269
1270#[async_trait(?Send)]
1271impl ExtendedStorageOps for SqliteStorage {
1272 async fn establish_session_from_bundle(
1273 &mut self,
1274 address: &ProtocolAddress,
1275 bundle: &PreKeyBundle,
1276 ) -> Result<(), Box<dyn std::error::Error>> {
1277 let mut rng = rand::rng();
1278 let timestamp = std::time::SystemTime::now();
1279
1280 process_prekey_bundle(
1281 address,
1282 self.session_store
1283 .as_mut()
1284 .expect("Storage not initialized"),
1285 self.identity_store
1286 .as_mut()
1287 .expect("Storage not initialized"),
1288 bundle,
1289 timestamp,
1290 &mut rng,
1291 UsePQRatchet::Yes,
1292 )
1293 .await?;
1294
1295 Ok(())
1296 }
1297
1298 async fn encrypt_message(
1299 &mut self,
1300 remote_address: &ProtocolAddress,
1301 plaintext: &[u8],
1302 ) -> Result<CiphertextMessage, SignalProtocolError> {
1303 let mut rng = rand::rng();
1304 let now = std::time::SystemTime::now();
1305
1306 message_encrypt(
1307 plaintext,
1308 remote_address,
1309 self.session_store
1310 .as_mut()
1311 .expect("Storage not initialized"),
1312 self.identity_store
1313 .as_mut()
1314 .expect("Storage not initialized"),
1315 now,
1316 &mut rng,
1317 )
1318 .await
1319 }
1320
1321 async fn decrypt_message(
1322 &mut self,
1323 remote_address: &ProtocolAddress,
1324 ciphertext: &CiphertextMessage,
1325 ) -> Result<Vec<u8>, SignalProtocolError> {
1326 let mut rng = rand::rng();
1327
1328 message_decrypt(
1329 ciphertext,
1330 remote_address,
1331 self.session_store
1332 .as_mut()
1333 .expect("Storage not initialized"),
1334 self.identity_store
1335 .as_mut()
1336 .expect("Storage not initialized"),
1337 self.pre_key_store
1338 .as_mut()
1339 .expect("Storage not initialized"),
1340 self.signed_pre_key_store
1341 .as_mut()
1342 .expect("Storage not initialized"),
1343 self.kyber_pre_key_store
1344 .as_mut()
1345 .expect("Storage not initialized"),
1346 &mut rng,
1347 UsePQRatchet::Yes,
1348 )
1349 .await
1350 }
1351}
1352
1353#[cfg(test)]
1354mod tests {
1355 use super::*;
1356 use libsignal_protocol::{kem, process_prekey_bundle, DeviceId, PreKeyBundle, Timestamp};
1357
1358 #[tokio::test]
1359 async fn test_sqlite_storage_creation() -> Result<(), Box<dyn std::error::Error>> {
1360 let storage = SqliteStorage::new(":memory:").await?;
1361 assert_eq!(storage.storage_type(), "sqlite");
1362 Ok(())
1363 }
1364
1365 #[tokio::test]
1366 async fn test_sqlite_storage_schema_initialization() -> Result<(), Box<dyn std::error::Error>> {
1367 let mut storage = SqliteStorage::new(":memory:").await?;
1368 storage.initialize_schema()?;
1369
1370 let version = storage.get_schema_version()?;
1371 assert_eq!(version, 2);
1372
1373 Ok(())
1374 }
1375
1376 #[tokio::test]
1377 async fn test_sqlite_session_store_count_empty() -> Result<(), Box<dyn std::error::Error>> {
1378 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1379 SqliteSessionStore::create_tables(&connection.lock().unwrap())?;
1380 let session_store = SqliteSessionStore::new(connection);
1381
1382 assert_eq!(session_store.session_count().await, 0);
1383
1384 Ok(())
1385 }
1386
1387 #[tokio::test]
1388 async fn test_sqlite_identity_store_count_empty() -> Result<(), Box<dyn std::error::Error>> {
1389 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1390 SqliteIdentityStore::create_tables(&connection.lock().unwrap())?;
1391 let identity_store = SqliteIdentityStore::new(connection);
1392
1393 assert_eq!(identity_store.identity_count().await, 0);
1394
1395 Ok(())
1396 }
1397
1398 #[tokio::test]
1399 async fn test_sqlite_pre_key_store_count_empty() -> Result<(), Box<dyn std::error::Error>> {
1400 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1401 SqlitePreKeyStore::create_tables(&connection.lock().unwrap())?;
1402 let pre_key_store = SqlitePreKeyStore::new(connection);
1403
1404 assert_eq!(pre_key_store.pre_key_count().await, 0);
1405
1406 Ok(())
1407 }
1408
1409 #[tokio::test]
1410 async fn test_sqlite_signed_pre_key_store_count_empty() -> Result<(), Box<dyn std::error::Error>>
1411 {
1412 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1413 SqliteSignedPreKeyStore::create_tables(&connection.lock().unwrap())?;
1414 let signed_pre_key_store = SqliteSignedPreKeyStore::new(connection);
1415 assert_eq!(signed_pre_key_store.signed_pre_key_count().await, 0);
1416
1417 Ok(())
1418 }
1419
1420 #[tokio::test]
1421 async fn test_sqlite_kyber_pre_key_store_count_empty() -> Result<(), Box<dyn std::error::Error>>
1422 {
1423 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1424 SqliteKyberPreKeyStore::create_tables(&connection.lock().unwrap())?;
1425 let kyber_pre_key_store = SqliteKyberPreKeyStore::new(connection);
1426 assert_eq!(kyber_pre_key_store.kyber_pre_key_count().await, 0);
1427
1428 Ok(())
1429 }
1430
1431 #[tokio::test]
1432 async fn test_sqlite_storage_container() -> Result<(), Box<dyn std::error::Error>> {
1433 let mut storage = SqliteStorage::new(":memory:").await?;
1434 storage.initialize()?;
1435
1436 assert_eq!(storage.storage_type(), "sqlite");
1437 assert_eq!(storage.session_store().session_count().await, 0);
1438 assert_eq!(storage.identity_store().identity_count().await, 0);
1439
1440 Ok(())
1441 }
1442
1443 #[tokio::test]
1444 async fn test_identity_key_store_trait() -> Result<(), Box<dyn std::error::Error>> {
1445 let mut storage = SqliteStorage::new(":memory:").await?;
1446 storage.initialize_schema()?;
1447
1448 let mut rng = rand::rng();
1449 let identity_key_pair = IdentityKeyPair::generate(&mut rng);
1450
1451 storage
1452 .identity_store()
1453 .set_local_identity_key_pair(&identity_key_pair)
1454 .await?;
1455 storage
1456 .identity_store()
1457 .set_local_registration_id(12345)
1458 .await?;
1459
1460 let retrieved_registration = storage.identity_store().get_local_registration_id().await?;
1461 assert_eq!(retrieved_registration, 12345);
1462
1463 let retrieved_identity = storage.identity_store().get_identity_key_pair().await?;
1464 assert_eq!(
1465 retrieved_identity.identity_key().serialize(),
1466 identity_key_pair.identity_key().serialize()
1467 );
1468
1469 Ok(())
1470 }
1471
1472 #[tokio::test]
1473 async fn test_pre_key_store_trait() -> Result<(), Box<dyn std::error::Error>> {
1474 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1475 SqlitePreKeyStore::create_tables(&connection.lock().unwrap())?;
1476 let mut pre_key_store = SqlitePreKeyStore::new(connection);
1477
1478 let mut rng = rand::rng();
1479 let key_pair = libsignal_protocol::KeyPair::generate(&mut rng);
1480 let prekey_id = PreKeyId::from(42u32);
1481 let prekey_record = PreKeyRecord::new(prekey_id, &key_pair);
1482
1483 pre_key_store
1484 .save_pre_key(prekey_id, &prekey_record)
1485 .await?;
1486 let retrieved_prekey = pre_key_store.get_pre_key(prekey_id).await?;
1487
1488 assert_eq!(retrieved_prekey.id()?, prekey_id);
1489 assert_eq!(
1490 retrieved_prekey.public_key()?.serialize(),
1491 prekey_record.public_key()?.serialize()
1492 );
1493
1494 let non_existent_id = PreKeyId::from(999u32);
1495 let result = pre_key_store.get_pre_key(non_existent_id).await;
1496 assert!(result.is_err());
1497
1498 Ok(())
1499 }
1500
1501 #[tokio::test]
1502 async fn test_signed_pre_key_store_trait() -> Result<(), Box<dyn std::error::Error>> {
1503 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1504 SqliteSignedPreKeyStore::create_tables(&connection.lock().unwrap())?;
1505 let mut signed_pre_key_store = SqliteSignedPreKeyStore::new(connection);
1506
1507 let mut rng = rand::rng();
1508 let identity_key_pair = IdentityKeyPair::generate(&mut rng);
1509 let key_pair = libsignal_protocol::KeyPair::generate(&mut rng);
1510 let signed_prekey_id = SignedPreKeyId::from(42u32);
1511 let timestamp = Timestamp::from_epoch_millis(
1512 std::time::SystemTime::now()
1513 .duration_since(std::time::UNIX_EPOCH)?
1514 .as_millis() as u64,
1515 );
1516 let signature = identity_key_pair
1517 .private_key()
1518 .calculate_signature(&key_pair.public_key.serialize(), &mut rng)?;
1519 let signed_prekey_record =
1520 SignedPreKeyRecord::new(signed_prekey_id, timestamp, &key_pair, &signature);
1521
1522 signed_pre_key_store
1523 .save_signed_pre_key(signed_prekey_id, &signed_prekey_record)
1524 .await?;
1525 let retrieved_prekey = signed_pre_key_store
1526 .get_signed_pre_key(signed_prekey_id)
1527 .await?;
1528
1529 assert_eq!(retrieved_prekey.id()?, signed_prekey_id);
1530 assert_eq!(
1531 retrieved_prekey.public_key()?.serialize(),
1532 signed_prekey_record.public_key()?.serialize()
1533 );
1534
1535 let non_existent_id = SignedPreKeyId::from(999u32);
1536 let result = signed_pre_key_store
1537 .get_signed_pre_key(non_existent_id)
1538 .await;
1539 assert!(result.is_err());
1540
1541 Ok(())
1542 }
1543
1544 #[tokio::test]
1545 async fn test_identity_store_remote_identities() -> Result<(), Box<dyn std::error::Error>> {
1546 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1547 SqliteIdentityStore::create_tables(&connection.lock().unwrap())?;
1548 let mut identity_store = SqliteIdentityStore::new(connection);
1549
1550 let mut rng = rand::rng();
1551 let alice_address =
1552 ProtocolAddress::new("alice@example.com".to_string(), DeviceId::new(1)?);
1553 let bob_address = ProtocolAddress::new("bob@example.com".to_string(), DeviceId::new(1)?);
1554 let alice_identity =
1555 IdentityKey::new(libsignal_protocol::KeyPair::generate(&mut rng).public_key);
1556
1557 let result = identity_store
1558 .save_identity(&alice_address, &alice_identity)
1559 .await?;
1560 assert_eq!(result, IdentityChange::NewOrUnchanged);
1561
1562 let retrieved_identity = identity_store.get_identity(&alice_address).await?;
1563 assert_eq!(retrieved_identity, Some(alice_identity));
1564
1565 assert!(
1566 identity_store
1567 .is_trusted_identity(&alice_address, &alice_identity, Direction::Receiving)
1568 .await?
1569 );
1570 assert!(
1571 identity_store
1572 .is_trusted_identity(&alice_address, &alice_identity, Direction::Sending)
1573 .await?
1574 );
1575
1576 let alice_new_identity =
1577 IdentityKey::new(libsignal_protocol::KeyPair::generate(&mut rng).public_key);
1578 let result = identity_store
1579 .save_identity(&alice_address, &alice_new_identity)
1580 .await?;
1581 assert_eq!(result, IdentityChange::ReplacedExisting);
1582
1583 let unknown_identity = identity_store.get_identity(&bob_address).await?;
1584 assert_eq!(unknown_identity, None);
1585
1586 Ok(())
1587 }
1588
1589 #[tokio::test]
1590 async fn test_sqlite_storage_close() -> Result<(), Box<dyn std::error::Error>> {
1591 let mut storage = SqliteStorage::new(":memory:").await?;
1592 storage.initialize_schema()?;
1593
1594 let mut rng = rand::rng();
1595 let identity_key_pair = IdentityKeyPair::generate(&mut rng);
1596 storage
1597 .identity_store()
1598 .set_local_identity_key_pair(&identity_key_pair)
1599 .await?;
1600
1601 storage.close()?;
1602
1603 storage.close()?; assert!(storage.is_closed(), "Storage should be marked as closed");
1606
1607 Ok(())
1608 }
1609
1610 #[tokio::test]
1611 async fn test_kyber_pre_key_store_trait() -> Result<(), Box<dyn std::error::Error>> {
1612 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1613 SqliteKyberPreKeyStore::create_tables(&connection.lock().unwrap())?;
1614 let mut kyber_pre_key_store = SqliteKyberPreKeyStore::new(connection);
1615
1616 let mut rng = rand::rng();
1617 let key_pair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
1618 let kyber_prekey_id = KyberPreKeyId::from(42u32);
1619 let timestamp = Timestamp::from_epoch_millis(
1620 std::time::SystemTime::now()
1621 .duration_since(std::time::UNIX_EPOCH)?
1622 .as_millis() as u64,
1623 );
1624 let signature = b"test_signature";
1625 let kyber_prekey_record =
1626 KyberPreKeyRecord::new(kyber_prekey_id, timestamp, &key_pair, signature);
1627
1628 kyber_pre_key_store
1629 .save_kyber_pre_key(kyber_prekey_id, &kyber_prekey_record)
1630 .await?;
1631 let retrieved_prekey = kyber_pre_key_store
1632 .get_kyber_pre_key(kyber_prekey_id)
1633 .await?;
1634
1635 assert_eq!(retrieved_prekey.id()?, kyber_prekey_id);
1636 assert_eq!(
1637 retrieved_prekey.public_key()?.serialize(),
1638 kyber_prekey_record.public_key()?.serialize()
1639 );
1640
1641 kyber_pre_key_store
1642 .mark_kyber_pre_key_used(kyber_prekey_id)
1643 .await?;
1644
1645 let non_existent_id = KyberPreKeyId::from(999u32);
1646 let result = kyber_pre_key_store.get_kyber_pre_key(non_existent_id).await;
1647 assert!(result.is_err());
1648
1649 Ok(())
1650 }
1651
1652 struct MockIdentityStore {
1653 identity_key_pair: IdentityKeyPair,
1654 registration_id: u32,
1655 }
1656
1657 impl MockIdentityStore {
1658 fn new(identity_key_pair: IdentityKeyPair) -> Self {
1659 Self {
1660 identity_key_pair,
1661 registration_id: 12345,
1662 }
1663 }
1664 }
1665
1666 #[async_trait(?Send)]
1667 impl IdentityKeyStore for MockIdentityStore {
1668 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> {
1669 Ok(self.identity_key_pair)
1670 }
1671
1672 async fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError> {
1673 Ok(self.registration_id)
1674 }
1675
1676 async fn save_identity(
1677 &mut self,
1678 _address: &ProtocolAddress,
1679 _identity: &IdentityKey,
1680 ) -> Result<IdentityChange, SignalProtocolError> {
1681 Ok(IdentityChange::NewOrUnchanged)
1682 }
1683
1684 async fn is_trusted_identity(
1685 &self,
1686 _address: &ProtocolAddress,
1687 _identity: &IdentityKey,
1688 _direction: Direction,
1689 ) -> Result<bool, SignalProtocolError> {
1690 Ok(true)
1691 }
1692
1693 async fn get_identity(
1694 &self,
1695 _address: &ProtocolAddress,
1696 ) -> Result<Option<IdentityKey>, SignalProtocolError> {
1697 Ok(None)
1698 }
1699 }
1700
1701 #[tokio::test]
1702 async fn test_session_store_persistence() -> Result<(), Box<dyn std::error::Error>> {
1703 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1704 SqliteSessionStore::create_tables(&connection.lock().unwrap())?;
1705 let mut session_store = SqliteSessionStore::new(connection);
1706
1707 let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);
1708
1709 let initial_result = session_store.load_session(&address).await?;
1710 assert!(
1711 initial_result.is_none(),
1712 "No session should exist initially"
1713 );
1714
1715 let mut rng = rand::rng();
1718 let alice_identity = IdentityKeyPair::generate(&mut rng);
1719 let bob_identity = IdentityKeyPair::generate(&mut rng);
1720
1721 let bob_pre_keys = crate::keys::generate_pre_keys(1, 1).await?;
1722 let bob_signed_pre_key = crate::keys::generate_signed_pre_key(&bob_identity, 1).await?;
1723 let kyber_keypair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
1724 let kyber_signature = bob_identity
1725 .private_key()
1726 .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;
1727
1728 let bundle = PreKeyBundle::new(
1729 12345,
1730 address.device_id(),
1731 Some((bob_pre_keys[0].0.into(), bob_pre_keys[0].1.public_key)),
1732 bob_signed_pre_key.id()?,
1733 bob_signed_pre_key.public_key()?,
1734 bob_signed_pre_key.signature()?.to_vec(),
1735 KyberPreKeyId::from(1u32),
1736 kyber_keypair.public_key,
1737 kyber_signature.to_vec(),
1738 *bob_identity.identity_key(),
1739 )?;
1740
1741 use libsignal_protocol::UsePQRatchet;
1742 use std::time::SystemTime;
1743
1744 process_prekey_bundle(
1745 &address,
1746 &mut session_store,
1747 &mut MockIdentityStore::new(alice_identity),
1748 &bundle,
1749 SystemTime::now(),
1750 &mut rng,
1751 UsePQRatchet::Yes,
1752 )
1753 .await?;
1754
1755 let loaded_session = session_store.load_session(&address).await?;
1756 assert!(
1757 loaded_session.is_some(),
1758 "Session should exist after process_prekey_bundle"
1759 );
1760
1761 let loaded_session = loaded_session.unwrap();
1762 let _serialized = loaded_session.serialize()?;
1763
1764 Ok(())
1765 }
1766
1767 #[tokio::test]
1768 async fn test_pre_key_store_removal() -> Result<(), Box<dyn std::error::Error>> {
1769 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1770 SqlitePreKeyStore::create_tables(&connection.lock().unwrap())?;
1771 let mut pre_key_store = SqlitePreKeyStore::new(connection);
1772
1773 let mut rng = rand::rng();
1774 let key_pair = libsignal_protocol::KeyPair::generate(&mut rng);
1775 let prekey_id = PreKeyId::from(42u32);
1776 let prekey_record = PreKeyRecord::new(prekey_id, &key_pair);
1777
1778 pre_key_store
1779 .save_pre_key(prekey_id, &prekey_record)
1780 .await?;
1781
1782 let retrieved_prekey = pre_key_store.get_pre_key(prekey_id).await?;
1783 assert_eq!(retrieved_prekey.id()?, prekey_id);
1784
1785 pre_key_store.remove_pre_key(prekey_id).await?;
1786
1787 let result = pre_key_store.get_pre_key(prekey_id).await;
1788 assert!(
1789 result.is_err(),
1790 "PreKey should be removed and no longer retrievable"
1791 );
1792
1793 pre_key_store.remove_pre_key(prekey_id).await?;
1794
1795 Ok(())
1796 }
1797
1798 #[tokio::test]
1799 async fn test_error_handling_nonexistent_keys() -> Result<(), Box<dyn std::error::Error>> {
1800 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1801 SqlitePreKeyStore::create_tables(&connection.lock().unwrap())?;
1802 SqliteSignedPreKeyStore::create_tables(&connection.lock().unwrap())?;
1803 SqliteKyberPreKeyStore::create_tables(&connection.lock().unwrap())?;
1804 SqliteSessionStore::create_tables(&connection.lock().unwrap())?;
1805 SqliteIdentityStore::create_tables(&connection.lock().unwrap())?;
1806
1807 let pre_key_store = SqlitePreKeyStore::new(connection.clone());
1808 let signed_pre_key_store = SqliteSignedPreKeyStore::new(connection.clone());
1809 let kyber_pre_key_store = SqliteKyberPreKeyStore::new(connection.clone());
1810 let session_store = SqliteSessionStore::new(connection.clone());
1811 let identity_store = SqliteIdentityStore::new(connection);
1812
1813 let result = pre_key_store.get_pre_key(PreKeyId::from(999u32)).await;
1814 assert!(result.is_err(), "Getting nonexistent PreKey should fail");
1815
1816 let result = signed_pre_key_store
1817 .get_signed_pre_key(SignedPreKeyId::from(999u32))
1818 .await;
1819 assert!(
1820 result.is_err(),
1821 "Getting nonexistent SignedPreKey should fail"
1822 );
1823
1824 let result = kyber_pre_key_store
1825 .get_kyber_pre_key(KyberPreKeyId::from(999u32))
1826 .await;
1827 assert!(
1828 result.is_err(),
1829 "Getting nonexistent KyberPreKey should fail"
1830 );
1831
1832 let address = ProtocolAddress::new("nonexistent".to_string(), DeviceId::new(1)?);
1833 let result = session_store.load_session(&address).await?;
1834 assert!(
1835 result.is_none(),
1836 "Loading nonexistent session should return None"
1837 );
1838
1839 let result = identity_store.get_identity(&address).await?;
1840 assert!(
1841 result.is_none(),
1842 "Getting nonexistent remote identity should return None"
1843 );
1844
1845 Ok(())
1846 }
1847
1848 #[tokio::test]
1849 async fn test_error_handling_database_constraints() -> Result<(), Box<dyn std::error::Error>> {
1850 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1851 SqlitePreKeyStore::create_tables(&connection.lock().unwrap())?;
1852 let mut pre_key_store = SqlitePreKeyStore::new(connection);
1853
1854 let mut rng = rand::rng();
1855 let key_pair1 = libsignal_protocol::KeyPair::generate(&mut rng);
1856 let key_pair2 = libsignal_protocol::KeyPair::generate(&mut rng);
1857 let prekey_id = PreKeyId::from(42u32);
1858 let prekey_record1 = PreKeyRecord::new(prekey_id, &key_pair1);
1859 let prekey_record2 = PreKeyRecord::new(prekey_id, &key_pair2);
1860
1861 pre_key_store
1862 .save_pre_key(prekey_id, &prekey_record1)
1863 .await?;
1864
1865 pre_key_store
1866 .save_pre_key(prekey_id, &prekey_record2)
1867 .await?;
1868
1869 let retrieved = pre_key_store.get_pre_key(prekey_id).await?;
1870 assert_eq!(
1871 retrieved.public_key()?.serialize(),
1872 prekey_record2.public_key()?.serialize()
1873 );
1874
1875 Ok(())
1876 }
1877
1878 #[tokio::test]
1879 async fn test_identity_change_detection() -> Result<(), Box<dyn std::error::Error>> {
1880 let connection = Arc::new(Mutex::new(Connection::open(":memory:")?));
1881 SqliteIdentityStore::create_tables(&connection.lock().unwrap())?;
1882 let mut identity_store = SqliteIdentityStore::new(connection);
1883
1884 let mut rng = rand::rng();
1885 let identity1 = IdentityKeyPair::generate(&mut rng);
1886 let identity2 = IdentityKeyPair::generate(&mut rng);
1887 let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);
1888
1889 let result = identity_store
1890 .save_identity(&address, identity1.identity_key())
1891 .await?;
1892 assert_eq!(
1893 result,
1894 IdentityChange::NewOrUnchanged,
1895 "First save should be NewOrUnchanged"
1896 );
1897
1898 let result = identity_store
1899 .save_identity(&address, identity1.identity_key())
1900 .await?;
1901 assert_eq!(
1902 result,
1903 IdentityChange::NewOrUnchanged,
1904 "Same identity should be NewOrUnchanged"
1905 );
1906
1907 let result = identity_store
1908 .save_identity(&address, identity2.identity_key())
1909 .await?;
1910 assert_eq!(
1911 result,
1912 IdentityChange::ReplacedExisting,
1913 "Different identity should be ReplacedExisting"
1914 );
1915
1916 let retrieved = identity_store.get_identity(&address).await?;
1917 assert!(retrieved.is_some(), "Identity should be retrievable");
1918 assert_eq!(
1919 retrieved.unwrap().serialize(),
1920 identity2.identity_key().serialize(),
1921 "New identity should be stored"
1922 );
1923
1924 Ok(())
1925 }
1926
1927 #[tokio::test]
1928 async fn test_extended_storage_ops_session_establishment(
1929 ) -> Result<(), Box<dyn std::error::Error>> {
1930 use crate::keys::{generate_identity_key_pair, generate_pre_keys, generate_signed_pre_key};
1931 use libsignal_protocol::*;
1932
1933 let mut storage = SqliteStorage::new(":memory:").await?;
1934 storage.initialize()?;
1935
1936 let mut rng = rand::rng();
1937 let identity_key_pair = generate_identity_key_pair().await?;
1938 let registration_id = 12345u32;
1939
1940 storage
1941 .identity_store
1942 .as_mut()
1943 .unwrap()
1944 .set_local_identity_key_pair(&identity_key_pair)
1945 .await?;
1946 storage
1947 .identity_store
1948 .as_mut()
1949 .unwrap()
1950 .set_local_registration_id(registration_id)
1951 .await?;
1952
1953 let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
1954
1955 let bob_identity = generate_identity_key_pair().await?;
1956 let bob_pre_keys = generate_pre_keys(1, 1).await?;
1957 let bob_signed_pre_key = generate_signed_pre_key(&bob_identity, 1).await?;
1958
1959 let kyber_keypair = libsignal_protocol::kem::KeyPair::generate(
1960 libsignal_protocol::kem::KeyType::Kyber1024,
1961 &mut rng,
1962 );
1963 let kyber_signature = bob_identity
1964 .private_key()
1965 .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;
1966
1967 let bob_bundle = PreKeyBundle::new(
1968 registration_id,
1969 DeviceId::new(1)?,
1970 Some((
1971 PreKeyId::from(bob_pre_keys[0].0),
1972 bob_pre_keys[0].1.public_key,
1973 )),
1974 SignedPreKeyId::from(1u32),
1975 bob_signed_pre_key.public_key()?,
1976 bob_signed_pre_key.signature()?.to_vec(),
1977 KyberPreKeyId::from(1u32),
1978 kyber_keypair.public_key,
1979 kyber_signature.to_vec(),
1980 *bob_identity.identity_key(),
1981 )?;
1982
1983 let session_count_before = storage.session_store().session_count().await;
1984 assert_eq!(session_count_before, 0, "Should start with no sessions");
1985
1986 storage
1987 .establish_session_from_bundle(&bob_address, &bob_bundle)
1988 .await?;
1989
1990 let session_count_after = storage.session_store().session_count().await;
1991 assert_eq!(
1992 session_count_after, 1,
1993 "Should have one session after establishment"
1994 );
1995
1996 let has_session = storage
1997 .session_store()
1998 .load_session(&bob_address)
1999 .await?
2000 .is_some();
2001 assert!(has_session, "Session should exist for Bob");
2002
2003 Ok(())
2004 }
2005
2006 #[tokio::test]
2007 async fn test_extended_storage_ops_encrypt_decrypt_message(
2008 ) -> Result<(), Box<dyn std::error::Error>> {
2009 use crate::keys::{generate_identity_key_pair, generate_pre_keys, generate_signed_pre_key};
2010 use libsignal_protocol::*;
2011
2012 let mut alice_storage = SqliteStorage::new(":memory:").await?;
2013 alice_storage.initialize()?;
2014 let mut bob_storage = SqliteStorage::new(":memory:").await?;
2015 bob_storage.initialize()?;
2016
2017 let mut rng = rand::rng();
2018 let alice_identity = generate_identity_key_pair().await?;
2019 let alice_registration_id = 11111u32;
2020 alice_storage
2021 .identity_store
2022 .as_mut()
2023 .unwrap()
2024 .set_local_identity_key_pair(&alice_identity)
2025 .await?;
2026 alice_storage
2027 .identity_store
2028 .as_mut()
2029 .unwrap()
2030 .set_local_registration_id(alice_registration_id)
2031 .await?;
2032
2033 let bob_identity = generate_identity_key_pair().await?;
2034 let bob_registration_id = 22222u32;
2035 bob_storage
2036 .identity_store
2037 .as_mut()
2038 .unwrap()
2039 .set_local_identity_key_pair(&bob_identity)
2040 .await?;
2041 bob_storage
2042 .identity_store
2043 .as_mut()
2044 .unwrap()
2045 .set_local_registration_id(bob_registration_id)
2046 .await?;
2047
2048 let alice_address = ProtocolAddress::new("alice".to_string(), DeviceId::new(1)?);
2049 let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
2050
2051 let bob_pre_keys = generate_pre_keys(1, 1).await?;
2052 let bob_signed_pre_key = generate_signed_pre_key(&bob_identity, 1).await?;
2053 bob_storage
2054 .pre_key_store
2055 .as_mut()
2056 .unwrap()
2057 .save_pre_key(
2058 PreKeyId::from(bob_pre_keys[0].0),
2059 &PreKeyRecord::new(PreKeyId::from(bob_pre_keys[0].0), &bob_pre_keys[0].1),
2060 )
2061 .await?;
2062 bob_storage
2063 .signed_pre_key_store
2064 .as_mut()
2065 .unwrap()
2066 .save_signed_pre_key(SignedPreKeyId::from(1u32), &bob_signed_pre_key)
2067 .await?;
2068
2069 let kyber_keypair = libsignal_protocol::kem::KeyPair::generate(
2070 libsignal_protocol::kem::KeyType::Kyber1024,
2071 &mut rng,
2072 );
2073 let kyber_signature = bob_identity
2074 .private_key()
2075 .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;
2076
2077 let now = std::time::SystemTime::now();
2078 let kyber_record = KyberPreKeyRecord::new(
2079 KyberPreKeyId::from(1u32),
2080 Timestamp::from_epoch_millis(
2081 now.duration_since(std::time::UNIX_EPOCH)?.as_millis() as u64
2082 ),
2083 &kyber_keypair,
2084 &kyber_signature,
2085 );
2086 bob_storage
2087 .kyber_pre_key_store
2088 .as_mut()
2089 .unwrap()
2090 .save_kyber_pre_key(KyberPreKeyId::from(1u32), &kyber_record)
2091 .await?;
2092
2093 let bob_bundle = PreKeyBundle::new(
2094 bob_registration_id,
2095 DeviceId::new(1)?,
2096 Some((
2097 PreKeyId::from(bob_pre_keys[0].0),
2098 bob_pre_keys[0].1.public_key,
2099 )),
2100 SignedPreKeyId::from(1u32),
2101 bob_signed_pre_key.public_key()?,
2102 bob_signed_pre_key.signature()?.to_vec(),
2103 KyberPreKeyId::from(1u32),
2104 kyber_keypair.public_key,
2105 kyber_signature.to_vec(),
2106 *bob_identity.identity_key(),
2107 )?;
2108
2109 alice_storage
2110 .establish_session_from_bundle(&bob_address, &bob_bundle)
2111 .await?;
2112
2113 let plaintext = b"Hello, Bob! This is a test message.";
2114 let ciphertext = alice_storage
2115 .encrypt_message(&bob_address, plaintext)
2116 .await?;
2117
2118 assert!(
2119 matches!(ciphertext.message_type(), CiphertextMessageType::PreKey),
2120 "First message should be a PreKey message"
2121 );
2122
2123 bob_storage
2124 .identity_store
2125 .as_mut()
2126 .unwrap()
2127 .save_identity(&alice_address, alice_identity.identity_key())
2128 .await?;
2129 let decrypted = bob_storage
2130 .decrypt_message(&alice_address, &ciphertext)
2131 .await?;
2132
2133 assert_eq!(
2134 decrypted, plaintext,
2135 "Decrypted message should match original"
2136 );
2137
2138 let response_plaintext = b"Hello, Alice! Got your message.";
2139 let response_ciphertext = bob_storage
2140 .encrypt_message(&alice_address, response_plaintext)
2141 .await?;
2142
2143 assert!(
2144 matches!(
2145 response_ciphertext.message_type(),
2146 CiphertextMessageType::Whisper
2147 ),
2148 "Response message should be a Whisper message"
2149 );
2150
2151 let response_decrypted = alice_storage
2152 .decrypt_message(&bob_address, &response_ciphertext)
2153 .await?;
2154 assert_eq!(
2155 response_decrypted, response_plaintext,
2156 "Response should decrypt correctly"
2157 );
2158
2159 Ok(())
2160 }
2161
2162 #[tokio::test]
2163 async fn test_extended_storage_ops_error_handling() -> Result<(), Box<dyn std::error::Error>> {
2164 let mut storage = SqliteStorage::new(":memory:").await?;
2165 storage.initialize()?;
2166
2167 let invalid_address = ProtocolAddress::new("nonexistent".to_string(), DeviceId::new(1)?);
2168
2169 let plaintext = b"test message";
2170 let encrypt_result = storage.encrypt_message(&invalid_address, plaintext).await;
2171 assert!(
2172 encrypt_result.is_err(),
2173 "Should fail to encrypt without session"
2174 );
2175
2176 Ok(())
2177 }
2178
2179 #[tokio::test]
2180 async fn test_sqlite_extended_storage_ops_integration() -> Result<(), Box<dyn std::error::Error>>
2181 {
2182 use libsignal_protocol::*;
2183
2184 use crate::keys::{generate_identity_key_pair, generate_pre_keys, generate_signed_pre_key};
2185 use crate::sqlite_storage::SqliteStorage;
2186 use crate::storage_trait::{
2187 ExtendedIdentityStore, ExtendedSessionStore, ExtendedStorageOps, SignalStorageContainer,
2188 };
2189 use std::fs;
2190
2191 let process_id = std::process::id();
2192 let timestamp = std::time::SystemTime::now()
2193 .duration_since(std::time::UNIX_EPOCH)
2194 .unwrap()
2195 .as_millis();
2196
2197 let temp_dir = std::env::temp_dir();
2198 let db_path = temp_dir.join(format!("test_sqlite_alice_{}_{}.db", process_id, timestamp));
2199 let db_path_str = db_path.to_str().unwrap();
2200
2201 if db_path.exists() {
2202 let _ = fs::remove_file(&db_path);
2203 }
2204
2205 let mut alice_storage = SqliteStorage::new(db_path_str).await?;
2206 alice_storage.initialize()?;
2207
2208 let db_path2 = temp_dir.join(format!("test_sqlite_bob_{}_{}.db", process_id, timestamp));
2209 let db_path2_str = db_path2.to_str().unwrap();
2210
2211 if db_path2.exists() {
2212 let _ = fs::remove_file(&db_path2);
2213 }
2214
2215 let mut bob_storage = SqliteStorage::new(db_path2_str).await?;
2216 bob_storage.initialize()?;
2217
2218 let mut rng = rand::rng();
2219 let alice_identity = generate_identity_key_pair().await?;
2220 let alice_registration_id = 11111u32;
2221 alice_storage
2222 .identity_store()
2223 .set_local_identity_key_pair(&alice_identity)
2224 .await?;
2225 alice_storage
2226 .identity_store()
2227 .set_local_registration_id(alice_registration_id)
2228 .await?;
2229
2230 let bob_identity = generate_identity_key_pair().await?;
2231 let bob_registration_id = 22222u32;
2232 bob_storage
2233 .identity_store()
2234 .set_local_identity_key_pair(&bob_identity)
2235 .await?;
2236 bob_storage
2237 .identity_store()
2238 .set_local_registration_id(bob_registration_id)
2239 .await?;
2240
2241 let alice_address = ProtocolAddress::new("alice".to_string(), DeviceId::new(1)?);
2242 let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
2243
2244 let bob_pre_keys = generate_pre_keys(1, 1).await?;
2245 let bob_signed_pre_key = generate_signed_pre_key(&bob_identity, 1).await?;
2246 bob_storage
2247 .pre_key_store()
2248 .save_pre_key(
2249 PreKeyId::from(bob_pre_keys[0].0),
2250 &PreKeyRecord::new(PreKeyId::from(bob_pre_keys[0].0), &bob_pre_keys[0].1),
2251 )
2252 .await?;
2253 bob_storage
2254 .signed_pre_key_store()
2255 .save_signed_pre_key(SignedPreKeyId::from(1u32), &bob_signed_pre_key)
2256 .await?;
2257
2258 let kyber_keypair = libsignal_protocol::kem::KeyPair::generate(
2259 libsignal_protocol::kem::KeyType::Kyber1024,
2260 &mut rng,
2261 );
2262 let kyber_signature = bob_identity
2263 .private_key()
2264 .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;
2265
2266 let now = std::time::SystemTime::now();
2267 let kyber_record = KyberPreKeyRecord::new(
2268 KyberPreKeyId::from(1u32),
2269 Timestamp::from_epoch_millis(
2270 now.duration_since(std::time::UNIX_EPOCH)?.as_millis() as u64
2271 ),
2272 &kyber_keypair,
2273 &kyber_signature,
2274 );
2275 bob_storage
2276 .kyber_pre_key_store()
2277 .save_kyber_pre_key(KyberPreKeyId::from(1u32), &kyber_record)
2278 .await?;
2279
2280 let bob_bundle = PreKeyBundle::new(
2281 bob_registration_id,
2282 DeviceId::new(1)?,
2283 Some((
2284 PreKeyId::from(bob_pre_keys[0].0),
2285 bob_pre_keys[0].1.public_key,
2286 )),
2287 SignedPreKeyId::from(1u32),
2288 bob_signed_pre_key.public_key()?,
2289 bob_signed_pre_key.signature()?.to_vec(),
2290 KyberPreKeyId::from(1u32),
2291 kyber_keypair.public_key,
2292 kyber_signature.to_vec(),
2293 *bob_identity.identity_key(),
2294 )?;
2295
2296 alice_storage
2297 .establish_session_from_bundle(&bob_address, &bob_bundle)
2298 .await?;
2299
2300 assert_eq!(
2301 alice_storage.session_store().session_count().await,
2302 1,
2303 "Alice should have established session with Bob"
2304 );
2305
2306 let plaintext = b"Hello Bob! This is an integration test with real SQLite files.";
2307 let ciphertext = alice_storage
2308 .encrypt_message(&bob_address, plaintext)
2309 .await?;
2310
2311 assert!(
2312 matches!(ciphertext.message_type(), CiphertextMessageType::PreKey),
2313 "First message should be a PreKey message"
2314 );
2315
2316 bob_storage
2317 .identity_store()
2318 .save_identity(&alice_address, alice_identity.identity_key())
2319 .await?;
2320 let decrypted = bob_storage
2321 .decrypt_message(&alice_address, &ciphertext)
2322 .await?;
2323
2324 assert_eq!(
2325 decrypted, plaintext,
2326 "Decrypted message should match original"
2327 );
2328
2329 let response_plaintext =
2330 b"Hello Alice! Integration test successful with persistent SQLite storage!";
2331 let response_ciphertext = bob_storage
2332 .encrypt_message(&alice_address, response_plaintext)
2333 .await?;
2334
2335 assert!(
2336 matches!(
2337 response_ciphertext.message_type(),
2338 CiphertextMessageType::Whisper
2339 ),
2340 "Response message should be a Whisper message"
2341 );
2342
2343 let response_decrypted = alice_storage
2344 .decrypt_message(&bob_address, &response_ciphertext)
2345 .await?;
2346 assert_eq!(
2347 response_decrypted, response_plaintext,
2348 "Response should decrypt correctly"
2349 );
2350
2351 alice_storage.close()?;
2352 bob_storage.close()?;
2353
2354 assert!(
2355 db_path.exists(),
2356 "SQLite database file should exist after test"
2357 );
2358 assert!(
2359 db_path2.exists(),
2360 "Bob's SQLite database file should exist after test"
2361 );
2362
2363 drop(alice_storage);
2364 drop(bob_storage);
2365
2366 #[cfg(windows)]
2367 std::thread::sleep(std::time::Duration::from_millis(100));
2368
2369 let _ = fs::remove_file(&db_path);
2370 let _ = fs::remove_file(&db_path2);
2371
2372 Ok(())
2373 }
2374
2375 #[tokio::test]
2376 async fn test_sqlite_storage_preserves_identity_across_instantiations(
2377 ) -> Result<(), Box<dyn std::error::Error>> {
2378 use crate::keys::generate_identity_key_pair;
2379 use std::{env, fs};
2380
2381 let temp_dir = env::temp_dir();
2382 let process_id = std::process::id();
2383 let timestamp = std::time::SystemTime::now()
2384 .duration_since(std::time::UNIX_EPOCH)
2385 .unwrap()
2386 .as_millis();
2387
2388 let db_path = temp_dir.join(format!(
2389 "test_identity_persistence_{}_{}.db",
2390 process_id, timestamp
2391 ));
2392 let db_path_str = db_path.to_str().unwrap();
2393
2394 let (original_identity, original_registration_id) = {
2395 let mut storage = SqliteStorage::new(db_path_str).await?;
2396 storage.initialize()?;
2397
2398 let identity = generate_identity_key_pair().await?;
2399 let registration_id = 12345u32;
2400
2401 storage
2402 .identity_store()
2403 .set_local_identity_key_pair(&identity)
2404 .await?;
2405 storage
2406 .identity_store()
2407 .set_local_registration_id(registration_id)
2408 .await?;
2409
2410 (identity, registration_id)
2411 }; {
2414 let mut storage_reopened = SqliteStorage::new(db_path_str).await?;
2415 storage_reopened.initialize()?;
2416
2417 let retrieved_identity = storage_reopened
2418 .identity_store()
2419 .get_identity_key_pair()
2420 .await?;
2421 let retrieved_registration_id = storage_reopened
2422 .identity_store()
2423 .get_local_registration_id()
2424 .await?;
2425
2426 assert_eq!(
2427 retrieved_identity.identity_key().serialize(),
2428 original_identity.identity_key().serialize(),
2429 "Identity key should be preserved across storage instantiations"
2430 );
2431 assert_eq!(
2432 retrieved_registration_id, original_registration_id,
2433 "Registration ID should be preserved across storage instantiations"
2434 );
2435 }
2436
2437 let _ = fs::remove_file(&db_path);
2438
2439 Ok(())
2440 }
2441
2442 #[tokio::test]
2443 async fn test_session_delete_operations() -> Result<(), Box<dyn std::error::Error>> {
2444 use crate::keys::{generate_identity_key_pair, generate_pre_keys, generate_signed_pre_key};
2445 use libsignal_protocol::*;
2446
2447 let mut storage = SqliteStorage::new(":memory:").await?;
2448 storage.initialize()?;
2449
2450 let identity = generate_identity_key_pair().await?;
2451 let registration_id = 12345u32;
2452 storage
2453 .identity_store()
2454 .set_local_identity_key_pair(&identity)
2455 .await?;
2456 storage
2457 .identity_store()
2458 .set_local_registration_id(registration_id)
2459 .await?;
2460
2461 let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
2462 let charlie_address = ProtocolAddress::new("charlie".to_string(), DeviceId::new(1)?);
2463
2464 let bob_identity = generate_identity_key_pair().await?;
2465 let bob_pre_keys = generate_pre_keys(1, 1).await?;
2466 let bob_signed_pre_key = generate_signed_pre_key(&bob_identity, 1).await?;
2467
2468 let mut rng = rand::rng();
2469 let kyber_keypair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
2470 let kyber_signature = bob_identity
2471 .private_key()
2472 .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;
2473
2474 let bob_bundle = PreKeyBundle::new(
2475 registration_id,
2476 DeviceId::new(1)?,
2477 Some((
2478 PreKeyId::from(bob_pre_keys[0].0),
2479 bob_pre_keys[0].1.public_key,
2480 )),
2481 SignedPreKeyId::from(1u32),
2482 bob_signed_pre_key.public_key()?,
2483 bob_signed_pre_key.signature()?.to_vec(),
2484 KyberPreKeyId::from(1u32),
2485 kyber_keypair.public_key,
2486 kyber_signature.to_vec(),
2487 *bob_identity.identity_key(),
2488 )?;
2489
2490 storage
2491 .establish_session_from_bundle(&bob_address, &bob_bundle)
2492 .await?;
2493 storage
2494 .establish_session_from_bundle(&charlie_address, &bob_bundle)
2495 .await?;
2496
2497 assert_eq!(
2498 storage.session_store().session_count().await,
2499 2,
2500 "Should have 2 sessions"
2501 );
2502
2503 storage.session_store().delete_session(&bob_address).await?;
2504 assert_eq!(
2505 storage.session_store().session_count().await,
2506 1,
2507 "Should have 1 session after deleting Bob's"
2508 );
2509
2510 let bob_session = storage.session_store().load_session(&bob_address).await?;
2511 assert!(bob_session.is_none(), "Bob's session should be deleted");
2512
2513 let charlie_session = storage
2514 .session_store()
2515 .load_session(&charlie_address)
2516 .await?;
2517 assert!(
2518 charlie_session.is_some(),
2519 "Charlie's session should still exist"
2520 );
2521
2522 Ok(())
2523 }
2524
2525 #[tokio::test]
2526 async fn test_identity_management_operations() -> Result<(), Box<dyn std::error::Error>> {
2527 use crate::keys::generate_identity_key_pair;
2528 use libsignal_protocol::*;
2529
2530 let mut storage = SqliteStorage::new(":memory:").await?;
2531 storage.initialize()?;
2532
2533 let local_identity = generate_identity_key_pair().await?;
2534 storage
2535 .identity_store()
2536 .set_local_identity_key_pair(&local_identity)
2537 .await?;
2538
2539 let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
2540 let charlie_address = ProtocolAddress::new("charlie".to_string(), DeviceId::new(1)?);
2541
2542 let bob_identity = generate_identity_key_pair().await?;
2543 let charlie_identity = generate_identity_key_pair().await?;
2544
2545 let result = storage
2546 .identity_store()
2547 .get_peer_identity(&bob_address)
2548 .await?;
2549 assert!(
2550 result.is_none(),
2551 "Should return None for non-existent peer identity"
2552 );
2553
2554 storage
2555 .identity_store()
2556 .save_identity(&bob_address, bob_identity.identity_key())
2557 .await?;
2558 storage
2559 .identity_store()
2560 .save_identity(&charlie_address, charlie_identity.identity_key())
2561 .await?;
2562
2563 assert_eq!(
2564 storage.identity_store().identity_count().await,
2565 2,
2566 "Should have 2 peer identities"
2567 );
2568
2569 let retrieved_bob = storage
2570 .identity_store()
2571 .get_peer_identity(&bob_address)
2572 .await?;
2573 assert!(retrieved_bob.is_some(), "Should retrieve Bob's identity");
2574 assert_eq!(
2575 retrieved_bob.unwrap(),
2576 *bob_identity.identity_key(),
2577 "Retrieved identity should match stored"
2578 );
2579
2580 let retrieved_charlie = storage
2581 .identity_store()
2582 .get_peer_identity(&charlie_address)
2583 .await?;
2584 assert!(
2585 retrieved_charlie.is_some(),
2586 "Should retrieve Charlie's identity"
2587 );
2588 assert_eq!(
2589 retrieved_charlie.unwrap(),
2590 *charlie_identity.identity_key(),
2591 "Retrieved identity should match stored"
2592 );
2593
2594 storage
2595 .identity_store()
2596 .delete_identity(&bob_address)
2597 .await?;
2598 assert_eq!(
2599 storage.identity_store().identity_count().await,
2600 1,
2601 "Should have 1 identity after deleting Bob's"
2602 );
2603
2604 let deleted_bob = storage
2605 .identity_store()
2606 .get_peer_identity(&bob_address)
2607 .await?;
2608 assert!(deleted_bob.is_none(), "Bob's identity should be deleted");
2609
2610 let still_charlie = storage
2611 .identity_store()
2612 .get_peer_identity(&charlie_address)
2613 .await?;
2614 assert!(
2615 still_charlie.is_some(),
2616 "Charlie's identity should still exist"
2617 );
2618
2619 storage.identity_store().clear_all_identities().await?;
2620 assert_eq!(
2621 storage.identity_store().identity_count().await,
2622 0,
2623 "Should have 0 identities after clearing all"
2624 );
2625
2626 let cleared_charlie = storage
2627 .identity_store()
2628 .get_peer_identity(&charlie_address)
2629 .await?;
2630 assert!(
2631 cleared_charlie.is_none(),
2632 "Charlie's identity should be cleared"
2633 );
2634
2635 Ok(())
2636 }
2637
2638 #[tokio::test]
2639 async fn test_clear_local_identity() -> Result<(), Box<dyn std::error::Error>> {
2640 use crate::keys::generate_identity_key_pair;
2641 use libsignal_protocol::*;
2642
2643 let mut storage = SqliteStorage::new(":memory:").await?;
2644 storage.initialize()?;
2645
2646 let identity = generate_identity_key_pair().await?;
2647 let registration_id = 12345u32;
2648 storage
2649 .identity_store()
2650 .set_local_identity_key_pair(&identity)
2651 .await?;
2652 storage
2653 .identity_store()
2654 .set_local_registration_id(registration_id)
2655 .await?;
2656
2657 let retrieved_identity = storage.identity_store().get_identity_key_pair().await?;
2658 assert_eq!(
2659 retrieved_identity.identity_key().serialize(),
2660 identity.identity_key().serialize()
2661 );
2662
2663 let retrieved_registration = storage.identity_store().get_local_registration_id().await?;
2664 assert_eq!(retrieved_registration, registration_id);
2665
2666 storage.identity_store().clear_local_identity().await?;
2667
2668 let result = storage.identity_store().get_identity_key_pair().await;
2669 assert!(
2670 result.is_err(),
2671 "Should return error when local identity is cleared"
2672 );
2673
2674 let result = storage.identity_store().get_local_registration_id().await;
2675 assert!(
2676 result.is_err(),
2677 "Should return error when local registration ID is cleared"
2678 );
2679
2680 Ok(())
2681 }
2682
2683 #[tokio::test]
2684 async fn test_bundle_metadata_record_and_retrieve() -> Result<(), Box<dyn std::error::Error>> {
2685 let mut storage = SqliteStorage::new(":memory:").await?;
2686 storage.initialize_schema()?;
2687
2688 let initial_metadata = storage.get_last_published_bundle_metadata()?;
2690 assert_eq!(initial_metadata, None, "Should have no metadata initially");
2691
2692 storage.record_published_bundle(100, 2, 2)?;
2694
2695 let metadata = storage.get_last_published_bundle_metadata()?;
2697 assert_eq!(
2698 metadata,
2699 Some((100, 2, 2)),
2700 "Should retrieve the recorded bundle metadata"
2701 );
2702
2703 storage.record_published_bundle(99, 3, 3)?;
2705
2706 let updated_metadata = storage.get_last_published_bundle_metadata()?;
2708 assert_eq!(
2709 updated_metadata,
2710 Some((99, 3, 3)),
2711 "Should retrieve the updated bundle metadata"
2712 );
2713
2714 Ok(())
2715 }
2716}