Skip to content

Commit 19cc8cd

Browse files
committed
const time compare
- Add `Mpi` const time compare. - Refactor `EcPoint` const time compare.
2 parents a0f2dd5 + 114bbc9 commit 19cc8cd

File tree

3 files changed

+111
-10
lines changed

3 files changed

+111
-10
lines changed

mbedtls/benches/ecp_eq_test.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ fn ecp_equal(a: &EcPoint, b: &EcPoint) {
66
}
77

88
fn ecp_equal_const_time(a: &EcPoint, b: &EcPoint) {
9-
assert!(!a.eq_const_time(&b));
9+
assert!(!a.eq_const_time(&b).unwrap());
1010
}
1111

1212
fn criterion_benchmark(c: &mut Criterion) {

mbedtls/src/bignum/mod.rs

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,25 @@ impl Mpi {
139139
}
140140
}
141141

142+
/// Checks if an [`Mpi`] is less than the other in constant time.
143+
///
144+
/// Will return [`Error::MpiBadInputData`] if the allocated length of the two input [`Mpi`]s is not the same.
145+
pub fn less_than_const_time(&self, other: &Mpi) -> Result<bool> {
146+
mpi_inner_less_than_const_time(&self.inner, &other.inner)
147+
}
148+
149+
/// Compares an [`Mpi`] with the other in constant time.
150+
///
151+
/// Will return [`Error::MpiBadInputData`] if the allocated length of the two input [`Mpi`]s is not the same.
152+
pub fn cmp_const_time(&self, other: &Mpi) -> Result<Ordering> {
153+
mpi_inner_cmp_const_time(&self.inner, &other.inner)
154+
}
155+
156+
/// Checks equalness with the other in constant time.
157+
pub fn eq_const_time(&self, other: &Mpi) -> Result<bool> {
158+
mpi_inner_eq_const_time(&self.inner, &other.inner)
159+
}
160+
142161
pub fn as_u32(&self) -> Result<u32> {
143162
if self.bit_length()? > 32 {
144163
// Not exactly correct but close enough
@@ -409,6 +428,35 @@ impl Mpi {
409428
}
410429
}
411430

431+
pub(super) fn mpi_inner_eq_const_time(x: &mpi, y: &mpi) -> core::prelude::v1::Result<bool, Error> {
432+
match mpi_inner_cmp_const_time(x, y) {
433+
Ok(order) => Ok(order == Ordering::Equal),
434+
Err(Error::MpiBadInputData) => Ok(false),
435+
Err(e) => Err(e),
436+
}
437+
}
438+
439+
fn mpi_inner_cmp_const_time(x: &mpi, y: &mpi) -> Result<Ordering> {
440+
let less = mpi_inner_less_than_const_time(x, y);
441+
let more = mpi_inner_less_than_const_time(y, x);
442+
match (less, more) {
443+
(Ok(true), Ok(false)) => Ok(Ordering::Less),
444+
(Ok(false), Ok(true)) => Ok(Ordering::Greater),
445+
(Ok(false), Ok(false)) => Ok(Ordering::Equal),
446+
(Ok(true), Ok(true)) => unreachable!(),
447+
(Err(e), _) => Err(e),
448+
(Ok(_), Err(e)) => Err(e),
449+
}
450+
}
451+
452+
fn mpi_inner_less_than_const_time(x: &mpi, y: &mpi) -> Result<bool> {
453+
let mut r = 0;
454+
unsafe {
455+
mpi_lt_mpi_ct(x, y, &mut r).into_result()?;
456+
};
457+
Ok(r == 1)
458+
}
459+
412460
impl Ord for Mpi {
413461
fn cmp(&self, other: &Mpi) -> Ordering {
414462
let r = unsafe { mpi_cmp_mpi(&self.inner, &other.inner) };
@@ -709,3 +757,52 @@ impl ShrAssign<usize> for Mpi {
709757
// mbedtls_mpi_sub_abs
710758
// mbedtls_mpi_mod_int
711759
// mbedtls_mpi_gcd
760+
761+
#[cfg(test)]
762+
mod tests {
763+
use core::str::FromStr;
764+
765+
use super::*;
766+
767+
#[test]
768+
fn test_less_than_const_time() {
769+
let mpi1 = Mpi::new(10).unwrap();
770+
let mpi2 = Mpi::new(20).unwrap();
771+
772+
assert_eq!(mpi1.less_than_const_time(&mpi2), Ok(true));
773+
774+
assert_eq!(mpi1.less_than_const_time(&mpi1), Ok(false));
775+
776+
assert_eq!(mpi2.less_than_const_time(&mpi1), Ok(false));
777+
778+
// Check: function returns `Error::MpiBadInputData` if the allocated length of the two input Mpis is not the same.
779+
let mpi3 = Mpi::from_str("0xdddddddddddddddddddddddddddddddd").unwrap();
780+
assert_eq!(mpi3.less_than_const_time(&mpi3), Ok(false));
781+
assert_eq!(mpi2.less_than_const_time(&mpi3), Err(Error::MpiBadInputData));
782+
}
783+
784+
#[test]
785+
fn test_cmp_const_time() {
786+
let mpi1 = Mpi::new(10).unwrap();
787+
let mpi2 = Mpi::new(20).unwrap();
788+
789+
assert_eq!(mpi1.cmp_const_time(&mpi2), Ok(Ordering::Less));
790+
791+
let mpi3 = Mpi::new(10).unwrap();
792+
assert_eq!(mpi1.cmp_const_time(&mpi3), Ok(Ordering::Equal));
793+
794+
let mpi4 = Mpi::new(5).unwrap();
795+
assert_eq!(mpi1.cmp_const_time(&mpi4), Ok(Ordering::Greater));
796+
}
797+
798+
#[test]
799+
fn test_eq_const_time() {
800+
let mpi1 = Mpi::new(10).unwrap();
801+
let mpi2 = Mpi::new(10).unwrap();
802+
803+
assert_eq!(mpi1.eq_const_time(&mpi2), Ok(true));
804+
805+
let mpi3 = Mpi::new(20).unwrap();
806+
assert_eq!(mpi1.eq_const_time(&mpi3), Ok(false));
807+
}
808+
}

mbedtls/src/ecp/mod.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -440,12 +440,16 @@ Please use `mul_with_rng` instead."
440440
/// This new implementation ensures there is no shortcut when any of `x, y ,z` fields of two points is not equal.
441441
///
442442
/// [`mbedtls_ecp_point_cmp`]: https://github.com/fortanix/rust-mbedtls/blob/main/mbedtls-sys/vendor/library/ecp.c#L809-L825
443-
pub fn eq_const_time(&self, other: &EcPoint) -> bool {
444-
unsafe {
445-
let x = mpi_cmp_mpi(&self.inner.X, &other.inner.X) == 0;
446-
let y = mpi_cmp_mpi(&self.inner.Y, &other.inner.Y) == 0;
447-
let z = mpi_cmp_mpi(&self.inner.Z, &other.inner.Z) == 0;
448-
x & y & z
443+
pub fn eq_const_time(&self, other: &EcPoint) -> Result<bool> {
444+
let x = crate::bignum::mpi_inner_eq_const_time(&self.inner.X, &other.inner.X);
445+
let y = crate::bignum::mpi_inner_eq_const_time(&self.inner.Y, &other.inner.Y);
446+
let z = crate::bignum::mpi_inner_eq_const_time(&self.inner.Z, &other.inner.Z);
447+
match (x, y, z) {
448+
(Ok(true), Ok(true), Ok(true)) => Ok(true),
449+
(Ok(_), Ok(_), Ok(_)) => Ok(false),
450+
(Ok(_), Ok(_), Err(e)) => Err(e),
451+
(Ok(_), Err(e), _) => Err(e),
452+
(Err(e), _, _) => Err(e),
449453
}
450454
}
451455

@@ -724,9 +728,9 @@ mod tests {
724728
assert!(g.eq(&g).unwrap());
725729
assert!(zero.eq(&zero).unwrap());
726730
assert!(!g.eq(&zero).unwrap());
727-
assert!(g.eq_const_time(&g));
728-
assert!(zero.eq_const_time(&zero));
729-
assert!(!g.eq_const_time(&zero));
731+
assert!(g.eq_const_time(&g).unwrap());
732+
assert!(zero.eq_const_time(&zero).unwrap());
733+
assert!(!g.eq_const_time(&zero).unwrap());
730734
}
731735

732736
#[test]

0 commit comments

Comments
 (0)