@@ -139,6 +139,25 @@ impl Mpi {
139
139
}
140
140
}
141
141
142
+ /// Checks if an [`Mpi`] is less than the other in constant time.
143
+ ///
144
+ /// Will return [`Error::MpiBadInputData`] if the allocated length of the two input [`Mpi`]s is not the same.
145
+ pub fn less_than_const_time ( & self , other : & Mpi ) -> Result < bool > {
146
+ mpi_inner_less_than_const_time ( & self . inner , & other. inner )
147
+ }
148
+
149
+ /// Compares an [`Mpi`] with the other in constant time.
150
+ ///
151
+ /// Will return [`Error::MpiBadInputData`] if the allocated length of the two input [`Mpi`]s is not the same.
152
+ pub fn cmp_const_time ( & self , other : & Mpi ) -> Result < Ordering > {
153
+ mpi_inner_cmp_const_time ( & self . inner , & other. inner )
154
+ }
155
+
156
+ /// Checks equalness with the other in constant time.
157
+ pub fn eq_const_time ( & self , other : & Mpi ) -> Result < bool > {
158
+ mpi_inner_eq_const_time ( & self . inner , & other. inner )
159
+ }
160
+
142
161
pub fn as_u32 ( & self ) -> Result < u32 > {
143
162
if self . bit_length ( ) ? > 32 {
144
163
// Not exactly correct but close enough
@@ -409,6 +428,35 @@ impl Mpi {
409
428
}
410
429
}
411
430
431
+ pub ( super ) fn mpi_inner_eq_const_time ( x : & mpi , y : & mpi ) -> core:: prelude:: v1:: Result < bool , Error > {
432
+ match mpi_inner_cmp_const_time ( x, y) {
433
+ Ok ( order) => Ok ( order == Ordering :: Equal ) ,
434
+ Err ( Error :: MpiBadInputData ) => Ok ( false ) ,
435
+ Err ( e) => Err ( e) ,
436
+ }
437
+ }
438
+
439
+ fn mpi_inner_cmp_const_time ( x : & mpi , y : & mpi ) -> Result < Ordering > {
440
+ let less = mpi_inner_less_than_const_time ( x, y) ;
441
+ let more = mpi_inner_less_than_const_time ( y, x) ;
442
+ match ( less, more) {
443
+ ( Ok ( true ) , Ok ( false ) ) => Ok ( Ordering :: Less ) ,
444
+ ( Ok ( false ) , Ok ( true ) ) => Ok ( Ordering :: Greater ) ,
445
+ ( Ok ( false ) , Ok ( false ) ) => Ok ( Ordering :: Equal ) ,
446
+ ( Ok ( true ) , Ok ( true ) ) => unreachable ! ( ) ,
447
+ ( Err ( e) , _) => Err ( e) ,
448
+ ( Ok ( _) , Err ( e) ) => Err ( e) ,
449
+ }
450
+ }
451
+
452
+ fn mpi_inner_less_than_const_time ( x : & mpi , y : & mpi ) -> Result < bool > {
453
+ let mut r = 0 ;
454
+ unsafe {
455
+ mpi_lt_mpi_ct ( x, y, & mut r) . into_result ( ) ?;
456
+ } ;
457
+ Ok ( r == 1 )
458
+ }
459
+
412
460
impl Ord for Mpi {
413
461
fn cmp ( & self , other : & Mpi ) -> Ordering {
414
462
let r = unsafe { mpi_cmp_mpi ( & self . inner , & other. inner ) } ;
@@ -709,3 +757,52 @@ impl ShrAssign<usize> for Mpi {
709
757
// mbedtls_mpi_sub_abs
710
758
// mbedtls_mpi_mod_int
711
759
// mbedtls_mpi_gcd
760
+
761
+ #[ cfg( test) ]
762
+ mod tests {
763
+ use core:: str:: FromStr ;
764
+
765
+ use super :: * ;
766
+
767
+ #[ test]
768
+ fn test_less_than_const_time ( ) {
769
+ let mpi1 = Mpi :: new ( 10 ) . unwrap ( ) ;
770
+ let mpi2 = Mpi :: new ( 20 ) . unwrap ( ) ;
771
+
772
+ assert_eq ! ( mpi1. less_than_const_time( & mpi2) , Ok ( true ) ) ;
773
+
774
+ assert_eq ! ( mpi1. less_than_const_time( & mpi1) , Ok ( false ) ) ;
775
+
776
+ assert_eq ! ( mpi2. less_than_const_time( & mpi1) , Ok ( false ) ) ;
777
+
778
+ // Check: function returns `Error::MpiBadInputData` if the allocated length of the two input Mpis is not the same.
779
+ let mpi3 = Mpi :: from_str ( "0xdddddddddddddddddddddddddddddddd" ) . unwrap ( ) ;
780
+ assert_eq ! ( mpi3. less_than_const_time( & mpi3) , Ok ( false ) ) ;
781
+ assert_eq ! ( mpi2. less_than_const_time( & mpi3) , Err ( Error :: MpiBadInputData ) ) ;
782
+ }
783
+
784
+ #[ test]
785
+ fn test_cmp_const_time ( ) {
786
+ let mpi1 = Mpi :: new ( 10 ) . unwrap ( ) ;
787
+ let mpi2 = Mpi :: new ( 20 ) . unwrap ( ) ;
788
+
789
+ assert_eq ! ( mpi1. cmp_const_time( & mpi2) , Ok ( Ordering :: Less ) ) ;
790
+
791
+ let mpi3 = Mpi :: new ( 10 ) . unwrap ( ) ;
792
+ assert_eq ! ( mpi1. cmp_const_time( & mpi3) , Ok ( Ordering :: Equal ) ) ;
793
+
794
+ let mpi4 = Mpi :: new ( 5 ) . unwrap ( ) ;
795
+ assert_eq ! ( mpi1. cmp_const_time( & mpi4) , Ok ( Ordering :: Greater ) ) ;
796
+ }
797
+
798
+ #[ test]
799
+ fn test_eq_const_time ( ) {
800
+ let mpi1 = Mpi :: new ( 10 ) . unwrap ( ) ;
801
+ let mpi2 = Mpi :: new ( 10 ) . unwrap ( ) ;
802
+
803
+ assert_eq ! ( mpi1. eq_const_time( & mpi2) , Ok ( true ) ) ;
804
+
805
+ let mpi3 = Mpi :: new ( 20 ) . unwrap ( ) ;
806
+ assert_eq ! ( mpi1. eq_const_time( & mpi3) , Ok ( false ) ) ;
807
+ }
808
+ }
0 commit comments