diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/cmux.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/cmux.cuh
index aea747c785..0325011689 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/cmux.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/cmux.cuh
@@ -14,27 +14,14 @@ __host__ void zero_out_if(cudaStream_t const *streams,
   cudaSetDevice(gpu_indexes[0]);
 
   auto params = mem_ptr->params;
-  int big_lwe_size = params.big_lwe_dimension + 1;
-
-  // Left message is shifted
-  int num_blocks = 0, num_threads = 0;
-  int num_entries = (params.big_lwe_dimension + 1);
-  getNumBlocksAndThreads(num_entries, 512, num_blocks, num_threads);
-
   // We can't use integer_radix_apply_bivariate_lookup_table_kb since the
-  // second operand is fixed
+  // second operand is not an array
   auto tmp_lwe_array_input = mem_ptr->tmp;
-  for (int i = 0; i < num_radix_blocks; i++) {
-    auto lwe_array_out_block = tmp_lwe_array_input + i * big_lwe_size;
-    auto lwe_array_input_block = lwe_array_input + i * big_lwe_size;
-
-    device_pack_bivariate_blocks<Torus>
-        <<<num_blocks, num_threads, 0, streams[0]>>>(
-            lwe_array_out_block, predicate->lwe_indexes_in,
-            lwe_array_input_block, lwe_condition, predicate->lwe_indexes_in,
-            params.big_lwe_dimension, params.message_modulus, 1);
-    check_cuda_error(cudaGetLastError());
-  }
+  pack_bivariate_blocks_with_single_block<Torus>(
+      streams, gpu_indexes, gpu_count, tmp_lwe_array_input,
+      predicate->lwe_indexes_in, lwe_array_input, lwe_condition,
+      predicate->lwe_indexes_in, params.big_lwe_dimension,
+      params.message_modulus, num_radix_blocks);
 
   integer_radix_apply_univariate_lookup_table_kb<Torus>(
       streams, gpu_indexes, gpu_count, lwe_array_out, tmp_lwe_array_input, bsks,
diff --git a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
index 8560b94c8d..28993ed406 100644
--- a/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
+++ b/backends/tfhe-cuda-backend/cuda/src/integer/integer.cuh
@@ -142,8 +142,10 @@ device_pack_bivariate_blocks(Torus *lwe_array_out, Torus const *lwe_indexes_out,
     int block_id = tid / (lwe_dimension + 1);
     int coeff_id = tid % (lwe_dimension + 1);
 
-    int pos_in = lwe_indexes_in[block_id] * (lwe_dimension + 1) + coeff_id;
-    int pos_out = lwe_indexes_out[block_id] * (lwe_dimension + 1) + coeff_id;
+    const int pos_in =
+        lwe_indexes_in[block_id] * (lwe_dimension + 1) + coeff_id;
+    const int pos_out =
+        lwe_indexes_out[block_id] * (lwe_dimension + 1) + coeff_id;
     lwe_array_out[pos_out] = lwe_array_1[pos_in] * shift + lwe_array_2[pos_in];
   }
 }
@@ -172,6 +174,50 @@ pack_bivariate_blocks(cudaStream_t const *streams, uint32_t const *gpu_indexes,
   check_cuda_error(cudaGetLastError());
 }
 
+// num_blocks * (lwe_dimension + 1) threads
+template <typename Torus>
+__global__ void device_pack_bivariate_blocks_with_single_block(
+    Torus *lwe_array_out, Torus const *lwe_indexes_out,
+    Torus const *lwe_array_1, Torus const *lwe_2, Torus const *lwe_indexes_in,
+    uint32_t lwe_dimension, uint32_t shift, uint32_t num_blocks) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+
+  if (tid < num_blocks * (lwe_dimension + 1)) {
+    int block_id = tid / (lwe_dimension + 1);
+    int coeff_id = tid % (lwe_dimension + 1);
+
+    const int pos_in =
+        lwe_indexes_in[block_id] * (lwe_dimension + 1) + coeff_id;
+    const int pos_out =
+        lwe_indexes_out[block_id] * (lwe_dimension + 1) + coeff_id;
+    lwe_array_out[pos_out] = lwe_array_1[pos_in] * shift + lwe_2[coeff_id];
+  }
+}
+
+/* Combine each block m1 of lwe_array_1 with the single block lwe_2
+ * so that out = m1 * shift + lwe_2
+ *
+ * This is for the special case when one of the operands is not an array
+ */
+template <typename Torus>
+__host__ void pack_bivariate_blocks_with_single_block(
+    cudaStream_t const *streams, uint32_t const *gpu_indexes,
+    uint32_t gpu_count, Torus *lwe_array_out, Torus const *lwe_indexes_out,
+    Torus const *lwe_array_1, Torus const *lwe_2, Torus const *lwe_indexes_in,
+    uint32_t lwe_dimension, uint32_t shift, uint32_t num_radix_blocks) {
+
+  cudaSetDevice(gpu_indexes[0]);
+  // Left message is shifted
+  int num_blocks = 0, num_threads = 0;
+  int num_entries = num_radix_blocks * (lwe_dimension + 1);
+  getNumBlocksAndThreads(num_entries, 512, num_blocks, num_threads);
+  device_pack_bivariate_blocks_with_single_block<Torus>
+      <<<num_blocks, num_threads, 0, streams[0]>>>(
+          lwe_array_out, lwe_indexes_out, lwe_array_1, lwe_2, lwe_indexes_in,
+          lwe_dimension, shift, num_radix_blocks);
+  check_cuda_error(cudaGetLastError());
+}
+
 template <typename Torus>
 __host__ void integer_radix_apply_univariate_lookup_table_kb(
     cudaStream_t const *streams, uint32_t const *gpu_indexes,
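
Note (reviewer sketch, not part of the diff): the new kernel packs each radix block m1 of lwe_array_1 with the single block lwe_2 as out = m1 * shift + lwe_2, so that the univariate lookup table applied afterwards in zero_out_if can recover both operands from one value. A minimal clear-value analogue follows; the concrete message_modulus, the sample values, the main() harness, and the "keep the block when the condition bit is set, otherwise zero it" predicate semantics are illustrative assumptions, not code from this change.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // zero_out_if passes shift = params.message_modulus when packing.
  const uint64_t message_modulus = 4;             // assumed 2-bit message space
  const uint64_t condition = 1;                   // clear analogue of lwe_condition (0 or 1)
  const std::vector<uint64_t> blocks = {3, 1, 2}; // clear analogue of lwe_array_input

  for (uint64_t m1 : blocks) {
    // Packing done once per block by device_pack_bivariate_blocks_with_single_block,
    // shown here on clear values instead of LWE coefficients.
    const uint64_t packed = m1 * message_modulus + condition;

    // On ciphertexts this decoding happens inside the programmable bootstrap via
    // the lookup table; it is spelled out here to show both operands are recoverable.
    const uint64_t msg = packed / message_modulus;
    const uint64_t cond = packed % message_modulus;

    // Assumed predicate semantics for zero_out_if: keep msg when cond is set.
    const uint64_t out = (cond != 0) ? msg : 0;
    assert(out == (condition != 0 ? m1 : 0));
  }
  return 0;
}

This packing is also why the existing bivariate path indexes lwe_array_2[pos_in] per block, while the new single-block variant reads the same lwe_2[coeff_id] for every radix block.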