Skip to content

Commit 361acba

Browse files
fix(gpu): refactor
1 parent 4afa5a1 commit 361acba

File tree

1 file changed

+6
-34
lines changed

1 file changed

+6
-34
lines changed

backends/tfhe-cuda-backend/cuda/include/integer/integer_utilities.h

Lines changed: 6 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -378,22 +378,7 @@ template <typename Torus> struct int_radix_lut {
378378
// Allocate LUT
379379
// LUT is used as a trivial encryption and must be initialized outside
380380
// this constructor
381-
for (uint i = 0; i < active_streams.count(); i++) {
382-
auto lut = (Torus *)cuda_malloc_with_size_tracking_async(
383-
num_luts * lut_buffer_size, active_streams.stream(i), active_streams.gpu_index(i),
384-
size_tracker, allocate_gpu_memory);
385-
auto lut_indexes = (Torus *)cuda_malloc_with_size_tracking_async(
386-
lut_indexes_size, active_streams.stream(i), active_streams.gpu_index(i),
387-
size_tracker, allocate_gpu_memory);
388-
// lut_indexes is initialized to 0 by default
389-
// if a different behavior is wanted, it should be rewritten later
390-
cuda_memset_with_size_tracking_async(
391-
lut_indexes, 0, lut_indexes_size, active_streams.stream(i),
392-
active_streams.gpu_index(i), allocate_gpu_memory);
393-
394-
lut_vec.push_back(lut);
395-
lut_indexes_vec.push_back(lut_indexes);
396-
}
381+
allocate_luts_and_indexes(num_radix_blocks, size_tracker);
397382

398383
// lwe_(input/output)_indexes are initialized to range(num_radix_blocks)
399384
// by default
@@ -447,7 +432,7 @@ template <typename Torus> struct int_radix_lut {
447432
// Keyswitch
448433
tmp_lwe_before_ks = new CudaRadixCiphertextFFI;
449434
create_zero_radix_ciphertext_async<Torus>(
450-
streams.stream(0), streams.gpu_index(0), tmp_lwe_before_ks,
435+
active_streams.stream(0), active_streams.gpu_index(0), tmp_lwe_before_ks,
451436
num_radix_blocks, params.big_lwe_dimension, size_tracker,
452437
allocate_gpu_memory);
453438
h_lut_indexes = (Torus *)(calloc(num_radix_blocks, sizeof(Torus)));
@@ -458,7 +443,8 @@ template <typename Torus> struct int_radix_lut {
458443
// constructor to reuse memory
459444
int_radix_lut(CudaStreams streams, int_radix_params params, uint32_t num_luts,
460445
uint32_t num_radix_blocks, int_radix_lut *base_lut_object,
461-
bool allocate_gpu_memory, uint64_t &size_tracker) {
446+
bool allocate_gpu_memory, uint64_t &size_tracker) :
447+
active_streams(streams.active_gpu_subset(num_radix_blocks)) {
462448

463449
this->params = params;
464450
this->num_blocks = num_radix_blocks;
@@ -489,22 +475,8 @@ template <typename Torus> struct int_radix_lut {
489475
// Allocate LUT
490476
// LUT is used as a trivial encryption and must be initialized outside
491477
// this constructor
492-
active_streams = streams.active_gpu_subset(num_radix_blocks);
493-
for (uint i = 0; i < active_streams.count(); i++) {
494-
auto lut = (Torus *)cuda_malloc_with_size_tracking_async(
495-
num_luts * lut_buffer_size, streams.stream(i), streams.gpu_index(i),
496-
size_tracker, allocate_gpu_memory);
497-
auto lut_indexes = (Torus *)cuda_malloc_with_size_tracking_async(
498-
lut_indexes_size, streams.stream(i), streams.gpu_index(i),
499-
size_tracker, allocate_gpu_memory);
500-
// lut_indexes is initialized to 0 by default
501-
// if a different behavior is wanted, it should be rewritten later
502-
cuda_memset_with_size_tracking_async(
503-
lut_indexes, 0, lut_indexes_size, streams.stream(i),
504-
streams.gpu_index(i), allocate_gpu_memory);
505-
lut_vec.push_back(lut);
506-
lut_indexes_vec.push_back(lut_indexes);
507-
}
478+
479+
allocate_luts_and_indexes(num_radix_blocks, size_tracker);
508480

509481
// lwe_(input/output)_indexes are initialized to range(num_radix_blocks)
510482
// by default

0 commit comments

Comments
 (0)