Consolidate and spec-ify NTT implementations #163

Open
3 tasks
marsella opened this issue Oct 31, 2024 · 1 comment

@marsella
Contributor

We have several versions of NTT floating around, including in the ML-KEM implementation, in several of the Dilithium versions, and standalone in Common/ntt.

None of them closely resembles the versions written in the specs (ML-KEM and ML-DSA are the ones we'd most like to emulate). The spec versions use dense nested loops whose bounds depend on the enclosing loops' variables, and they update sequence elements in place at non-consecutive indices. That style is hard to implement faithfully in Cryptol, where sequences are immutable and loops have to be expressed through recursion, folds, or sequence comprehensions. See more discussion on #156.
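For reference, here is a minimal Python model (not Cryptol) of FIPS 203 Algorithm 9 with the loop structure the spec uses. It assumes the ML-KEM parameters q = 3329 and ζ = 17; the names `naive_ntt` and `bit_rev7` are hypothetical and not taken from the repo. Note that the `start` and `j` bounds depend on `len`, and each butterfly overwrites `f_hat[j]` and `f_hat[j + len]`, two entries `len` positions apart. Python's mutable lists make this direct; a Cryptol version has to thread the sequence through recursion or folds instead.

```python
# Hypothetical Python model of FIPS 203 Algorithm 9 (NTT), mutating a list in place.
# Assumes ML-KEM parameters: q = 3329, zeta = 17 (a primitive 256th root of unity mod q).
Q = 3329
ZETA = 17


def bit_rev7(i: int) -> int:
    """Reverse the low 7 bits of i (BitRev7 in FIPS 203)."""
    return int(f"{i:07b}"[::-1], 2)


def naive_ntt(f: list[int]) -> list[int]:
    """Loop-based NTT following the step numbering of FIPS 203 Algorithm 9."""
    assert len(f) == 256
    f_hat = list(f)                                       # Step 1
    i = 1                                                 # Step 2
    length = 128                                          # Step 3 (`len` in the spec)
    while length >= 2:
        for start in range(0, 256, 2 * length):           # Step 4
            zeta = pow(ZETA, bit_rev7(i), Q)              # Step 5
            i += 1                                        # Step 6
            for j in range(start, start + length):        # Step 7
                t = zeta * f_hat[j + length] % Q          # Step 8
                f_hat[j + length] = (f_hat[j] - t) % Q    # Step 9: non-consecutive update
                f_hat[j] = (f_hat[j] + t) % Q             # Step 10
        length //= 2                                      # Step 3: len <- len / 2
    return f_hat                                          # Step 14
```

For example, `naive_ntt([1] + [0] * 255)` returns `[1, 0] * 128`, since the constant polynomial 1 reduces to 1 modulo every quadratic factor of X^256 + 1.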

The versions we have match other NTT reference implementations. I don't have sources, but I believe there's a Python library that matches the ML-KEM one and a C reference implementation that matches the Dilithium ones.

There are also some recursion-based versions that are faster than the naive ones in the specs. Most specs say explicitly that "any algorithm that's mathematically equivalent" is fine to use, but we need to make sure such versions are proven equivalent.
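One possible recursion-based formulation, sketched here in Python as a hypothetical example (not one of the existing implementations), splits the polynomial instead of looping. Each call takes the 2L coefficients of a residue, applies one layer of butterflies with ζ^BitRev7(node), and recurses on the two halves with children at heap indices 2·node and 2·node + 1. Those indices coincide with the values Algorithm 9's counter `i` takes at each layer, so the output order should match the spec's. It assumes q = 3329 and ζ = 17; `recursive_ntt` and `split` are illustrative names.

```python
# Hypothetical recursive NTT for ML-KEM (q = 3329, zeta = 17); not taken from the repo.
Q = 3329
ZETA = 17


def bit_rev7(i: int) -> int:
    """Reverse the low 7 bits of i (BitRev7 in FIPS 203)."""
    return int(f"{i:07b}"[::-1], 2)


def recursive_ntt(f: list[int]) -> list[int]:
    """NTT computed by recursively splitting each residue into two halves."""
    assert len(f) == 256

    def split(block: list[int], node: int) -> list[int]:
        # `block` holds the coefficients of the current residue; `node` is the
        # heap-style index that plays the role of Algorithm 9's counter `i`.
        half = len(block) // 2
        if half < 2:                      # 2-coefficient leaves are final outputs
            return block
        zeta = pow(ZETA, bit_rev7(node), Q)
        lo = [(block[j] + zeta * block[j + half]) % Q for j in range(half)]
        hi = [(block[j] - zeta * block[j + half]) % Q for j in range(half)]
        return split(lo, 2 * node) + split(hi, 2 * node + 1)

    return split([x % Q for x in f], 1)
```

Running this against a loop-based version on random inputs only gives evidence of equivalence; the issue asks for a proof, which in Cryptol would be a `:prove` of an equality property between the two functions.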

  • It's probably best to have a single implementation rather than multiple versions scattered around. Consolidate them into the Common/ directory.
  • We need some kind of source to reference. Ideally it would be a spec, but if we end up matching a reference implementation instead, we should cite it.
  • Make sure the "fast" version is proven equivalent to the "naive" version and that only one of them is public in the NTT module.
@marsella
Contributor Author

Here's the old recursive attempt at something spec-adherent:

````cryptol
/**
 * ```repl
 * :prove NaiveNTTsMatch
 * ```
 */
property NaiveNTTsMatch f = NaiveNTT' f == NaiveNTT f

private
    /**
     * Naive version of NTT, implemented using recursion instead of loops.
     * [FIPS-203] Algorithm 9.
     *
     * Note that this implementation is spread out across multiple functions
     * to support the use of numeric constraint guards.
     */
    NaiveNTT' : Rq -> Tq
    NaiveNTT' f = state.f_hat where
        // Steps 1 - 2. Initialize `f_hat`, `i`.
        state0 = { z = 0, i = 1, f_hat = f }
        // Step 3. Initialize `len` and evaluate the body of the loop.
        state = len_loop`{len = 128} state0

    type State = { z : Z q, i : [8] , f_hat : Tq }

    // Steps 3 - 13.
    len_loop : {len} (len <= 128) => State -> State
    len_loop state
        // Step 3: Stop if we're at the end of the loop.
        | len < 2 => state
        // Otherwise, we're in a valid loop iteration.
        | len >= 2 => state'' where
            // Evaluate the body of the loop...
            state' = start_loop`{len, 0} state
            // ...then start the next iteration.
            state'' = len_loop`{len / 2} state'

    // Steps 4 - 12.
    start_loop : {len, start} (fin len, fin start) => State -> State
    start_loop state
        // Step 4: Stop if we're at the end of the loop.
        | start >= 256 => state
        // Otherwise, we're in a valid loop iteration.
        | start < 256 => state''' where
            // Step 5.
            z = zeta ^^(BitRev7 state.i)
            // Step 6.
            i = state.i + 1
            // Save the changes from Steps 5 - 6.
            state' = { z = z, i = i, f_hat = state.f_hat }
            // Steps 7 - 11. Evaluate the `j`-loop.
            state'' = j_loop`{len, start, start} state'
            // Start the next iteration of the `start` loop.
            state''' = start_loop`{len, start + 2 * len} state''

    // Steps 7 - 11.
    j_loop : {len, start, j} (start <= j, j <= start + len)
        => State -> State
    j_loop state
        // Step 7: Stop if we're at the end of the loop.
        | (j == start + len) => state
        // This case is impossible to reach; `j + len` will always be a valid
        // index into `f_hat`. It's not possible to infer that from the type
        // constraints we have now, so it's stated explicitly.
        | (j + len >= 256) => state
        // Otherwise, we're in a valid loop iteration.
        | (j + len < 256, j < start + len) => state'' where
            // Step 8.
            t = state.z * state.f_hat @(`j + `len)
            // Step 9.
            f_hat' = set_f`{j + len} state.f_hat (state.f_hat @`j - t)
            // Step 10.
            f_hat'' = set_f`{j} f_hat' (f_hat' @`j + t)
            // Save the changes made in Steps 8 - 10.
            state' = {
                z = state.z,
                i = state.i,
                f_hat = f_hat''
            }
            // Start the next iteration of the loop.
            state'' = j_loop`{len, start, j+1} state'

            // Helper function to set the `idx`th value of the polynomial.
            set_f : {idx} (idx <= 255) => Tq -> Z q -> Tq
            set_f poly val = take`{idx} poly # [val] # drop`{idx + 1} poly
````
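To sanity-check the control flow of the Cryptol snippet, here is a hypothetical Python model (not part of the repo) that mirrors its structure one-to-one: `len_loop`, `start_loop`, and `j_loop` recurse over an explicit immutable `State`, and `set_f` rebuilds the sequence the same way `take # [val] # drop` does. Python has no numeric type parameters or constraint guards, so `len`, `start`, and `j` become ordinary arguments (`len` is renamed `length`), the guards become `if` tests, and the unreachable `j + len >= 256` branch is dropped. It assumes q = 3329 and ζ = 17; `naive_ntt_rec` is an illustrative name.

```python
# Hypothetical Python model of the recursive, state-threading Cryptol NaiveNTT'.
# Assumes ML-KEM parameters q = 3329 and zeta = 17.
from typing import NamedTuple

Q = 3329
ZETA = 17


def bit_rev7(i: int) -> int:
    """Reverse the low 7 bits of i (BitRev7 in FIPS 203)."""
    return int(f"{i:07b}"[::-1], 2)


class State(NamedTuple):
    z: int          # current zeta (Step 5)
    i: int          # counter from Steps 2 and 6
    f_hat: tuple    # 256 coefficients; never mutated, only rebuilt


def set_f(poly: tuple, idx: int, val: int) -> tuple:
    # Mirrors `take`{idx} poly # [val] # drop`{idx + 1} poly`.
    return poly[:idx] + (val,) + poly[idx + 1:]


def j_loop(length: int, start: int, j: int, s: State) -> State:
    if j == start + length:                                   # Step 7: loop exit
        return s
    t = s.z * s.f_hat[j + length] % Q                         # Step 8
    f1 = set_f(s.f_hat, j + length, (s.f_hat[j] - t) % Q)     # Step 9
    f2 = set_f(f1, j, (f1[j] + t) % Q)                        # Step 10
    return j_loop(length, start, j + 1, s._replace(f_hat=f2))


def start_loop(length: int, start: int, s: State) -> State:
    if start >= 256:                                          # Step 4: loop exit
        return s
    s1 = State(z=pow(ZETA, bit_rev7(s.i), Q), i=s.i + 1, f_hat=s.f_hat)  # Steps 5 - 6
    s2 = j_loop(length, start, start, s1)                     # Steps 7 - 11
    return start_loop(length, start + 2 * length, s2)


def len_loop(length: int, s: State) -> State:
    if length < 2:                                            # Step 3: loop exit
        return s
    return len_loop(length // 2, start_loop(length, 0, s))


def naive_ntt_rec(f: list[int]) -> list[int]:
    s0 = State(z=0, i=1, f_hat=tuple(x % Q for x in f))       # Steps 1 - 2
    return list(len_loop(128, s0).f_hat)
```

The recursion depth stays in the low hundreds (the deepest chain is `j_loop` at `len = 128`), so Python's default recursion limit is not an issue here.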
