diff --git a/build.rs b/build.rs index cbb1c4c6bf..c50c5c2d11 100644 --- a/build.rs +++ b/build.rs @@ -1026,7 +1026,10 @@ fn prefix_all_symbols(pp: char, prefix_prefix: &str, prefix: &str) -> String { "x25519_ge_double_scalarmult_vartime", "x25519_ge_frombytes_vartime", "x25519_ge_scalarmult_base", - "x25519_ge_scalarmult_base_adx", + "x25519_ge_scalarmult_base_adx_recode", + "x25519_ge_scalarmult_base_adx_add", + "x25519_ge_scalarmult_base_adx_dbl_4", + "x25519_ge_scalarmult_base_adx_canon", "x25519_ge_scalarmult_base_adx_from_bytes", "x25519_public_from_private_generic_masked", "x25519_sc_mask", diff --git a/src/ec/curve25519/adx.rs b/src/ec/curve25519/adx.rs index 5f49d6c2d6..3cfa9ea3ab 100644 --- a/src/ec/curve25519/adx.rs +++ b/src/ec/curve25519/adx.rs @@ -27,14 +27,27 @@ pub(super) fn get_features(cpu: cpu::Features) -> Option { } pub fn scalarmult_base(a: &Scalar, _cpu: RequiredFeatures) -> P3 { - prefixed_extern! { - unsafe fn x25519_ge_scalarmult_base_adx(t: &mut MaybeUninit<[[u8; 32]; 4]>, a: &Scalar); - unsafe fn x25519_ge_scalarmult_base_adx_from_bytes(h: &mut MaybeUninit, t: &[[u8; 32]; 4]); + let mut e: MaybeUninit = MaybeUninit::uninit(); + let e = unsafe { + x25519_ge_scalarmult_base_adx_recode(&mut e, a); + e.assume_init_ref() + }; + let mut r = ge_p3_4::new_0_1_1_0(); + unsafe { + x25519_ge_scalarmult_base_adx_add(&mut r, e, true); } - - let mut t = MaybeUninit::uninit(); + const LAST_DOUBLING: usize = 3; + for i in 0..=LAST_DOUBLING { + unsafe { + x25519_ge_scalarmult_base_adx_dbl_4(&mut r, i != LAST_DOUBLING); + } + } + unsafe { + x25519_ge_scalarmult_base_adx_add(&mut r, e, false); + } + let mut t: MaybeUninit = MaybeUninit::uninit(); let t = unsafe { - x25519_ge_scalarmult_base_adx(&mut t, a); + x25519_ge_scalarmult_base_adx_canon(&mut t, &mut r); t.assume_init_ref() }; let mut h = MaybeUninit::uninit(); @@ -43,3 +56,40 @@ pub fn scalarmult_base(a: &Scalar, _cpu: RequiredFeatures) -> P3 { h.assume_init() } } + +type Digits = [i8; 64]; + +// Keep in sync with ge_p3_4 in curve25519_64_adx.h +#[repr(C)] +struct ge_p3_4 { + X: fe4, + Y: fe4, + Z: fe4, + T: fe4, +} + +type ge_p3_4_bytes = [[u8; 32]; 4]; + +impl ge_p3_4 { + fn new_0_1_1_0() -> Self { + const ZERO: fe4 = [0, 0, 0, 0]; + const ONE: fe4 = [1, 0, 0, 0]; + Self { + X: ZERO, + Y: ONE, + Z: ONE, + T: ZERO, + } + } +} + +type fe4 = [u64; 4]; + +prefixed_extern! { + // Postcondition: `e` is a valid `E` for the value `a`. + unsafe fn x25519_ge_scalarmult_base_adx_recode(e: &mut MaybeUninit, a: &Scalar); + unsafe fn x25519_ge_scalarmult_base_adx_add(r: &mut ge_p3_4, e: &Digits, odd: bool); + unsafe fn x25519_ge_scalarmult_base_adx_dbl_4(r: &mut ge_p3_4, skip_t: bool); + unsafe fn x25519_ge_scalarmult_base_adx_canon(t: &mut MaybeUninit, r: &mut ge_p3_4); + unsafe fn x25519_ge_scalarmult_base_adx_from_bytes(h: &mut MaybeUninit, t: &ge_p3_4_bytes); +} diff --git a/third_party/fiat/curve25519_64_adx.h b/third_party/fiat/curve25519_64_adx.h index 02e8ad9114..f0eb8e3873 100644 --- a/third_party/fiat/curve25519_64_adx.h +++ b/third_party/fiat/curve25519_64_adx.h @@ -5,7 +5,9 @@ #include #include +// Keep in sync with Rust `fe4`. typedef uint64_t fe4[4]; + typedef uint8_t fiat_uint1; typedef int8_t fiat_int1; @@ -546,6 +548,7 @@ void x25519_scalar_mult_adx(uint8_t out[32], const uint8_t scalar[32], OPENSSL_memcpy(out, x2, sizeof(fe4)); } +// Keep in sync with Rust `ge_p3_4`. typedef struct { fe4 X; fe4 Y; @@ -559,8 +562,14 @@ typedef struct { fe4 xy2d; } ge_precomp_4; +// This was named "inline_x25519_ge_dbl_4" but it was never inlined. +RING_NOINLINE // https://github.com/rust-lang/rust/issues/116573 __attribute__((target("adx,bmi2"))) -static void inline_x25519_ge_dbl_4(ge_p3_4 *r, const ge_p3_4 *p, bool skip_t) { +void x25519_ge_scalarmult_base_adx_dbl_4(ge_p3_4 *r, bool skip_t) { + // Originally this function accepted a separate `p` argument, but it was + // always used with `p` equal to `r`. + const ge_p3_4 *p = r; + // Transcribed from a Coq function proven against affine coordinates. // https://github.com/mit-plv/fiat-crypto/blob/9943ba9e7d8f3e1c0054b2c94a5edca46ea73ef8/src/Curves/Edwards/XYZT/Basic.v#L136-L165 fe4 trX, trZ, trT, t0, cX, cY, cZ, cT; @@ -645,8 +654,7 @@ static inline void table_select_4(ge_precomp_4 *t, const int pos, // a[31] <= 127 RING_NOINLINE // https://github.com/rust-lang/rust/issues/116573 __attribute__((target("adx,bmi2"))) -void x25519_ge_scalarmult_base_adx(uint8_t h[4][32], const uint8_t a[32]) { - signed char e[64]; +void x25519_ge_scalarmult_base_adx_recode(signed char e[64], const uint8_t a[32]) { signed char carry; for (unsigned i = 0; i < 32; ++i) { @@ -665,31 +673,28 @@ void x25519_ge_scalarmult_base_adx(uint8_t h[4][32], const uint8_t a[32]) { } e[63] += carry; // each e[i] is between -8 and 8 +} - ge_p3_4 r = {{0}, {1}, {1}, {0}}; - for (unsigned i = 1; i < 64; i += 2) { - ge_precomp_4 t; - table_select_4(&t, i / 2, e[i]); - ge_p3_add_p3_precomp_4(&r, &r, &t); - } - - inline_x25519_ge_dbl_4(&r, &r, /*skip_t=*/true); - inline_x25519_ge_dbl_4(&r, &r, /*skip_t=*/true); - inline_x25519_ge_dbl_4(&r, &r, /*skip_t=*/true); - inline_x25519_ge_dbl_4(&r, &r, /*skip_t=*/false); - - for (unsigned i = 0; i < 64; i += 2) { +// `odd` is `true` for step 2 and `false` for step 4. +RING_NOINLINE // https://github.com/rust-lang/rust/issues/116573 +__attribute__((target("adx,bmi2"))) +void x25519_ge_scalarmult_base_adx_add(ge_p3_4 *r, const signed char e[64], bool odd) { + for (unsigned i = odd; i < 64; i += 2) { ge_precomp_4 t; table_select_4(&t, i / 2, e[i]); - ge_p3_add_p3_precomp_4(&r, &r, &t); + ge_p3_add_p3_precomp_4(r, r, &t); } +} +RING_NOINLINE // https://github.com/rust-lang/rust/issues/116573 +__attribute__((target("adx,bmi2"))) +void x25519_ge_scalarmult_base_adx_canon(uint8_t h[4][32], ge_p3_4 *r) { // fe4 uses saturated 64-bit limbs, so converting to bytes is just a copy. // Satisfy stated precondition of fiat_25519_from_bytes; tests pass either way - fe4_canon(r.X, r.X); - fe4_canon(r.Y, r.Y); - fe4_canon(r.Z, r.Z); - fe4_canon(r.T, r.T); + fe4_canon(r->X, r->X); + fe4_canon(r->Y, r->Y); + fe4_canon(r->Z, r->Z); + fe4_canon(r->T, r->T); OPENSSL_STATIC_ASSERT(sizeof(ge_p3_4) == sizeof(uint8_t[4][32]), ""); - OPENSSL_memcpy(h, &r, sizeof(ge_p3_4)); + OPENSSL_memcpy(h, r, sizeof(ge_p3_4)); }