Skip to content

Commit

Permalink
Preparation for NTT (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
fboemer authored and GitHub Enterprise committed Feb 24, 2024
1 parent f1778a8 commit 311db33
Show file tree
Hide file tree
Showing 11 changed files with 429 additions and 64 deletions.
26 changes: 26 additions & 0 deletions Sources/SwiftHe/Error.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import Foundation

enum HeError: Error, Equatable {
case compositeModulus(_ modulus: Int)
case coprimeModuli(moduli: [Int])
case invalidDegree(_ degree: Int)
case invalidNttModulus(modulus: Int, degree: Int)
case polyFormatMismatch(got: PolyFormat, expected: PolyFormat)
}

extension HeError: LocalizedError {
var errorDescription: String? {
switch self {
case let .compositeModulus(modulus):
"Composite modulus \(modulus)"
case let .coprimeModuli(moduli):
"Coprime moduli \(moduli)"
case let .invalidDegree(degree):
"Invalid degree \(degree)"
case let .invalidNttModulus(modulus, degree):
"Invalid NTT modulus \(modulus) for degree \(degree)"
case let .polyFormatMismatch(got, expected):
"PolyFormat mismatch: got \(got), expected \(expected)"
}
}
}
76 changes: 76 additions & 0 deletions Sources/SwiftHe/Ntt.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
extension FixedWidthInteger {
func isNttModulus(for degree: Self) -> Bool {
degree.isPowerOfTwo && isPrime(variableTime: true) && self % (Self(2) * degree) == 1
}
}

extension UnsignedInteger where Self: FixedWidthInteger {
func isPrimitiveRootOfUnity(degree: Int, modulus: Self) -> Bool {
// For degree a power of two, it suffices to check root^(degree/2) == -1 mod p
// This implies root^degree == 1 mod p. Also, note 2 is the only prime factor of
// degree. See
// https://en.wikipedia.org/wiki/Root_of_unity_modulo_n#Testing_whether_x_is_a_primitive_k-th_root_of_unity_modulo_n
precondition(degree.isPowerOfTwo)
return powMod(exponent: Self(degree / 2), modulus: modulus, variableTime: true) == modulus - 1
}

/// Generates a primitive `degree'th` root of unity for integers mod `self`
func generatePrimitiveRootOfUnity(degree: Int) -> Self? {
precondition(degree.isPowerOfTwo)
precondition(isPrime(variableTime: true))

// See https://en.wikipedia.org/wiki/Root_of_unity_modulo_n#Finding_a_primitive_k-th_root_of_unity_modulo_n
// Carmichael function lambda(p) = p - 1 for p prime
let lambdaP = self - 1

// "If k does not divide lambda(n), then there will be no k-th roots of unity, at all."
if !lambdaP.isMultiple(of: Self(degree)) {
return nil
}

// The number of primitive roots mod p for p prime is phi(p-1), where phi is
// Euler's totient function. We know phi(p-1) > p / (e^gamma log(log(p)) + 3 /
// log(log(p)) (https://en.wikipedia.org/wiki/Euler%27s_totient_function#Growth_rate).
// So the probability that a random value in [0, p-1] is a primitive root is at
// least phi(p-1)/p > 1 / (e^gamma log(log(p)) + 3 / log(log(p)) > 1/8 for p
// < 2^64 and where gamma is the Euler–Mascheroni constant ~= 0.577. That
// is, we have at least 1/8 chance of finding a root on each attempt. So, (1 -
// 1/8)^T < 2^{-128} yields T = 665 trials suffices for less than 2^{-128}
// chance of failure.
let trialCount = 665
var rng = SystemRandomNumberGenerator()
for _ in 0..<trialCount {
var root = Self.random(in: 0..<self, using: &rng)
// root^(lambda(p)/degree) will be a primitive degree'th root of unity if root
// is a lambda(p)'th root
root = root.powMod(exponent: lambdaP / Self(degree), modulus: self, variableTime: true)
if root.isPrimitiveRootOfUnity(degree: degree, modulus: self) {
return root
}
}
return nil
}

/// Generates the smallest primitive `degree`'th primitive root for integers mod `self`
/// Degree must be a power of two that divides p - 1
/// p must be prime
func minPrimitiveRootOfUnity(degree: Int) -> Self? {
guard var smallestGenerator = generatePrimitiveRootOfUnity(degree: degree) else {
return nil
}
var currentGenerator = smallestGenerator

// Given a generator g, g^l is a degree'th root of unity iff l and degree are
// co-prime. Since degree is a power of two, we can check g, g^3, g^5, ...
// See https://en.wikipedia.org/wiki/Root_of_unity_modulo_n#Finding_multiple_primitive_k-th_roots_modulo_n
let generatorSquared = currentGenerator.powMod(exponent: 2, modulus: self, variableTime: true)
let modulus = ReduceModulus(modulus: self, bound: ReduceModulus.InputBound.ModulusSquared, variableTime: true)
for _ in 0..<degree / 2 {
if currentGenerator < smallestGenerator {
smallestGenerator = currentGenerator
}
currentGenerator = modulus.multiplyMod(currentGenerator, generatorSquared)
}
return smallestGenerator
}
}
17 changes: 17 additions & 0 deletions Sources/SwiftHe/PolyContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,21 @@ struct PolyContext<T: UnsignedInteger & FixedWidthInteger>: Equatable {
let degree: Int
/// CRT-representation of the modulus `Q = q_1 * q_2 * ... * q_L`
let moduli: [T]

init(degree: Int, moduli: [T]) throws {
guard degree.isPowerOfTwo else {
throw HeError.invalidDegree(degree)
}
for modulus in moduli {
guard modulus.isPrime(variableTime: true) else {
throw HeError.compositeModulus(Int(modulus))
}
}
guard moduli.allUnique() else {
throw HeError.coprimeModuli(moduli: moduli.map { Int($0) })
}

self.degree = degree
self.moduli = moduli
}
}
34 changes: 26 additions & 8 deletions Sources/SwiftHe/PolyRq.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,25 @@
// Created by Karl Tarbe on 2/19/24.
//

/// The forward NTT converts from Coefficient form to Evaluation form.
/// The inverse NTT converts from Evaluation form to Coefficient form
enum PolyFormat {
case Coeff /// Coefficient format
case Eval /// Evaluation format
}

/// Represents a polynomial in `R_q = Z_q(X)^N / (X^N + 1)` for `N` a power of
/// two and `q` a (possibly) multi-word integer
struct PolyRq<T: UnsignedInteger & FixedWidthInteger>: Equatable {
let context: PolyContext<T>
var format: PolyFormat
var data: Array2d<T>

init(context: PolyContext<T>, data: Array2d<T>) {
init(context: PolyContext<T>, format: PolyFormat, data: Array2d<T>) {
precondition(context.degree == data.columnCount)
precondition(context.moduli.count == data.rowCount)
self.context = context
self.format = format
self.data = data
assert(isValidData())
}
Expand All @@ -32,6 +41,7 @@ struct PolyRq<T: UnsignedInteger & FixedWidthInteger>: Equatable {
extension PolyRq {
func checkMetadataMatches(with other: Self) {
precondition(context == other.context)
precondition(format == other.format)
precondition(data.rowCount == other.data.rowCount)
precondition(data.columnCount == other.data.columnCount)
}
Expand Down Expand Up @@ -78,23 +88,25 @@ extension PolyRq {

extension PolyRq {
/// Initialize a zero Polynomial with all coefficients set to zero
static func zero(context: PolyContext<T>) -> Self {
static func zero(context: PolyContext<T>, format: PolyFormat) -> Self {
let degree = context.degree
let moduliCount = context.moduli.count
let zeroes = Array2d(
data: Array(repeating: T.zero, count: degree * moduliCount),
rowCount: moduliCount,
columnCount: degree)
return Self(context: context, data: zeroes)
return Self(context: context, format: format, data: zeroes)
}

static func random(context: PolyContext<T>) -> Self {
static func random(context: PolyContext<T>, format: PolyFormat) -> Self {
var rng: any RandomNumberGenerator = SystemRandomNumberGenerator()
return Self.random(context: context, using: &rng)
return Self.random(context: context, format: format, using: &rng)
}

static func random(context: PolyContext<T>, using rng: inout any RandomNumberGenerator) -> Self {
var poly = Self.zero(context: context)
static func random(context: PolyContext<T>, format: PolyFormat,
using rng: inout any RandomNumberGenerator) -> Self
{
var poly = Self.zero(context: context, format: format)
poly.randomizeUniform(using: &rng)
return poly
}
Expand Down Expand Up @@ -134,7 +146,7 @@ extension PolyRq {
}

static prefix func - (_ rhs: Self) -> Self {
var result = Self.zero(context: rhs.context)
var result = Self.zero(context: rhs.context, format: rhs.format)
for (rnsIndex, modulus) in result.context.moduli.enumerated() {
for index in result.polyIndices(rnsIndex: rnsIndex) {
result[index] = rhs[index].negateMod(modulus: modulus)
Expand All @@ -143,4 +155,10 @@ extension PolyRq {

return result
}

private func checkPolyFormat(_ format: PolyFormat) throws {
guard self.format == format else {
throw HeError.polyFormatMismatch(got: self.format, expected: format)
}
}
}
80 changes: 77 additions & 3 deletions Sources/SwiftHe/Scalar.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,27 @@ extension UnsignedInteger where Self: FixedWidthInteger {
let sum = self &+ modulus &- rhs
return sum.subtractIfExceeds(modulus)
}

func powMod(exponent: Self, modulus: Self, variableTime: Bool) -> Self {
precondition(variableTime)
var base = self
var exponent = exponent
let modulus = ReduceModulus(
modulus: modulus,
bound: ReduceModulus.InputBound.ModulusSquared,
variableTime: variableTime)
var result = Self(1)
for _ in 0...exponent.log2 {
if (exponent & 1) != 0 {
result = modulus.multiplyMod(result, base)
}
if exponent > 0 {
base = modulus.multiplyMod(base, base)
}
exponent >>= 1
}
return result
}
}

extension UInt32 {
Expand Down Expand Up @@ -66,7 +87,7 @@ struct MultiplyConstantModulus<T: UnsignedInteger & FixedWidthInteger> {
let factor: T // Barrett factor

/// Note: leaks multiplicand, modulus through timing
init(_ multiplicand: T, modulus: T, variableTime: Bool) {
init(multiplicand: T, modulus: T, variableTime: Bool) {
precondition(variableTime) // TODO: support constant-time
assert(multiplicand < modulus)
self.multiplicand = multiplicand
Expand Down Expand Up @@ -182,10 +203,19 @@ struct ReduceModulus<T: UnsignedInteger & FixedWidthInteger> {
let alphaMinusBeta = T.bitWidth
let nPlusBeta = n &+ reduceModulusBeta
let xShift = x &>> nPlusBeta
let qHat = (xShift &* factor) &>> alphaMinusBeta
let z = x &- qHat &* DoubleWidth(modulus)
// TODO: possibly improve performence by only computing the low T.bitWidth bits of result
let qHat = (xShift &* DoubleWidth(factor.low)) &>> alphaMinusBeta
let z = x &- qHat &* DoubleWidth<T>(modulus)
return T(z.low).subtractIfExceeds(modulus)
}

/// Returns `x * y mod p` for `x, y < p`.
func multiplyMod(_ x: T, _ y: T) -> T {
precondition(x < modulus)
precondition(y < modulus)
let product = x.multipliedFullWidth(by: y)
return reduceProduct(DoubleWidth<T>(product))
}
}

extension FixedWidthInteger {
Expand All @@ -211,4 +241,48 @@ extension FixedWidthInteger {
let multiplied = multipliedFullWidth(by: rhs)
return modulus.dividingFullWidth(multiplied).remainder
}

func isPrime(variableTime: Bool) -> Bool {
precondition(variableTime)
if self <= 1 {
return false
}
// Rabin-prime primality test
let bases: [Self] = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37]
for base in bases {
if self == base {
return true
}
if isMultiple(of: base) {
return false
}
}

// write self = 2**r * d + 1 with d odd
var r = Self.bitWidth - 1
while r > 0, !(self - 1).isMultiple(of: Self(1) << r) {
r -= 1
}
let twoPowR = Self(1) << r
let d = (self - 1) / twoPowR
assert(r != 0)
assert(self == twoPowR * d + 1)
assert(d & 1 == 1)

let nPos = Self.Magnitude(self)
witnessLoop: for base in bases {
var x = UInt(base).powMod(exponent: UInt(d), modulus: UInt(nPos), variableTime: true)
if x == 1 || x == self - 1 {
continue
}
for _ in 0..<r {
x = x.powMod(exponent: 2, modulus: UInt(nPos), variableTime: true)
if x == self - 1 {
continue witnessLoop
}
}
return false
}
return true
}
}
11 changes: 11 additions & 0 deletions Sources/SwiftHe/Util.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
extension Sequence where Element: Hashable {
func allUnique() -> Bool {
var seen = Set<Self.Element>()
for element in self {
guard seen.insert(element).inserted else {
return false
}
}
return true
}
}
37 changes: 37 additions & 0 deletions Tests/SwiftHeTests/NttTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//
// NttTests.swift
// SwiftHeTests
//
// Created by Fabian Boemer on 2/22/24.
//

@testable import SwiftHe
import XCTest

final class NttTests: XCTestCase {
func testIsPrimitiveRootOfUnity() {
XCTAssertTrue(UInt32(12).isPrimitiveRootOfUnity(degree: 2, modulus: 13))
XCTAssertFalse(UInt32(11).isPrimitiveRootOfUnity(degree: 2, modulus: 13))
XCTAssertFalse(UInt32(12).isPrimitiveRootOfUnity(degree: 4, modulus: 13))

XCTAssertTrue(UInt64(28).isPrimitiveRootOfUnity(degree: 2, modulus: 29))
XCTAssertTrue(UInt64(12).isPrimitiveRootOfUnity(degree: 4, modulus: 29))
XCTAssertFalse(UInt64(12).isPrimitiveRootOfUnity(degree: 2, modulus: 29))
XCTAssertFalse(UInt64(12).isPrimitiveRootOfUnity(degree: 8, modulus: 29))

XCTAssertTrue(UInt64(1_234_565_440).isPrimitiveRootOfUnity(degree: 2, modulus: 1_234_565_441))
XCTAssertTrue(UInt64(960_907_033).isPrimitiveRootOfUnity(degree: 8, modulus: 1_234_565_441))
XCTAssertTrue(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 16, modulus: 1_234_565_441))
XCTAssertFalse(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 32, modulus: 1_234_565_441))
XCTAssertFalse(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 8, modulus: 1_234_565_441))
XCTAssertFalse(UInt64(1_180_581_915).isPrimitiveRootOfUnity(degree: 2, modulus: 1_234_565_441))
}

func testMinPrimitiveRootOfUnity() {
XCTAssertEqual(UInt32(11).minPrimitiveRootOfUnity(degree: 2), 10)
XCTAssertEqual(UInt32(29).minPrimitiveRootOfUnity(degree: 2), 28)
XCTAssertEqual(UInt32(29).minPrimitiveRootOfUnity(degree: 4), 12)
XCTAssertEqual(UInt64(1_234_565_441).minPrimitiveRootOfUnity(degree: 2), 1_234_565_440)
XCTAssertEqual(UInt64(1_234_565_441).minPrimitiveRootOfUnity(degree: 8), 249_725_733)
}
}
Loading

0 comments on commit 311db33

Please sign in to comment.