// signal_bridge/memory_storage.rs

//! In-memory storage implementation for libsignal.
//!
//! This module provides in-memory implementations of the libsignal protocol store
//! traits and of this crate's extended storage traits. All data lives in process
//! memory and is lost when the process terminates.

6use crate::storage_trait::*;
7use async_trait::async_trait;
8use libsignal_protocol::*;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::Mutex;
12
13fn address_key(address: &ProtocolAddress) -> String {
14    format!("{}:{}", address.name(), u32::from(address.device_id()))
15}
16
17/// In-memory session storage (data lost on process termination)
18pub struct MemorySessionStore {
19    sessions: Arc<Mutex<HashMap<String, SessionRecord>>>,
20}
21
22impl Default for MemorySessionStore {
23    fn default() -> Self {
24        Self::new()
25    }
26}
27
28impl MemorySessionStore {
29    pub fn new() -> Self {
30        Self {
31            sessions: Arc::new(Mutex::new(HashMap::new())),
32        }
33    }
34}
35
36#[async_trait(?Send)]
37impl SessionStore for MemorySessionStore {
38    async fn load_session(
39        &self,
40        address: &ProtocolAddress,
41    ) -> Result<Option<SessionRecord>, SignalProtocolError> {
42        let key = address_key(address);
43        let store = self.sessions.lock().await;
44        Ok(store.get(&key).cloned())
45    }
46
47    async fn store_session(
48        &mut self,
49        address: &ProtocolAddress,
50        record: &SessionRecord,
51    ) -> Result<(), SignalProtocolError> {
52        let key = address_key(address);
53        let mut store = self.sessions.lock().await;
54        store.insert(key, record.clone());
55        Ok(())
56    }
57}
58
59#[async_trait(?Send)]
60impl ExtendedSessionStore for MemorySessionStore {
61    async fn session_count(&self) -> usize {
62        let store = self.sessions.lock().await;
63        store.len()
64    }
65
66    async fn clear_all_sessions(&mut self) -> Result<(), Box<dyn std::error::Error>> {
67        let mut store = self.sessions.lock().await;
68        store.clear();
69        Ok(())
70    }
71    async fn delete_session(
72        &mut self,
73        address: &ProtocolAddress,
74    ) -> Result<(), Box<dyn std::error::Error>> {
75        let key = address_key(address);
76        let mut store = self.sessions.lock().await;
77        store.remove(&key);
78        Ok(())
79    }
80}
81
82/// In-memory identity key storage (data lost on process termination)
83pub struct MemoryIdentityStore {
84    identity_keys: Arc<Mutex<HashMap<String, IdentityKey>>>,
85    local_identity_key_pair: Arc<Mutex<Option<IdentityKeyPair>>>,
86    local_registration_id: Arc<Mutex<Option<u32>>>,
87}
88
89impl Default for MemoryIdentityStore {
90    fn default() -> Self {
91        Self::new()
92    }
93}
94
95impl MemoryIdentityStore {
96    pub fn new() -> Self {
97        Self {
98            identity_keys: Arc::new(Mutex::new(HashMap::new())),
99            local_identity_key_pair: Arc::new(Mutex::new(None)),
100            local_registration_id: Arc::new(Mutex::new(None)),
101        }
102    }
103}
104
105#[async_trait(?Send)]
106impl IdentityKeyStore for MemoryIdentityStore {
107    async fn get_identity_key_pair(&self) -> Result<IdentityKeyPair, SignalProtocolError> {
108        let store = self.local_identity_key_pair.lock().await;
109        match *store {
110            Some(identity_key_pair) => Ok(identity_key_pair),
111            None => Err(SignalProtocolError::InvalidState(
112                "storage",
113                "Local identity key pair not set".to_string(),
114            )),
115        }
116    }
117
118    async fn get_local_registration_id(&self) -> Result<u32, SignalProtocolError> {
119        let store = self.local_registration_id.lock().await;
120        match *store {
121            Some(registration_id) => Ok(registration_id),
122            None => Err(SignalProtocolError::InvalidState(
123                "storage",
124                "Local registration ID not set".to_string(),
125            )),
126        }
127    }
128
129    async fn save_identity(
130        &mut self,
131        address: &ProtocolAddress,
132        identity_key: &IdentityKey,
133    ) -> Result<IdentityChange, SignalProtocolError> {
134        let existing = self.get_identity(address).await.ok().flatten();
135
136        let key = address_key(address);
137        let mut store = self.identity_keys.lock().await;
138        store.insert(key, *identity_key);
139
140        match existing {
141            Some(existing_key) if existing_key != *identity_key => {
142                Ok(IdentityChange::ReplacedExisting)
143            }
144            Some(_) => Ok(IdentityChange::NewOrUnchanged),
145            None => Ok(IdentityChange::NewOrUnchanged),
146        }
147    }
148
149    async fn is_trusted_identity(
150        &self,
151        address: &ProtocolAddress,
152        identity_key: &IdentityKey,
153        _direction: Direction,
154    ) -> Result<bool, SignalProtocolError> {
155        let key = address_key(address);
156        let store = self.identity_keys.lock().await;
157        match store.get(&key) {
158            Some(stored_key) => Ok(*stored_key == *identity_key),
159            None => Ok(true), // Trust on first use
160        }
161    }
162
163    async fn get_identity(
164        &self,
165        address: &ProtocolAddress,
166    ) -> Result<Option<IdentityKey>, SignalProtocolError> {
167        let key = address_key(address);
168        let store = self.identity_keys.lock().await;
169        Ok(store.get(&key).copied())
170    }
171}
172
173#[async_trait(?Send)]
174impl ExtendedIdentityStore for MemoryIdentityStore {
175    async fn identity_count(&self) -> usize {
176        let store = self.identity_keys.lock().await;
177        store.len()
178    }
179
180    async fn set_local_identity_key_pair(
181        &self,
182        identity_key_pair: &IdentityKeyPair,
183    ) -> Result<(), Box<dyn std::error::Error>> {
184        let mut store = self.local_identity_key_pair.lock().await;
185        *store = Some(*identity_key_pair);
186        Ok(())
187    }
188
189    async fn set_local_registration_id(
190        &self,
191        registration_id: u32,
192    ) -> Result<(), Box<dyn std::error::Error>> {
193        let mut store = self.local_registration_id.lock().await;
194        *store = Some(registration_id);
195        Ok(())
196    }
197    async fn get_peer_identity(
198        &self,
199        address: &ProtocolAddress,
200    ) -> Result<Option<IdentityKey>, Box<dyn std::error::Error>> {
201        let key = address_key(address);
202        let store = self.identity_keys.lock().await;
203        Ok(store.get(&key).cloned())
204    }
205    async fn delete_identity(
206        &mut self,
207        address: &ProtocolAddress,
208    ) -> Result<(), Box<dyn std::error::Error>> {
209        let key = address_key(address);
210        let mut store = self.identity_keys.lock().await;
211        store.remove(&key);
212        Ok(())
213    }
214    async fn clear_all_identities(&mut self) -> Result<(), Box<dyn std::error::Error>> {
215        let mut store = self.identity_keys.lock().await;
216        store.clear();
217        Ok(())
218    }
219    async fn clear_local_identity(&mut self) -> Result<(), Box<dyn std::error::Error>> {
220        let mut identity_store = self.local_identity_key_pair.lock().await;
221        let mut registration_store = self.local_registration_id.lock().await;
222        *identity_store = None;
223        *registration_store = None;
224        Ok(())
225    }
226}
227
228/// In-memory pre-key storage (data lost on process termination)
229pub struct MemoryPreKeyStore {
230    pre_keys: Arc<Mutex<HashMap<u32, KeyPair>>>,
231}
232
233impl Default for MemoryPreKeyStore {
234    fn default() -> Self {
235        Self::new()
236    }
237}
238
239impl MemoryPreKeyStore {
240    pub fn new() -> Self {
241        Self {
242            pre_keys: Arc::new(Mutex::new(HashMap::new())),
243        }
244    }
245}
246
247#[async_trait(?Send)]
248impl PreKeyStore for MemoryPreKeyStore {
249    async fn get_pre_key(&self, prekey_id: PreKeyId) -> Result<PreKeyRecord, SignalProtocolError> {
250        let store = self.pre_keys.lock().await;
251        match store.get(&u32::from(prekey_id)) {
252            Some(key_pair) => Ok(PreKeyRecord::new(prekey_id, key_pair)),
253            None => Err(SignalProtocolError::InvalidPreKeyId),
254        }
255    }
256
257    async fn save_pre_key(
258        &mut self,
259        prekey_id: PreKeyId,
260        record: &PreKeyRecord,
261    ) -> Result<(), SignalProtocolError> {
262        let key_pair = record.key_pair()?;
263        let mut store = self.pre_keys.lock().await;
264        store.insert(u32::from(prekey_id), key_pair);
265        Ok(())
266    }
267
268    async fn remove_pre_key(&mut self, prekey_id: PreKeyId) -> Result<(), SignalProtocolError> {
269        let mut store = self.pre_keys.lock().await;
270        store.remove(&u32::from(prekey_id));
271        Ok(())
272    }
273}
274
275#[async_trait(?Send)]
276impl ExtendedPreKeyStore for MemoryPreKeyStore {
277    async fn pre_key_count(&self) -> usize {
278        let store = self.pre_keys.lock().await;
279        store.len()
280    }
281
282    async fn clear_all_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
283        let mut store = self.pre_keys.lock().await;
284        store.clear();
285        Ok(())
286    }
287
288    async fn get_max_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
289        let store = self.pre_keys.lock().await;
290        Ok(store.keys().max().copied())
291    }
292
293    async fn delete_pre_key(&mut self, id: PreKeyId) -> Result<(), Box<dyn std::error::Error>> {
294        let mut store = self.pre_keys.lock().await;
295        store.remove(&u32::from(id));
296        Ok(())
297    }
298}
299
300/// In-memory signed pre-key storage (data lost on process termination)
301pub struct MemorySignedPreKeyStore {
302    signed_pre_keys: Arc<Mutex<HashMap<u32, SignedPreKeyRecord>>>,
303}
304
305impl Default for MemorySignedPreKeyStore {
306    fn default() -> Self {
307        Self::new()
308    }
309}
310
311impl MemorySignedPreKeyStore {
312    pub fn new() -> Self {
313        Self {
314            signed_pre_keys: Arc::new(Mutex::new(HashMap::new())),
315        }
316    }
317}
318
319#[async_trait(?Send)]
320impl SignedPreKeyStore for MemorySignedPreKeyStore {
321    async fn get_signed_pre_key(
322        &self,
323        signed_prekey_id: SignedPreKeyId,
324    ) -> Result<SignedPreKeyRecord, SignalProtocolError> {
325        let store = self.signed_pre_keys.lock().await;
326        match store.get(&u32::from(signed_prekey_id)) {
327            Some(record) => Ok(record.clone()),
328            None => Err(SignalProtocolError::InvalidSignedPreKeyId),
329        }
330    }
331
332    async fn save_signed_pre_key(
333        &mut self,
334        signed_prekey_id: SignedPreKeyId,
335        record: &SignedPreKeyRecord,
336    ) -> Result<(), SignalProtocolError> {
337        let mut store = self.signed_pre_keys.lock().await;
338        store.insert(u32::from(signed_prekey_id), record.clone());
339        Ok(())
340    }
341}
342
343#[async_trait(?Send)]
344impl ExtendedSignedPreKeyStore for MemorySignedPreKeyStore {
345    async fn signed_pre_key_count(&self) -> usize {
346        let store = self.signed_pre_keys.lock().await;
347        store.len()
348    }
349
350    async fn clear_all_signed_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
351        let mut store = self.signed_pre_keys.lock().await;
352        store.clear();
353        Ok(())
354    }
355
356    async fn get_max_signed_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
357        let store = self.signed_pre_keys.lock().await;
358        Ok(store.keys().max().copied())
359    }
360
361    async fn delete_signed_pre_key(
362        &mut self,
363        id: SignedPreKeyId,
364    ) -> Result<(), Box<dyn std::error::Error>> {
365        let mut store = self.signed_pre_keys.lock().await;
366        store.remove(&u32::from(id));
367        Ok(())
368    }
369
370    async fn get_signed_pre_keys_older_than(
371        &self,
372        timestamp_millis: u64,
373    ) -> Result<Vec<SignedPreKeyId>, Box<dyn std::error::Error>> {
374        let store = self.signed_pre_keys.lock().await;
375        let mut expired = Vec::new();
376        for (id, record) in store.iter() {
377            if let Ok(ts) = record.timestamp() {
378                if ts.epoch_millis() < timestamp_millis {
379                    expired.push(SignedPreKeyId::from(*id));
380                }
381            }
382        }
383        Ok(expired)
384    }
385}
386
387/// In-memory Kyber post-quantum pre-key storage (data lost on process termination)
388pub struct MemoryKyberPreKeyStore {
389    kyber_pre_keys: Arc<Mutex<HashMap<u32, KyberPreKeyRecord>>>,
390}
391
392impl Default for MemoryKyberPreKeyStore {
393    fn default() -> Self {
394        Self::new()
395    }
396}
397
398impl MemoryKyberPreKeyStore {
399    pub fn new() -> Self {
400        Self {
401            kyber_pre_keys: Arc::new(Mutex::new(HashMap::new())),
402        }
403    }
404}
405
406#[async_trait(?Send)]
407impl KyberPreKeyStore for MemoryKyberPreKeyStore {
408    async fn get_kyber_pre_key(
409        &self,
410        kyber_prekey_id: KyberPreKeyId,
411    ) -> Result<KyberPreKeyRecord, SignalProtocolError> {
412        let store = self.kyber_pre_keys.lock().await;
413        match store.get(&u32::from(kyber_prekey_id)) {
414            Some(record) => Ok(record.clone()),
415            None => Err(SignalProtocolError::InvalidKyberPreKeyId),
416        }
417    }
418
419    async fn save_kyber_pre_key(
420        &mut self,
421        kyber_prekey_id: KyberPreKeyId,
422        record: &KyberPreKeyRecord,
423    ) -> Result<(), SignalProtocolError> {
424        let mut store = self.kyber_pre_keys.lock().await;
425        store.insert(u32::from(kyber_prekey_id), record.clone());
426        Ok(())
427    }
428
429    async fn mark_kyber_pre_key_used(
430        &mut self,
431        _kyber_prekey_id: KyberPreKeyId,
432    ) -> Result<(), SignalProtocolError> {
433        // For memory storage, we don't implement usage tracking since this is atest/development storage
434        // In production storage, this would typically mark keys as consumed to prevent reuse
435        Ok(())
436    }
437}
438
439#[async_trait(?Send)]
440impl ExtendedKyberPreKeyStore for MemoryKyberPreKeyStore {
441    async fn kyber_pre_key_count(&self) -> usize {
442        let store = self.kyber_pre_keys.lock().await;
443        store.len()
444    }
445
446    async fn clear_all_kyber_pre_keys(&mut self) -> Result<(), Box<dyn std::error::Error>> {
447        let mut store = self.kyber_pre_keys.lock().await;
448        store.clear();
449        Ok(())
450    }
451
452    async fn get_max_kyber_pre_key_id(&self) -> Result<Option<u32>, Box<dyn std::error::Error>> {
453        let store = self.kyber_pre_keys.lock().await;
454        Ok(store.keys().max().copied())
455    }
456
457    async fn delete_kyber_pre_key(
458        &mut self,
459        id: KyberPreKeyId,
460    ) -> Result<(), Box<dyn std::error::Error>> {
461        let mut store = self.kyber_pre_keys.lock().await;
462        store.remove(&u32::from(id));
463        Ok(())
464    }
465
466    async fn get_kyber_pre_keys_older_than(
467        &self,
468        timestamp_millis: u64,
469    ) -> Result<Vec<KyberPreKeyId>, Box<dyn std::error::Error>> {
470        let store = self.kyber_pre_keys.lock().await;
471        let mut expired = Vec::new();
472        for (id, record) in store.iter() {
473            if let Ok(ts) = record.timestamp() {
474                if ts.epoch_millis() < timestamp_millis {
475                    expired.push(KyberPreKeyId::from(*id));
476                }
477            }
478        }
479        Ok(expired)
480    }
481}
482
483/// Complete in-memory Signal Protocol storage implementation
484pub struct MemoryStorage {
485    pub session_store: MemorySessionStore,
486    pub identity_store: MemoryIdentityStore,
487    pub pre_key_store: MemoryPreKeyStore,
488    pub signed_pre_key_store: MemorySignedPreKeyStore,
489    pub kyber_pre_key_store: MemoryKyberPreKeyStore,
490}
491
492impl Default for MemoryStorage {
493    fn default() -> Self {
494        Self::new()
495    }
496}
497
498impl MemoryStorage {
499    pub fn new() -> Self {
500        Self {
501            session_store: MemorySessionStore::new(),
502            identity_store: MemoryIdentityStore::new(),
503            pre_key_store: MemoryPreKeyStore::new(),
504            signed_pre_key_store: MemorySignedPreKeyStore::new(),
505            kyber_pre_key_store: MemoryKyberPreKeyStore::new(),
506        }
507    }
508}
509
510impl SignalStorageContainer for MemoryStorage {
511    type SessionStore = MemorySessionStore;
512    type IdentityStore = MemoryIdentityStore;
513    type PreKeyStore = MemoryPreKeyStore;
514    type SignedPreKeyStore = MemorySignedPreKeyStore;
515    type KyberPreKeyStore = MemoryKyberPreKeyStore;
516
517    fn session_store(&mut self) -> &mut Self::SessionStore {
518        &mut self.session_store
519    }
520
521    fn identity_store(&mut self) -> &mut Self::IdentityStore {
522        &mut self.identity_store
523    }
524
525    fn pre_key_store(&mut self) -> &mut Self::PreKeyStore {
526        &mut self.pre_key_store
527    }
528
529    fn signed_pre_key_store(&mut self) -> &mut Self::SignedPreKeyStore {
530        &mut self.signed_pre_key_store
531    }
532
533    fn kyber_pre_key_store(&mut self) -> &mut Self::KyberPreKeyStore {
534        &mut self.kyber_pre_key_store
535    }
536
537    fn initialize(&mut self) -> Result<(), Box<dyn std::error::Error>> {
538        Ok(())
539    }
540
541    fn close(&mut self) -> Result<(), Box<dyn std::error::Error>> {
542        Ok(())
543    }
544
545    fn storage_type(&self) -> &'static str {
546        "memory"
547    }
548}
549
550#[async_trait(?Send)]
551impl ExtendedStorageOps for MemoryStorage {
552    async fn establish_session_from_bundle(
553        &mut self,
554        address: &ProtocolAddress,
555        bundle: &PreKeyBundle,
556    ) -> Result<(), Box<dyn std::error::Error>> {
557        MemoryStorage::establish_session_from_bundle(self, address, bundle).await
558    }
559
560    async fn encrypt_message(
561        &mut self,
562        remote_address: &ProtocolAddress,
563        plaintext: &[u8],
564    ) -> Result<CiphertextMessage, SignalProtocolError> {
565        MemoryStorage::encrypt_message(self, remote_address, plaintext).await
566    }
567
568    async fn decrypt_message(
569        &mut self,
570        remote_address: &ProtocolAddress,
571        ciphertext: &CiphertextMessage,
572    ) -> Result<Vec<u8>, SignalProtocolError> {
573        MemoryStorage::decrypt_message(self, remote_address, ciphertext).await
574    }
575}
576
577impl MemoryStorage {
578    pub async fn establish_session_from_bundle(
579        &mut self,
580        address: &ProtocolAddress,
581        bundle: &PreKeyBundle,
582    ) -> Result<(), Box<dyn std::error::Error>> {
583        let mut rng = rand::rng();
584        let timestamp = std::time::SystemTime::now();
585
586        process_prekey_bundle(
587            address,
588            &mut self.session_store,
589            &mut self.identity_store,
590            bundle,
591            timestamp,
592            &mut rng,
593            UsePQRatchet::Yes,
594        )
595        .await?;
596
597        Ok(())
598    }
599
600    pub async fn encrypt_message(
601        &mut self,
602        remote_address: &ProtocolAddress,
603        plaintext: &[u8],
604    ) -> Result<CiphertextMessage, SignalProtocolError> {
605        let mut rng = rand::rng();
606        let now = std::time::SystemTime::now();
607        message_encrypt(
608            plaintext,
609            remote_address,
610            &mut self.session_store,
611            &mut self.identity_store,
612            now,
613            &mut rng,
614        )
615        .await
616    }
617
618    pub async fn decrypt_message(
619        &mut self,
620        remote_address: &ProtocolAddress,
621        ciphertext: &CiphertextMessage,
622    ) -> Result<Vec<u8>, SignalProtocolError> {
623        let mut rng = rand::rng();
624
625        message_decrypt(
626            ciphertext,
627            remote_address,
628            &mut self.session_store,
629            &mut self.identity_store,
630            &mut self.pre_key_store,
631            &self.signed_pre_key_store,
632            &mut self.kyber_pre_key_store,
633            &mut rng,
634            UsePQRatchet::Yes,
635        )
636        .await
637    }
638}
639
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent module's own imports (`libsignal_protocol::*`
    // and `crate::storage_trait::*`). Individual tests therefore only import the crate's
    // key-generation helpers.
    use super::*;

    #[tokio::test]
    async fn test_memory_storage_creation() -> Result<(), Box<dyn std::error::Error>> {
        let mut storage = MemoryStorage::new();
        storage.initialize()?;
        assert_eq!(storage.storage_type(), "memory");
        Ok(())
    }

    #[tokio::test]
    async fn test_session_storage() -> Result<(), Box<dyn std::error::Error>> {
        let mut storage = MemoryStorage::new();
        let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);
        let session_record = SessionRecord::new_fresh();

        assert_eq!(storage.session_store.session_count().await, 0);

        storage
            .session_store
            .store_session(&address, &session_record)
            .await?;
        assert_eq!(storage.session_store.session_count().await, 1);

        let loaded = storage.session_store.load_session(&address).await?;
        assert!(loaded.is_some());

        storage.session_store.clear_all_sessions().await?;
        assert_eq!(storage.session_store.session_count().await, 0);

        Ok(())
    }

    #[tokio::test]
    async fn test_identity_storage() -> Result<(), Box<dyn std::error::Error>> {
        let mut storage = MemoryStorage::new();
        let mut rng = rand::rng();
        let identity_key_pair = IdentityKeyPair::generate(&mut rng);
        let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);

        storage
            .identity_store
            .set_local_identity_key_pair(&identity_key_pair)
            .await?;
        storage
            .identity_store
            .set_local_registration_id(12345)
            .await?;

        let retrieved_identity = storage.identity_store.get_identity_key_pair().await?;
        assert_eq!(
            retrieved_identity.identity_key().serialize(),
            identity_key_pair.identity_key().serialize()
        );

        let retrieved_registration = storage.identity_store.get_local_registration_id().await?;
        assert_eq!(retrieved_registration, 12345);

        assert_eq!(storage.identity_store.identity_count().await, 0);

        storage
            .identity_store
            .save_identity(&address, identity_key_pair.identity_key())
            .await?;
        assert_eq!(storage.identity_store.identity_count().await, 1);

        let retrieved = storage.identity_store.get_identity(&address).await?;
        assert!(retrieved.is_some());
        assert_eq!(
            retrieved.unwrap().serialize(),
            identity_key_pair.identity_key().serialize()
        );

        Ok(())
    }

    /// `save_identity` must report `NewOrUnchanged` for a first save and for an identical
    /// re-save, and `ReplacedExisting` only when the stored key actually changes. The TOFU
    /// check must then reject the superseded key.
    #[tokio::test]
    async fn test_save_identity_reports_changes() -> Result<(), Box<dyn std::error::Error>> {
        let mut storage = MemoryStorage::new();
        let mut rng = rand::rng();
        let original = IdentityKeyPair::generate(&mut rng);
        let replacement = IdentityKeyPair::generate(&mut rng);
        let address = ProtocolAddress::new("peer".to_string(), DeviceId::new(1)?);

        let first = storage
            .identity_store
            .save_identity(&address, original.identity_key())
            .await?;
        assert!(matches!(first, IdentityChange::NewOrUnchanged));

        let repeat = storage
            .identity_store
            .save_identity(&address, original.identity_key())
            .await?;
        assert!(matches!(repeat, IdentityChange::NewOrUnchanged));

        let changed = storage
            .identity_store
            .save_identity(&address, replacement.identity_key())
            .await?;
        assert!(matches!(changed, IdentityChange::ReplacedExisting));

        assert_eq!(
            storage.identity_store.get_identity(&address).await?,
            Some(*replacement.identity_key())
        );
        assert_eq!(storage.identity_store.identity_count().await, 1);
        assert!(
            !storage
                .identity_store
                .is_trusted_identity(&address, original.identity_key(), Direction::Receiving)
                .await?,
            "Superseded identity key must no longer be trusted"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_key_generation_storage_integration() -> Result<(), Box<dyn std::error::Error>> {
        use crate::keys::{generate_identity_key_pair, generate_pre_keys, generate_signed_pre_key};

        let identity_key_pair = generate_identity_key_pair().await?;
        let pre_keys = generate_pre_keys(1, 5).await?;
        let signed_pre_key = generate_signed_pre_key(&identity_key_pair, 1).await?;

        let mut storage = MemoryStorage::new();

        let address = ProtocolAddress::new("test_user".to_string(), DeviceId::new(1)?);
        storage
            .identity_store
            .save_identity(&address, identity_key_pair.identity_key())
            .await?;
        let retrieved_identity = storage.identity_store.get_identity(&address).await?;
        assert!(retrieved_identity.is_some());
        assert_eq!(
            retrieved_identity.unwrap().serialize(),
            identity_key_pair.identity_key().serialize()
        );

        for (key_id, key_pair) in &pre_keys {
            let record = PreKeyRecord::new((*key_id).into(), key_pair);
            storage
                .pre_key_store
                .save_pre_key((*key_id).into(), &record)
                .await?;
            let retrieved_record = storage.pre_key_store.get_pre_key((*key_id).into()).await?;
            assert_eq!(
                retrieved_record.key_pair()?.public_key.serialize(),
                key_pair.public_key.serialize()
            );
        }

        storage
            .signed_pre_key_store
            .save_signed_pre_key(signed_pre_key.id()?, &signed_pre_key)
            .await?;
        let retrieved_signed_key = storage
            .signed_pre_key_store
            .get_signed_pre_key(signed_pre_key.id()?)
            .await?;
        assert_eq!(retrieved_signed_key.id()?, signed_pre_key.id()?);

        assert_eq!(storage.identity_store.identity_count().await, 1);
        assert_eq!(storage.pre_key_store.pre_key_count().await, 5);
        assert_eq!(storage.signed_pre_key_store.signed_pre_key_count().await, 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_session_delete_operations() -> Result<(), Box<dyn std::error::Error>> {
        use crate::keys::{generate_identity_key_pair, generate_pre_keys, generate_signed_pre_key};

        let mut storage = MemoryStorage::new();

        let identity = generate_identity_key_pair().await?;
        let registration_id = 12345u32;
        storage
            .identity_store
            .set_local_identity_key_pair(&identity)
            .await?;
        storage
            .identity_store
            .set_local_registration_id(registration_id)
            .await?;

        let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
        let charlie_address = ProtocolAddress::new("charlie".to_string(), DeviceId::new(1)?);

        let bob_identity = generate_identity_key_pair().await?;
        let bob_pre_keys = generate_pre_keys(1, 1).await?;
        let bob_signed_pre_key = generate_signed_pre_key(&bob_identity, 1).await?;

        let mut rng = rand::rng();
        let kyber_keypair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
        let kyber_signature = bob_identity
            .private_key()
            .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;

        let bob_bundle = PreKeyBundle::new(
            registration_id,
            DeviceId::new(1)?,
            Some((
                PreKeyId::from(bob_pre_keys[0].0),
                bob_pre_keys[0].1.public_key,
            )),
            SignedPreKeyId::from(1u32),
            bob_signed_pre_key.public_key()?,
            bob_signed_pre_key.signature()?.to_vec(),
            KyberPreKeyId::from(1u32),
            kyber_keypair.public_key,
            kyber_signature.to_vec(),
            *bob_identity.identity_key(),
        )?;

        storage
            .establish_session_from_bundle(&bob_address, &bob_bundle)
            .await?;
        storage
            .establish_session_from_bundle(&charlie_address, &bob_bundle)
            .await?;
        assert_eq!(
            storage.session_store.session_count().await,
            2,
            "Should have 2 sessions"
        );

        storage.session_store.delete_session(&bob_address).await?;
        assert_eq!(
            storage.session_store.session_count().await,
            1,
            "Should have 1 session after deleting Bob's"
        );

        let bob_session = storage.session_store.load_session(&bob_address).await?;
        assert!(bob_session.is_none(), "Bob's session should be deleted");

        let charlie_session = storage.session_store.load_session(&charlie_address).await?;
        assert!(
            charlie_session.is_some(),
            "Charlie's session should still exist"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_identity_management_operations() -> Result<(), Box<dyn std::error::Error>>
    {
        use crate::keys::generate_identity_key_pair;

        let mut storage = MemoryStorage::new();

        let local_identity = generate_identity_key_pair().await?;
        storage
            .identity_store
            .set_local_identity_key_pair(&local_identity)
            .await?;

        let bob_address = ProtocolAddress::new("bob".to_string(), DeviceId::new(1)?);
        let charlie_address = ProtocolAddress::new("charlie".to_string(), DeviceId::new(1)?);

        let bob_identity = generate_identity_key_pair().await?;
        let charlie_identity = generate_identity_key_pair().await?;

        let result = storage
            .identity_store
            .get_peer_identity(&bob_address)
            .await?;
        assert!(
            result.is_none(),
            "Should return None for non-existent peer identity"
        );

        storage
            .identity_store
            .save_identity(&bob_address, bob_identity.identity_key())
            .await?;
        storage
            .identity_store
            .save_identity(&charlie_address, charlie_identity.identity_key())
            .await?;

        assert_eq!(
            storage.identity_store.identity_count().await,
            2,
            "Should have 2 peer identities"
        );

        let retrieved_bob = storage
            .identity_store
            .get_peer_identity(&bob_address)
            .await?;
        assert!(retrieved_bob.is_some(), "Should retrieve Bob's identity");
        assert_eq!(
            retrieved_bob.unwrap(),
            *bob_identity.identity_key(),
            "Retrieved identity should match stored"
        );

        let retrieved_charlie = storage
            .identity_store
            .get_peer_identity(&charlie_address)
            .await?;
        assert!(
            retrieved_charlie.is_some(),
            "Should retrieve Charlie's identity"
        );
        assert_eq!(
            retrieved_charlie.unwrap(),
            *charlie_identity.identity_key(),
            "Retrieved identity should match stored"
        );

        storage.identity_store.delete_identity(&bob_address).await?;
        assert_eq!(
            storage.identity_store.identity_count().await,
            1,
            "Should have 1 identity after deleting Bob's"
        );

        let deleted_bob = storage
            .identity_store
            .get_peer_identity(&bob_address)
            .await?;
        assert!(deleted_bob.is_none(), "Bob's identity should be deleted");

        let still_charlie = storage
            .identity_store
            .get_peer_identity(&charlie_address)
            .await?;
        assert!(
            still_charlie.is_some(),
            "Charlie's identity should still exist"
        );

        storage.identity_store.clear_all_identities().await?;
        assert_eq!(
            storage.identity_store.identity_count().await,
            0,
            "Should have 0 identities after clearing all"
        );

        let cleared_charlie = storage
            .identity_store
            .get_peer_identity(&charlie_address)
            .await?;
        assert!(
            cleared_charlie.is_none(),
            "Charlie's identity should be cleared"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_memory_clear_local_identity() -> Result<(), Box<dyn std::error::Error>> {
        use crate::keys::generate_identity_key_pair;

        let mut storage = MemoryStorage::new();

        let identity = generate_identity_key_pair().await?;
        let registration_id = 12345u32;
        storage
            .identity_store
            .set_local_identity_key_pair(&identity)
            .await?;
        storage
            .identity_store
            .set_local_registration_id(registration_id)
            .await?;

        let retrieved_identity = storage.identity_store.get_identity_key_pair().await?;
        assert_eq!(
            retrieved_identity.identity_key().serialize(),
            identity.identity_key().serialize()
        );

        let retrieved_registration = storage.identity_store.get_local_registration_id().await?;
        assert_eq!(retrieved_registration, registration_id);

        storage.identity_store.clear_local_identity().await?;

        let result = storage.identity_store.get_identity_key_pair().await;
        assert!(
            result.is_err(),
            "Should return error when local identity is cleared"
        );

        let result = storage.identity_store.get_local_registration_id().await;
        assert!(
            result.is_err(),
            "Should return error when local registration ID is cleared"
        );

        Ok(())
    }
}