13
13
// limitations under the License.
14
14
#include < algorithm>
15
15
#include < cmath>
16
+ #include < cstddef>
17
+ #include < cstdint>
16
18
#include < optional>
17
19
#include < string>
18
20
#include < utility>
24
26
#include " absl/strings/string_view.h" // from @com_google_absl
25
27
#include " absl/synchronization/blocking_counter.h" // from @com_google_absl
26
28
#include " absl/types/span.h" // from @com_google_absl
29
+ #include " jax_tpu_embedding/sparsecore/lib/core/fdo_types.h"
27
30
#include " jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
28
31
#include " jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
29
32
#include " pybind11/cast.h" // from @pybind11
@@ -250,8 +253,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
250
253
const absl::string_view stacked_table_name, const bool allow_id_dropping,
251
254
py::array_t <int > row_pointer_buffer, py::array_t <int > embedding_id_buffer,
252
255
py::array_t <int > sample_id_buffer, py::array_t <float > gain_buffer,
253
- py:: array_t <int > max_ids_buffer, py:: array_t <int > max_unique_ids_buffer,
254
- py:: array_t <int > required_buffer_size_per_sc_buffer) {
256
+ absl::Span <int > max_ids_buffer, absl::Span <int > max_unique_ids_buffer,
257
+ absl::Span <int > required_buffer_size_per_sc_buffer) {
255
258
const int num_scs = num_sc_per_device * num_global_devices;
256
259
int batch_size_for_device = 0 ;
257
260
int total_num_coo_tensors = 0 ;
@@ -299,10 +302,6 @@ void PreprocessInputForStackedTablePerLocalDevice(
299
302
auto * embedding_ids_data = embedding_id_buffer.mutable_data ();
300
303
auto * sample_ids_data = sample_id_buffer.mutable_data ();
301
304
auto * gains_data = gain_buffer.mutable_data ();
302
- auto * total_max_ids_per_sc = max_ids_buffer.mutable_data ();
303
- auto * total_max_unique_ids_per_sc = max_unique_ids_buffer.mutable_data ();
304
- auto * required_buffer_size_per_sc =
305
- required_buffer_size_per_sc_buffer.mutable_data ();
306
305
// The remaining section does not require GIL.
307
306
py::gil_scoped_release release;
308
307
@@ -318,8 +317,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
318
317
stacked_table_metadata[0 ].max_ids_per_partition ,
319
318
stacked_table_metadata[0 ].max_unique_ids_per_partition ,
320
319
stacked_table_name, allow_id_dropping, num_sc_per_device,
321
- total_num_coo_tensors, total_max_ids_per_sc ,
322
- total_max_unique_ids_per_sc, required_buffer_size_per_sc );
320
+ total_num_coo_tensors, max_ids_buffer, max_unique_ids_buffer ,
321
+ required_buffer_size_per_sc_buffer );
323
322
for (int i = 0 ; i < num_sc_per_device; ++i) {
324
323
coo_tensors_by_id[i].emplace_back (batch_size_per_sc * (i + 1 ), 0 , 0.0 );
325
324
}
@@ -359,6 +358,13 @@ static inline py::slice GetBufferSliceForGivenDevice(bool has_leading_dimension,
359
358
(start_index + 1 ) * first_dim_size, 1 );
360
359
}
361
360
361
+ static inline absl::Span<int > GetStatsSliceForGivenDevice (
362
+ std::vector<int >& stats, int device_index, int stats_size_per_device) {
363
+ return absl::MakeSpan (stats).subspan (
364
+ device_index * stats_size_per_device,
365
+ (device_index + 1 ) * stats_size_per_device);
366
+ }
367
+
362
368
py::tuple PreprocessSparseDenseMatmulInput (
363
369
py::list features, py::list feature_weights, py::list feature_specs,
364
370
const int local_device_count, const int global_device_count,
@@ -379,9 +385,9 @@ py::tuple PreprocessSparseDenseMatmulInput(
379
385
py::dict lhs_embedding_ids;
380
386
py::dict lhs_sample_ids;
381
387
py::dict lhs_gains;
382
- py::dict max_ids_per_partition;
383
- py::dict max_unique_ids_per_partition;
384
- py::dict required_buffer_sizes;
388
+ FdoStats::FdoStatsPerStackedTable max_ids_per_partition;
389
+ FdoStats::FdoStatsPerStackedTable max_unique_ids_per_partition;
390
+ FdoStats::FdoStatsPerStackedTable required_buffer_sizes;
385
391
const int num_scs = num_sc_per_device * global_device_count;
386
392
const int row_pointers_size_per_sc = std::max (num_scs, 8 );
387
393
@@ -437,15 +443,10 @@ py::tuple PreprocessSparseDenseMatmulInput(
437
443
py::array_t <float > gains_per_device =
438
444
py::array_t <float >(shape_container);
439
445
const int stats_size_per_device = num_scs;
440
- py::array::ShapeContainer stats_shape = GetArrayShapeBasedOnLeadingDim (
441
- /* has_leading_dimension=*/ false , local_device_count,
442
- stats_size_per_device);
443
- py::array_t <int > max_ids_per_partition_per_sc =
444
- py::array_t <int >(stats_shape);
445
- py::array_t <int > max_unique_ids_per_partition_per_sc =
446
- py::array_t <int >(stats_shape);
447
- py::array_t <int > required_buffer_size_per_sc =
448
- py::array_t <int >(stats_shape);
446
+ size_t stats_size = local_device_count * stats_size_per_device;
447
+ std::vector<int > max_ids_per_partition_per_sc (stats_size);
448
+ std::vector<int > max_unique_ids_per_partition_per_sc (stats_size);
449
+ std::vector<int > required_buffer_size_per_sc (stats_size);
449
450
for (int local_device = 0 ; local_device < local_device_count;
450
451
++local_device) {
451
452
// Get the tuple outputs for the current split.
@@ -459,15 +460,14 @@ py::tuple PreprocessSparseDenseMatmulInput(
459
460
embedding_ids_per_device[static_buffer_slice];
460
461
auto sample_id_buffer = sample_ids_per_device[static_buffer_slice];
461
462
auto gain_buffer = gains_per_device[static_buffer_slice];
462
- py::slice stats_slice =
463
- GetBufferSliceForGivenDevice (/* has_leading_dimension=*/ false ,
464
- local_device, stats_size_per_device);
465
- auto max_ids_per_partition_per_sc_buffer =
466
- max_ids_per_partition_per_sc[stats_slice];
467
- auto max_unique_ids_per_partition_per_sc_buffer =
468
- max_unique_ids_per_partition_per_sc[stats_slice];
469
- auto required_buffer_size_per_sc_buffer =
470
- required_buffer_size_per_sc[stats_slice];
463
+ auto device_max_ids_per_partition =
464
+ GetStatsSliceForGivenDevice (max_ids_per_partition_per_sc,
465
+ local_device, stats_size_per_device);
466
+ auto device_max_unique_ids_per_partition =
467
+ GetStatsSliceForGivenDevice (max_unique_ids_per_partition_per_sc,
468
+ local_device, stats_size_per_device);
469
+ auto device_required_buffer_size = GetStatsSliceForGivenDevice (
470
+ required_buffer_size_per_sc, local_device, stats_size_per_device);
471
471
PreprocessInputForStackedTablePerLocalDevice (
472
472
stacked_table_metadata, features, feature_weights, local_device,
473
473
local_device_count, coo_buffer_size, row_pointers_size_per_sc,
@@ -477,10 +477,8 @@ py::tuple PreprocessSparseDenseMatmulInput(
477
477
py::cast<py::array_t <int >>(embedding_id_buffer),
478
478
py::cast<py::array_t <int >>(sample_id_buffer),
479
479
py::cast<py::array_t <float >>(gain_buffer),
480
- py::cast<py::array_t <int >>(max_ids_per_partition_per_sc_buffer),
481
- py::cast<py::array_t <int >>(
482
- max_unique_ids_per_partition_per_sc_buffer),
483
- py::cast<py::array_t <int >>(required_buffer_size_per_sc_buffer));
480
+ device_max_ids_per_partition, device_max_unique_ids_per_partition,
481
+ device_required_buffer_size);
484
482
}
485
483
lhs_row_pointers[stacked_table_name.c_str ()] =
486
484
std::move (row_pointers_per_device);
@@ -490,21 +488,21 @@ py::tuple PreprocessSparseDenseMatmulInput(
490
488
std::move (sample_ids_per_device);
491
489
lhs_gains[stacked_table_name.c_str ()] = std::move (gains_per_device);
492
490
max_ids_per_partition[stacked_table_name.c_str ()] =
493
- std::move ( max_ids_per_partition_per_sc) ;
491
+ max_ids_per_partition_per_sc;
494
492
max_unique_ids_per_partition[stacked_table_name.c_str ()] =
495
- std::move ( max_unique_ids_per_partition_per_sc) ;
493
+ max_unique_ids_per_partition_per_sc;
496
494
required_buffer_sizes[stacked_table_name.c_str ()] =
497
- std::move ( required_buffer_size_per_sc) ;
495
+ required_buffer_size_per_sc;
498
496
counter.DecrementCount ();
499
497
});
500
498
}
501
499
counter.Wait ();
502
500
}
503
- py::dict stats;
504
- stats[ " max_ids " ] = max_ids_per_partition;
505
- stats[ " max_unique_ids " ] = max_unique_ids_per_partition;
506
- stats[ " required_buffer_size " ] = std::move ( required_buffer_sizes);
507
-
501
+ FdoStats stats{
502
+ . max_ids_per_partition = max_ids_per_partition,
503
+ . max_unique_ids_per_partition = max_unique_ids_per_partition,
504
+ . required_buffer_sizes = required_buffer_sizes,
505
+ };
508
506
// GIL is held at this point.
509
507
return py::make_tuple (lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids,
510
508
lhs_gains, stats);
0 commit comments