From 6d0134ce5d985cf85ed0a928b388f9c80b27046c Mon Sep 17 00:00:00 2001 From: Sebastian Angel Date: Sat, 9 Dec 2017 21:14:03 -0500 Subject: [PATCH] SealPIR rust wrapper v0.1.0 --- .gitignore | 6 + .gitmodules | 3 + Cargo.toml | 30 + README.md | 75 ++ SEAL_v2.3.0-4.patch | 1801 +++++++++++++++++++++++++++++++++ SealPIR | 1 + benches/pir.rs | 161 +++ build.rs | 26 + sealpir-bindings/pir_rust.cpp | 154 +++ sealpir-bindings/pir_rust.hpp | 80 ++ src/client.rs | 149 +++ src/lib.rs | 19 + src/server.rs | 160 +++ tests/pir.rs | 289 ++++++ 14 files changed, 2954 insertions(+) create mode 100644 .gitignore create mode 100644 .gitmodules create mode 100644 Cargo.toml create mode 100644 README.md create mode 100644 SEAL_v2.3.0-4.patch create mode 160000 SealPIR create mode 100644 benches/pir.rs create mode 100644 build.rs create mode 100644 sealpir-bindings/pir_rust.cpp create mode 100644 sealpir-bindings/pir_rust.hpp create mode 100644 src/client.rs create mode 100644 src/lib.rs create mode 100644 src/server.rs create mode 100644 tests/pir.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..315b787 --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +Cargo.lock +.clang-format +sealpir_bindings/bin/ +sealpir_bindings/obj/ +target/ +deps/ diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..3f6e868 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "SealPIR"] + path = SealPIR + url = https://github.com/sga001/SealPIR.git diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..88bbb68 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "sealpir" +version = "0.1.0" +authors = ["Sebastian Angel "] + +[build-dependencies] +gcc = "0.3.54" + +[dependencies] +libc = "0.2.42" +rand = "0.5.0" +serde = "1.0.65" +serde_derive = "1.0.65" + +[dev-dependencies] +criterion = "0.2.3" +serde_bytes = "0.10.4" + +[[bench]] +name = "pir" +harness = false + +[profile.release] +opt-level = 3 +debug = false +rpath = false +lto = true +debug-assertions = true +codegen-units = 1 +panic = 'unwind' diff --git a/README.md b/README.md new file mode 100644 index 0000000..e44d951 --- /dev/null +++ b/README.md @@ -0,0 +1,75 @@ +# SealPIR (Rust): Rust wrappers for SealPIR + +SealPIR is a research library and should not be used in production systems. SealPIR allows a client to download an element from a database stored by a server without revealing which element was downloaded. SealPIR was introduced in our [paper](https://eprint.iacr.org/2017/1142.pdf). + + +SealPIR relies on SEAL. The rest of this README assumes that SEAL is installed in the folder deps/SEAL within this repository. See below for instructions on how to install SEAL. + +# Compiling SEAL + +SealPIR depends on SEAL v2.3.0-4 and a patch that exposes the substitution operator. You can get SEAL v2.3.0-4 from this [link](http://sealcrypto.org). + +Once you have downloaded SEAL, apply the patch SEAL_v2.3.0-4.patch (available in this repository) to it. Here are the exact steps. + +We assume that you are in the SEAL directory, and that you have copied the patch to this directory. + +First, convert the SEAL directory into a git repo: + +```sh +$ git init +$ git add . +$ git commit -m "SEAL v2.3.0-4" +``` +Then, apply the patch: + +```sh +$ git am SEAL_v2.3.0-4.patch +``` + +Finally, compile SEAL (NOTE: gcc-8 is not currently supported): + +```sh +$ cd SEAL +$ ./configure CXXFLAGS="-O3 -march=native -fPIC" +$ make clean && make +``` + +# Compiling SealPIR-Rust + +SealPIR's Rust wrapper works with [Rust](https://www.rust-lang.org/) nightly (we have tested with Rust 1.28.0). It also depends on the C++ version of SealPIR (included as a submodule) and SEAL (see above). + +To compile SealPIR-Rust just do: + +```sh +$ git submodule init +$ git submodule update +$ cargo build +``` + +# Reproducing the results in the paper + +If you would like to reproduce the microbenchmarks found in the paper (Figure 9), simply run: + +```sh +$ cargo bench [prefix of name of benchmark (or leave blank to run all)] +``` + +For example, to reproduce the SealPIR entries of the first row of Figure 9 (Query), simply +run: + +```sh +$ cargo bench query +``` + +To reproduce a single data point, for example the Expand entry for SealPIR where n=262,144, run: + +```sh +$ cargo bench expand_d2/262144 +``` + +Note that the reply microbenchmark ("Answer" in Figure 9) already includes the cost of Expand (we subtract this cost in the paper). + + +You can find the code that runs these benchmarks (and their names) in ``benches/pir.rs``. + +To reproduce latency and throughput results, check out the [pir-test](https://github.com/sga001/pir-test) repository (this also has examples on how to use SealPIR in a client-server networked application). diff --git a/SEAL_v2.3.0-4.patch b/SEAL_v2.3.0-4.patch new file mode 100644 index 0000000..dff6008 --- /dev/null +++ b/SEAL_v2.3.0-4.patch @@ -0,0 +1,1801 @@ +From e37bf6b79c81cbbeff19378ad425f987c036286b Mon Sep 17 00:00:00 2001 +From: Kim Laine +Date: Mon, 4 Dec 2017 16:09:56 -0800 +Subject: [PATCH 1/3] Explosed generic Galois automorphisms in public API + +--- + SEAL/seal/evaluator.cpp | 5 + + SEAL/seal/evaluator.h | 181 ++++++++++++++++++++++++++------ + SEAL/seal/keygenerator.cpp | 8 -- + SEAL/seal/keygenerator.h | 40 ++++--- + SEALNET/sealnet/EvaluatorWrapper.cpp | 146 ++++++++++++++++++++++++++ + SEALNET/sealnet/EvaluatorWrapper.h | 157 ++++++++++++++++++++++++--- + SEALNET/sealnet/KeyGeneratorWrapper.cpp | 33 ++++++ + SEALNET/sealnet/KeyGeneratorWrapper.h | 35 ++++++ + SEALNETTest/EvaluatorWrapper.cs | 82 +++++++++++++++ + SEALNETTest/KeyGeneratorWrapper.cs | 147 ++++++++++++++++++++++++++ + SEALTest/evaluator.cpp | 79 ++++++++++++++ + SEALTest/keygenerator.cpp | 146 ++++++++++++++++++++++++++ + 12 files changed, 990 insertions(+), 69 deletions(-) + +diff --git a/SEAL/seal/evaluator.cpp b/SEAL/seal/evaluator.cpp +index 0a5b99d..8945b5d 100644 +--- a/SEAL/seal/evaluator.cpp ++++ b/SEAL/seal/evaluator.cpp +@@ -1791,6 +1791,11 @@ namespace seal + + void Evaluator::rotate_rows(Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys, const MemoryPoolHandle &pool) + { ++ if (!qualifiers_.enable_batching) ++ { ++ throw logic_error("encryption parameters do not support batching"); ++ } ++ + // Is there anything to do? + if (steps == 0) + { +diff --git a/SEAL/seal/evaluator.h b/SEAL/seal/evaluator.h +index 5b67455..93cc1d2 100644 +--- a/SEAL/seal/evaluator.h ++++ b/SEAL/seal/evaluator.h +@@ -894,13 +894,136 @@ namespace seal + @throws std::invalid_argument if plain_ntt is zero + @throws std::logic_error if destination_ntt is aliased and needs to be reallocated + */ +- inline void multiply_plain_ntt(const Ciphertext &encrypted_ntt, const Plaintext &plain_ntt, +- Ciphertext &destination_ntt) ++ inline void multiply_plain_ntt(const Ciphertext &encrypted_ntt, ++ const Plaintext &plain_ntt, Ciphertext &destination_ntt) + { + destination_ntt = encrypted_ntt; + multiply_plain_ntt(destination_ntt, plain_ntt); + } + ++ /** ++ Applies a Galois automorphism to a ciphertext. To evaluate the Galois automorphism, ++ an appropriate set of Galois keys must also be provided. Dynamic memory allocations ++ in the process are allocated from the memory pool pointed to by the given ++ MemoryPoolHandle. ++ ++ The desired Galois automorphism is given as a Galois element, and must be an odd ++ integer in the interval [1, M-1], where M = 2*N, and N = degree(poly_modulus). Used ++ with batching, a Galois element 3^i % M corresponds to a cyclic row rotation i steps ++ to the left, and a Galois element 3^(N/2-i) % M corresponds to a cyclic row rotation ++ i steps to the right. The Galois element M-1 corresponds to a column rotation (row ++ swap). In the polynomial view (not batching), a Galois automorphism by a Galois ++ element p changes Enc(plain(x)) to Enc(plain(x^p)). ++ ++ @param[in] encrypted The ciphertext to apply the Galois automorphism to ++ @param[in] galois_elt The Galois element ++ @param[in] galois_keys The Galois keys ++ @param[in] pool The MemoryPoolHandle pointing to a valid memory pool ++ @throws std::invalid_argument if encrypted or galois_keys is not valid for the ++ encryption parameters ++ @throws std::invalid_argument if encrypted has size greater than two ++ @throws std::invalid_argument if the Galois element is not valid ++ @throws std::invalid_argument if necessary Galois keys are not present ++ @throws std::invalid_argument if pool is uninitialized ++ */ ++ void apply_galois(Ciphertext &encrypted, std::uint64_t galois_elt, ++ const GaloisKeys &galois_keys, const MemoryPoolHandle &pool); ++ ++ /** ++ Applies a Galois automorphism to a ciphertext. To evaluate the Galois automorphism, ++ an appropriate set of Galois keys must also be provided. Dynamic memory allocations ++ in the process are allocated from the memory pool pointed to by the local ++ MemoryPoolHandle. ++ ++ The desired Galois automorphism is given as a Galois element, and must be an odd ++ integer in the interval [1, M-1], where M = 2*N, and N = degree(poly_modulus). Used ++ with batching, a Galois element 3^i % M corresponds to a cyclic row rotation i steps ++ to the left, and a Galois element 3^(N/2-i) % M corresponds to a cyclic row rotation ++ i steps to the right. The Galois element M-1 corresponds to a column rotation (row ++ swap). In the polynomial view (not batching), a Galois automorphism by a Galois ++ element p changes Enc(plain(x)) to Enc(plain(x^p)). ++ ++ @param[in] encrypted The ciphertext to apply the Galois automorphism to ++ @param[in] galois_elt The Galois element ++ @param[in] galois_keys The Galois keys ++ @throws std::invalid_argument if encrypted or galois_keys is not valid for the ++ encryption parameters ++ @throws std::invalid_argument if encrypted has size greater than two ++ @throws std::invalid_argument if the Galois element is not valid ++ @throws std::invalid_argument if necessary Galois keys are not present ++ */ ++ inline void apply_galois(Ciphertext &encrypted, std::uint64_t galois_elt, ++ const GaloisKeys &galois_keys) ++ { ++ apply_galois(encrypted, galois_elt, galois_keys, pool_); ++ } ++ ++ /** ++ Applies a Galois automorphism to a ciphertext and writes the result to the ++ destination parameter. To evaluate the Galois automorphism, an appropriate set of ++ Galois keys must also be provided. Dynamic memory allocations in the process are ++ allocated from the memory pool pointed to by the given MemoryPoolHandle. ++ ++ The desired Galois automorphism is given as a Galois element, and must be an odd ++ integer in the interval [1, M-1], where M = 2*N, and N = degree(poly_modulus). Used ++ with batching, a Galois element 3^i % M corresponds to a cyclic row rotation i steps ++ to the left, and a Galois element 3^(N/2-i) % M corresponds to a cyclic row rotation ++ i steps to the right. The Galois element M-1 corresponds to a column rotation (row ++ swap). In the polynomial view (not batching), a Galois automorphism by a Galois ++ element p changes Enc(plain(x)) to Enc(plain(x^p)). ++ ++ @param[in] encrypted The ciphertext to apply the Galois automorphism to ++ @param[in] galois_elt The Galois element ++ @param[in] galois_keys The Galois keys ++ @param[out] destination The ciphertext to overwrite with the result ++ @param[in] pool The MemoryPoolHandle pointing to a valid memory pool ++ @throws std::invalid_argument if encrypted or galois_keys is not valid for the ++ encryption parameters ++ @throws std::invalid_argument if encrypted has size greater than two ++ @throws std::invalid_argument if the Galois element is not valid ++ @throws std::invalid_argument if necessary Galois keys are not present ++ @throws std::logic_error if destination is aliased and needs to be reallocated ++ @throws std::invalid_argument if pool is uninitialized ++ */ ++ inline void apply_galois(const Ciphertext &encrypted, std::uint64_t galois_elt, ++ const GaloisKeys &galois_keys, Ciphertext &destination, ++ const MemoryPoolHandle &pool) ++ { ++ destination = encrypted; ++ apply_galois(destination, galois_elt, galois_keys, pool); ++ } ++ ++ /** ++ Applies a Galois automorphism to a ciphertext and writes the result to the ++ destination parameter. To evaluate the Galois automorphism, an appropriate set of ++ Galois keys must also be provided. Dynamic memory allocations in the process are ++ allocated from the memory pool pointed to by the local MemoryPoolHandle. ++ ++ The desired Galois automorphism is given as a Galois element, and must be an odd ++ integer in the interval [1, M-1], where M = 2*N, and N = degree(poly_modulus). Used ++ with batching, a Galois element 3^i % M corresponds to a cyclic row rotation i steps ++ to the left, and a Galois element 3^(N/2-i) % M corresponds to a cyclic row rotation ++ i steps to the right. The Galois element M-1 corresponds to a column rotation (row ++ swap). In the polynomial view (not batching), a Galois automorphism by a Galois ++ element p changes Enc(plain(x)) to Enc(plain(x^p)). ++ ++ @param[in] encrypted The ciphertext to apply the Galois automorphism to ++ @param[in] galois_elt The Galois element ++ @param[in] galois_keys The Galois keys ++ @param[out] destination The ciphertext to overwrite with the result ++ @throws std::invalid_argument if encrypted or galois_keys is not valid for the ++ encryption parameters ++ @throws std::invalid_argument if encrypted has size greater than two ++ @throws std::invalid_argument if the Galois element is not valid ++ @throws std::invalid_argument if necessary Galois keys are not present ++ @throws std::logic_error if destination is aliased and needs to be reallocated ++ */ ++ inline void apply_galois(const Ciphertext &encrypted, std::uint64_t galois_elt, ++ const GaloisKeys &galois_keys, Ciphertext &destination) ++ { ++ apply_galois(encrypted, galois_elt, galois_keys, destination, pool_); ++ } ++ + /** + Rotates plaintext matrix rows cyclically. When batching is used, this function rotates + the encrypted plaintext matrix rows cyclically to the left (steps > 0) or to the right +@@ -913,6 +1036,7 @@ namespace seal + @param[in] steps The number of steps to rotate (negative left, positive right) + @param[in] galois_keys The Galois keys + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool ++ @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted has size greater than two +@@ -934,13 +1058,15 @@ namespace seal + @param[in] encrypted The ciphertext to rotate + @param[in] steps The number of steps to rotate (negative left, positive right) + @param[in] galois_keys The Galois keys ++ @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted has size greater than two + @throws std::invalid_argument if steps has too big absolute value + @throws std::invalid_argument if necessary Galois keys are not present + */ +- inline void rotate_rows(Ciphertext &encrypted, int steps, const GaloisKeys &galois_keys) ++ inline void rotate_rows(Ciphertext &encrypted, int steps, ++ const GaloisKeys &galois_keys) + { + rotate_rows(encrypted, steps, galois_keys, pool_); + } +@@ -959,6 +1085,7 @@ namespace seal + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool ++ @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted has size greater than two +@@ -968,7 +1095,8 @@ namespace seal + @throws std::invalid_argument if pool is uninitialized + */ + inline void rotate_rows(const Ciphertext &encrypted, int steps, +- const GaloisKeys &galois_keys, Ciphertext &destination, const MemoryPoolHandle &pool) ++ const GaloisKeys &galois_keys, Ciphertext &destination, ++ const MemoryPoolHandle &pool) + { + destination = encrypted; + rotate_rows(destination, steps, galois_keys, pool); +@@ -987,6 +1115,7 @@ namespace seal + @param[in] steps The number of steps to rotate (negative left, positive right) + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result ++ @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted has size greater than two +@@ -1011,6 +1140,7 @@ namespace seal + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool ++ @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted has size greater than two +@@ -1021,6 +1151,10 @@ namespace seal + inline void rotate_columns(Ciphertext &encrypted, const GaloisKeys &galois_keys, + const MemoryPoolHandle &pool) + { ++ if (!qualifiers_.enable_batching) ++ { ++ throw std::logic_error("encryption parameters do not support batching"); ++ } + std::uint64_t m = (parms_.poly_modulus().coeff_count() - 1) << 1; + apply_galois(encrypted, m - 1, galois_keys, pool); + } +@@ -1035,6 +1169,7 @@ namespace seal + @param[in] encrypted The ciphertext to rotate + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result ++ @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted has size greater than two +@@ -1058,6 +1193,7 @@ namespace seal + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result + @param[in] pool The MemoryPoolHandle pointing to a valid memory pool ++ @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted has size greater than two +@@ -1084,6 +1220,7 @@ namespace seal + @param[in] encrypted The ciphertext to rotate + @param[in] galois_keys The Galois keys + @param[out] destination The ciphertext to overwrite with the rotated result ++ @throws std::logic_error if the encryption parameters do not support batching + @throws std::invalid_argument if encrypted or galois_keys is not valid for the + encryption parameters + @throws std::invalid_argument if encrypted has size greater than two +@@ -1101,10 +1238,11 @@ namespace seal + + Evaluator &operator =(Evaluator &&assign) = delete; + +- void relinearize(Ciphertext &encrypted, const EvaluationKeys &evaluation_keys, int destination_size, +- const MemoryPoolHandle &pool); ++ void relinearize(Ciphertext &encrypted, const EvaluationKeys &evaluation_keys, ++ int destination_size, const MemoryPoolHandle &pool); + +- inline void decompose_single_coeff(const std::uint64_t *value, std::uint64_t *destination, const MemoryPoolHandle &pool) ++ inline void decompose_single_coeff(const std::uint64_t *value, ++ std::uint64_t *destination, const MemoryPoolHandle &pool) + { + #ifdef SEAL_DEBUG + if (value == nullptr) +@@ -1150,7 +1288,8 @@ namespace seal + } + } + +- inline void decompose(const std::uint64_t *value, std::uint64_t *destination, const MemoryPoolHandle &pool) ++ inline void decompose(const std::uint64_t *value, std::uint64_t *destination, ++ const MemoryPoolHandle &pool) + { + #ifdef SEAL_DEBUG + if (value == nullptr) +@@ -1209,32 +1348,6 @@ namespace seal + + void populate_Zmstar_to_generator(); + +- // The apply_galois function applies a Galois automorphism to a ciphertext. +- // It is needed for slot permutations. +- // Input: encryption of M(x) and an integer p such that gcd(p, m) = 1. +- // Output: encryption of M(x^p). +- // The function requires certain GaloisKeys and auxiliary data. +- void apply_galois(Ciphertext &encrypted, std::uint64_t galois_elt, const GaloisKeys &evaluation_keys, +- const MemoryPoolHandle &pool); +- +- inline void apply_galois(Ciphertext &encrypted, std::uint64_t galois_elt, const GaloisKeys &evaluation_keys) +- { +- apply_galois(encrypted, galois_elt, evaluation_keys, pool_); +- } +- +- inline void apply_galois(const Ciphertext &encrypted, std::uint64_t galois_elt, +- const GaloisKeys &evaluation_keys, Ciphertext &destination, const MemoryPoolHandle &pool) +- { +- destination = encrypted; +- apply_galois(destination, galois_elt, evaluation_keys, pool); +- } +- +- inline void apply_galois(const Ciphertext &encrypted, std::uint64_t galois_elt, +- const GaloisKeys &evaluation_keys, Ciphertext &destination) +- { +- apply_galois(encrypted, galois_elt, evaluation_keys, destination, pool_); +- } +- + MemoryPoolHandle pool_; + + EncryptionParameters parms_; +diff --git a/SEAL/seal/keygenerator.cpp b/SEAL/seal/keygenerator.cpp +index fa7fd46..ee6338f 100644 +--- a/SEAL/seal/keygenerator.cpp ++++ b/SEAL/seal/keygenerator.cpp +@@ -293,10 +293,6 @@ namespace seal + { + throw logic_error("cannot generate galois keys for unspecified secret key"); + } +- if (!qualifiers_.enable_batching) +- { +- throw logic_error("encryption parameters are not valid for batching"); +- } + + // Check that decomposition_bit_count is in correct interval + if (decomposition_bit_count < SEAL_DBC_MIN || decomposition_bit_count > SEAL_DBC_MAX) +@@ -426,10 +422,6 @@ namespace seal + { + throw logic_error("cannot generate galois keys for unspecified secret key"); + } +- if (!qualifiers_.enable_batching) +- { +- throw logic_error("encryption parameters are not valid for batching"); +- } + + // Check that decomposition_bit_count is in correct interval + if (decomposition_bit_count < SEAL_DBC_MIN || decomposition_bit_count > SEAL_DBC_MAX) +diff --git a/SEAL/seal/keygenerator.h b/SEAL/seal/keygenerator.h +index a702a3e..393aa61 100644 +--- a/SEAL/seal/keygenerator.h ++++ b/SEAL/seal/keygenerator.h +@@ -99,15 +99,38 @@ namespace seal + } + + /** +- Generates Galois keys. ++ Generates Galois keys. This function creates logarithmically many (in degree of the ++ polynomial modulus) Galois keys that is sufficient to apply any Galois automorphism ++ (e.g. rotations) on encrypted data. Most users will want to use this overload of ++ the function. + + @param[in] decomposition_bit_count The decomposition bit count + @param[out] galois_keys The Galois keys instance to overwrite with the generated keys + @throws std::invalid_argument if decomposition_bit_count is not within [1, 60] +- @throws std::logic_error if the encryption parameters do not support batching +- */ ++ */ + void generate_galois_keys(int decomposition_bit_count, GaloisKeys &galois_keys); + ++ /** ++ Generates Galois keys. This function creates specific Galois keys that can be used to ++ apply specific Galois automorphisms on encrypted data. The user needs to give as ++ input a vector of Galois elements corresponding to the keys that are to be created. ++ ++ The Galois elements are odd integers in the interval [1, M-1], where M = 2*N, and ++ N = degree(poly_modulus). Used with batching, a Galois element 3^i % M corresponds ++ to a cyclic row rotation i steps to the left, and a Galois element 3^(N/2-i) % M ++ corresponds to a cyclic row rotation i steps to the right. The Galois element M-1 ++ corresponds to a column rotation (row swap). In the polynomial view (not batching), ++ a Galois automorphism by a Galois element p changes Enc(plain(x)) to Enc(plain(x^p)). ++ ++ @param[in] decomposition_bit_count The decomposition bit count ++ @param[in] galois_elts The Galois elements for which to generate keys ++ @param[out] galois_keys The Galois keys instance to overwrite with the generated keys ++ @throws std::invalid_argument if decomposition_bit_count is not within [1, 60] ++ @throws std::invalid_argument if the Galois elements are not valid ++ */ ++ void generate_galois_keys(int decomposition_bit_count, ++ const std::vector &galois_elts, GaloisKeys &galois_keys); ++ + private: + KeyGenerator(const KeyGenerator ©) = delete; + +@@ -141,17 +164,6 @@ namespace seal + return generated_; + } + +- void generate_galois_keys(int decomposition_bit_count, +- const std::vector &galois_elts, GaloisKeys &galois_keys); +- +- inline GaloisKeys generate_galois_keys(int decomposition_bit_count, +- const std::vector &galois_elts) +- { +- GaloisKeys keys; +- generate_galois_keys(decomposition_bit_count, galois_elts, keys); +- return keys; +- } +- + MemoryPoolHandle pool_; + + EncryptionParameters parms_; +diff --git a/SEALNET/sealnet/EvaluatorWrapper.cpp b/SEALNET/sealnet/EvaluatorWrapper.cpp +index b4868ed..648fbfd 100644 +--- a/SEALNET/sealnet/EvaluatorWrapper.cpp ++++ b/SEALNET/sealnet/EvaluatorWrapper.cpp +@@ -1513,6 +1513,152 @@ namespace Microsoft + } + } + ++ void Evaluator::ApplyGalois(Ciphertext ^encrypted, UInt64 galoisElt, GaloisKeys ^galoisKeys) ++ { ++ if (evaluator_ == nullptr) ++ { ++ throw gcnew ObjectDisposedException("Evaluator is disposed"); ++ } ++ if (encrypted == nullptr) ++ { ++ throw gcnew ArgumentNullException("encrypted cannot be null"); ++ } ++ if (galoisKeys == nullptr) ++ { ++ throw gcnew ArgumentNullException("galoisKeys cannot be null"); ++ } ++ try ++ { ++ evaluator_->apply_galois(encrypted->GetCiphertext(), galoisElt, galoisKeys->GetKeys()); ++ GC::KeepAlive(encrypted); ++ GC::KeepAlive(galoisKeys); ++ } ++ catch (const exception &e) ++ { ++ HandleException(&e); ++ } ++ catch (...) ++ { ++ HandleException(nullptr); ++ } ++ } ++ ++ void Evaluator::ApplyGalois(Ciphertext ^encrypted, UInt64 galoisElt, GaloisKeys ^galoisKeys, ++ MemoryPoolHandle ^pool) ++ { ++ if (evaluator_ == nullptr) ++ { ++ throw gcnew ObjectDisposedException("Evaluator is disposed"); ++ } ++ if (encrypted == nullptr) ++ { ++ throw gcnew ArgumentNullException("encrypted cannot be null"); ++ } ++ if (galoisKeys == nullptr) ++ { ++ throw gcnew ArgumentNullException("galoisKeys cannot be null"); ++ } ++ if (pool == nullptr) ++ { ++ throw gcnew ArgumentNullException("pool cannot be null"); ++ } ++ try ++ { ++ evaluator_->apply_galois(encrypted->GetCiphertext(), galoisElt, ++ galoisKeys->GetKeys(), pool->GetHandle()); ++ GC::KeepAlive(encrypted); ++ GC::KeepAlive(galoisKeys); ++ GC::KeepAlive(pool); ++ } ++ catch (const exception &e) ++ { ++ HandleException(&e); ++ } ++ catch (...) ++ { ++ HandleException(nullptr); ++ } ++ } ++ ++ void Evaluator::ApplyGalois(Ciphertext ^encrypted, UInt64 galoisElt, GaloisKeys ^galoisKeys, ++ Ciphertext ^destination) ++ { ++ if (evaluator_ == nullptr) ++ { ++ throw gcnew ObjectDisposedException("Evaluator is disposed"); ++ } ++ if (encrypted == nullptr) ++ { ++ throw gcnew ArgumentNullException("encrypted cannot be null"); ++ } ++ if (galoisKeys == nullptr) ++ { ++ throw gcnew ArgumentNullException("galoisKeys cannot be null"); ++ } ++ if (destination == nullptr) ++ { ++ throw gcnew ArgumentNullException("destination cannot be null"); ++ } ++ try ++ { ++ evaluator_->apply_galois(encrypted->GetCiphertext(), galoisElt, ++ galoisKeys->GetKeys(), destination->GetCiphertext()); ++ GC::KeepAlive(encrypted); ++ GC::KeepAlive(galoisKeys); ++ GC::KeepAlive(destination); ++ } ++ catch (const exception &e) ++ { ++ HandleException(&e); ++ } ++ catch (...) ++ { ++ HandleException(nullptr); ++ } ++ } ++ ++ void Evaluator::ApplyGalois(Ciphertext ^encrypted, UInt64 galoisElt, GaloisKeys ^galoisKeys, ++ Ciphertext ^destination, MemoryPoolHandle ^pool) ++ { ++ if (evaluator_ == nullptr) ++ { ++ throw gcnew ObjectDisposedException("Evaluator is disposed"); ++ } ++ if (encrypted == nullptr) ++ { ++ throw gcnew ArgumentNullException("encrypted cannot be null"); ++ } ++ if (galoisKeys == nullptr) ++ { ++ throw gcnew ArgumentNullException("galoisKeys cannot be null"); ++ } ++ if (destination == nullptr) ++ { ++ throw gcnew ArgumentNullException("destination cannot be null"); ++ } ++ if (pool == nullptr) ++ { ++ throw gcnew ArgumentNullException("pool cannot be null"); ++ } ++ try ++ { ++ evaluator_->apply_galois(encrypted->GetCiphertext(), galoisElt, ++ galoisKeys->GetKeys(), destination->GetCiphertext(), pool->GetHandle()); ++ GC::KeepAlive(encrypted); ++ GC::KeepAlive(galoisKeys); ++ GC::KeepAlive(destination); ++ GC::KeepAlive(pool); ++ } ++ catch (const exception &e) ++ { ++ HandleException(&e); ++ } ++ catch (...) ++ { ++ HandleException(nullptr); ++ } ++ } ++ + void Evaluator::RotateRows(Ciphertext ^encrypted, int steps, GaloisKeys ^galoisKeys) + { + if (evaluator_ == nullptr) +diff --git a/SEALNET/sealnet/EvaluatorWrapper.h b/SEALNET/sealnet/EvaluatorWrapper.h +index 1f99af1..7a7ef05 100644 +--- a/SEALNET/sealnet/EvaluatorWrapper.h ++++ b/SEALNET/sealnet/EvaluatorWrapper.h +@@ -113,7 +113,7 @@ namespace Microsoft + by the given . + The SEALContext +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool + if encryption parameters are not valid + if pool is uninitialized + if context or pool is null +@@ -256,7 +256,7 @@ namespace Microsoft + + The first ciphertext to multiply + The second ciphertext to multiply +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool + if encrypted1 or encrypted2 is not valid + for the encryption parameters + if pool is uninitialized +@@ -296,7 +296,7 @@ namespace Microsoft + The first ciphertext to multiply + The second ciphertext to multiply + The ciphertext to overwrite with the multiplication result +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool + if encrypted1 or encrypted2 is not valid + for the encryption parameters + if pool is uninitialized +@@ -330,7 +330,7 @@ namespace Microsoft + the memory pool pointed to by the given . + + The ciphertext to square +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool + if encrypted is not valid for the + encryption parameters + if encrypted or pool is +@@ -365,7 +365,7 @@ namespace Microsoft + + The ciphertext to square + The ciphertext to overwrite with the square +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool + if encrypted is not valid for the + encryption parameters + if pool is uninitialized +@@ -385,7 +385,7 @@ namespace Microsoft + + The ciphertext to relinearize + The evaluation keys +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool + if encrypted or evaluationKeys is not + valid for the encryption parameters + if the size of evaluationKeys is too +@@ -450,7 +450,7 @@ namespace Microsoft + The ciphertext to relinearize + The evaluation keys + The ciphertext to overwrite with the relinearized result +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool + if encrypted or evaluationKeys is not valid + for the encryption parameters + if the size of evaluationKeys is too +@@ -1024,6 +1024,121 @@ namespace Microsoft + */ + void MultiplyPlainNTT(Ciphertext ^encryptedNTT, Plaintext ^plainNTT); + ++ /** ++ Applies a Galois automorphism to a ciphertext. ++ ++ ++ Applies a Galois automorphism to a ciphertext. To evaluate the Galois automorphism, ++ an appropriate set of Galois keys must also be provided. Dynamic memory allocations ++ in the process are allocated from the memory pool pointed to by the local ++ . ++ ++ The ciphertext to apply the Galois automorphism to ++ The Galois element ++ The Galois keys ++ if encrypted or galoisKeys is not valid ++ for the encryption parameters ++ if encrypted has size greater than ++ two ++ if the Galois element is not ++ valid ++ if necessary Galois keys are not ++ present ++ if encrypted or galoisKeys is ++ null ++ */ ++ void ApplyGalois(Ciphertext ^encrypted, System::UInt64 galoisElt, ++ GaloisKeys ^galoisKeys); ++ ++ /** ++ Applies a Galois automorphism to a ciphertext. ++ ++ ++ Applies a Galois automorphism to a ciphertext. To evaluate the Galois automorphism, ++ an appropriate set of Galois keys must also be provided. Dynamic memory allocations ++ in the process are allocated from the memory pool pointed to by the given ++ . ++ ++ The ciphertext to apply the Galois automorphism to ++ The Galois element ++ The Galois keys ++ The ++ The MemoryPoolHandle pointing to a valid memory pool ++ if encrypted or galoisKeys is not valid ++ for the encryption parameters ++ if encrypted has size greater than ++ two ++ if the Galois element is not ++ valid ++ if necessary Galois keys are not ++ present ++ if pool is uninitialized ++ if encrypted, galoisKeys or pool ++ is null ++ */ ++ void ApplyGalois(Ciphertext ^encrypted, System::UInt64 galoisElt, ++ GaloisKeys ^galoisKeys, MemoryPoolHandle ^pool); ++ ++ /** ++ Applies a Galois automorphism to a ciphertext and writes the result ++ to the destination parameter. ++ ++ ++ Applies a Galois automorphism to a ciphertext and writes the result to the ++ destination parameter. To evaluate the Galois automorphism, an appropriate ++ set of Galois keys must also be provided. Dynamic memory allocations in the ++ process are allocated from the memory pool pointed to by the local ++ . ++ ++ The ciphertext to apply the Galois automorphism to ++ The Galois element ++ The Galois keys ++ The ciphertext to overwrite with the result ++ if encrypted or galoisKeys is not valid ++ for the encryption parameters ++ if encrypted has size greater than ++ two ++ if the Galois element is not ++ valid ++ if necessary Galois keys are not ++ present ++ if encrypted, galoisKeys, or ++ destination is null ++ */ ++ void ApplyGalois(Ciphertext ^encrypted, System::UInt64 galoisElt, ++ GaloisKeys ^galoisKeys, Ciphertext ^destination); ++ ++ /** ++ Applies a Galois automorphism to a ciphertext and writes the result ++ to the destination parameter. ++ ++ ++ Applies a Galois automorphism to a ciphertext and writes the result to the ++ destination parameter. To evaluate the Galois automorphism, an appropriate ++ set of Galois keys must also be provided. Dynamic memory allocations in the ++ process are allocated from the memory pool pointed to by the given ++ . ++ ++ The ciphertext to apply the Galois automorphism to ++ The Galois element ++ The Galois keys ++ The ciphertext to overwrite with the result ++ The MemoryPoolHandle pointing to a valid memory pool ++ if encrypted or galoisKeys is not valid ++ for the encryption parameters ++ if encrypted has size greater than ++ two ++ if the Galois element is not ++ valid ++ if necessary Galois keys are not ++ present ++ if pool is uninitialized ++ if encrypted, galoisKeys, ++ destination, or pool is null ++ */ ++ void ApplyGalois(Ciphertext ^encrypted, System::UInt64 galoisElt, ++ GaloisKeys ^galoisKeys, Ciphertext ^destination, MemoryPoolHandle ^pool); ++ + /** + Rotates plaintext matrix rows cyclically. + +@@ -1038,6 +1153,8 @@ namespace Microsoft + The ciphertext to rotate + The number of steps to rotate (negative left, positive right) + The Galois keys ++ if the encryption parameters do ++ not support batching + if encrypted or galoisKeys is not valid + for the encryption parameters + if encrypted has size greater than +@@ -1046,7 +1163,7 @@ namespace Microsoft + value + if necessary Galois keys are not + present +- if encrypted, galoisKeys or pool ++ if encrypted, galoisKeys, or pool + is null + */ + void RotateRows(Ciphertext ^encrypted, int steps, GaloisKeys ^galoisKeys); +@@ -1065,7 +1182,9 @@ namespace Microsoft + The ciphertext to rotate + The number of steps to rotate (negative left, positive right) + The Galois keys +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool ++ if the encryption parameters do ++ not support batching + if encrypted or galoisKeys is not valid + for the encryption parameters + if encrypted has size greater than +@@ -1075,7 +1194,7 @@ namespace Microsoft + if necessary Galois keys are not + present + if pool is uninitialized +- if encrypted, galoisKeys or pool ++ if encrypted, galoisKeys, or pool + is null + */ + void RotateRows(Ciphertext ^encrypted, int steps, GaloisKeys ^galoisKeys, +@@ -1097,6 +1216,8 @@ namespace Microsoft + The number of steps to rotate (negative left, positive right) + The Galois keys + The ciphertext to overwrite with the rotated result ++ if the encryption parameters do ++ not support batching + if encrypted or galoisKeys is not valid + for the encryption parameters + if encrypted has size greater than +@@ -1127,7 +1248,9 @@ namespace Microsoft + The number of steps to rotate (negative left, positive right) + The Galois keys + The ciphertext to overwrite with the rotated result +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory param> ++ if the encryption parameters do ++ not support batching + if encrypted or galoisKeys is not valid + for the encryption parameters + if encrypted has size greater than +@@ -1156,6 +1279,8 @@ namespace Microsoft + + The ciphertext to rotate + The Galois keys ++ if the encryption parameters do ++ not support batching + if encrypted or galoisKeys is not valid + for the encryption parameters + if encrypted has size greater than +@@ -1180,7 +1305,9 @@ namespace Microsoft + + The ciphertext to rotate + The Galois keys +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool ++ if the encryption parameters do ++ not support batching + if encrypted or galoisKeys is not valid + for the encryption parameters + if encrypted has size greater than +@@ -1208,6 +1335,8 @@ namespace Microsoft + The ciphertext to rotate + The Galois keys + The ciphertext to overwrite with the rotated result ++ if the encryption parameters do ++ not support batching + if encrypted or galoisKeys is not valid + for the encryption parameters + if encrypted has size greater than +@@ -1234,7 +1363,9 @@ namespace Microsoft + The ciphertext to rotate + The Galois keys + The ciphertext to overwrite with the rotated result +- The MemoryPoolHandle pointing to a valid memory pool/param> ++ The MemoryPoolHandle pointing to a valid memory pool ++ if the encryption parameters do ++ not support batching + if encrypted or galoisKeys is not valid + for the encryption parameters + if necessary Galois keys are not +diff --git a/SEALNET/sealnet/KeyGeneratorWrapper.cpp b/SEALNET/sealnet/KeyGeneratorWrapper.cpp +index c9d52de..e1547f1 100644 +--- a/SEALNET/sealnet/KeyGeneratorWrapper.cpp ++++ b/SEALNET/sealnet/KeyGeneratorWrapper.cpp +@@ -1,10 +1,12 @@ + #include ++#include + #include "sealnet/KeyGeneratorWrapper.h" + #include "sealnet/BigPolyWrapper.h" + #include "sealnet/BigUIntWrapper.h" + #include "sealnet/Common.h" + + using namespace System; ++using namespace System::Collections::Generic; + using namespace std; + + namespace Microsoft +@@ -205,6 +207,37 @@ namespace Microsoft + } + } + ++ void KeyGenerator::GenerateGaloisKeys(int decompositionBitCount, List ^galoisElts, ++ GaloisKeys ^galoisKeys) ++ { ++ if (generator_ == nullptr) ++ { ++ throw gcnew ObjectDisposedException("KeyGenerator is disposed"); ++ } ++ if (galoisKeys == nullptr) ++ { ++ throw gcnew ArgumentNullException("galoisKeys cannot be null"); ++ } ++ try ++ { ++ std::vector v_galois_elts; ++ for (int i = 0; i < galoisElts->Count; i++) ++ { ++ v_galois_elts.push_back(galoisElts[i]); ++ } ++ generator_->generate_galois_keys(decompositionBitCount, v_galois_elts, galoisKeys->GetKeys()); ++ GC::KeepAlive(galoisElts); ++ } ++ catch (const exception &e) ++ { ++ HandleException(&e); ++ } ++ catch (...) ++ { ++ HandleException(nullptr); ++ } ++ } ++ + Microsoft::Research::SEAL::PublicKey ^KeyGenerator::PublicKey::get() + { + if (generator_ == nullptr) +diff --git a/SEALNET/sealnet/KeyGeneratorWrapper.h b/SEALNET/sealnet/KeyGeneratorWrapper.h +index 7fd722c..cf7f1fc 100644 +--- a/SEALNET/sealnet/KeyGeneratorWrapper.h ++++ b/SEALNET/sealnet/KeyGeneratorWrapper.h +@@ -158,6 +158,12 @@ namespace Microsoft + /** + Generates Galois keys. + ++ ++ Generates Galois keys. This function creates logarithmically many (in degree of the ++ polynomial modulus) Galois keys that is sufficient to apply any Galois automorphism ++ (e.g. rotations) on encrypted data. Most users will want to use this overload of ++ the function. ++ + The decomposition bit count + The Galois keys instance to overwrite with the generated + keys +@@ -167,6 +173,35 @@ namespace Microsoft + */ + void GenerateGaloisKeys(int decompositionBitCount, GaloisKeys ^galoisKeys); + ++ /** ++ Generates Galois keys. ++ ++ ++ Generates Galois keys. This function creates specific Galois keys that can be used to ++ apply specific Galois automorphisms on encrypted data. The user needs to give as ++ input a vector of Galois elements corresponding to the keys that are to be created. ++ ++ The Galois elements are odd integers in the interval [1, M-1], where M = 2*N, and ++ N = degree(PolyModulus). Used with batching, a Galois element 3^i % M corresponds ++ to a cyclic row rotation i steps to the left, and a Galois element 3^(N/2-i) % M ++ corresponds to a cyclic row rotation i steps to the right. The Galois element M-1 ++ corresponds to a column rotation (row swap). In the polynomial view (not batching), ++ a Galois automorphism by a Galois element p changes Enc(plain(x)) to Enc(plain(x^p)). ++ ++ The decomposition bit count ++ The Galois elements for which to generate keys ++ The Galois keys instance to overwrite with the generated ++ keys ++ if decompositionBitCount is not ++ within [0, 60] ++ if the Galois elements are not ++ valid ++ if galoisKeys is null ++ */ ++ void GenerateGaloisKeys(int decompositionBitCount, ++ System::Collections::Generic::List ^galoisElts, ++ GaloisKeys ^galoisKeys); ++ + /** + Destroys the KeyGenerator. + */ +diff --git a/SEALNETTest/EvaluatorWrapper.cs b/SEALNETTest/EvaluatorWrapper.cs +index 541786c..4321ecd 100644 +--- a/SEALNETTest/EvaluatorWrapper.cs ++++ b/SEALNETTest/EvaluatorWrapper.cs +@@ -982,6 +982,88 @@ namespace SEALNETTest + Assert.AreEqual(encrypted.HashBlock, parms.HashBlock); + } + ++ [TestMethod] ++ public void FVEncryptApplyGaloisDecryptNET() ++ { ++ var parms = new EncryptionParameters(); ++ var plain_modulus = new SmallModulus(257); ++ parms.NoiseStandardDeviation = 3.19; ++ parms.PlainModulus = plain_modulus; ++ parms.PolyModulus = "1x^8 + 1"; ++ parms.CoeffModulus = new List { ++ DefaultParams.SmallMods40Bit(0), DefaultParams.SmallMods40Bit(1) ++ }; ++ var context = new SEALContext(parms); ++ var keygen = new KeyGenerator(context); ++ var glk = new GaloisKeys(); ++ keygen.GenerateGaloisKeys(24, new List { 1, 3, 5, 15 }, glk); ++ ++ var encryptor = new Encryptor(context, keygen.PublicKey); ++ var evaluator = new Evaluator(context); ++ var decryptor = new Decryptor(context, keygen.SecretKey); ++ ++ var plain = new Plaintext("1"); ++ var encrypted = new Ciphertext(); ++ encryptor.Encrypt(plain, encrypted); ++ evaluator.ApplyGalois(encrypted, 1, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 3, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 5, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 15, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1", plain.ToString()); ++ ++ plain.Set("1x^1"); ++ encryptor.Encrypt(plain, encrypted); ++ evaluator.ApplyGalois(encrypted, 1, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1x^1", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 3, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1x^3", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 5, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("100x^7", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 15, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1x^1", plain.ToString()); ++ ++ plain.Set("1x^2"); ++ encryptor.Encrypt(plain, encrypted); ++ evaluator.ApplyGalois(encrypted, 1, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1x^2", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 3, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1x^6", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 5, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("100x^6", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 15, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1x^2", plain.ToString()); ++ ++ plain.Set("1x^3 + 2x^2 + 1x^1 + 1"); ++ encryptor.Encrypt(plain, encrypted); ++ evaluator.ApplyGalois(encrypted, 1, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1x^3 + 2x^2 + 1x^1 + 1", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 3, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("2x^6 + 1x^3 + 100x^1 + 1", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 5, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("100x^7 + FFx^6 + 100x^5 + 1", plain.ToString()); ++ evaluator.ApplyGalois(encrypted, 15, glk); ++ decryptor.Decrypt(encrypted, plain); ++ Assert.AreEqual("1x^3 + 2x^2 + 1x^1 + 1", plain.ToString()); ++ } ++ + [TestMethod] + public void FVEncryptRotateMatrixDecryptNET() + { +diff --git a/SEALNETTest/KeyGeneratorWrapper.cs b/SEALNETTest/KeyGeneratorWrapper.cs +index 04f9738..2a74a89 100644 +--- a/SEALNETTest/KeyGeneratorWrapper.cs ++++ b/SEALNETTest/KeyGeneratorWrapper.cs +@@ -1,6 +1,7 @@ + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Microsoft.Research.SEAL; + using System.Collections.Generic; ++using System; + + namespace SEALNETTest + { +@@ -35,6 +36,79 @@ namespace SEALNETTest + keygen.GenerateEvaluationKeys(2, 2, evk); + Assert.AreEqual(evk.HashBlock, parms.HashBlock); + Assert.AreEqual(60, evk.Key(2)[0].Size); ++ ++ var galks = new GaloisKeys(); ++ keygen.GenerateGaloisKeys(60, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.AreEqual(2, galks.Key(3)[0].Size); ++ Assert.AreEqual(10, galks.Size); ++ ++ keygen.GenerateGaloisKeys(30, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.AreEqual(4, galks.Key(3)[0].Size); ++ Assert.AreEqual(10, galks.Size); ++ ++ keygen.GenerateGaloisKeys(2, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.AreEqual(60, galks.Key(3)[0].Size); ++ Assert.AreEqual(10, galks.Size); ++ ++ keygen.GenerateGaloisKeys(60, new List { 1, 3, 5, 7 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsTrue(galks.HasKey(1)); ++ Assert.IsTrue(galks.HasKey(3)); ++ Assert.IsTrue(galks.HasKey(5)); ++ Assert.IsTrue(galks.HasKey(7)); ++ Assert.IsFalse(galks.HasKey(9)); ++ Assert.IsFalse(galks.HasKey(127)); ++ Assert.AreEqual(2, galks.Key(1)[0].Size); ++ Assert.AreEqual(2, galks.Key(3)[0].Size); ++ Assert.AreEqual(2, galks.Key(5)[0].Size); ++ Assert.AreEqual(2, galks.Key(7)[0].Size); ++ Assert.AreEqual(4, galks.Size); ++ ++ keygen.GenerateGaloisKeys(30, new List { 1, 3, 5, 7 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsTrue(galks.HasKey(1)); ++ Assert.IsTrue(galks.HasKey(3)); ++ Assert.IsTrue(galks.HasKey(5)); ++ Assert.IsTrue(galks.HasKey(7)); ++ Assert.IsFalse(galks.HasKey(9)); ++ Assert.IsFalse(galks.HasKey(127)); ++ Assert.AreEqual(4, galks.Key(1)[0].Size); ++ Assert.AreEqual(4, galks.Key(3)[0].Size); ++ Assert.AreEqual(4, galks.Key(5)[0].Size); ++ Assert.AreEqual(4, galks.Key(7)[0].Size); ++ Assert.AreEqual(4, galks.Size); ++ ++ keygen.GenerateGaloisKeys(2, new List { 1, 3, 5, 7 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsTrue(galks.HasKey(1)); ++ Assert.IsTrue(galks.HasKey(3)); ++ Assert.IsTrue(galks.HasKey(5)); ++ Assert.IsTrue(galks.HasKey(7)); ++ Assert.IsFalse(galks.HasKey(9)); ++ Assert.IsFalse(galks.HasKey(127)); ++ Assert.AreEqual(60, galks.Key(1)[0].Size); ++ Assert.AreEqual(60, galks.Key(3)[0].Size); ++ Assert.AreEqual(60, galks.Key(5)[0].Size); ++ Assert.AreEqual(60, galks.Key(7)[0].Size); ++ Assert.AreEqual(4, galks.Size); ++ ++ keygen.GenerateGaloisKeys(30, new List { 1 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsTrue(galks.HasKey(1)); ++ Assert.IsFalse(galks.HasKey(3)); ++ Assert.IsFalse(galks.HasKey(127)); ++ Assert.AreEqual(4, galks.Key(1)[0].Size); ++ Assert.AreEqual(1, galks.Size); ++ ++ keygen.GenerateGaloisKeys(30, new List { 127 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsFalse(galks.HasKey(1)); ++ Assert.IsTrue(galks.HasKey(127)); ++ Assert.AreEqual(4, galks.Key(127)[0].Size); ++ Assert.AreEqual(1, galks.Size); + } + { + parms.NoiseStandardDeviation = 3.19; +@@ -61,6 +135,79 @@ namespace SEALNETTest + keygen.GenerateEvaluationKeys(4, 1, evk); + Assert.AreEqual(evk.HashBlock, parms.HashBlock); + Assert.AreEqual(30, evk.Key(2)[0].Size); ++ ++ var galks = new GaloisKeys(); ++ keygen.GenerateGaloisKeys(60, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.AreEqual(2, galks.Key(3)[0].Size); ++ Assert.AreEqual(14, galks.Size); ++ ++ keygen.GenerateGaloisKeys(30, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.AreEqual(4, galks.Key(3)[0].Size); ++ Assert.AreEqual(14, galks.Size); ++ ++ keygen.GenerateGaloisKeys(2, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.AreEqual(60, galks.Key(3)[0].Size); ++ Assert.AreEqual(14, galks.Size); ++ ++ keygen.GenerateGaloisKeys(60, new List { 1, 3, 5, 7 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsTrue(galks.HasKey(1)); ++ Assert.IsTrue(galks.HasKey(3)); ++ Assert.IsTrue(galks.HasKey(5)); ++ Assert.IsTrue(galks.HasKey(7)); ++ Assert.IsFalse(galks.HasKey(9)); ++ Assert.IsFalse(galks.HasKey(511)); ++ Assert.AreEqual(2, galks.Key(1)[0].Size); ++ Assert.AreEqual(2, galks.Key(3)[0].Size); ++ Assert.AreEqual(2, galks.Key(5)[0].Size); ++ Assert.AreEqual(2, galks.Key(7)[0].Size); ++ Assert.AreEqual(4, galks.Size); ++ ++ keygen.GenerateGaloisKeys(30, new List { 1, 3, 5, 7 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsTrue(galks.HasKey(1)); ++ Assert.IsTrue(galks.HasKey(3)); ++ Assert.IsTrue(galks.HasKey(5)); ++ Assert.IsTrue(galks.HasKey(7)); ++ Assert.IsFalse(galks.HasKey(9)); ++ Assert.IsFalse(galks.HasKey(511)); ++ Assert.AreEqual(4, galks.Key(1)[0].Size); ++ Assert.AreEqual(4, galks.Key(3)[0].Size); ++ Assert.AreEqual(4, galks.Key(5)[0].Size); ++ Assert.AreEqual(4, galks.Key(7)[0].Size); ++ Assert.AreEqual(4, galks.Size); ++ ++ keygen.GenerateGaloisKeys(2, new List { 1, 3, 5, 7 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsTrue(galks.HasKey(1)); ++ Assert.IsTrue(galks.HasKey(3)); ++ Assert.IsTrue(galks.HasKey(5)); ++ Assert.IsTrue(galks.HasKey(7)); ++ Assert.IsFalse(galks.HasKey(9)); ++ Assert.IsFalse(galks.HasKey(511)); ++ Assert.AreEqual(60, galks.Key(1)[0].Size); ++ Assert.AreEqual(60, galks.Key(3)[0].Size); ++ Assert.AreEqual(60, galks.Key(5)[0].Size); ++ Assert.AreEqual(60, galks.Key(7)[0].Size); ++ Assert.AreEqual(4, galks.Size); ++ ++ keygen.GenerateGaloisKeys(30, new List { 1 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsTrue(galks.HasKey(1)); ++ Assert.IsFalse(galks.HasKey(3)); ++ Assert.IsFalse(galks.HasKey(511)); ++ Assert.AreEqual(4, galks.Key(1)[0].Size); ++ Assert.AreEqual(1, galks.Size); ++ ++ keygen.GenerateGaloisKeys(30, new List { 511 }, galks); ++ Assert.AreEqual(galks.HashBlock, parms.HashBlock); ++ Assert.IsFalse(galks.HasKey(1)); ++ Assert.IsTrue(galks.HasKey(511)); ++ Assert.AreEqual(4, galks.Key(511)[0].Size); ++ Assert.AreEqual(1, galks.Size); + } + } + } +diff --git a/SEALTest/evaluator.cpp b/SEALTest/evaluator.cpp +index 51e083a..edde078 100644 +--- a/SEALTest/evaluator.cpp ++++ b/SEALTest/evaluator.cpp +@@ -964,6 +964,85 @@ namespace SEALTest + Assert::IsTrue(encrypted.hash_block() == parms.hash_block()); + } + ++ TEST_METHOD(FVEncryptApplyGaloisDecrypt) ++ { ++ EncryptionParameters parms; ++ SmallModulus plain_modulus(257); ++ BigPoly poly_modulus("1x^8 + 1"); ++ parms.set_poly_modulus(poly_modulus); ++ parms.set_plain_modulus(plain_modulus); ++ parms.set_coeff_modulus({ small_mods_40bit(0), small_mods_40bit(1) }); ++ SEALContext context(parms); ++ KeyGenerator keygen(context); ++ GaloisKeys glk; ++ keygen.generate_galois_keys(24, { 1, 3, 5, 15 }, glk); ++ ++ Encryptor encryptor(context, keygen.public_key()); ++ Evaluator evaluator(context); ++ Decryptor decryptor(context, keygen.secret_key()); ++ ++ Plaintext plain("1"); ++ Ciphertext encrypted; ++ encryptor.encrypt(plain, encrypted); ++ evaluator.apply_galois(encrypted, 1, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 3, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 5, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 15, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1" == plain.to_string()); ++ ++ plain = "1x^1"; ++ encryptor.encrypt(plain, encrypted); ++ evaluator.apply_galois(encrypted, 1, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1x^1" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 3, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1x^3" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 5, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("100x^7" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 15, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1x^1" == plain.to_string()); ++ ++ plain = "1x^2"; ++ encryptor.encrypt(plain, encrypted); ++ evaluator.apply_galois(encrypted, 1, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1x^2" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 3, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1x^6" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 5, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("100x^6" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 15, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1x^2" == plain.to_string()); ++ ++ plain = "1x^3 + 2x^2 + 1x^1 + 1"; ++ encryptor.encrypt(plain, encrypted); ++ evaluator.apply_galois(encrypted, 1, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 3, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("2x^6 + 1x^3 + 100x^1 + 1" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 5, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("100x^7 + FFx^6 + 100x^5 + 1" == plain.to_string()); ++ evaluator.apply_galois(encrypted, 15, glk); ++ decryptor.decrypt(encrypted, plain); ++ Assert::IsTrue("1x^3 + 2x^2 + 1x^1 + 1" == plain.to_string()); ++ } ++ + TEST_METHOD(FVEncryptRotateMatrixDecrypt) + { + EncryptionParameters parms; +diff --git a/SEALTest/keygenerator.cpp b/SEALTest/keygenerator.cpp +index b4e15b3..a64b1af 100644 +--- a/SEALTest/keygenerator.cpp ++++ b/SEALTest/keygenerator.cpp +@@ -68,6 +68,79 @@ namespace SEALTest + } + } + } ++ ++ GaloisKeys galks; ++ keygen.generate_galois_keys(60, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::AreEqual(2, galks.key(3)[0].size()); ++ Assert::AreEqual(10, galks.size()); ++ ++ keygen.generate_galois_keys(30, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::AreEqual(4, galks.key(3)[0].size()); ++ Assert::AreEqual(10, galks.size()); ++ ++ keygen.generate_galois_keys(2, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::AreEqual(60, galks.key(3)[0].size()); ++ Assert::AreEqual(10, galks.size()); ++ ++ keygen.generate_galois_keys(60, { 1, 3, 5, 7 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsTrue(galks.has_key(1)); ++ Assert::IsTrue(galks.has_key(3)); ++ Assert::IsTrue(galks.has_key(5)); ++ Assert::IsTrue(galks.has_key(7)); ++ Assert::IsFalse(galks.has_key(9)); ++ Assert::IsFalse(galks.has_key(127)); ++ Assert::AreEqual(2, galks.key(1)[0].size()); ++ Assert::AreEqual(2, galks.key(3)[0].size()); ++ Assert::AreEqual(2, galks.key(5)[0].size()); ++ Assert::AreEqual(2, galks.key(7)[0].size()); ++ Assert::AreEqual(4, galks.size()); ++ ++ keygen.generate_galois_keys(30, { 1, 3, 5, 7 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsTrue(galks.has_key(1)); ++ Assert::IsTrue(galks.has_key(3)); ++ Assert::IsTrue(galks.has_key(5)); ++ Assert::IsTrue(galks.has_key(7)); ++ Assert::IsFalse(galks.has_key(9)); ++ Assert::IsFalse(galks.has_key(127)); ++ Assert::AreEqual(4, galks.key(1)[0].size()); ++ Assert::AreEqual(4, galks.key(3)[0].size()); ++ Assert::AreEqual(4, galks.key(5)[0].size()); ++ Assert::AreEqual(4, galks.key(7)[0].size()); ++ Assert::AreEqual(4, galks.size()); ++ ++ keygen.generate_galois_keys(2, { 1, 3, 5, 7 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsTrue(galks.has_key(1)); ++ Assert::IsTrue(galks.has_key(3)); ++ Assert::IsTrue(galks.has_key(5)); ++ Assert::IsTrue(galks.has_key(7)); ++ Assert::IsFalse(galks.has_key(9)); ++ Assert::IsFalse(galks.has_key(127)); ++ Assert::AreEqual(60, galks.key(1)[0].size()); ++ Assert::AreEqual(60, galks.key(3)[0].size()); ++ Assert::AreEqual(60, galks.key(5)[0].size()); ++ Assert::AreEqual(60, galks.key(7)[0].size()); ++ Assert::AreEqual(4, galks.size()); ++ ++ keygen.generate_galois_keys(30, { 1 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsTrue(galks.has_key(1)); ++ Assert::IsFalse(galks.has_key(3)); ++ Assert::IsFalse(galks.has_key(127)); ++ Assert::AreEqual(4, galks.key(1)[0].size()); ++ Assert::AreEqual(1, galks.size()); ++ ++ keygen.generate_galois_keys(30, { 127 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsFalse(galks.has_key(1)); ++ Assert::IsTrue(galks.has_key(127)); ++ Assert::AreEqual(4, galks.key(127)[0].size()); ++ Assert::AreEqual(1, galks.size()); + } + { + parms.set_noise_standard_deviation(3.19); +@@ -121,6 +194,79 @@ namespace SEALTest + } + } + } ++ ++ GaloisKeys galks; ++ keygen.generate_galois_keys(60, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::AreEqual(2, galks.key(3)[0].size()); ++ Assert::AreEqual(14, galks.size()); ++ ++ keygen.generate_galois_keys(30, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::AreEqual(4, galks.key(3)[0].size()); ++ Assert::AreEqual(14, galks.size()); ++ ++ keygen.generate_galois_keys(2, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::AreEqual(60, galks.key(3)[0].size()); ++ Assert::AreEqual(14, galks.size()); ++ ++ keygen.generate_galois_keys(60, { 1, 3, 5, 7 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsTrue(galks.has_key(1)); ++ Assert::IsTrue(galks.has_key(3)); ++ Assert::IsTrue(galks.has_key(5)); ++ Assert::IsTrue(galks.has_key(7)); ++ Assert::IsFalse(galks.has_key(9)); ++ Assert::IsFalse(galks.has_key(511)); ++ Assert::AreEqual(2, galks.key(1)[0].size()); ++ Assert::AreEqual(2, galks.key(3)[0].size()); ++ Assert::AreEqual(2, galks.key(5)[0].size()); ++ Assert::AreEqual(2, galks.key(7)[0].size()); ++ Assert::AreEqual(4, galks.size()); ++ ++ keygen.generate_galois_keys(30, { 1, 3, 5, 7 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsTrue(galks.has_key(1)); ++ Assert::IsTrue(galks.has_key(3)); ++ Assert::IsTrue(galks.has_key(5)); ++ Assert::IsTrue(galks.has_key(7)); ++ Assert::IsFalse(galks.has_key(9)); ++ Assert::IsFalse(galks.has_key(511)); ++ Assert::AreEqual(4, galks.key(1)[0].size()); ++ Assert::AreEqual(4, galks.key(3)[0].size()); ++ Assert::AreEqual(4, galks.key(5)[0].size()); ++ Assert::AreEqual(4, galks.key(7)[0].size()); ++ Assert::AreEqual(4, galks.size()); ++ ++ keygen.generate_galois_keys(2, { 1, 3, 5, 7 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsTrue(galks.has_key(1)); ++ Assert::IsTrue(galks.has_key(3)); ++ Assert::IsTrue(galks.has_key(5)); ++ Assert::IsTrue(galks.has_key(7)); ++ Assert::IsFalse(galks.has_key(9)); ++ Assert::IsFalse(galks.has_key(511)); ++ Assert::AreEqual(60, galks.key(1)[0].size()); ++ Assert::AreEqual(60, galks.key(3)[0].size()); ++ Assert::AreEqual(60, galks.key(5)[0].size()); ++ Assert::AreEqual(60, galks.key(7)[0].size()); ++ Assert::AreEqual(4, galks.size()); ++ ++ keygen.generate_galois_keys(30, { 1 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsTrue(galks.has_key(1)); ++ Assert::IsFalse(galks.has_key(3)); ++ Assert::IsFalse(galks.has_key(511)); ++ Assert::AreEqual(4, galks.key(1)[0].size()); ++ Assert::AreEqual(1, galks.size()); ++ ++ keygen.generate_galois_keys(30, { 511 }, galks); ++ Assert::IsTrue(galks.hash_block() == parms.hash_block()); ++ Assert::IsFalse(galks.has_key(1)); ++ Assert::IsTrue(galks.has_key(511)); ++ Assert::AreEqual(4, galks.key(511)[0].size()); ++ Assert::AreEqual(1, galks.size()); + } + } + }; +-- +2.14.1 + + +From 1fdfa03edbb50835beedefe26d74f17918c4f1e1 Mon Sep 17 00:00:00 2001 +From: Kim Laine +Date: Mon, 4 Dec 2017 17:31:18 -0800 +Subject: [PATCH 2/3] Added negacyclic_shift_poly_coeffmod + +--- + SEAL/seal/util/polyarithsmallmod.cpp | 28 ++++++--- + SEAL/seal/util/polyarithsmallmod.h | 54 +++++++++++++++++- + SEALTest/util/polyarithsmallmod.cpp | 108 +++++++++++++++++++++++++++++++++++ + 3 files changed, 181 insertions(+), 9 deletions(-) + +diff --git a/SEAL/seal/util/polyarithsmallmod.cpp b/SEAL/seal/util/polyarithsmallmod.cpp +index 5bfeede..4719348 100644 +--- a/SEAL/seal/util/polyarithsmallmod.cpp ++++ b/SEAL/seal/util/polyarithsmallmod.cpp +@@ -407,16 +407,30 @@ namespace seal + } + } + +- uint64_t poly_infty_norm_coeffmod(const std::uint64_t *poly, int poly_coeff_count, const SmallModulus &modulus) ++ uint64_t poly_infty_norm_coeffmod(const std::uint64_t *operand, int coeff_count, const SmallModulus &modulus) + { ++#ifdef SEAL_DEBUG ++ if (operand == nullptr && coeff_count > 0) ++ { ++ throw invalid_argument("operand"); ++ } ++ if (coeff_count < 0) ++ { ++ throw invalid_argument("coeff_count"); ++ } ++ if (modulus.is_zero()) ++ { ++ throw invalid_argument("modulus"); ++ } ++#endif + // Construct negative threshold (first negative modulus value) to compute absolute values of coeffs. + uint64_t modulus_neg_threshold = (modulus.value() + 1) >> 1; + + // Mod out the poly coefficients and choose a symmetric representative from [-modulus,modulus). Keep track of the max. + uint64_t result = 0; +- for (int coeff_index = 0; coeff_index < poly_coeff_count; coeff_index++) ++ for (int coeff_index = 0; coeff_index < coeff_count; coeff_index++) + { +- uint64_t poly_coeff = poly[coeff_index] % modulus.value(); ++ uint64_t poly_coeff = operand[coeff_index] % modulus.value(); + if (poly_coeff >= modulus_neg_threshold) + { + poly_coeff = modulus.value() - poly_coeff; +@@ -594,14 +608,14 @@ namespace seal + return true; + } + +- void exponentiate_poly_polymod_coeffmod(const uint64_t *poly, const uint64_t *exponent, int exponent_uint64_count, const PolyModulus &poly_modulus, const SmallModulus &modulus, uint64_t *result, MemoryPool &pool) ++ void exponentiate_poly_polymod_coeffmod(const uint64_t *operand, const uint64_t *exponent, int exponent_uint64_count, const PolyModulus &poly_modulus, const SmallModulus &modulus, uint64_t *result, MemoryPool &pool) + { + int poly_modulus_coeff_count = poly_modulus.coeff_count(); + #ifdef SEAL_DEBUG + int poly_modulus_coeff_uint64_count = poly_modulus.coeff_uint64_count(); +- if (poly == nullptr) ++ if (operand == nullptr) + { +- throw invalid_argument("poly"); ++ throw invalid_argument("operand"); + } + if (exponent == nullptr) + { +@@ -631,7 +645,7 @@ namespace seal + return; + } + +- modulo_poly(poly, poly_modulus_coeff_count, poly_modulus, modulus, result, pool); ++ modulo_poly(operand, poly_modulus_coeff_count, poly_modulus, modulus, result, pool); + + if (is_equal_uint(exponent, exponent_uint64_count, 1)) + { +diff --git a/SEAL/seal/util/polyarithsmallmod.h b/SEAL/seal/util/polyarithsmallmod.h +index d660439..f081184 100644 +--- a/SEAL/seal/util/polyarithsmallmod.h ++++ b/SEAL/seal/util/polyarithsmallmod.h +@@ -556,14 +556,64 @@ namespace seal + modulo_poly_inplace(result, result_coeff_count, poly_modulus, modulus); + } + +- std::uint64_t poly_infty_norm_coeffmod(const std::uint64_t *poly, int poly_coeff_count, ++ std::uint64_t poly_infty_norm_coeffmod(const std::uint64_t *operand, int coeff_count, + const SmallModulus &modulus); + + bool try_invert_poly_coeffmod(const std::uint64_t *operand, const std::uint64_t *poly_modulus, + int coeff_count, const SmallModulus &modulus, std::uint64_t *result, MemoryPool &pool); + +- void exponentiate_poly_polymod_coeffmod(const std::uint64_t *poly, const std::uint64_t *exponent, ++ void exponentiate_poly_polymod_coeffmod(const std::uint64_t *operand, const std::uint64_t *exponent, + int exponent_uint64_count, const PolyModulus &poly_modulus, const SmallModulus &modulus, + std::uint64_t *result, MemoryPool &pool); ++ ++ inline void negacyclic_shift_poly_coeffmod(const std::uint64_t *operand, int coeff_count, int shift, ++ const SmallModulus &modulus, std::uint64_t *result) ++ { ++#ifdef SEAL_DEBUG ++ if (operand == nullptr && coeff_count > 0) ++ { ++ throw std::invalid_argument("operand"); ++ } ++ if (result == nullptr && coeff_count > 0) ++ { ++ throw std::invalid_argument("result"); ++ } ++ if (operand == result && coeff_count > 0) ++ { ++ throw std::invalid_argument("operand cannot point to the same location as result"); ++ } ++ if (coeff_count < 0) ++ { ++ throw std::invalid_argument("coeff_count"); ++ } ++ if (modulus.is_zero()) ++ { ++ throw std::invalid_argument("modulus"); ++ } ++ if (shift < 0) ++ { ++ throw std::invalid_argument("shift"); ++ } ++ if (util::get_power_of_two(static_cast(coeff_count)) < 0) ++ { ++ throw std::invalid_argument("coeff_count"); ++ } ++#endif ++ std::uint64_t index_raw = shift; ++ std::uint64_t coeff_count_mod_mask = static_cast(coeff_count) - 1; ++ std::uint64_t index; ++ for (int i = 0; i < coeff_count; i++, operand++, index_raw++) ++ { ++ index = index_raw & coeff_count_mod_mask; ++ if (!(index_raw & static_cast(coeff_count)) || (*operand == 0)) ++ { ++ result[index] = *operand; ++ } ++ else ++ { ++ result[index] = modulus.value() - *operand; ++ } ++ } ++ } + } + } +diff --git a/SEALTest/util/polyarithsmallmod.cpp b/SEALTest/util/polyarithsmallmod.cpp +index e917034..93df00b 100644 +--- a/SEALTest/util/polyarithsmallmod.cpp ++++ b/SEALTest/util/polyarithsmallmod.cpp +@@ -470,6 +470,114 @@ namespace SEALTest + Assert::AreEqual(9ULL, result[1]); + Assert::AreEqual(0ULL, result[2]); + } ++ ++ TEST_METHOD(NegacyclicShiftPolyCoeffSmallMod) ++ { ++ MemoryPool &pool = *global_variables::global_memory_pool; ++ Pointer poly(allocate_zero_poly(4, 1, pool)); ++ Pointer result(allocate_zero_poly(4, 1, pool)); ++ ++ SmallModulus mod(10); ++ int coeff_count = 4; ++ ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 0, mod, result.get()); ++ Assert::AreEqual(0ULL, result[0]); ++ Assert::AreEqual(0ULL, result[1]); ++ Assert::AreEqual(0ULL, result[2]); ++ Assert::AreEqual(0ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); ++ Assert::AreEqual(0ULL, result[0]); ++ Assert::AreEqual(0ULL, result[1]); ++ Assert::AreEqual(0ULL, result[2]); ++ Assert::AreEqual(0ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 4, mod, result.get()); ++ Assert::AreEqual(0ULL, result[0]); ++ Assert::AreEqual(0ULL, result[1]); ++ Assert::AreEqual(0ULL, result[2]); ++ Assert::AreEqual(0ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 5, mod, result.get()); ++ Assert::AreEqual(0ULL, result[0]); ++ Assert::AreEqual(0ULL, result[1]); ++ Assert::AreEqual(0ULL, result[2]); ++ Assert::AreEqual(0ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 8, mod, result.get()); ++ Assert::AreEqual(0ULL, result[0]); ++ Assert::AreEqual(0ULL, result[1]); ++ Assert::AreEqual(0ULL, result[2]); ++ Assert::AreEqual(0ULL, result[3]); ++ ++ poly[0] = 1; ++ poly[1] = 2; ++ poly[2] = 3; ++ poly[3] = 4; ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 0, mod, result.get()); ++ Assert::AreEqual(1ULL, result[0]); ++ Assert::AreEqual(2ULL, result[1]); ++ Assert::AreEqual(3ULL, result[2]); ++ Assert::AreEqual(4ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); ++ Assert::AreEqual(6ULL, result[0]); ++ Assert::AreEqual(1ULL, result[1]); ++ Assert::AreEqual(2ULL, result[2]); ++ Assert::AreEqual(3ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 4, mod, result.get()); ++ Assert::AreEqual(9ULL, result[0]); ++ Assert::AreEqual(8ULL, result[1]); ++ Assert::AreEqual(7ULL, result[2]); ++ Assert::AreEqual(6ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 5, mod, result.get()); ++ Assert::AreEqual(4ULL, result[0]); ++ Assert::AreEqual(9ULL, result[1]); ++ Assert::AreEqual(8ULL, result[2]); ++ Assert::AreEqual(7ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 8, mod, result.get()); ++ Assert::AreEqual(1ULL, result[0]); ++ Assert::AreEqual(2ULL, result[1]); ++ Assert::AreEqual(3ULL, result[2]); ++ Assert::AreEqual(4ULL, result[3]); ++ ++ poly[0] = 1; ++ poly[1] = 2; ++ poly[2] = 0; ++ poly[3] = 4; ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 0, mod, result.get()); ++ Assert::AreEqual(1ULL, result[0]); ++ Assert::AreEqual(2ULL, result[1]); ++ Assert::AreEqual(0ULL, result[2]); ++ Assert::AreEqual(4ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); ++ Assert::AreEqual(6ULL, result[0]); ++ Assert::AreEqual(1ULL, result[1]); ++ Assert::AreEqual(2ULL, result[2]); ++ Assert::AreEqual(0ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 4, mod, result.get()); ++ Assert::AreEqual(9ULL, result[0]); ++ Assert::AreEqual(8ULL, result[1]); ++ Assert::AreEqual(0ULL, result[2]); ++ Assert::AreEqual(6ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 5, mod, result.get()); ++ Assert::AreEqual(4ULL, result[0]); ++ Assert::AreEqual(9ULL, result[1]); ++ Assert::AreEqual(8ULL, result[2]); ++ Assert::AreEqual(0ULL, result[3]); ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 8, mod, result.get()); ++ Assert::AreEqual(1ULL, result[0]); ++ Assert::AreEqual(2ULL, result[1]); ++ Assert::AreEqual(0ULL, result[2]); ++ Assert::AreEqual(4ULL, result[3]); ++ ++ poly[0] = 1; ++ poly[1] = 2; ++ poly[2] = 3; ++ poly[3] = 4; ++ coeff_count = 2; ++ negacyclic_shift_poly_coeffmod(poly.get(), coeff_count, 1, mod, result.get()); ++ negacyclic_shift_poly_coeffmod(poly.get() + 2, coeff_count, 1, mod, result.get() + 2); ++ Assert::AreEqual(8ULL, result[0]); ++ Assert::AreEqual(1ULL, result[1]); ++ Assert::AreEqual(6ULL, result[2]); ++ Assert::AreEqual(3ULL, result[3]); ++ } + }; + } + } +\ No newline at end of file +-- +2.14.1 + + +From 9c5a16fb3e8ffc5bae69ba175867d89b774091c5 Mon Sep 17 00:00:00 2001 +From: Sebastian Angel +Date: Tue, 15 May 2018 00:11:24 +0000 +Subject: [PATCH 3/3] enable mutable + +--- + SEAL/seal/util/defines.h | 4 ++-- + 1 file changed, 2 insertions(+), 2 deletions(-) + +diff --git a/SEAL/seal/util/defines.h b/SEAL/seal/util/defines.h +index 20c0bf2..4e1b2fb 100644 +--- a/SEAL/seal/util/defines.h ++++ b/SEAL/seal/util/defines.h +@@ -19,7 +19,7 @@ + // parameter compatibility checks pass in cases where they normally + // should not pass. Please note that it is extremely easy to break + // things by doing this, and the consequences can be unexpected. +-//#define SEAL_EXPOSE_MUTABLE_HASH_BLOCK ++#define SEAL_EXPOSE_MUTABLE_HASH_BLOCK + + // Allow ciphertext data to be directly modified by exposing the + // functions seal::Ciphertext::mutable_pointer(int) and +@@ -28,7 +28,7 @@ + // way of mutating ciphertext data is by allocating memory manually, + // and using aliased ciphertexts pointing to the allocated memory, + // which can then be mutated freely. +-//#define SEAL_EXPOSE_MUTABLE_CIPHERTEXT ++#define SEAL_EXPOSE_MUTABLE_CIPHERTEXT + + // For security reasons one should never throw when decoding fails due + // to overflow, but in some cases this might help in diagnosing problems. +-- +2.14.1 + diff --git a/SealPIR b/SealPIR new file mode 160000 index 0000000..d821285 --- /dev/null +++ b/SealPIR @@ -0,0 +1 @@ +Subproject commit d82128503ec6b6a740edefef30cdff8d0d54481c diff --git a/benches/pir.rs b/benches/pir.rs new file mode 100644 index 0000000..e30cbc2 --- /dev/null +++ b/benches/pir.rs @@ -0,0 +1,161 @@ +#![feature(custom_attribute, custom_derive, plugin)] + +#[macro_use] +extern crate criterion; +extern crate rand; +extern crate sealpir; +extern crate serde; +#[macro_use] +extern crate serde_derive; + +use criterion::Criterion; +use rand::ChaChaRng; +use rand::{RngCore, FromEntropy}; +use sealpir::client::PirClient; +use sealpir::server::PirServer; +use std::time::Duration; + +const SIZE: usize = 288; +const DIM: u32 = 2; +const LOGT: u32 = 23; +const POLY_DEGREE: u32 = 2048; +const NUMS: [u32; 3] = [1 << 16, 1 << 18, 1 << 20]; + +#[derive(Serialize, Clone)] +struct Element { + #[serde(serialize_with = "<[_]>::serialize")] + e: [u8; SIZE], +} + +fn setup(c: &mut Criterion) { + c.bench_function_over_inputs( + &format!("setup_d{}", DIM), + |b, &&num| { + // setup + let mut rng = ChaChaRng::from_entropy(); + let mut collection = vec![]; + for _ in 0..num { + let mut x = [0u8; SIZE]; + rng.fill_bytes(&mut x); + collection.push(x); + } + // measurement + b.iter(|| { + let mut server = PirServer::new(num, SIZE as u32, POLY_DEGREE, LOGT, DIM); + server.setup(&collection); + }) + }, + &NUMS, + ); +} + +fn query(c: &mut Criterion) { + c.bench_function_over_inputs( + &format!("query_d{}", DIM), + |b, &&num| { + // setup + let client = PirClient::new(num, SIZE as u32, POLY_DEGREE, LOGT, DIM); + // measurement + b.iter_with_setup(|| rand::random::() % num, |idx| client.gen_query(idx)); + }, + &NUMS, + ); +} + +fn expand(c: &mut Criterion) { + c.bench_function_over_inputs( + &format!("expand_d{}", DIM), + |b, &&num| { + // setup + let mut rng = ChaChaRng::from_entropy(); + let mut collection = vec![]; + for _ in 0..num { + let mut x = [0u8; SIZE]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let mut server = PirServer::new(num, SIZE as u32, POLY_DEGREE, LOGT, DIM); + let client = PirClient::new(num, SIZE as u32, POLY_DEGREE, LOGT, DIM); + let key = client.get_key(); + server.setup(&collection); + server.set_galois_key(key, 0); + + // measurement + b.iter_with_setup( + || client.gen_query(rand::random::() % num), + |query| server.expand(&query, 0), + ); + }, + &NUMS, + ); +} + +fn reply(c: &mut Criterion) { + c.bench_function_over_inputs( + &format!("reply_d{}", DIM), + |b, &&num| { + // setup + let mut rng = ChaChaRng::from_entropy(); + let mut collection = vec![]; + for _ in 0..num { + let mut x = [0u8; SIZE]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let mut server = PirServer::new(num, SIZE as u32, POLY_DEGREE, LOGT, DIM); + let client = PirClient::new(num, SIZE as u32, POLY_DEGREE, LOGT, DIM); + let key = client.get_key(); + server.setup(&collection); + server.set_galois_key(key, 0); + + // measurement + b.iter_with_setup( + || client.gen_query(rand::random::() % num), + |query| server.gen_reply(&query, 0), + ); + }, + &NUMS, + ); +} + +fn decode(c: &mut Criterion) { + c.bench_function_over_inputs( + &format!("decode_d{}", DIM), + |b, &&num| { + // setup + let mut rng = ChaChaRng::from_entropy(); + let mut collection = vec![]; + for _ in 0..num { + let mut x = [0u8; SIZE]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let mut server = PirServer::new(num, SIZE as u32, POLY_DEGREE, LOGT, DIM); + let client = PirClient::new(num, SIZE as u32, POLY_DEGREE, LOGT, DIM); + let key = client.get_key(); + server.setup(&collection); + server.set_galois_key(key, 0); + let idx = rand::random::() % num; + let query = client.gen_query(idx); + let reply = server.gen_reply(&query, 0); + + // measurement + b.iter(|| client.decode_reply::(idx, &reply)); + }, + &NUMS, + ); +} + +criterion_group! { + name = benches; + config = Criterion::default() + .sample_size(10) + .measurement_time(Duration::new(5, 0)) + .without_plots(); + targets = setup, query, expand, reply, decode +} + +criterion_main!(benches); diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..7b3faa8 --- /dev/null +++ b/build.rs @@ -0,0 +1,26 @@ +extern crate gcc; +use std::env; + +fn main() { + gcc::Build::new() + .file("SealPIR/pir.cpp") + .file("SealPIR/pir_server.cpp") + .file("SealPIR/pir_client.cpp") + .file("sealpir-bindings/pir_rust.cpp") + .include("sealpir-bindings/") + .include("SealPIR/") + .include("deps/SEAL/SEAL/") + .flag("-Wno-unknown-pragmas") + .flag("-Wno-sign-compare") + .flag("-Wno-unused-parameter") + .flag("-std=c++11") + .flag("-fopenmp") + .pic(true) + .cpp(true) + .compile("libsealpir.a"); + + let link_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); + + println!("cargo:rustc-link-search={}/deps/SEAL/bin/", link_dir); + println!("cargo:rustc-link-lib=static=seal"); +} diff --git a/sealpir-bindings/pir_rust.cpp b/sealpir-bindings/pir_rust.cpp new file mode 100644 index 0000000..56b9168 --- /dev/null +++ b/sealpir-bindings/pir_rust.cpp @@ -0,0 +1,154 @@ +#include "pir_rust.hpp" + +void *new_parameters(uint32_t ele_num, uint32_t ele_size, uint32_t N, uint32_t logt, uint32_t d) { + Parameters *param = new Parameters; + gen_params(ele_num, ele_size, N, logt, d, param->params, param->expanded_params, + param->pir_params); + return (void *)param; +} + +void update_parameters(void *params, uint32_t ele_num, uint32_t ele_size, uint32_t d) { + Parameters *param = (Parameters *) params; + update_params(ele_num, ele_size, d, param->params, param->expanded_params, param->pir_params); +} + +void delete_parameters(void *params) { delete ((Parameters *)params); } + +void *new_pir_client(const void *params) { + Parameters *param = (Parameters *)params; + PIRClient *client = new PIRClient(param->params, param->expanded_params, param->pir_params); + return (void *)client; +} + +void update_client_params(void *pir_client, const void *params) { + PIRClient *client = (PIRClient *) pir_client; + Parameters *param = (Parameters *)params; + client->update_parameters(param->expanded_params, param->pir_params); +} + +void delete_pir_client(void *pir_client) { delete ((PIRClient *)pir_client); } + +void *new_pir_server(const void *params) { + Parameters *param = (Parameters *)params; + PIRServer *server = new PIRServer(param->expanded_params, param->pir_params); + return (void *)server; +} + +void update_server_params(void *pir_server, const void *params) { + PIRServer *server = (PIRServer *) pir_server; + Parameters *param = (Parameters *)params; + server->update_parameters(param->expanded_params, param->pir_params); +} + +void delete_pir_server(void *pir_server) { delete ((PIRServer *)pir_server); } + +uint8_t *get_galois_key(const void *pir_client, uint32_t *key_size) { + PIRClient *client = (PIRClient *)pir_client; + seal::GaloisKeys galois = client->generate_galois_keys(); + string ser = serialize_galoiskeys(galois); + + uint32_t size = ser.size(); + uint8_t *out = (uint8_t *)malloc(size); + memcpy(out, ser.data(), size); + *key_size = size; + return out; +} + +void set_galois_key(void *pir_server, const uint8_t *galois_key, uint32_t key_size, + uint32_t client_id) { + PIRServer *server = (PIRServer *)pir_server; + string gal_str = string((const char *)galois_key, key_size); + seal::GaloisKeys *galois = deserialize_galoiskeys(gal_str); + server->set_galois_key(client_id, *galois); + delete galois; +} + +uint32_t get_fv_index(const void *pir_client, uint32_t ele_index, uint32_t ele_size) { + PIRClient *client = (PIRClient *)pir_client; + return client->get_fv_index(ele_index, ele_size); +} + +uint32_t get_fv_offset(const void *pir_client, uint32_t ele_index, uint32_t ele_size) { + PIRClient *client = (PIRClient *)pir_client; + return client->get_fv_offset(ele_index, ele_size); +} + +uint8_t *generate_query(const void *pir_client, uint32_t index, uint32_t *query_size, + uint32_t *query_num) { + PIRClient *client = (PIRClient *)pir_client; + PirQuery query = client->generate_query(index); + *query_num = query.size(); + string ser = serialize_ciphertexts(query); + + uint32_t size = ser.size(); + uint8_t *out = (uint8_t *)malloc(size); + memcpy(out, ser.data(), size); + *query_size = size; + return out; +} + +void expand_query(const void *pir_server, const void *params, const uint8_t *query, + uint32_t query_size, uint32_t query_num, uint32_t client_id) { + + PIRServer *server = (PIRServer *)pir_server; + Parameters *param = (Parameters *)params; + string query_str = string((const char *)query, query_size); + + PirQuery query_des = deserialize_ciphertexts(query_num, query_str, CIPHER_SIZE); + + for (uint32_t i = 0; i < query_num; i++) { + uint32_t m = param->pir_params.nvec[i]; + PirQuery expanded = server->expand_query(query_des[i], m, client_id); + } +} + +uint8_t *generate_reply(const void *pir_server, const uint8_t *query, uint32_t query_size, + uint32_t query_num, uint32_t *reply_size, uint32_t *reply_num, + uint32_t client_id) { + + PIRServer *server = (PIRServer *)pir_server; + + string query_str = string((const char *)query, query_size); + + PirQuery query_des = deserialize_ciphertexts(query_num, query_str, CIPHER_SIZE); + PirReply reply = server->generate_reply(query_des, client_id); + *reply_num = reply.size(); + + string ser = serialize_ciphertexts(reply); + uint32_t size = ser.size(); + uint8_t *out = (uint8_t *)malloc(size); + memcpy(out, ser.data(), size); + *reply_size = size; + return out; +} + +void set_database(void *pir_server, const uint8_t *database, uint32_t ele_num, uint32_t ele_size) { + PIRServer *server = (PIRServer *)pir_server; + server->set_database(database, ele_num, ele_size); +} + +void preprocess_db(void *pir_server) { + PIRServer *server = (PIRServer *)pir_server; + server->preprocess_database(); +} + +uint8_t *decode_reply(const void *pir_client, const void *params, const uint8_t *reply, + uint32_t reply_size, uint32_t reply_num, uint32_t *size) { + + PIRClient *client = (PIRClient *)pir_client; + Parameters *param = (Parameters *)params; + + string reply_str = string((const char *)reply, reply_size); + + PirReply reply_res = deserialize_ciphertexts(reply_num, reply_str, CIPHER_SIZE); + seal::Plaintext result = client->decode_reply(reply_res); + + uint32_t logtp = ceil(log2(param->expanded_params.plain_modulus().value())); + uint32_t N = param->expanded_params.poly_modulus().coeff_count() - 1; + + uint8_t *elems = (uint8_t *)malloc((N * logtp) / 8); + coeffs_to_bytes(logtp, result, elems, (N * logtp) / 8); + + *size = (N * logtp) / 8; + return elems; +} diff --git a/sealpir-bindings/pir_rust.hpp b/sealpir-bindings/pir_rust.hpp new file mode 100644 index 0000000..fe267ae --- /dev/null +++ b/sealpir-bindings/pir_rust.hpp @@ -0,0 +1,80 @@ +#ifndef SEAL_PIR_RUST_H +#define SEAL_PIR_RUST_H + +#include "pir.hpp" +#include "pir_client.hpp" +#include "pir_server.hpp" + +extern "C" { + +struct Parameters { + seal::EncryptionParameters params; + seal::EncryptionParameters expanded_params; + PirParams pir_params; +}; + +// returns a pointer to SealPIR's parameters +void *new_parameters(uint32_t ele_num, uint32_t ele_size, uint32_t N, uint32_t logt, uint32_t d); +void update_parameters(void *params, uint32_t ele_num, uint32_t ele_size, uint32_t d); +void delete_parameters(void *params); + +// Client operations + +// returns a pointer to a PirClient object +void *new_pir_client(const void *params); +void update_client_params(void *pir_client, const void *params); +void delete_pir_client(void *pir_client); + +// returns index of FV plaintext given the index of an element +uint32_t get_fv_index(const void *pir_client, uint32_t ele_index, uint32_t ele_size); + +// returns the offset within an FV plaintext for the given element index +uint32_t get_fv_offset(const void *pir_client, uint32_t ele_index, uint32_t ele_size); + +// get the serialized representation of a galois key +uint8_t *get_galois_key(const void *pir_client, uint32_t *key_size); + +// get the serialized version of a PIR query for the given index +// num: number of ciphertexts making up the query +// query_size: size in bytes +uint8_t *generate_query(const void *pir_client, uint32_t index, uint32_t *query_size, + uint32_t *query_num); + +// decodes the given reply and returns a pointer to the N coefficients +// reply_num: number of ciphertexts making up the query +// reply_size: size in bytes of the reply +// size: size in bytes of the decoded elements +uint8_t *decode_reply(const void *pir_client, const void *param, const uint8_t *reply, + uint32_t reply_size, uint32_t reply_num, uint32_t *size); + +// Server operations + +// returns a pointer to a PirServer object +void *new_pir_server(const void *params); +void update_server_params(void *pir_server, const void *params); +void delete_pir_server(void *pir_server); + +// deserializes the galois key and configures it for the given client +void set_galois_key(void *pir_server, const uint8_t *galois_key, uint32_t key_size, + uint32_t client_id); + +// sets the existing database +void set_database(void *pir_server, const uint8_t *database, uint32_t ele_num, uint32_t ele_size); + +// preprocesses the database +void preprocess_db(void *pir_server); + +// For microbenchmark purposes only (generate_reply does this already) +void expand_query(const void *pir_server, const void *params, const uint8_t *query, + uint32_t query_size, uint32_t query_num, uint32_t client_id); + +// generates a reply for the given client +// query_size: bytes of query +// query_num: number of ciphertexts +// reply_num: number of ciphertexts +// reply_size: bytes of reply +uint8_t *generate_reply(const void *pir_server, const uint8_t *query, uint32_t query_size, + uint32_t query_num, uint32_t *reply_size, uint32_t *reply_num, + uint32_t client_id); +} +#endif diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..f98507e --- /dev/null +++ b/src/client.rs @@ -0,0 +1,149 @@ +use libc; +use std::mem; +use std::slice; + +use super::{PirQuery, PirReply}; + +extern "C" { + fn new_parameters(ele_num: u32, ele_size: u32, N: u32, logt: u32, d: u32) -> *mut libc::c_void; + fn update_parameters(params: *mut libc::c_void, ele_num: u32, ele_size: u32, d: u32); + fn delete_parameters(params: *mut libc::c_void); + + fn new_pir_client(params: *const libc::c_void) -> *mut libc::c_void; + fn update_client_params(pir_client: *mut libc::c_void, params: *const libc::c_void); + fn delete_pir_client(pir_client: *mut libc::c_void); + + fn get_fv_index(pir_client: *const libc::c_void, ele_idx: u32, ele_size: u32) -> u32; + fn get_fv_offset(pir_client: *const libc::c_void, ele_idx: u32, ele_size: u32) -> u32; + + fn get_galois_key(pir_client: *const libc::c_void, key_size: &mut u32) -> *mut u8; + + fn generate_query( + pir_client: *const libc::c_void, + index: u32, + query_size: &mut u32, + query_num: &mut u32, + ) -> *mut u8; + + fn decode_reply( + pir_client: *const libc::c_void, + params: *const libc::c_void, + reply: *const u8, + reply_size: u32, + reply_num: u32, + result_size: &mut u32, + ) -> *mut u8; +} + +pub struct PirClient<'a> { + client: &'a mut libc::c_void, + params: &'a mut libc::c_void, + ele_size: u32, + ele_num: u32, + key: Vec, +} + +impl<'a> Drop for PirClient<'a> { + fn drop(&mut self) { + unsafe { + delete_pir_client(self.client); + delete_parameters(self.params); + } + } +} + +impl<'a> PirClient<'a> { + pub fn new( + ele_num: u32, + ele_size: u32, + poly_degree: u32, + log_plain_mod: u32, + d: u32, + ) -> PirClient<'a> { + let param_ptr: &'a mut libc::c_void = + unsafe { &mut *(new_parameters(ele_num, ele_size, poly_degree, log_plain_mod, d)) }; + + let client_ptr: &'a mut libc::c_void = unsafe { &mut *(new_pir_client(param_ptr)) }; + + let mut key_size: u32 = 0; + + let key: Vec = unsafe { + let ptr = get_galois_key(client_ptr, &mut key_size); + let key = slice::from_raw_parts_mut(ptr as *mut u8, key_size as usize).to_vec(); + libc::free(ptr as *mut libc::c_void); + key + }; + + PirClient { + client: client_ptr, + params: param_ptr, + ele_size, + ele_num, + key, + } + } + + pub fn update_params(&mut self, ele_num: u32, ele_size: u32, d: u32) { + unsafe { + update_parameters(self.params, ele_num, ele_size, d); + update_client_params(self.client, self.params); + } + + self.ele_size = ele_size; + self.ele_num = ele_num; + } + + pub fn get_key(&'a self) -> &'a Vec { + &self.key + } + + pub fn gen_query(&self, index: u32) -> PirQuery { + assert!(index <= self.ele_num); + + let mut query_size: u32 = 0; // # of bytes + let mut query_num: u32 = 0; // # of ciphertexts + + let query: Vec = unsafe { + let fv_index = get_fv_index(self.client, index, self.ele_size); + let ptr = generate_query(self.client, fv_index, &mut query_size, &mut query_num); + let q = slice::from_raw_parts_mut(ptr as *mut u8, query_size as usize).to_vec(); + libc::free(ptr as *mut libc::c_void); + q + }; + + PirQuery { + query, + num: query_num, + } + } + + pub fn decode_reply(&self, ele_index: u32, reply: &PirReply) -> T + where + T: Clone, + { + assert_eq!(self.ele_size as usize, mem::size_of::()); + + let mut result_size: u32 = 0; + let result: T = unsafe { + // returns the content of the FV plaintext + let ptr = decode_reply( + self.client, + self.params, + reply.reply.as_ptr(), + reply.reply.len() as u32, + reply.num, + &mut result_size, + ); + + // offset into the FV plaintext + let offset = get_fv_offset(self.client, ele_index, self.ele_size); + assert!(offset + self.ele_size <= result_size as u32); + + let r = slice::from_raw_parts_mut((ptr as *mut T).offset(offset as isize), 1).to_vec(); + libc::free(ptr as *mut libc::c_void); + r[0].clone() + }; + + result + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..1983978 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,19 @@ +extern crate libc; +extern crate serde; +#[macro_use] +extern crate serde_derive; + +#[derive(Serialize, Deserialize, Clone)] +pub struct PirQuery { + pub query: Vec, + pub num: u32, +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct PirReply { + pub reply: Vec, + pub num: u32, +} + +pub mod client; +pub mod server; diff --git a/src/server.rs b/src/server.rs new file mode 100644 index 0000000..7841194 --- /dev/null +++ b/src/server.rs @@ -0,0 +1,160 @@ +use super::{PirQuery, PirReply}; +use libc; +use std::mem; +use std::slice; + +extern "C" { + fn new_parameters(ele_num: u32, ele_size: u32, N: u32, logt: u32, d: u32) -> *mut libc::c_void; + fn update_parameters(params: *mut libc::c_void, ele_num: u32, ele_size: u32, d: u32); + fn delete_parameters(params: *mut libc::c_void); + + fn new_pir_server(params: *const libc::c_void) -> *mut libc::c_void; + fn update_server_params(pir_server: *mut libc::c_void, params: *const libc::c_void); + fn delete_pir_server(pir_server: *mut libc::c_void); + + fn set_galois_key( + pir_server: *mut libc::c_void, + galois_key: *const u8, + key_size: u32, + client_id: u32, + ); + fn set_database( + pir_server: *mut libc::c_void, + database: *const u8, + ele_num: u32, + ele_size: u32, + ); + + fn preprocess_db(pir_server: *mut libc::c_void); + + // for debugging/benchmark purposes only + fn expand_query( + pir_server: *const libc::c_void, + params: *const libc::c_void, + query: *const u8, + query_size: u32, + query_num: u32, + client_id: u32, + ); + + fn generate_reply( + pir_server: *const libc::c_void, + query: *const u8, + query_size: u32, + query_num: u32, + reply_size: &mut u32, + reply_num: &mut u32, + client_id: u32, + ) -> *mut u8; +} + +pub struct PirServer<'a> { + server: &'a mut libc::c_void, + params: &'a mut libc::c_void, + ele_num: u32, + ele_size: u32, +} + +impl<'a> Drop for PirServer<'a> { + fn drop(&mut self) { + unsafe { + delete_pir_server(self.server); + delete_parameters(self.params); + } + } +} + +impl<'a> PirServer<'a> { + pub fn new( + ele_num: u32, + ele_size: u32, + poly_degree: u32, + log_plain_mod: u32, + d: u32, + ) -> PirServer<'a> { + let params: &'a mut libc::c_void = + unsafe { &mut *(new_parameters(ele_num, ele_size, poly_degree, log_plain_mod, d)) }; + + let server_ptr: &'a mut libc::c_void = unsafe { &mut *(new_pir_server(params)) }; + + PirServer { + server: server_ptr, + params, + ele_num, + ele_size, + } + } + + pub fn update_params(&mut self, ele_num: u32, ele_size: u32, d: u32) { + unsafe { + update_parameters(self.params, ele_num, ele_size, d); + update_server_params(self.server, self.params); + } + + self.ele_size = ele_size; + self.ele_num = ele_num; + } + + pub fn setup(&mut self, collection: &[T]) { + assert_eq!(collection.len(), self.ele_num as usize); + assert_eq!(mem::size_of::(), self.ele_size as usize); + + unsafe { + set_database( + self.server, + collection.as_ptr() as *const u8, + self.ele_num, + self.ele_size, + ); + + preprocess_db(self.server); + } + } + + pub fn set_galois_key(&mut self, key: &[u8], client_id: u32) { + unsafe { + set_galois_key(self.server, key.as_ptr(), key.len() as u32, client_id); + } + } + + #[inline] + pub fn gen_reply(&self, query: &PirQuery, client_id: u32) -> PirReply { + let mut reply_size: u32 = 0; + let mut reply_num: u32 = 0; + + let reply: Vec = unsafe { + let ptr = generate_reply( + self.server, + query.query.as_ptr(), + query.query.len() as u32, + query.num, + &mut reply_size, + &mut reply_num, + client_id, + ); + + let ans = slice::from_raw_parts_mut(ptr as *mut u8, reply_size as usize).to_vec(); + libc::free(ptr as *mut libc::c_void); + ans + }; + + PirReply { + reply, + num: reply_num, + } + } + + // for microbenchmark purposes only + pub fn expand(&self, query: &PirQuery, client_id: u32) { + unsafe { + expand_query( + self.server, + self.params, + query.query.as_ptr(), + query.query.len() as u32, + query.num, + client_id, + ) + } + } +} diff --git a/tests/pir.rs b/tests/pir.rs new file mode 100644 index 0000000..6ea3106 --- /dev/null +++ b/tests/pir.rs @@ -0,0 +1,289 @@ +extern crate rand; +extern crate sealpir; +use rand::RngCore; +use sealpir::client::PirClient; +use sealpir::server::PirServer; + +#[test] +fn pir_very_small_collection_test() { + let poly_degree = 2048; + let log_plain_mod = 19; + let num = 2; + let d = 2; + + let mut collection: Vec<[u8; 288]> = Vec::new(); + let mut rng = rand::thread_rng(); + + for _ in 0..num { + let mut x: [u8; 288] = [0; 288]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let truth = collection.clone(); + + let mut server = PirServer::new(8168, 288, poly_degree, log_plain_mod, d); + let mut client = PirClient::new(8168, 288, poly_degree, log_plain_mod, d); + + { + let key = client.get_key(); + println!("Key size {}", key.len()); + server.set_galois_key(key, 0); + } + + client.update_params(num, 288, d); + server.update_params(num, 288, d); + + server.setup(&collection); + + let index = 0; + let query = client.gen_query(index); + let reply = server.gen_reply(&query, 0); + let result = client.decode_reply::<[u8; 288]>(index, &reply); + assert_eq!(&result[..], &truth[index as usize][..]); +} + +#[test] +fn pir_small_collection_test() { + let poly_degree = 2048; + let log_plain_mod = 20; + let num = 100; + let d = 2; + + let mut collection: Vec<[u8; 288]> = Vec::new(); + let mut rng = rand::thread_rng(); + + for _ in 0..num { + let mut x: [u8; 288] = [0; 288]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let truth = collection.clone(); + + let mut server = PirServer::new(8168, 288, poly_degree, log_plain_mod, d); + let mut client = PirClient::new(8168, 288, poly_degree, log_plain_mod, d); + + { + let key = client.get_key(); + server.set_galois_key(key, 0); + } + + client.update_params(num, 288, d); + server.update_params(num, 288, d); + + server.setup(&collection); + + let index = 0; + let query = client.gen_query(index); + let reply = server.gen_reply(&query, 0); + let result = client.decode_reply::<[u8; 288]>(index, &reply); + assert_eq!(&result[..], &truth[index as usize][..]); +} + +#[test] +fn pir_medium_collection_test() { + let poly_degree = 2048; + let log_plain_mod = 23; + let num = 1 << 16; + let d = 2; + + let mut collection: Vec<[u8; 288]> = Vec::new(); + let mut rng = rand::thread_rng(); + + for _ in 0..num { + let mut x: [u8; 288] = [0; 288]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let truth = collection.clone(); + + let mut server = PirServer::new(8168, 288, poly_degree, log_plain_mod, d); + let mut client = PirClient::new(8168, 288, poly_degree, log_plain_mod, d); + + { + let key = client.get_key(); + server.set_galois_key(key, 0); + } + + client.update_params(num, 288, d); + server.update_params(num, 288, d); + + server.setup(&collection); + + let index = 0; + let query = client.gen_query(index); + let reply = server.gen_reply(&query, 0); + let result = client.decode_reply::<[u8; 288]>(index, &reply); + assert_eq!(&result[..], &truth[index as usize][..]); +} + +#[test] +fn pir_large_collection_test() { + let poly_degree = 2048; + let log_plain_mod = 23; + let num = 1 << 18; + let d = 2; + + let mut collection: Vec<[u8; 288]> = Vec::new(); + let mut rng = rand::thread_rng(); + + for _ in 0..num { + let mut x: [u8; 288] = [0; 288]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let truth = collection.clone(); + + let mut server = PirServer::new(num, 288, poly_degree, log_plain_mod, d); + let client = PirClient::new(num, 288, poly_degree, log_plain_mod, d); + let key = client.get_key(); + + server.set_galois_key(key, 0); + server.setup(&collection); + + let index = 0; + let query = client.gen_query(index); + let reply = server.gen_reply(&query, 0); + let result = client.decode_reply::<[u8; 288]>(index, &reply); + assert_eq!(&result[..], &truth[index as usize][..]); +} + +#[test] +fn pir_very_large_collection_test() { + let poly_degree = 2048; + let log_plain_mod = 23; + let num = 1 << 20; + let d = 2; + + let mut collection: Vec<[u8; 288]> = Vec::new(); + let mut rng = rand::thread_rng(); + + for _ in 0..num { + let mut x: [u8; 288] = [0; 288]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let truth = collection.clone(); + + let mut server = PirServer::new(8168, 288, poly_degree, log_plain_mod, d); + let mut client = PirClient::new(8168, 288, poly_degree, log_plain_mod, d); + + { + let key = client.get_key(); + println!("Key size {}", key.len()); + server.set_galois_key(key, 0); + } + + client.update_params(num, 288, d); + server.update_params(num, 288, d); + + server.setup(&collection); + + let index = 0; + let query = client.gen_query(index); + let reply = server.gen_reply(&query, 0); + let result = client.decode_reply::<[u8; 288]>(index, &reply); + assert_eq!(&result[..], &truth[index as usize][..]); +} + +#[test] +fn pir_largest_collection_test() { + let poly_degree = 2048; + let log_plain_mod = 24; + let num = 1 << 22; + let d = 2; + + let mut collection: Vec<[u8; 288]> = Vec::new(); + let mut rng = rand::thread_rng(); + + for _ in 0..num { + let mut x: [u8; 288] = [0; 288]; + rng.fill_bytes(&mut x); + collection.push(x); + } + + let truth = collection.clone(); + + let mut server = PirServer::new(8168, 288, poly_degree, log_plain_mod, d); + let mut client = PirClient::new(8168, 288, poly_degree, log_plain_mod, d); + + { + let key = client.get_key(); + println!("Key size {}", key.len()); + server.set_galois_key(key, 0); + } + + client.update_params(num, 288, d); + server.update_params(num, 288, d); + + server.setup(&collection); + + let index = 0; + let query = client.gen_query(index); + let reply = server.gen_reply(&query, 0); + let result = client.decode_reply::<[u8; 288]>(index, &reply); + assert_eq!(&result[..], &truth[index as usize][..]); +} + +#[test] +fn pir_sizes() { + let size = 288; + let index = 70; + + let mut collection: Vec<[u8; 288]> = Vec::new(); + + let mut rng = rand::thread_rng(); + + let logts = vec![20, 23]; + let ds = vec![2]; + let ns = vec![1 << 16, 1 << 18, 1 << 20]; + + let mut num_prev = 0; + + for num in ns { + for _ in num_prev..num { + let mut x: [u8; 288] = [0; 288]; + rng.fill_bytes(&mut x); + + collection.push(x); + } + + num_prev = num; + + for d in &ds { + for logt in &logts { + let mut server = PirServer::new(num, size, 2048, *logt, *d); + + + let client = PirClient::new(num, size, 2048, *logt, *d); + let galois = client.get_key(); + + server.setup(&collection); + server.set_galois_key(&galois, 0); + + let query = client.gen_query(index); + let reply = server.gen_reply(&query, 0); + + println!( + "query: num {}, logt {}, d {}, size {}", + num, + *logt, + *d, + query.query.len() / 1024 + ); + println!( + "reply num {}, logt {}, d {}, size {}", + num, + *logt, + *d, + reply.reply.len() / 1024 + ); + + } + } + } +}