signal_bridge/
key_rotation.rs

//! Key rotation functionality for Signal Protocol keys.
//!
//! Implements periodic rotation of signed pre-keys and Kyber pre-keys, plus
//! consumption-based pre-key management, following the Signal Protocol security model.
5
6use crate::keys::{generate_pre_keys, generate_signed_pre_key};
7use crate::storage_trait::{
8    ExtendedKyberPreKeyStore, ExtendedPreKeyStore, ExtendedSignedPreKeyStore,
9    SignalStorageContainer,
10};
11use libsignal_protocol::*;
12
/// Number of seconds in one day; the time windows below are expressed in whole days.
const SECONDS_PER_DAY: u64 = 24 * 60 * 60;

/// When the stored pre-key count drops below this value, `consume_pre_key` replenishes.
pub const MIN_PRE_KEY_COUNT: usize = 50;
/// Size of each batch generated by `replenish_pre_keys`.
pub const REPLENISH_COUNT: u32 = 100;
/// Age (in seconds) after which the current signed/Kyber pre-key should be rotated: 7 days.
pub const ROTATION_INTERVAL_SECS: u64 = 7 * SECONDS_PER_DAY;
/// Age (in seconds) after which stored signed/Kyber pre-keys become eligible for cleanup:
/// 7 days.
pub const GRACE_PERIOD_SECS: u64 = 7 * SECONDS_PER_DAY;
21
22/// Rotates the signed pre-key by generating and storing a new one
23///
24/// # Arguments
25/// * `storage` - Signal Protocol storage container
26/// * `identity_key_pair` - Identity key pair to sign the new pre-key
27pub async fn rotate_signed_pre_key<S: SignalStorageContainer>(
28    storage: &mut S,
29    identity_key_pair: &IdentityKeyPair,
30) -> Result<(), Box<dyn std::error::Error>> {
31    let current_max_id = storage
32        .signed_pre_key_store()
33        .get_max_signed_pre_key_id()
34        .await?
35        .unwrap_or(0);
36
37    let new_signed_pre_key = generate_signed_pre_key(identity_key_pair, current_max_id + 1).await?;
38
39    storage
40        .signed_pre_key_store()
41        .save_signed_pre_key(new_signed_pre_key.id()?, &new_signed_pre_key)
42        .await?;
43
44    Ok(())
45}
46
47/// Checks if the current signed pre-key needs rotation
48///
49/// # Arguments
50/// * `storage` - Signal Protocol storage container
51///
52/// # Returns
53/// true if key is older than rotation interval, false otherwise
54pub async fn signed_pre_key_needs_rotation<S: SignalStorageContainer>(
55    storage: &mut S,
56) -> Result<bool, Box<dyn std::error::Error>> {
57    let max_id = storage
58        .signed_pre_key_store()
59        .get_max_signed_pre_key_id()
60        .await?;
61
62    let Some(id) = max_id else {
63        return Ok(true);
64    };
65
66    let current_key = storage
67        .signed_pre_key_store()
68        .get_signed_pre_key(SignedPreKeyId::from(id))
69        .await?;
70
71    let key_timestamp = current_key.timestamp()?.epoch_millis();
72    let now = std::time::SystemTime::now()
73        .duration_since(std::time::UNIX_EPOCH)?
74        .as_millis() as u64;
75
76    let key_age_secs = (now - key_timestamp) / 1000;
77    Ok(key_age_secs > ROTATION_INTERVAL_SECS)
78}
79
80/// Deletes signed pre-keys older than the grace period
81///
82/// # Arguments
83/// * `storage` - Signal Protocol storage container
84pub async fn cleanup_expired_signed_pre_keys<S: SignalStorageContainer>(
85    storage: &mut S,
86) -> Result<(), Box<dyn std::error::Error>> {
87    let now = std::time::SystemTime::now()
88        .duration_since(std::time::UNIX_EPOCH)?
89        .as_millis() as u64;
90
91    let cutoff = now - (GRACE_PERIOD_SECS * 1000);
92
93    let expired_ids = storage
94        .signed_pre_key_store()
95        .get_signed_pre_keys_older_than(cutoff)
96        .await?;
97
98    let current_count = storage.signed_pre_key_store().signed_pre_key_count().await;
99    if current_count <= 1 {
100        return Ok(());
101    }
102
103    for id in expired_ids {
104        if storage.signed_pre_key_store().signed_pre_key_count().await > 1 {
105            storage
106                .signed_pre_key_store()
107                .delete_signed_pre_key(id)
108                .await?;
109        }
110    }
111
112    Ok(())
113}
114
115/// Consumes a pre-key and triggers replenishment if count falls below threshold
116///
117/// # Arguments
118/// * `storage` - Signal Protocol storage container
119/// * `pre_key_id` - ID of the pre-key to consume
120pub async fn consume_pre_key<S: SignalStorageContainer>(
121    storage: &mut S,
122    pre_key_id: PreKeyId,
123) -> Result<(), Box<dyn std::error::Error>> {
124    storage.pre_key_store().delete_pre_key(pre_key_id).await?;
125
126    let current_count = storage.pre_key_store().pre_key_count().await;
127    if current_count < MIN_PRE_KEY_COUNT {
128        replenish_pre_keys(storage).await?;
129    }
130
131    Ok(())
132}
133
134/// Generates and stores new batch of pre-keys
135///
136/// # Arguments
137/// * `storage` - Signal Protocol storage container
138pub async fn replenish_pre_keys<S: SignalStorageContainer>(
139    storage: &mut S,
140) -> Result<(), Box<dyn std::error::Error>> {
141    let next_id = storage
142        .pre_key_store()
143        .get_max_pre_key_id()
144        .await?
145        .map(|id| id + 1)
146        .unwrap_or(1);
147
148    let new_pre_keys = generate_pre_keys(next_id, REPLENISH_COUNT).await?;
149
150    for (key_id, key_pair) in &new_pre_keys {
151        let record = PreKeyRecord::new((*key_id).into(), key_pair);
152        storage
153            .pre_key_store()
154            .save_pre_key((*key_id).into(), &record)
155            .await?;
156    }
157
158    Ok(())
159}
160
161/// Rotates the Kyber post-quantum pre-key by generating and storing a new one
162///
163/// # Arguments
164/// * `storage` - Signal Protocol storage container
165/// * `identity_key_pair` - Identity key pair to sign the new Kyber pre-key
166pub async fn rotate_kyber_pre_key<S: SignalStorageContainer>(
167    storage: &mut S,
168    identity_key_pair: &IdentityKeyPair,
169) -> Result<(), Box<dyn std::error::Error>> {
170    let current_max_id = storage
171        .kyber_pre_key_store()
172        .get_max_kyber_pre_key_id()
173        .await?
174        .unwrap_or(0);
175
176    let mut rng = rand::rng();
177    let kyber_keypair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
178    let kyber_signature = identity_key_pair
179        .private_key()
180        .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;
181
182    let now = std::time::SystemTime::now();
183    let kyber_record = KyberPreKeyRecord::new(
184        KyberPreKeyId::from(current_max_id + 1),
185        Timestamp::from_epoch_millis(now.duration_since(std::time::UNIX_EPOCH)?.as_millis() as u64),
186        &kyber_keypair,
187        &kyber_signature,
188    );
189
190    storage
191        .kyber_pre_key_store()
192        .save_kyber_pre_key(KyberPreKeyId::from(current_max_id + 1), &kyber_record)
193        .await?;
194
195    Ok(())
196}
197
198/// Checks if the current Kyber pre-key needs rotation
199///
200/// # Arguments
201/// * `storage` - Signal Protocol storage container
202///
203/// # Returns
204/// true if key is older than rotation interval, false otherwise
205pub async fn kyber_pre_key_needs_rotation<S: SignalStorageContainer>(
206    storage: &mut S,
207) -> Result<bool, Box<dyn std::error::Error>> {
208    let max_id = storage
209        .kyber_pre_key_store()
210        .get_max_kyber_pre_key_id()
211        .await?;
212
213    let Some(id) = max_id else {
214        return Ok(true);
215    };
216
217    let current_key = storage
218        .kyber_pre_key_store()
219        .get_kyber_pre_key(KyberPreKeyId::from(id))
220        .await?;
221
222    let key_timestamp = current_key.timestamp()?.epoch_millis();
223    let now = std::time::SystemTime::now()
224        .duration_since(std::time::UNIX_EPOCH)?
225        .as_millis() as u64;
226
227    let key_age_secs = (now - key_timestamp) / 1000;
228    Ok(key_age_secs > ROTATION_INTERVAL_SECS)
229}
230
231/// Deletes Kyber pre-keys older than the grace period
232///
233/// # Arguments
234/// * `storage` - Signal Protocol storage container
235pub async fn cleanup_expired_kyber_pre_keys<S: SignalStorageContainer>(
236    storage: &mut S,
237) -> Result<(), Box<dyn std::error::Error>> {
238    let now = std::time::SystemTime::now()
239        .duration_since(std::time::UNIX_EPOCH)?
240        .as_millis() as u64;
241
242    let cutoff = now - (GRACE_PERIOD_SECS * 1000);
243
244    let expired_ids = storage
245        .kyber_pre_key_store()
246        .get_kyber_pre_keys_older_than(cutoff)
247        .await?;
248
249    let current_count = storage.kyber_pre_key_store().kyber_pre_key_count().await;
250    if current_count <= 1 {
251        return Ok(());
252    }
253
254    for id in expired_ids {
255        if storage.kyber_pre_key_store().kyber_pre_key_count().await > 1 {
256            storage
257                .kyber_pre_key_store()
258                .delete_kyber_pre_key(id)
259                .await?;
260        }
261    }
262
263    Ok(())
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keys::{generate_identity_key_pair, generate_signed_pre_key};
    use crate::sqlite_storage::SqliteStorage;
    use crate::storage_trait::{
        ExtendedIdentityStore, ExtendedKyberPreKeyStore, ExtendedPreKeyStore,
        ExtendedSignedPreKeyStore, SignalStorageContainer,
    };

    /// Result type shared by the tests and their helpers.
    type TestResult<T = ()> = Result<T, Box<dyn std::error::Error>>;

    const SECS_PER_DAY: u64 = 24 * 60 * 60;

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn now_millis() -> TestResult<u64> {
        Ok(std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)?
            .as_millis() as u64)
    }

    /// Millisecond timestamp `days` whole days before now, truncated to whole seconds.
    fn days_ago_millis(days: u64) -> TestResult<u64> {
        let now_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)?
            .as_secs();
        Ok((now_secs - days * SECS_PER_DAY) * 1000)
    }

    /// Opens an initialized in-memory store and registers a fresh local identity in it.
    async fn storage_with_identity() -> TestResult<(SqliteStorage, IdentityKeyPair)> {
        let mut storage = SqliteStorage::new(":memory:").await?;
        storage.initialize()?;

        let identity_key_pair = generate_identity_key_pair().await?;
        storage
            .identity_store()
            .set_local_identity_key_pair(&identity_key_pair)
            .await?;

        Ok((storage, identity_key_pair))
    }

    /// Builds a signed pre-key record with an explicit creation timestamp.
    fn signed_pre_key_at(
        identity_key_pair: &IdentityKeyPair,
        id: u32,
        timestamp_millis: u64,
    ) -> TestResult<SignedPreKeyRecord> {
        let mut rng = rand::rng();
        let key_pair = KeyPair::generate(&mut rng);
        let signature = identity_key_pair
            .private_key()
            .calculate_signature(&key_pair.public_key.serialize(), &mut rng)?;
        Ok(SignedPreKeyRecord::new(
            id.into(),
            Timestamp::from_epoch_millis(timestamp_millis),
            &key_pair,
            &signature,
        ))
    }

    /// Builds a Kyber pre-key record with an explicit creation timestamp.
    fn kyber_pre_key_at(
        identity_key_pair: &IdentityKeyPair,
        id: u32,
        timestamp_millis: u64,
    ) -> TestResult<KyberPreKeyRecord> {
        let mut rng = rand::rng();
        let kyber_keypair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
        let kyber_signature = identity_key_pair
            .private_key()
            .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;
        Ok(KyberPreKeyRecord::new(
            KyberPreKeyId::from(id),
            Timestamp::from_epoch_millis(timestamp_millis),
            &kyber_keypair,
            &kyber_signature,
        ))
    }

    /// Saves a signed pre-key record under its own ID.
    async fn save_signed(storage: &mut SqliteStorage, record: &SignedPreKeyRecord) -> TestResult {
        storage
            .signed_pre_key_store()
            .save_signed_pre_key(record.id()?, record)
            .await?;
        Ok(())
    }

    /// Saves a Kyber pre-key record under `id`.
    async fn save_kyber(
        storage: &mut SqliteStorage,
        id: u32,
        record: &KyberPreKeyRecord,
    ) -> TestResult {
        storage
            .kyber_pre_key_store()
            .save_kyber_pre_key(KyberPreKeyId::from(id), record)
            .await?;
        Ok(())
    }

    /// Stores pre-keys with IDs `1..=count`.
    async fn seed_pre_keys(storage: &mut SqliteStorage, count: u32) -> TestResult {
        let mut rng = rand::rng();
        for i in 1..=count {
            let key_pair = KeyPair::generate(&mut rng);
            let record = PreKeyRecord::new(i.into(), &key_pair);
            storage
                .pre_key_store()
                .save_pre_key(i.into(), &record)
                .await?;
        }
        Ok(())
    }

    #[tokio::test]
    async fn test_rotate_signed_pre_key_generates_new_key() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        let initial_key = generate_signed_pre_key(&identity_key_pair, 1).await?;
        save_signed(&mut storage, &initial_key).await?;
        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            1
        );

        rotate_signed_pre_key(&mut storage, &identity_key_pair).await?;

        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            2
        );
        let max_id = storage
            .signed_pre_key_store()
            .get_max_signed_pre_key_id()
            .await?
            .expect("Should have max ID");
        assert_eq!(max_id, 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_signed_pre_key_needs_rotation_when_old() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        // 8 days old: past the 7-day rotation interval.
        let old_key = signed_pre_key_at(&identity_key_pair, 1, days_ago_millis(8)?)?;
        save_signed(&mut storage, &old_key).await?;

        assert!(
            signed_pre_key_needs_rotation(&mut storage).await?,
            "Key older than rotation interval should need rotation"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_signed_pre_key_does_not_need_rotation_when_fresh() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        let fresh_key = generate_signed_pre_key(&identity_key_pair, 1).await?;
        save_signed(&mut storage, &fresh_key).await?;

        assert!(
            !signed_pre_key_needs_rotation(&mut storage).await?,
            "Fresh key should not need rotation"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_signed_pre_key_needs_rotation_when_missing() -> TestResult {
        let (mut storage, _identity_key_pair) = storage_with_identity().await?;

        assert!(
            signed_pre_key_needs_rotation(&mut storage).await?,
            "A store without a signed pre-key should need one"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_cleanup_expired_signed_pre_keys() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        // 14 days old: past the 7-day grace period.
        let expired_key = signed_pre_key_at(&identity_key_pair, 1, days_ago_millis(14)?)?;
        save_signed(&mut storage, &expired_key).await?;
        let fresh_key = generate_signed_pre_key(&identity_key_pair, 2).await?;
        save_signed(&mut storage, &fresh_key).await?;
        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            2
        );

        cleanup_expired_signed_pre_keys(&mut storage).await?;

        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            1
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_cleanup_keeps_sole_expired_signed_pre_key() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        let expired_key = signed_pre_key_at(&identity_key_pair, 1, days_ago_millis(14)?)?;
        save_signed(&mut storage, &expired_key).await?;

        cleanup_expired_signed_pre_keys(&mut storage).await?;

        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            1,
            "The only signed pre-key must survive cleanup"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_consume_pre_key_removes_key() -> TestResult {
        let mut storage = SqliteStorage::new(":memory:").await?;
        storage.initialize()?;

        seed_pre_keys(&mut storage, 51).await?;
        assert_eq!(storage.pre_key_store().pre_key_count().await, 51);

        consume_pre_key(&mut storage, PreKeyId::from(1)).await?;

        // 50 is not below MIN_PRE_KEY_COUNT, so no replenishment happens.
        assert_eq!(storage.pre_key_store().pre_key_count().await, 50);

        Ok(())
    }

    #[tokio::test]
    async fn test_replenish_pre_keys_when_low() -> TestResult {
        let mut storage = SqliteStorage::new(":memory:").await?;
        storage.initialize()?;
        assert_eq!(storage.pre_key_store().pre_key_count().await, 0);

        replenish_pre_keys(&mut storage).await?;

        assert_eq!(
            storage.pre_key_store().pre_key_count().await,
            REPLENISH_COUNT as usize
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_consume_pre_key_triggers_replenishment() -> TestResult {
        let mut storage = SqliteStorage::new(":memory:").await?;
        storage.initialize()?;

        seed_pre_keys(&mut storage, 49).await?;
        assert_eq!(storage.pre_key_store().pre_key_count().await, 49);

        consume_pre_key(&mut storage, PreKeyId::from(1)).await?;

        assert!(storage.pre_key_store().pre_key_count().await >= MIN_PRE_KEY_COUNT);

        Ok(())
    }

    #[tokio::test]
    async fn test_rotate_kyber_pre_key_generates_new_key() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        let initial_key = kyber_pre_key_at(&identity_key_pair, 1, now_millis()?)?;
        save_kyber(&mut storage, 1, &initial_key).await?;
        assert_eq!(storage.kyber_pre_key_store().kyber_pre_key_count().await, 1);

        rotate_kyber_pre_key(&mut storage, &identity_key_pair).await?;

        assert_eq!(storage.kyber_pre_key_store().kyber_pre_key_count().await, 2);
        let max_id = storage
            .kyber_pre_key_store()
            .get_max_kyber_pre_key_id()
            .await?
            .expect("Should have max ID");
        assert_eq!(max_id, 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_kyber_pre_key_needs_rotation_when_old() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        // Create a Kyber key with an old timestamp (8 days ago)
        let old_key = kyber_pre_key_at(&identity_key_pair, 1, days_ago_millis(8)?)?;
        save_kyber(&mut storage, 1, &old_key).await?;

        assert!(
            kyber_pre_key_needs_rotation(&mut storage).await?,
            "Kyber key older than rotation interval should need rotation"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_kyber_pre_key_does_not_need_rotation_when_fresh() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        // Create a fresh Kyber key
        let fresh_key = kyber_pre_key_at(&identity_key_pair, 1, now_millis()?)?;
        save_kyber(&mut storage, 1, &fresh_key).await?;

        assert!(
            !kyber_pre_key_needs_rotation(&mut storage).await?,
            "Fresh Kyber key should not need rotation"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_kyber_pre_key_needs_rotation_when_missing() -> TestResult {
        let (mut storage, _identity_key_pair) = storage_with_identity().await?;

        assert!(
            kyber_pre_key_needs_rotation(&mut storage).await?,
            "A store without a Kyber pre-key should need one"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_cleanup_expired_kyber_pre_keys() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        // Create an expired Kyber key (14 days old, past grace period)
        let expired_key = kyber_pre_key_at(&identity_key_pair, 1, days_ago_millis(14)?)?;
        save_kyber(&mut storage, 1, &expired_key).await?;

        // Create a fresh Kyber key
        let fresh_key = kyber_pre_key_at(&identity_key_pair, 2, now_millis()?)?;
        save_kyber(&mut storage, 2, &fresh_key).await?;
        assert_eq!(storage.kyber_pre_key_store().kyber_pre_key_count().await, 2);

        cleanup_expired_kyber_pre_keys(&mut storage).await?;

        assert_eq!(storage.kyber_pre_key_store().kyber_pre_key_count().await, 1);

        Ok(())
    }

    #[tokio::test]
    async fn test_cleanup_keeps_sole_expired_kyber_pre_key() -> TestResult {
        let (mut storage, identity_key_pair) = storage_with_identity().await?;

        let expired_key = kyber_pre_key_at(&identity_key_pair, 1, days_ago_millis(14)?)?;
        save_kyber(&mut storage, 1, &expired_key).await?;

        cleanup_expired_kyber_pre_keys(&mut storage).await?;

        assert_eq!(
            storage.kyber_pre_key_store().kyber_pre_key_count().await,
            1,
            "The only Kyber pre-key must survive cleanup"
        );

        Ok(())
    }
}