Skip to content

Commit

Permalink
chore(shortint): add multi bit GPU alias
Browse files Browse the repository at this point in the history
- add easy access to compact PK tuniform params
  • Loading branch information
IceTDrinker committed Apr 9, 2024
1 parent d1fe49f commit bea9b77
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 19 deletions.
14 changes: 7 additions & 7 deletions tfhe/benches/core_crypto/ks_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ use serde::Serialize;
use tfhe::boolean::prelude::*;
use tfhe::core_crypto::prelude::*;
use tfhe::keycache::NamedParam;
#[cfg(feature = "gpu")]
use tfhe::shortint::parameters::{
PARAM_GPU_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
};
#[cfg(not(feature = "gpu"))]
use tfhe::shortint::parameters::{
PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2_KS_PBS,
Expand All @@ -16,11 +21,6 @@ use tfhe::shortint::parameters::{
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS,
};
#[cfg(feature = "gpu")]
use tfhe::shortint::parameters::{
PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
};
use tfhe::shortint::prelude::*;
use tfhe::shortint::{MultiBitPBSParameters, PBSParameters};

Expand All @@ -43,8 +43,8 @@ const SHORTINT_MULTI_BIT_BENCH_PARAMS: [MultiBitPBSParameters; 6] = [

#[cfg(feature = "gpu")]
const SHORTINT_MULTI_BIT_BENCH_PARAMS: [MultiBitPBSParameters; 2] = [
PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
];

const BOOLEAN_BENCH_PARAMS: [(&str, BooleanParameters); 2] = [
Expand Down
4 changes: 2 additions & 2 deletions tfhe/benches/core_crypto/pbs_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ fn multi_bit_benchmark_parameters<Scalar: UnsignedInteger + Default>(
if Scalar::BITS == 64 {
let parameters = if cfg!(feature = "gpu") {
vec![
PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
]
} else {
vec![
Expand Down
4 changes: 2 additions & 2 deletions tfhe/benches/integer/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ impl Default for ParamsAndNumBlocksIter {

if env_config.is_multi_bit {
#[cfg(feature = "gpu")]
let params = vec![PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS.into()];
let params = vec![PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS.into()];
#[cfg(not(feature = "gpu"))]
let params = vec![PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS.into()];
let params = vec![PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS.into()];

let params_and_bit_sizes = iproduct!(params, env_config.bit_sizes());
Self {
Expand Down
6 changes: 3 additions & 3 deletions tfhe/benches/integer/signed_bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use std::vec::IntoIter;
use tfhe::integer::keycache::KEY_CACHE;
use tfhe::integer::{IntegerKeyKind, RadixCiphertext, ServerKey, SignedRadixCiphertext, I256};
use tfhe::keycache::NamedParam;
#[cfg(feature = "gpu")]
use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
use tfhe::shortint::parameters::PARAM_MESSAGE_2_CARRY_2_KS_PBS;
#[cfg(not(feature = "gpu"))]
use tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS;
#[cfg(feature = "gpu")]
use tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;

fn gen_random_i256(rng: &mut ThreadRng) -> I256 {
let clearlow = rng.gen::<u128>();
Expand All @@ -37,7 +37,7 @@ impl Default for ParamsAndNumBlocksIter {

if env_config.is_multi_bit {
#[cfg(feature = "gpu")]
let params = vec![PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS.into()];
let params = vec![PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS.into()];
#[cfg(not(feature = "gpu"))]
let params = vec![PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS.into()];

Expand Down
8 changes: 4 additions & 4 deletions tfhe/docs/guides/run_on_gpu.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,19 +123,19 @@ TFHE-rs includes the possibility to leverage the high number of threads given by
To do so, the configuration should be updated with

```Rust
let config = ConfigBuilder::with_custom_parameters(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, None).build();
let config = ConfigBuilder::with_custom_parameters(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, None).build();
```

The complete example becomes:

```rust
use tfhe::{ConfigBuilder, set_server_key, FheUint8, ClientKey, CompressedServerKey};
use tfhe::prelude::*;
use tfhe::shortint::parameters::PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;
use tfhe::shortint::parameters::PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;

fn main() {

let config = ConfigBuilder::with_custom_parameters(PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, None).build();
let config = ConfigBuilder::with_custom_parameters(PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS, None).build();

let client_key= ClientKey::generate(config);
let compressed_server_key = CompressedServerKey::new(&client_key);
Expand Down Expand Up @@ -201,7 +201,7 @@ All operations follow the same syntax than the one described in [here](../gettin
# Benchmarks

All GPU benchmarks presented here were obtained on a single H100 GPU, and rely on the multithreaded PBS algorithm.
The cryptographic parameters PARAM\_MULTI\_BIT\_MESSAGE\_1\_CARRY\_2\_GROUP\_3\_KS\_PBS were used.
The cryptographic parameters PARAM\_GPU\_MULTI\_BIT\_MESSAGE\_1\_CARRY\_2\_GROUP\_3\_KS\_PBS were used.
Performance is the following when the inputs of the benchmarked operation are encrypted:

| Operation \ Size | `FheUint7` | `FheUint16` | `FheUint32` | `FheUint64` | `FheUint128` | `FheUint256` |
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ macro_rules! create_gpu_parametrized_test{
PARAM_MESSAGE_2_CARRY_2_KS_PBS,
// PARAM_MESSAGE_3_CARRY_3_KS_PBS,
// PARAM_MESSAGE_4_CARRY_4_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS
PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS
});
};
}
Expand Down
8 changes: 8 additions & 0 deletions tfhe/src/shortint/keycache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ named_params_impl!( ShortintParameterSet =>
PARAM_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
PARAM_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS,
// MultiBit Group 2 GPU
PARAM_GPU_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS,
// MultiBit Group 3 GPU
PARAM_GPU_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS,
PARAM_GPU_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS,
// CPK
PARAM_MESSAGE_1_CARRY_1_COMPACT_PK_KS_PBS,
PARAM_MESSAGE_1_CARRY_2_COMPACT_PK_KS_PBS,
Expand Down
1 change: 1 addition & 0 deletions tfhe/src/shortint/parameters/classic/compact_pk/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::shortint::parameters::classic::compact_pk::gaussian::p_fail_2_minus_4
ks_pbs, pbs_ks,
};
use crate::shortint::ClassicPBSParameters;
pub use tuniform::p_fail_2_minus_40::ks_pbs::*;

pub mod gaussian;
pub mod tuniform;
Expand Down
15 changes: 15 additions & 0 deletions tfhe/src/shortint/parameters/multi_bit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub use crate::core_crypto::commons::parameters::{
use crate::core_crypto::prelude::LweCiphertextParameters;
use crate::shortint::ciphertext::{Degree, MaxNoiseLevel, NoiseLevel};
use crate::shortint::parameters::p_fail_2_minus_40::ks_pbs::*;
use crate::shortint::parameters::p_fail_2_minus_40::ks_pbs_gpu::*;
use crate::shortint::parameters::{
CarryModulus, CiphertextModulus, EncryptionKeyChoice, LweBskGroupingFactor, MessageModulus,
};
Expand Down Expand Up @@ -117,3 +118,17 @@ pub const DEFAULT_MULTI_BIT_GROUP_2: MultiBitPBSParameters =
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS;
pub const DEFAULT_MULTI_BIT_GROUP_3: MultiBitPBSParameters =
PARAM_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS;

// GPU
pub const PARAM_GPU_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_2_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M40;
pub const PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_2_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M40;
pub const PARAM_GPU_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_2_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M40;
pub const PARAM_GPU_MULTI_BIT_MESSAGE_1_CARRY_1_GROUP_3_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M40;
pub const PARAM_GPU_MULTI_BIT_MESSAGE_2_CARRY_2_GROUP_3_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M40;
pub const PARAM_GPU_MULTI_BIT_MESSAGE_3_CARRY_3_GROUP_3_KS_PBS: MultiBitPBSParameters =
PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M40;
156 changes: 156 additions & 0 deletions tfhe/src/shortint/parameters/multi_bit/p_fail_2_minus_40/ks_pbs_gpu.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
use crate::core_crypto::prelude::*;
use crate::shortint::ciphertext::MaxNoiseLevel;
use crate::shortint::parameters::multi_bit::MultiBitPBSParameters;
use crate::shortint::parameters::{CarryModulus, MessageModulus};

// Group 2
pub const PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M40:
MultiBitPBSParameters = MultiBitPBSParameters {
lwe_dimension: LweDimension(764),
glwe_dimension: GlweDimension(3),
polynomial_size: PolynomialSize(512),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.000006025673585415336,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0000000000039666089171633006,
)),
pbs_base_log: DecompositionBaseLog(18),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(6),
ks_level: DecompositionLevelCount(2),
message_modulus: MessageModulus(2),
carry_modulus: CarryModulus(2),
max_noise_level: MaxNoiseLevel::new(3),
log2_p_fail: -40.,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
grouping_factor: LweBskGroupingFactor(2),
deterministic_execution: false,
};

pub const PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M40:
MultiBitPBSParameters = MultiBitPBSParameters {
lwe_dimension: LweDimension(818),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.000002226459789930014,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0000000000000003152931493498455,
)),
pbs_base_log: DecompositionBaseLog(22),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(5),
ks_level: DecompositionLevelCount(3),
message_modulus: MessageModulus(4),
carry_modulus: CarryModulus(4),
max_noise_level: MaxNoiseLevel::new(5),
log2_p_fail: -40.,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
grouping_factor: LweBskGroupingFactor(2),
deterministic_execution: false,
};

pub const PARAM_GPU_MULTI_BIT_GROUP_2_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M40:
MultiBitPBSParameters = MultiBitPBSParameters {
lwe_dimension: LweDimension(922),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(8192),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0000003272369292345697,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0000000000000000002168404344971009,
)),
pbs_base_log: DecompositionBaseLog(14),
pbs_level: DecompositionLevelCount(2),
ks_base_log: DecompositionBaseLog(4),
ks_level: DecompositionLevelCount(4),
message_modulus: MessageModulus(8),
carry_modulus: CarryModulus(8),
max_noise_level: MaxNoiseLevel::new(9),
log2_p_fail: -40.,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
grouping_factor: LweBskGroupingFactor(2),
deterministic_execution: false,
};

// Group 3
pub const PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_1_CARRY_1_KS_PBS_GAUSSIAN_2M40:
MultiBitPBSParameters = MultiBitPBSParameters {
lwe_dimension: LweDimension(765),
glwe_dimension: GlweDimension(3),
polynomial_size: PolynomialSize(512),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.000005915594083804978,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0000000000039666089171633006,
)),
pbs_base_log: DecompositionBaseLog(18),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(6),
ks_level: DecompositionLevelCount(2),
message_modulus: MessageModulus(2),
carry_modulus: CarryModulus(2),
max_noise_level: MaxNoiseLevel::new(3),
log2_p_fail: -40.,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
grouping_factor: LweBskGroupingFactor(3),
deterministic_execution: false,
};

pub const PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M40:
MultiBitPBSParameters = MultiBitPBSParameters {
lwe_dimension: LweDimension(888),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(2048),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0000006125031601933181,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0000000000000003152931493498455,
)),
pbs_base_log: DecompositionBaseLog(21),
pbs_level: DecompositionLevelCount(1),
ks_base_log: DecompositionBaseLog(7),
ks_level: DecompositionLevelCount(2),
message_modulus: MessageModulus(4),
carry_modulus: CarryModulus(4),
max_noise_level: MaxNoiseLevel::new(5),
log2_p_fail: -40.,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
grouping_factor: LweBskGroupingFactor(3),
deterministic_execution: false,
};

pub const PARAM_GPU_MULTI_BIT_GROUP_3_MESSAGE_3_CARRY_3_KS_PBS_GAUSSIAN_2M40:
MultiBitPBSParameters = MultiBitPBSParameters {
lwe_dimension: LweDimension(972),
glwe_dimension: GlweDimension(1),
polynomial_size: PolynomialSize(8192),
lwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.00000013016688349592805,
)),
glwe_noise_distribution: DynamicDistribution::new_gaussian_from_std_dev(StandardDev(
0.0000000000000000002168404344971009,
)),
pbs_base_log: DecompositionBaseLog(14),
pbs_level: DecompositionLevelCount(2),
ks_base_log: DecompositionBaseLog(6),
ks_level: DecompositionLevelCount(3),
message_modulus: MessageModulus(8),
carry_modulus: CarryModulus(8),
max_noise_level: MaxNoiseLevel::new(9),
log2_p_fail: -40.,
ciphertext_modulus: CiphertextModulus::new_native(),
encryption_key_choice: EncryptionKeyChoice::Big,
grouping_factor: LweBskGroupingFactor(3),
deterministic_execution: false,
};
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod ks_pbs;
pub mod ks_pbs_gpu;

0 comments on commit bea9b77

Please sign in to comment.