1use crate::keys::{generate_pre_keys, generate_signed_pre_key};
7use crate::storage_trait::{
8 ExtendedKyberPreKeyStore, ExtendedPreKeyStore, ExtendedSignedPreKeyStore,
9 SignalStorageContainer,
10};
11use libsignal_protocol::*;
12
/// Number of seconds in one day; the interval constants below are expressed in days.
const SECS_PER_DAY: u64 = 24 * 60 * 60;

/// Lower bound on stored one-time pre-keys: `consume_pre_key` calls
/// `replenish_pre_keys` once the stored count drops below this value.
pub const MIN_PRE_KEY_COUNT: usize = 50;

/// Number of one-time pre-keys requested from `generate_pre_keys` per
/// `replenish_pre_keys` call.
pub const REPLENISH_COUNT: u32 = 100;

/// Age, in seconds (seven days), beyond which the current signed / Kyber
/// pre-key is reported as due for rotation by the `*_needs_rotation` checks.
pub const ROTATION_INTERVAL_SECS: u64 = 7 * SECS_PER_DAY;

/// Age, in seconds (seven days), used by the cleanup functions: keys the store
/// reports as older than `now - GRACE_PERIOD_SECS` become deletion candidates.
pub const GRACE_PERIOD_SECS: u64 = 7 * SECS_PER_DAY;
21
22pub async fn rotate_signed_pre_key<S: SignalStorageContainer>(
28 storage: &mut S,
29 identity_key_pair: &IdentityKeyPair,
30) -> Result<(), Box<dyn std::error::Error>> {
31 let current_max_id = storage
32 .signed_pre_key_store()
33 .get_max_signed_pre_key_id()
34 .await?
35 .unwrap_or(0);
36
37 let new_signed_pre_key = generate_signed_pre_key(identity_key_pair, current_max_id + 1).await?;
38
39 storage
40 .signed_pre_key_store()
41 .save_signed_pre_key(new_signed_pre_key.id()?, &new_signed_pre_key)
42 .await?;
43
44 Ok(())
45}
46
47pub async fn signed_pre_key_needs_rotation<S: SignalStorageContainer>(
55 storage: &mut S,
56) -> Result<bool, Box<dyn std::error::Error>> {
57 let max_id = storage
58 .signed_pre_key_store()
59 .get_max_signed_pre_key_id()
60 .await?;
61
62 let Some(id) = max_id else {
63 return Ok(true);
64 };
65
66 let current_key = storage
67 .signed_pre_key_store()
68 .get_signed_pre_key(SignedPreKeyId::from(id))
69 .await?;
70
71 let key_timestamp = current_key.timestamp()?.epoch_millis();
72 let now = std::time::SystemTime::now()
73 .duration_since(std::time::UNIX_EPOCH)?
74 .as_millis() as u64;
75
76 let key_age_secs = (now - key_timestamp) / 1000;
77 Ok(key_age_secs > ROTATION_INTERVAL_SECS)
78}
79
80pub async fn cleanup_expired_signed_pre_keys<S: SignalStorageContainer>(
85 storage: &mut S,
86) -> Result<(), Box<dyn std::error::Error>> {
87 let now = std::time::SystemTime::now()
88 .duration_since(std::time::UNIX_EPOCH)?
89 .as_millis() as u64;
90
91 let cutoff = now - (GRACE_PERIOD_SECS * 1000);
92
93 let expired_ids = storage
94 .signed_pre_key_store()
95 .get_signed_pre_keys_older_than(cutoff)
96 .await?;
97
98 let current_count = storage.signed_pre_key_store().signed_pre_key_count().await;
99 if current_count <= 1 {
100 return Ok(());
101 }
102
103 for id in expired_ids {
104 if storage.signed_pre_key_store().signed_pre_key_count().await > 1 {
105 storage
106 .signed_pre_key_store()
107 .delete_signed_pre_key(id)
108 .await?;
109 }
110 }
111
112 Ok(())
113}
114
115pub async fn consume_pre_key<S: SignalStorageContainer>(
121 storage: &mut S,
122 pre_key_id: PreKeyId,
123) -> Result<(), Box<dyn std::error::Error>> {
124 storage.pre_key_store().delete_pre_key(pre_key_id).await?;
125
126 let current_count = storage.pre_key_store().pre_key_count().await;
127 if current_count < MIN_PRE_KEY_COUNT {
128 replenish_pre_keys(storage).await?;
129 }
130
131 Ok(())
132}
133
134pub async fn replenish_pre_keys<S: SignalStorageContainer>(
139 storage: &mut S,
140) -> Result<(), Box<dyn std::error::Error>> {
141 let next_id = storage
142 .pre_key_store()
143 .get_max_pre_key_id()
144 .await?
145 .map(|id| id + 1)
146 .unwrap_or(1);
147
148 let new_pre_keys = generate_pre_keys(next_id, REPLENISH_COUNT).await?;
149
150 for (key_id, key_pair) in &new_pre_keys {
151 let record = PreKeyRecord::new((*key_id).into(), key_pair);
152 storage
153 .pre_key_store()
154 .save_pre_key((*key_id).into(), &record)
155 .await?;
156 }
157
158 Ok(())
159}
160
161pub async fn rotate_kyber_pre_key<S: SignalStorageContainer>(
167 storage: &mut S,
168 identity_key_pair: &IdentityKeyPair,
169) -> Result<(), Box<dyn std::error::Error>> {
170 let current_max_id = storage
171 .kyber_pre_key_store()
172 .get_max_kyber_pre_key_id()
173 .await?
174 .unwrap_or(0);
175
176 let mut rng = rand::rng();
177 let kyber_keypair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
178 let kyber_signature = identity_key_pair
179 .private_key()
180 .calculate_signature(&kyber_keypair.public_key.serialize(), &mut rng)?;
181
182 let now = std::time::SystemTime::now();
183 let kyber_record = KyberPreKeyRecord::new(
184 KyberPreKeyId::from(current_max_id + 1),
185 Timestamp::from_epoch_millis(now.duration_since(std::time::UNIX_EPOCH)?.as_millis() as u64),
186 &kyber_keypair,
187 &kyber_signature,
188 );
189
190 storage
191 .kyber_pre_key_store()
192 .save_kyber_pre_key(KyberPreKeyId::from(current_max_id + 1), &kyber_record)
193 .await?;
194
195 Ok(())
196}
197
198pub async fn kyber_pre_key_needs_rotation<S: SignalStorageContainer>(
206 storage: &mut S,
207) -> Result<bool, Box<dyn std::error::Error>> {
208 let max_id = storage
209 .kyber_pre_key_store()
210 .get_max_kyber_pre_key_id()
211 .await?;
212
213 let Some(id) = max_id else {
214 return Ok(true);
215 };
216
217 let current_key = storage
218 .kyber_pre_key_store()
219 .get_kyber_pre_key(KyberPreKeyId::from(id))
220 .await?;
221
222 let key_timestamp = current_key.timestamp()?.epoch_millis();
223 let now = std::time::SystemTime::now()
224 .duration_since(std::time::UNIX_EPOCH)?
225 .as_millis() as u64;
226
227 let key_age_secs = (now - key_timestamp) / 1000;
228 Ok(key_age_secs > ROTATION_INTERVAL_SECS)
229}
230
231pub async fn cleanup_expired_kyber_pre_keys<S: SignalStorageContainer>(
236 storage: &mut S,
237) -> Result<(), Box<dyn std::error::Error>> {
238 let now = std::time::SystemTime::now()
239 .duration_since(std::time::UNIX_EPOCH)?
240 .as_millis() as u64;
241
242 let cutoff = now - (GRACE_PERIOD_SECS * 1000);
243
244 let expired_ids = storage
245 .kyber_pre_key_store()
246 .get_kyber_pre_keys_older_than(cutoff)
247 .await?;
248
249 let current_count = storage.kyber_pre_key_store().kyber_pre_key_count().await;
250 if current_count <= 1 {
251 return Ok(());
252 }
253
254 for id in expired_ids {
255 if storage.kyber_pre_key_store().kyber_pre_key_count().await > 1 {
256 storage
257 .kyber_pre_key_store()
258 .delete_kyber_pre_key(id)
259 .await?;
260 }
261 }
262
263 Ok(())
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::keys::{generate_identity_key_pair, generate_signed_pre_key};
    use crate::sqlite_storage::SqliteStorage;
    use crate::storage_trait::{
        ExtendedIdentityStore, ExtendedKyberPreKeyStore, ExtendedPreKeyStore,
        ExtendedSignedPreKeyStore, SignalStorageContainer,
    };

    type TestResult = Result<(), Box<dyn std::error::Error>>;

    /// Seconds in one day, used to back-date key timestamps.
    const DAY_SECS: u64 = 24 * 60 * 60;

    // ------------------------------------------------------------------
    // Fixtures
    // ------------------------------------------------------------------

    /// Opens a fresh in-memory SQLite store and initializes it.
    async fn new_storage() -> Result<SqliteStorage, Box<dyn std::error::Error>> {
        let mut storage = SqliteStorage::new(":memory:").await?;
        storage.initialize()?;
        Ok(storage)
    }

    /// Opens a fresh store and saves a newly generated local identity key pair in it.
    async fn new_storage_with_identity(
    ) -> Result<(SqliteStorage, IdentityKeyPair), Box<dyn std::error::Error>> {
        let mut storage = new_storage().await?;
        let identity_key_pair = generate_identity_key_pair().await?;
        storage
            .identity_store()
            .set_local_identity_key_pair(&identity_key_pair)
            .await?;
        Ok((storage, identity_key_pair))
    }

    /// Current time in milliseconds since the Unix epoch.
    fn now_millis() -> Result<u64, Box<dyn std::error::Error>> {
        Ok(std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)?
            .as_millis() as u64)
    }

    /// Timestamp `days` days before now, truncated to whole seconds.
    fn days_ago(days: u64) -> Result<Timestamp, Box<dyn std::error::Error>> {
        let secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)?
            .as_secs()
            - days * DAY_SECS;
        Ok(Timestamp::from_epoch_millis(secs * 1000))
    }

    /// Builds a signed pre-key record with an explicit ID and creation timestamp.
    fn signed_pre_key_at(
        identity_key_pair: &IdentityKeyPair,
        id: u32,
        created: Timestamp,
    ) -> Result<SignedPreKeyRecord, Box<dyn std::error::Error>> {
        let mut rng = rand::rng();
        let key_pair = KeyPair::generate(&mut rng);
        let signature = identity_key_pair
            .private_key()
            .calculate_signature(&key_pair.public_key.serialize(), &mut rng)?;
        Ok(SignedPreKeyRecord::new(id.into(), created, &key_pair, &signature))
    }

    /// Builds a Kyber1024 pre-key record with an explicit ID and creation timestamp.
    fn kyber_pre_key_at(
        identity_key_pair: &IdentityKeyPair,
        id: u32,
        created: Timestamp,
    ) -> Result<KyberPreKeyRecord, Box<dyn std::error::Error>> {
        let mut rng = rand::rng();
        let key_pair = kem::KeyPair::generate(kem::KeyType::Kyber1024, &mut rng);
        let signature = identity_key_pair
            .private_key()
            .calculate_signature(&key_pair.public_key.serialize(), &mut rng)?;
        Ok(KyberPreKeyRecord::new(
            KyberPreKeyId::from(id),
            created,
            &key_pair,
            &signature,
        ))
    }

    /// Saves a signed pre-key record under its own ID.
    async fn save_signed(storage: &mut SqliteStorage, record: &SignedPreKeyRecord) -> TestResult {
        storage
            .signed_pre_key_store()
            .save_signed_pre_key(record.id()?, record)
            .await?;
        Ok(())
    }

    /// Saves a Kyber pre-key record under `id`.
    async fn save_kyber(
        storage: &mut SqliteStorage,
        id: u32,
        record: &KyberPreKeyRecord,
    ) -> TestResult {
        storage
            .kyber_pre_key_store()
            .save_kyber_pre_key(KyberPreKeyId::from(id), record)
            .await?;
        Ok(())
    }

    /// Stores one-time pre-keys with IDs `1..=count`.
    async fn seed_pre_keys(storage: &mut SqliteStorage, count: u32) -> TestResult {
        let mut rng = rand::rng();
        for id in 1..=count {
            let key_pair = KeyPair::generate(&mut rng);
            let record = PreKeyRecord::new(id.into(), &key_pair);
            storage
                .pre_key_store()
                .save_pre_key(id.into(), &record)
                .await?;
        }
        Ok(())
    }

    // ------------------------------------------------------------------
    // Signed pre-keys
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_rotate_signed_pre_key_generates_new_key() -> TestResult {
        let (mut storage, identity_key_pair) = new_storage_with_identity().await?;

        let initial_key = generate_signed_pre_key(&identity_key_pair, 1).await?;
        save_signed(&mut storage, &initial_key).await?;
        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            1
        );

        rotate_signed_pre_key(&mut storage, &identity_key_pair).await?;
        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            2
        );

        let max_id = storage
            .signed_pre_key_store()
            .get_max_signed_pre_key_id()
            .await?
            .expect("Should have max ID");
        assert_eq!(max_id, 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_signed_pre_key_needs_rotation_when_old() -> TestResult {
        let (mut storage, identity_key_pair) = new_storage_with_identity().await?;

        let old_key = signed_pre_key_at(&identity_key_pair, 1, days_ago(8)?)?;
        save_signed(&mut storage, &old_key).await?;

        let needs_rotation = signed_pre_key_needs_rotation(&mut storage).await?;
        assert!(
            needs_rotation,
            "Key older than rotation interval should need rotation"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_signed_pre_key_does_not_need_rotation_when_fresh() -> TestResult {
        let (mut storage, identity_key_pair) = new_storage_with_identity().await?;

        let fresh_key = generate_signed_pre_key(&identity_key_pair, 1).await?;
        save_signed(&mut storage, &fresh_key).await?;

        let needs_rotation = signed_pre_key_needs_rotation(&mut storage).await?;
        assert!(!needs_rotation, "Fresh key should not need rotation");

        Ok(())
    }

    #[tokio::test]
    async fn test_cleanup_expired_signed_pre_keys() -> TestResult {
        let (mut storage, identity_key_pair) = new_storage_with_identity().await?;

        let expired_key = signed_pre_key_at(&identity_key_pair, 1, days_ago(14)?)?;
        save_signed(&mut storage, &expired_key).await?;

        let fresh_key = generate_signed_pre_key(&identity_key_pair, 2).await?;
        save_signed(&mut storage, &fresh_key).await?;

        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            2
        );

        cleanup_expired_signed_pre_keys(&mut storage).await?;

        assert_eq!(
            storage.signed_pre_key_store().signed_pre_key_count().await,
            1
        );

        Ok(())
    }

    // ------------------------------------------------------------------
    // One-time pre-keys
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_consume_pre_key_removes_key() -> TestResult {
        let mut storage = new_storage().await?;

        seed_pre_keys(&mut storage, 51).await?;
        assert_eq!(storage.pre_key_store().pre_key_count().await, 51);

        consume_pre_key(&mut storage, PreKeyId::from(1)).await?;
        assert_eq!(storage.pre_key_store().pre_key_count().await, 50);

        Ok(())
    }

    #[tokio::test]
    async fn test_replenish_pre_keys_when_low() -> TestResult {
        let mut storage = new_storage().await?;
        assert_eq!(storage.pre_key_store().pre_key_count().await, 0);

        replenish_pre_keys(&mut storage).await?;

        assert_eq!(
            storage.pre_key_store().pre_key_count().await,
            REPLENISH_COUNT as usize
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_consume_pre_key_triggers_replenishment() -> TestResult {
        let mut storage = new_storage().await?;

        seed_pre_keys(&mut storage, 49).await?;
        assert_eq!(storage.pre_key_store().pre_key_count().await, 49);

        consume_pre_key(&mut storage, PreKeyId::from(1)).await?;
        assert!(storage.pre_key_store().pre_key_count().await >= MIN_PRE_KEY_COUNT);

        Ok(())
    }

    // ------------------------------------------------------------------
    // Kyber pre-keys
    // ------------------------------------------------------------------

    #[tokio::test]
    async fn test_rotate_kyber_pre_key_generates_new_key() -> TestResult {
        let (mut storage, identity_key_pair) = new_storage_with_identity().await?;

        let created = Timestamp::from_epoch_millis(now_millis()?);
        let initial_key = kyber_pre_key_at(&identity_key_pair, 1, created)?;
        save_kyber(&mut storage, 1, &initial_key).await?;
        assert_eq!(storage.kyber_pre_key_store().kyber_pre_key_count().await, 1);

        rotate_kyber_pre_key(&mut storage, &identity_key_pair).await?;
        assert_eq!(storage.kyber_pre_key_store().kyber_pre_key_count().await, 2);

        let max_id = storage
            .kyber_pre_key_store()
            .get_max_kyber_pre_key_id()
            .await?
            .expect("Should have max ID");
        assert_eq!(max_id, 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_kyber_pre_key_needs_rotation_when_old() -> TestResult {
        let (mut storage, identity_key_pair) = new_storage_with_identity().await?;

        let old_key = kyber_pre_key_at(&identity_key_pair, 1, days_ago(8)?)?;
        save_kyber(&mut storage, 1, &old_key).await?;

        let needs_rotation = kyber_pre_key_needs_rotation(&mut storage).await?;
        assert!(
            needs_rotation,
            "Kyber key older than rotation interval should need rotation"
        );

        Ok(())
    }

    #[tokio::test]
    async fn test_kyber_pre_key_does_not_need_rotation_when_fresh() -> TestResult {
        let (mut storage, identity_key_pair) = new_storage_with_identity().await?;

        let created = Timestamp::from_epoch_millis(now_millis()?);
        let fresh_key = kyber_pre_key_at(&identity_key_pair, 1, created)?;
        save_kyber(&mut storage, 1, &fresh_key).await?;

        let needs_rotation = kyber_pre_key_needs_rotation(&mut storage).await?;
        assert!(!needs_rotation, "Fresh Kyber key should not need rotation");

        Ok(())
    }

    #[tokio::test]
    async fn test_cleanup_expired_kyber_pre_keys() -> TestResult {
        let (mut storage, identity_key_pair) = new_storage_with_identity().await?;

        let expired_key = kyber_pre_key_at(&identity_key_pair, 1, days_ago(14)?)?;
        save_kyber(&mut storage, 1, &expired_key).await?;

        let created = Timestamp::from_epoch_millis(now_millis()?);
        let fresh_key = kyber_pre_key_at(&identity_key_pair, 2, created)?;
        save_kyber(&mut storage, 2, &fresh_key).await?;

        assert_eq!(storage.kyber_pre_key_store().kyber_pre_key_count().await, 2);

        cleanup_expired_kyber_pre_keys(&mut storage).await?;

        assert_eq!(storage.kyber_pre_key_store().kyber_pre_key_count().await, 1);

        Ok(())
    }
}