1use crate::storage_trait::*;
7use async_trait::async_trait;
8use libsignal_protocol::*;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
13fn address_key(address: &ProtocolAddress) -> String {
14 format!("{}:{}", address.name(), u32::from(address.device_id()))
15}
16
17pub struct MemorySessionStore {
19 sessions: Arc<Mutex<HashMap<String, SessionRecord>>>,
20}
21
22impl Default for MemorySessionStore {
23 fn default() -> Self {
24 Self::new()
25 }
26}
27
28impl MemorySessionStore {
29 pub fn new() -> Self {
30 Self {
31 sessions: Arc::new(Mutex::new(HashMap::new())),
32 }
33 }
34}
35
36#[async_trait(?Send)]
37impl SessionStore for MemorySessionStore {
38 async fn load_session(
39 &self,
40 address: &ProtocolAddress,
41 ) -> Result<Option<SessionRecord>, SignalProtocolError> {
42 let key = address_key(address);
43 let store = self.sessions.lock().await;
44 Ok(store.get(&key).cloned())
45 }
46
47 async fn store_session(
48 &mut self,
49 address: &ProtocolAddress,
50 record: &SessionRecord,
51 ) -> Result<(), SignalProtocolError> {
52 let key = address_key(address);
53 let mut store = self.sessions.lock().await;
54 store.insert(key, record.clone());
55 Ok(())
56 }
57}
58
59#[async_trait(?Send)]
60impl ExtendedSessionStore for MemorySessionStore {
61 async fn session_count(&self) -> usize {
62 let store = self.sessions.lock().await;
63 store.len()
64 }
65
66 async fn clear_all_sessions(&mut self) -> Result<(), Box<dyn std::error::Error>> {
67 let mut store = self.sessions.lock().await;
68 store.clear();
69 Ok(())
70 }
71 async fn delete_session(
72 &mut self,
73 address: &ProtocolAddress,
74 ) -> Result<(), Box<dyn std::error::Error>> {
75 let key = address_key(address);
76 let mut store = self.sessions.lock().await;
77 store.remove(&key);
78 Ok(())
79 }
80}
81
82pub struct MemoryIdentityStore {
84 identity_keys: Arc<Mutex<HashMap<String, IdentityKey>>>,
85 local_identity_key_pair: Arc<Mutex<Option<IdentityKeyPair>>>,
86 local_registration_id: Arc<Mutex<Option<u32>>>,
87}
88
89impl Default for MemoryIdentityStore {
90 fn default() -> Self {
91 Self::new()
92 }
93}
94
95impl MemoryIdentityStore {
96 pub fn new() -> Self {
97 Self {
98 identity_keys: Arc::new(Mutex::new(HashMap::new())),
99 local_identity_key_pair: Arc::new(Mutex::new(None)),
100 local_registration_id: Arc::new(Mutex::new(None)),
101 }
102 }
103}
104
105#[async_trait(?Send)]
106impl IdentityKeyStore for MemoryIdentityStore {
107 async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> {
108 let store = self.local_identity_key_pair.lock().await;
109 match *store {
110 Some(identity_key_pair) => Ok(identity_key_pair),
111 None => Err(SignalProtocolError::InvalidState(
112 "storage",
113 "Local identity key pair not set".to_string(),
114 )),
115 }
116 }
117
118 async fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError> {
119 let store = self.local_registration_id.lock().await;
120 match *store {
121 Some(registration_id) => Ok(registration_id),
122 None => Err(SignalProtocolError::InvalidState(
123 "storage",
124 "Local registration ID not set".to_string(),
125 )),
126 }
127 }
128
129 async fn save_identity(
130 &mut self,
131 address: &ProtocolAddress,
132 identity_key: &IdentityKey,
133 ) -> Result<IdentityChange, SignalProtocolError> {
134 let existing = self.get_identity(address).await.ok().flatten();
135
136 let key = address_key(address);
137 let mut store = self.identity_keys.lock().await;
138 store.insert(key, *identity_key);
139
140 match existing {
141 Some(existing_key) if existing_key != *identity_key => {
142 Ok(IdentityChange::ReplacedExisting)
143 }
144 Some(_) => Ok(IdentityChange::NewOrUnchanged),
145 None => Ok(IdentityChange::NewOrUnchanged),
146 }
147 }
148
149 async fn is_trusted_identity(
150 &self,
151 address: &ProtocolAddress,
152 identity_key: &IdentityKey,
153 _direction: Direction,
154 ) -> Result<bool, SignalProtocolError> {
155 let key = address_key(address);
156 let store = self.identity_keys.lock().await;
157 match store.get(&key) {
158 Some(stored_key) => Ok(*stored_key == *identity_key),
159 None => Ok(true), }
161 }
162
163 async fn get_identity(
164 &self,
165 address: &ProtocolAddress,
166 ) -> Result<Option<IdentityKey>, SignalProtocolError> {
167 let key = address_key(address);
168 let store = self.identity_keys.lock().await;
169 Ok(store.get(&key).copied())
170 }
171}
172
173#[async_trait(?Send)]
174impl ExtendedIdentityStore for MemoryIdentityStore {
175 async fn identity_count(&self) -> usize {
176 let store = self.identity_keys.lock().await;
177 store.len()
178 }
179
180 async fn set_local_identity_key_pair(
181 &self,
182 identity_key_pair: &IdentityKeyPair,
183 ) -> Result<(), Box<dyn std::error::Error>> {
184 let mut store = self.local_identity_key_pair.lock().await;
185 *store = Some(*identity_key_pair);
186 Ok(())
187 }
188
189 async fn set_local_registration_id(
190 &self,
191 registration_id: u32,
192 ) -> Result<(), Box<dyn std::error::Error>> {
193 let mut store = self.local_registration_id.lock().await;
194 *store = Some(registration_id);
195 Ok(())
196 }
197 async fn get_peer_identity(
198 &self,
199 address: &ProtocolAddress,
200 ) -> Result<Option<IdentityKey>, Box<dyn std::error::Error>> {
201 let key = address_key(address);
202 let store = self.identity_keys.lock().await;
203 Ok(store.get(&key).cloned())
204 }
205 async fn delete_identity(
206 &mut self,
207 address: &ProtocolAddress,
208 ) -> Result<(), Box<dyn std::error::Error>> {
209 let key = address_key(address);
210 let mut store = self.identity_keys.lock().await;
211 store.remove(&key);
212 Ok(())
213 }
214 async fn clear_all_identities(&mut self) -> Result<(), Box<dyn std::error::Error>> {
215 let mut store = self.identity_keys.lock().await;
216 store.clear();
217 Ok(())
218 }
219 async fn clear_local_identity(&mut self) -> Result<(), Box<dyn std::error::Error>> {
220 let mut identity_store = self.local_identity_key_pair.lock().await;
221 let mut registration_store = self.local_registration_id.lock().await;
222 *identity_store = None;
223 *registration_store = None;
224 Ok(())
225 }
226}
227
228pub struct MemoryPreKeyStore {
230 pre_keys: Arc<Mutex<HashMap<u32, KeyPair>>>,
231}
232
233impl Default for MemoryPreKeyStore {
234 fn default() -> Self {
235 Self::new()
236 }
237}
238
239impl MemoryPreKeyStore {
240 pub fn new() -> Self {
241 Self {
242 pre_keys: Arc::new(Mutex::new(HashMap::new())),
243 }
244 }
245}
246
247#[async_trait(?Send)]
248impl PreKeyStore for MemoryPreKeyStore {
249 async fn get_pre_key(&self, prekey_id: PreKeyId) -> Result<PreKeyRecord, SignalProtocolError> {
250 let store = self.pre_keys.lock().await;
251 match store.get(&u32::from(prekey_id)) {
252 Some(key_pair) => Ok(PreKeyRecord::new(prekey_id, key_pair)),
253 None => Err(SignalProtocolError::InvalidPreKeyId),
254 }
255 }
256
257 async fn save_pre_key(
258 &mut self,
259 prekey_id: PreKeyId,
260 record: &PreKeyRecord,
261 ) -> Result<(), SignalProtocolError> {
262 let key_pair = record.key_pair()?;
263 let mut store = self.pre_keys.lock().await;
264 store.insert(u32::from(prekey_id), key_pair);
265 Ok(())
266 }
267
268 async fn remove_pre_key(&mut self, prekey_id: PreKeyId) -> Result<(), SignalProtocolError> {
269 let mut store = self.pre_keys.lock().await;
270 store.remove(&u32::from(prekey_id));
271 Ok(())
272 }
273}
274
275#[async_trait(?Send)]
276impl ExtendedPreKeyStore for MemoryPreKeyStore {
277 async fn pre_key_count(&self) -> usize {
278 let store = self.pre_keys.lock().await;
279 store.len()
280 }
281
282 async fn clear_all_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
283 let mut store = self.pre_keys.lock().await;
284 store.clear();
285 Ok(())
286 }
287
288 async fn get_max_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
289 let store = self.pre_keys.lock().await;
290 Ok(store.keys().max().copied())
291 }
292
293 async fn delete_pre_key(&mut self, id: PreKeyId) -> Result<(), Box<dyn std::error::Error>> {
294 let mut store = self.pre_keys.lock().await;
295 store.remove(&u32::from(id));
296 Ok(())
297 }
298}
299
300pub struct MemorySignedPreKeyStore {
302 signed_pre_keys: Arc<Mutex<HashMap<u32, SignedPreKeyRecord>>>,
303}
304
305impl Default for MemorySignedPreKeyStore {
306 fn default() -> Self {
307 Self::new()
308 }
309}
310
311impl MemorySignedPreKeyStore {
312 pub fn new() -> Self {
313 Self {
314 signed_pre_keys: Arc::new(Mutex::new(HashMap::new())),
315 }
316 }
317}
318
319#[async_trait(?Send)]
320impl SignedPreKeyStore for MemorySignedPreKeyStore {
321 async fn get_signed_pre_key(
322 &self,
323 signed_prekey_id: SignedPreKeyId,
324 ) -> Result<SignedPreKeyRecord, SignalProtocolError> {
325 let store = self.signed_pre_keys.lock().await;
326 match store.get(&u32::from(signed_prekey_id)) {
327 Some(record) => Ok(record.clone()),
328 None => Err(SignalProtocolError::InvalidSignedPreKeyId),
329 }
330 }
331
332 async fn save_signed_pre_key(
333 &mut self,
334 signed_prekey_id: SignedPreKeyId,
335 record: &SignedPreKeyRecord,
336 ) -> Result<(), SignalProtocolError> {
337 let mut store = self.signed_pre_keys.lock().await;
338 store.insert(u32::from(signed_prekey_id), record.clone());
339 Ok(())
340 }
341}
342
343#[async_trait(?Send)]
344impl ExtendedSignedPreKeyStore for MemorySignedPreKeyStore {
345 async fn signed_pre_key_count(&self) -> usize {
346 let store = self.signed_pre_keys.lock().await;
347 store.len()
348 }
349
350 async fn clear_all_signed_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
351 let mut store = self.signed_pre_keys.lock().await;
352 store.clear();
353 Ok(())
354 }
355
356 async fn get_max_signed_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
357 let store = self.signed_pre_keys.lock().await;
358 Ok(store.keys().max().copied())
359 }
360
361 async fn delete_signed_pre_key(
362 &mut self,
363 id: SignedPreKeyId,
364 ) -> Result<(), Box<dyn std::error::Error>> {
365 let mut store = self.signed_pre_keys.lock().await;
366 store.remove(&u32::from(id));
367 Ok(())
368 }
369
370 async fn get_signed_pre_keys_older_than(
371 &self,
372 timestamp_millis: u64,
373 ) -> Result<Vec<SignedPreKeyId>, Box<dyn std::error::Error>> {
374 let store = self.signed_pre_keys.lock().await;
375 let mut expired = Vec::new();
376 for (id, record) in store.iter() {
377 if let Ok(ts) = record.timestamp() {
378 if ts.epoch_millis() < timestamp_millis {
379 expired.push(SignedPreKeyId::from(*id));
380 }
381 }
382 }
383 Ok(expired)
384 }
385}
386
387pub struct MemoryKyberPreKeyStore {
389 kyber_pre_keys: Arc<Mutex<HashMap<u32, KyberPreKeyRecord>>>,
390}
391
392impl Default for MemoryKyberPreKeyStore {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
398impl MemoryKyberPreKeyStore {
399 pub fn new() -> Self {
400 Self {
401 kyber_pre_keys: Arc::new(Mutex::new(HashMap::new())),
402 }
403 }
404}
405
406#[async_trait(?Send)]
407impl KyberPreKeyStore for MemoryKyberPreKeyStore {
408 async fn get_kyber_pre_key(
409 &self,
410 kyber_prekey_id: KyberPreKeyId,
411 ) -> Result<KyberPreKeyRecord, SignalProtocolError> {
412 let store = self.kyber_pre_keys.lock().await;
413 match store.get(&u32::from(kyber_prekey_id)) {
414 Some(record) => Ok(record.clone()),
415 None => Err(SignalProtocolError::InvalidKyberPreKeyId),
416 }
417 }
418
419 async fn save_kyber_pre_key(
420 &mut self,
421 kyber_prekey_id: KyberPreKeyId,
422 record: &KyberPreKeyRecord,
423 ) -> Result<(), SignalProtocolError> {
424 let mut store = self.kyber_pre_keys.lock().await;
425 store.insert(u32::from(kyber_prekey_id), record.clone());
426 Ok(())
427 }
428
429 async fn mark_kyber_pre_key_used(
430 &mut self,
431 _kyber_prekey_id: KyberPreKeyId,
432 ) -> Result<(), SignalProtocolError> {
433 Ok(())
436 }
437}
438
439#[async_trait(?Send)]
440impl ExtendedKyberPreKeyStore for MemoryKyberPreKeyStore {
441 async fn kyber_pre_key_count(&self) -> usize {
442 let store = self.kyber_pre_keys.lock().await;
443 store.len()
444 }
445
446 async fn clear_all_kyber_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
447 let mut store = self.kyber_pre_keys.lock().await;
448 store.clear();
449 Ok(())
450 }
451
452 async fn get_max_kyber_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
453 let store = self.kyber_pre_keys.lock().await;
454 Ok(store.keys().max().copied())
455 }
456
457 async fn delete_kyber_pre_key(
458 &mut self,
459 id: KyberPreKeyId,
460 ) -> Result<(), Box<dyn std::error::Error>> {
461 let mut store = self.kyber_pre_keys.lock().await;
462 store.remove(&u32::from(id));
463 Ok(())
464 }
465
466 async fn get_kyber_pre_keys_older_than(
467 &self,
468 timestamp_millis: u64,
469 ) -> Result<Vec<KyberPreKeyId>, Box<dyn std::error::Error>> {
470 let store = self.kyber_pre_keys.lock().await;
471 let mut expired = Vec::new();
472 for (id, record) in store.iter() {
473 if let Ok(ts) = record.timestamp() {
474 if ts.epoch_millis() < timestamp_millis {
475 expired.push(KyberPreKeyId::from(*id));
476 }
477 }
478 }
479 Ok(expired)
480 }
481}
482
483pub struct MemoryStorage {
485 pub session_store: MemorySessionStore,
486 pub identity_store: MemoryIdentityStore,
487 pub pre_key_store: MemoryPreKeyStore,
488 pub signed_pre_key_store: MemorySignedPreKeyStore,
489 pub kyber_pre_key_store: MemoryKyberPreKeyStore,
490}
491
492impl Default for MemoryStorage {
493 fn default() -> Self {
494 Self::new()
495 }
496}
497
498impl MemoryStorage {
499 pub fn new() -> Self {
500 Self {
501 session_store: MemorySessionStore::new(),
502 identity_store: MemoryIdentityStore::new(),
503 pre_key_store: MemoryPreKeyStore::new(),
504 signed_pre_key_store: MemorySignedPreKeyStore::new(),
505 kyber_pre_key_store: MemoryKyberPreKeyStore::new(),
506 }
507 }
508}
509
510impl SignalStorageContainer for MemoryStorage {
511 type SessionStore = MemorySessionStore;
512 type IdentityStore = MemoryIdentityStore;
513 type PreKeyStore = MemoryPreKeyStore;
514 type SignedPreKeyStore = MemorySignedPreKeyStore;
515 type KyberPreKeyStore = MemoryKyberPreKeyStore;
516
517 fn session_store(&mut self) -> &mut Self::SessionStore {
518 &mut self.session_store
519 }
520
521 fn identity_store(&mut self) -> &mut Self::IdentityStore {
522 &mut self.identity_store
523 }
524
525 fn pre_key_store(&mut self) -> &mut Self::PreKeyStore {
526 &mut self.pre_key_store
527 }
528
529 fn signed_pre_key_store(&mut self) -> &mut Self::SignedPreKeyStore {
530 &mut self.signed_pre_key_store
531 }
532
533 fn kyber_pre_key_store(&mut self) -> &mut Self::KyberPreKeyStore {
534 &mut self.kyber_pre_key_store
535 }
536
537 fn initialize(&mut self) -> Result<(), Box<dyn std::error::Error>> {
538 Ok(())
539 }
540
541 fn close(&mut self) -> Result<(), Box<dyn std::error::Error>> {
542 Ok(())
543 }
544
545 fn storage_type(&self) -> &'static str {
546 "memory"
547 }
548}
549
550#[async_trait(?Send)]
551impl ExtendedStorageOps for MemoryStorage {
552 async fn establish_session_from_bundle(
553 &mut self,
554 address: &ProtocolAddress,
555 bundle: &PreKeyBundle,
556 ) -> Result<(), Box<dyn std::error::Error>> {
557 MemoryStorage::establish_session_from_bundle(self, address, bundle).await
558 }
559
560 async fn encrypt_message(
561 &mut self,
562 remote_address: &ProtocolAddress,
563 plaintext: &[u8],
564 ) -> Result<CiphertextMessage, SignalProtocolError> {
565 MemoryStorage::encrypt_message(self, remote_address, plaintext).await
566 }
567
568 async fn decrypt_message(
569 &mut self,
570 remote_address: &ProtocolAddress,
571 ciphertext: &CiphertextMessage,
572 ) -> Result<Vec<u8>, SignalProtocolError> {
573 MemoryStorage::decrypt_message(self, remote_address, ciphertext).await
574 }
575}
576
577impl MemoryStorage {
578 pub async fn establish_session_from_bundle(
579 &mut self,
580 address: &ProtocolAddress,
581 bundle: &PreKeyBundle,
582 ) -> Result<(), Box<dyn std::error::Error>> {
583 let mut rng = rand::rng();
584 let timestamp = std::time::SystemTime::now();
585
586 process_prekey_bundle(
587 address,
588 &mut self.session_store,
589 &mut self.identity_store,
590 bundle,
591 timestamp,
592 &mut rng,
593 UsePQRatchet::Yes,
594 )
595 .await?;
596
597 Ok(())
598 }
599
600 pub async fn encrypt_message(
601 &mut self,
602 remote_address: &ProtocolAddress,
603 plaintext: &[u8],
604 ) -> Result<CiphertextMessage, SignalProtocolError> {
605 let mut rng = rand::rng();
606 let now = std::time::SystemTime::now();
607 message_encrypt(
608 plaintext,
609 remote_address,
610 &mut self.session_store,
611 &mut self.identity_store,
612 now,
613 &mut rng,
614 )
615 .await
616 }
617
618 pub async fn decrypt_message(
619 &mut self,
620 remote_address: &ProtocolAddress,
621 ciphertext: &CiphertextMessage,
622 ) -> Result<Vec<u8>, SignalProtocolError> {
623 let mut rng = rand::rng();
624
625 message_decrypt(
626 ciphertext,
627 remote_address,
628 &mut self.session_store,
629 &mut self.identity_store,
630 &mut self.pre_key_store,
631 &self.signed_pre_key_store,
632 &mut self.kyber_pre_key_store,
633 &mut rng,
634 UsePQRatchet::Yes,
635 )
636 .await
637 }
638}
639
#[cfg(test)]
mod tests {
    // `super::*` already brings in `MemoryStorage`, the storage traits and the
    // libsignal_protocol names, so individual tests only import project helpers.
    use super::*;

    #[tokio::test]
    async fn test_memory_storage_creation() -> Result<(), Box<dyn std::error::Error>> {
        let mut storage = MemoryStorage::new();
        storage.initialize()?;
        assert_eq!(storage.storage_type(), "memory");
        Ok(())
    }

    #[tokio::test]
    async fn test_session_storage() -> Result<(), Box<dyn std::error::Error>> {
        let mut storage = MemoryStorage::new();
        let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);
        let session_record = SessionRecord::new_fresh();

        assert_eq!(storage.session_store.session_count().await, 0);

        storage
            .session_store
            .store_session(&address, &session_record)
            .await?;
        assert_eq!(storage.session_store.session_count().await, 1);

        let loaded = storage.session_store.load_session(&address).await?;
        assert!(loaded.is_some());

        storage.session_store.clear_all_sessions().await?;
        assert_eq!(storage.session_store.session_count().await, 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_identity_storage() -> Result<(), Box<dyn std::error::Error>> {
        let mut storage = MemoryStorage::new();
        let mut rng = rand::rng();
        let identity_key_pair = IdentityKeyPair::generate(&mut rng);
        let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);

        storage
            .identity_store
            .set_local_identity_key_pair(&identity_key_pair)
            .await?;
        storage
            .identity_store
            .set_local_registration_id(12345)
            .await?;

        let retrieved_identity = storage.identity_store.get_identity_key_pair().await?;
        assert_eq!(
            retrieved_identity.identity_key().serialize(),
            identity_key_pair.identity_key().serialize()
        );

        let retrieved_registration = storage.identity_store.get_local_registration_id().await?;
        assert_eq!(retrieved_registration, 12345);

        assert_eq!(storage.identity_store.identity_count().await, 0);

        storage
            .identity_store
            .save_identity(&address, identity_key_pair.identity_key())
            .await?;
        assert_eq!(storage.identity_store.identity_count().await, 1);

        let retrieved = storage.identity_store.get_identity(&address).await?;
        assert!(retrieved.is_some());
        assert_eq!(
            retrieved.unwrap().serialize(),
            identity_key_pair.identity_key().serialize()
        );

        Ok(())
    }

    /// `save_identity` reports a replacement only when the stored key actually changes,
    /// and `is_trusted_identity` then trusts the newly saved key only.
    #[tokio::test]
    async fn test_save_identity_reports_replacement() -> Result<(), Box<dyn std::error::Error>> {
        let mut storage = MemoryStorage::new();
        let mut rng = rand::rng();
        let first = IdentityKeyPair::generate(&mut rng);
        let second = IdentityKeyPair::generate(&mut rng);
        let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);

        let change = storage
            .identity_store
            .save_identity(&address, first.identity_key())
            .await?;
        assert!(matches!(change, IdentityChange::NewOrUnchanged));

        let change = storage
            .identity_store
            .save_identity(&address, first.identity_key())
            .await?;
        assert!(matches!(change, IdentityChange::NewOrUnchanged));

        let change = storage
            .identity_store
            .save_identity(&address, second.identity_key())
            .await?;
        assert!(matches!(change, IdentityChange::ReplacedExisting));
        assert_eq!(storage.identity_store.identity_count().await, 1);

        assert!(
            storage
                .identity_store
                .is_trusted_identity(&address, second.identity_key(), Direction::Receiving)
                .await?
        );
        assert!(
            !storage
                .identity_store
                .is_trusted_identity(&address, first.identity_key(), Direction::Sending)
                .await?
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_key_generation_storage_integration() -> Result<(), Box<dyn std::error::Error>> {
        use crate::keys::{generate_identity_key_pair, generate_pre_keys, generate_signed_pre_key};

        let identity_key_pair = generate_identity_key_pair().await?;
        let pre_keys = generate_pre_keys(1, 5).await?;
        let signed_pre_key = generate_signed_pre_key(&identity_key_pair, 1).await?;

        let mut storage = MemoryStorage::new();

        let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);
        storage
            .identity_store
            .save_identity(&address, identity_key_pair.identity_key())
            .await?;
        let retrieved_identity = storage.identity_store.get_identity(&address).await?;
        assert!(retrieved_identity.is_some());
        assert_eq!(
            retrieved_identity.unwrap().serialize(),
            identity_key_pair.identity_key().serialize()
        );

        for (key_id, key_pair) in &pre_keys {
            let record = PreKeyRecord::new((*key_id).into(), key_pair);
            storage
                .pre_key_store
                .save_pre_key((*key_id).into(), &record)
                .await?;
            let retrieved_record = storage.pre_key_store.get_pre_key((*key_id).into()).await?;
            assert_eq!(
                retrieved_record.key_pair()?.public_key.serialize(),
                key_pair.public_key.serialize()
            );
        }

        storage
            .signed_pre_key_store
            .save_signed_pre_key(signed_pre_key.id()?, &signed_pre_key)
            .await?;
        let retrieved_signed_key = storage
            .signed_pre_key_store
            .get_signed_pre_key(signed_pre_key.id()?)
            .await?;
        assert_eq!(retrieved_signed_key.id()?, signed_pre_key.id()?);

        assert_eq!(storage.identity_store.identity_count().await, 1);
        assert_eq!(storage.pre_key_store.pre_key_count().await, 5);
        assert_eq!(storage.signed_pre_key_store.signed_pre_key_count().await, 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_session_delete_operations() -> Result<(), Box<dyn std::error::Error>> {
        use crate::keys::{generate_identity_key_pair, generate_pre_keys, generate_signed_pre_key};

        let mut storage = MemoryStorage::new();

        let identity = generate_identity_key_pair().await?;
        let registration_id = 12345u32;
        storage
            .identity_store
            .set_local_identity_key_pair(&identity)
            .await?;
        storage
            .identity_store
            .set_local_registration_id(registration_id)
            .await?;

        let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
        let charlie_address = ProtocolAddress::new("charlie".to_string(), DeviceId::new(1)?);

        let bob_identity = generate_identity_key_pair().await?;
        let bob_pre_keys = generate_pre_keys(1, 1).await?;
        let bob_signed_pre_key = generate_signed_pre_key(&bob_identity, 1).await?;

        let mut rng = rand::rng();
        let kyber_keypair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
        let kyber_signature = bob_identity
            .private_key()
            .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;

        let bob_bundle = PreKeyBundle::new(
            registration_id,
            DeviceId::new(1)?,
            Some((
                PreKeyId::from(bob_pre_keys[0].0),
                bob_pre_keys[0].1.public_key,
            )),
            SignedPreKeyId::from(1u32),
            bob_signed_pre_key.public_key()?,
            bob_signed_pre_key.signature()?.to_vec(),
            KyberPreKeyId::from(1u32),
            kyber_keypair.public_key,
            kyber_signature.to_vec(),
            *bob_identity.identity_key(),
        )?;

        // The same bundle is reused for Charlie; only the per-address bookkeeping
        // matters for this test.
        storage
            .establish_session_from_bundle(&bob_address, &bob_bundle)
            .await?;
        storage
            .establish_session_from_bundle(&charlie_address, &bob_bundle)
            .await?;
        assert_eq!(
            storage.session_store.session_count().await,
            2,
            "Should have 2 sessions"
        );

        storage.session_store.delete_session(&bob_address).await?;
        assert_eq!(
            storage.session_store.session_count().await,
            1,
            "Should have 1 session after deleting Bob's"
        );

        let bob_session = storage.session_store.load_session(&bob_address).await?;
        assert!(bob_session.is_none(), "Bob's session should be deleted");

        let charlie_session = storage.session_store.load_session(&charlie_address).await?;
        assert!(
            charlie_session.is_some(),
            "Charlie's session should still exist"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_identity_management_operations() -> Result<(), Box<dyn std::error::Error>>
    {
        use crate::keys::generate_identity_key_pair;

        let mut storage = MemoryStorage::new();

        let local_identity = generate_identity_key_pair().await?;
        storage
            .identity_store
            .set_local_identity_key_pair(&local_identity)
            .await?;

        let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
        let charlie_address = ProtocolAddress::new("charlie".to_string(), DeviceId::new(1)?);

        let bob_identity = generate_identity_key_pair().await?;
        let charlie_identity = generate_identity_key_pair().await?;

        let result = storage
            .identity_store
            .get_peer_identity(&bob_address)
            .await?;
        assert!(
            result.is_none(),
            "Should return None for non-existent peer identity"
        );

        storage
            .identity_store
            .save_identity(&bob_address, bob_identity.identity_key())
            .await?;
        storage
            .identity_store
            .save_identity(&charlie_address, charlie_identity.identity_key())
            .await?;

        assert_eq!(
            storage.identity_store.identity_count().await,
            2,
            "Should have 2 peer identities"
        );

        let retrieved_bob = storage
            .identity_store
            .get_peer_identity(&bob_address)
            .await?;
        assert!(retrieved_bob.is_some(), "Should retrieve Bob's identity");
        assert_eq!(
            retrieved_bob.unwrap(),
            *bob_identity.identity_key(),
            "Retrieved identity should match stored"
        );

        let retrieved_charlie = storage
            .identity_store
            .get_peer_identity(&charlie_address)
            .await?;
        assert!(
            retrieved_charlie.is_some(),
            "Should retrieve Charlie's identity"
        );
        assert_eq!(
            retrieved_charlie.unwrap(),
            *charlie_identity.identity_key(),
            "Retrieved identity should match stored"
        );

        storage.identity_store.delete_identity(&bob_address).await?;
        assert_eq!(
            storage.identity_store.identity_count().await,
            1,
            "Should have 1 identity after deleting Bob's"
        );

        let deleted_bob = storage
            .identity_store
            .get_peer_identity(&bob_address)
            .await?;
        assert!(deleted_bob.is_none(), "Bob's identity should be deleted");

        let still_charlie = storage
            .identity_store
            .get_peer_identity(&charlie_address)
            .await?;
        assert!(
            still_charlie.is_some(),
            "Charlie's identity should still exist"
        );

        storage.identity_store.clear_all_identities().await?;
        assert_eq!(
            storage.identity_store.identity_count().await,
            0,
            "Should have 0 identities after clearing all"
        );

        let cleared_charlie = storage
            .identity_store
            .get_peer_identity(&charlie_address)
            .await?;
        assert!(
            cleared_charlie.is_none(),
            "Charlie's identity should be cleared"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_clear_local_identity() -> Result<(), Box<dyn std::error::Error>> {
        use crate::keys::generate_identity_key_pair;

        let mut storage = MemoryStorage::new();

        let identity = generate_identity_key_pair().await?;
        let registration_id = 12345u32;
        storage
            .identity_store
            .set_local_identity_key_pair(&identity)
            .await?;
        storage
            .identity_store
            .set_local_registration_id(registration_id)
            .await?;

        let retrieved_identity = storage.identity_store.get_identity_key_pair().await?;
        assert_eq!(
            retrieved_identity.identity_key().serialize(),
            identity.identity_key().serialize()
        );

        let retrieved_registration = storage.identity_store.get_local_registration_id().await?;
        assert_eq!(retrieved_registration, registration_id);

        storage.identity_store.clear_local_identity().await?;

        let result = storage.identity_store.get_identity_key_pair().await;
        assert!(
            result.is_err(),
            "Should return error when local identity is cleared"
        );

        let result = storage.identity_store.get_local_registration_id().await;
        assert!(
            result.is_err(),
            "Should return error when local registration ID is cleared"
        );

        Ok(())
    }
}