// limitations under the License.
#include <algorithm>
#include <cmath>
+ #include <functional>
#include <limits>
#include <optional>
#include <string>
#include <utility>
+ #include <variant>
#include <vector>

#include "absl/container/flat_hash_map.h"  // from @com_google_absl
#include "absl/strings/string_view.h"  // from @com_google_absl
#include "absl/synchronization/blocking_counter.h"  // from @com_google_absl
#include "absl/types/span.h"  // from @com_google_absl
+ #include "third_party/eigen3/Eigen/Core"
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
#include "pybind11/cast.h"  // from @pybind11
+ #include "pybind11/eigen.h"  // from @pybind11
#include "pybind11/gil.h"  // from @pybind11
#include "pybind11/numpy.h"  // from @pybind11
#include "pybind11/pybind11.h"  // from @pybind11
#include "pybind11/pytypes.h"  // from @pybind11
+ #include "pybind11/stl.h"  // from @pybind11
#include "tsl/profiler/lib/connected_traceme.h"
#include "tsl/profiler/lib/traceme.h"

@@ -40,6 +45,10 @@ namespace jax_sc_embedding {
namespace {

namespace py = ::pybind11;
+ using MatrixXi =
+     Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+ using MatrixXf =
+     Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

namespace {

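Note on the new aliases: they are deliberately RowMajor so that one device's slice (a single row of a local_device_count x N buffer) is contiguous in memory and can be handed out as a writable Eigen::Ref<Eigen::VectorXi> without a copy; the per-device preprocessing below relies on this. A minimal standalone sketch of that property (the names here are illustrative, not taken from this file):

#include <cassert>

#include <Eigen/Core>

using RowMajorMatrixXi =
    Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Fills one row of a caller-owned matrix through a no-copy view.
void FillRow(Eigen::Ref<Eigen::VectorXi> row, int value) {
  row.setConstant(value);
}

int main() {
  RowMajorMatrixXi per_device = RowMajorMatrixXi::Zero(2, 4);
  // A row of a row-major matrix has inner stride 1, so it binds to a mutable
  // Eigen::Ref<Eigen::VectorXi> directly; with a column-major matrix this
  // call would not compile for a non-const Ref.
  FillRow(per_device.row(1), 7);
  assert(per_device(1, 3) == 7);  // The write landed in the parent matrix.
  return 0;
}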
@@ -248,19 +257,20 @@ GetStackedTableMetadata(py::list feature_specs, py::list features) {
// a table that has no parent in the table stacking hierarchy. So in the case
// of table stacking, the stacked table is the top level table and in the case
// where we don't have any table stacking, the table itself is top level.
- //
- // IMPORTANT: Assumes that GIL is held.
void PreprocessInputForStackedTablePerLocalDevice(
    const absl::Span<const StackedTableMetadata> stacked_table_metadata,
    py::list features, py::list feature_weights, const int local_device_id,
    const int local_device_count, const int coo_buffer_size,
    const int row_pointers_size_per_sc, const int num_global_devices,
    const int num_sc_per_device, const int sharding_strategy,
    const absl::string_view stacked_table_name, const bool allow_id_dropping,
-     py::array_t<int> row_pointer_buffer, py::array_t<int> embedding_id_buffer,
-     py::array_t<int> sample_id_buffer, py::array_t<float> gain_buffer,
-     py::array_t<int> max_ids_buffer, py::array_t<int> max_unique_ids_buffer,
-     py::array_t<int> required_buffer_size_per_sc_buffer) {
+     Eigen::Ref<Eigen::VectorXi> row_pointer_buffer,
+     Eigen::Ref<Eigen::VectorXi> embedding_id_buffer,
+     Eigen::Ref<Eigen::VectorXi> sample_id_buffer,
+     Eigen::Ref<Eigen::VectorXf> gain_buffer,
+     Eigen::Ref<Eigen::VectorXi> max_ids_buffer,
+     Eigen::Ref<Eigen::VectorXi> max_unique_ids_buffer,
+     Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc_buffer) {
  const int num_scs = num_sc_per_device * num_global_devices;
  int batch_size_for_device = 0;
  int total_num_coo_tensors = 0;
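Switching these parameters from py::array_t to Eigen::Ref means the output buffers are plain C++ views rather than Python objects, so once the py::list inputs have been read the heavy per-device work can run with the GIL released (see the gil_scoped_release below). A compilable sketch of that calling convention, using a hypothetical FillFromList helper and pybind11's embedded interpreter (assumed setup, not code from this library):

#include <algorithm>
#include <vector>

#include <Eigen/Core>
#include <pybind11/embed.h>

namespace py = pybind11;

// Hypothetical helper: reads Python values while the GIL is held, then fills
// the caller-owned Eigen buffer with the GIL released.
void FillFromList(py::list values, Eigen::Ref<Eigen::VectorXi> out) {
  std::vector<int> staged;
  staged.reserve(static_cast<size_t>(out.size()));
  for (py::handle v : values) staged.push_back(v.cast<int>());  // Needs the GIL.
  py::gil_scoped_release release;  // The numeric work below is GIL-free.
  const Eigen::Index n = std::min<Eigen::Index>(
      out.size(), static_cast<Eigen::Index>(staged.size()));
  for (Eigen::Index i = 0; i < n; ++i) out[i] = staged[static_cast<size_t>(i)];
}

int main() {
  py::scoped_interpreter guard;
  Eigen::VectorXi out = Eigen::VectorXi::Zero(3);
  FillFromList(py::eval("[1, 2, 3]").cast<py::list>(), out);
  return out[2] == 3 ? 0 : 1;
}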
@@ -302,16 +312,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
        feature_split, feature_weights_split, row_offset, col_offset, col_shift,
        num_scs, num_global_devices, metadata.row_combiner, coo_tensors);
  }
-   row_pointer_buffer[py::make_tuple(py::ellipsis())] = coo_buffer_size;
-
-   auto* row_pointer_data = row_pointer_buffer.mutable_data();
-   auto* embedding_ids_data = embedding_id_buffer.mutable_data();
-   auto* sample_ids_data = sample_id_buffer.mutable_data();
-   auto* gains_data = gain_buffer.mutable_data();
-   auto* total_max_ids_per_sc = max_ids_buffer.mutable_data();
-   auto* total_max_unique_ids_per_sc = max_unique_ids_buffer.mutable_data();
-   auto* required_buffer_size_per_sc =
-       required_buffer_size_per_sc_buffer.mutable_data();
+   row_pointer_buffer.setConstant(coo_buffer_size);
+
  // The remaining section does not require GIL.
  py::gil_scoped_release release;

@@ -327,11 +329,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
      stacked_table_metadata[0].max_ids_per_partition,
      stacked_table_metadata[0].max_unique_ids_per_partition,
      stacked_table_name, allow_id_dropping, num_sc_per_device,
-       total_num_coo_tensors, total_max_ids_per_sc,
-       total_max_unique_ids_per_sc, required_buffer_size_per_sc);
+       total_num_coo_tensors, max_ids_buffer, max_unique_ids_buffer,
+       required_buffer_size_per_sc_buffer);
  for (int i = 0; i < num_sc_per_device; ++i) {
    coo_tensors_by_id[i].emplace_back(batch_size_per_sc * (i + 1), 0, 0.0);
-     required_buffer_size_per_sc[i]++;
+     required_buffer_size_per_sc_buffer[i]++;
  }
  //
  // Step 3: Compute the row pointers for each group of IDs.
@@ -340,35 +342,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
  const int coo_buffer_size_per_sc = coo_buffer_size / num_sc_per_device;
  FillRowPointersPerLocalDevice(
      coo_tensors_by_id, row_pointers_size_per_sc, coo_buffer_size_per_sc,
-       batch_size_per_sc, num_scs, num_sc_per_device, row_pointer_data,
-       embedding_ids_data, sample_ids_data, gains_data);
+       batch_size_per_sc, num_scs, num_sc_per_device, row_pointer_buffer,
+       embedding_id_buffer, sample_id_buffer, gain_buffer);
  }
}

- // Helper function to get the shape container for the output arrays.
- // If `has_leading_dimension` is true, the shape will be
- // [local_device_count, inner_dim_size]. Otherwise, the shape will be
- // [local_device_count * inner_dim_size].
- static inline py::array::ShapeContainer GetArrayShapeBasedOnLeadingDim(
-     bool has_leading_dimension, int local_device_count, int inner_dim_size) {
-   return has_leading_dimension
-              ? py::array::ShapeContainer({local_device_count, inner_dim_size})
-              : py::array::ShapeContainer({local_device_count * inner_dim_size});
- }
-
- // Helper function to get the slice for a given device.
- // If `has_leading_dimension` is true, the slice will be
- // [device_index:device_index+1]. Otherwise, the slice will be
- // [device_index * first_dim_size:(device_index + 1) * first_dim_size].
- static inline py::slice GetBufferSliceForGivenDevice(bool has_leading_dimension,
-                                                      int start_index,
-                                                      int first_dim_size) {
-   return has_leading_dimension
-              ? py::slice(start_index, start_index + 1, 1)
-              : py::slice(start_index * first_dim_size,
-                          (start_index + 1) * first_dim_size, 1);
- }
-
py::tuple PreprocessSparseDenseMatmulInput(
    py::list features, py::list feature_weights, py::list feature_specs,
    const int local_device_count, const int global_device_count,
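With those helpers gone, each device's output is simply one row of a (local_device_count x inner_dim_size) row-major matrix. Because the storage is row-major, flattening the matrix afterwards yields exactly the [device * inner_dim_size, (device + 1) * inner_dim_size) slices the removed GetBufferSliceForGivenDevice computed for the non-leading-dimension layout. A small self-contained check of that equivalence (illustrative names only):

#include <cassert>

#include <Eigen/Core>

using RowMajorMatrixXi =
    Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

int main() {
  const int local_device_count = 2;
  const int inner_dim_size = 3;
  RowMajorMatrixXi buffers(local_device_count, inner_dim_size);
  buffers << 0, 1, 2,
             10, 11, 12;

  // Row d occupies elements [d * inner_dim_size, (d + 1) * inner_dim_size)
  // of the flat, row-major storage.
  Eigen::Map<Eigen::VectorXi> flat(buffers.data(),
                                   local_device_count * inner_dim_size);
  assert(flat[4] == buffers(1, 1));  // Element (1, 1) is flat index 1 * 3 + 1.
  return 0;
}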
@@ -429,81 +407,63 @@ py::tuple PreprocessSparseDenseMatmulInput(
        const int coo_buffer_size_per_device = ComputeCooBufferSizePerDevice(
            num_scs, num_sc_per_device, stacked_table_metadata);

-         // Acquire GIL before creating Python arrays.
-         py::gil_scoped_acquire acq;
-         py::array_t<int> row_pointers_per_device =
-             py::array_t<int>(GetArrayShapeBasedOnLeadingDim(
-                 has_leading_dimension, local_device_count,
-                 row_pointers_size_per_sc * num_sc_per_device));
-
-         py::array::ShapeContainer shape_container =
-             GetArrayShapeBasedOnLeadingDim(has_leading_dimension,
-                                            local_device_count,
-                                            coo_buffer_size_per_device);
-         py::array_t<int> embedding_ids_per_device =
-             py::array_t<int>(shape_container);
-         py::array_t<int> sample_ids_per_device =
-             py::array_t<int>(shape_container);
-         py::array_t<float> gains_per_device =
-             py::array_t<float>(shape_container);
+         MatrixXi row_pointers_per_device(
+             local_device_count, row_pointers_size_per_sc * num_sc_per_device);
+         MatrixXi embedding_ids_per_device(local_device_count,
+                                           coo_buffer_size_per_device);
+         MatrixXi sample_ids_per_device(local_device_count,
+                                        coo_buffer_size_per_device);
+         MatrixXf gains_per_device(local_device_count,
+                                   coo_buffer_size_per_device);
+
        const int stats_size_per_device = num_scs;
        // NOTE: max ids and max unique ids are {global_sc_count *
        // num_devices}, where they are then aggregated(max) along device
        // dimension to get {global_sc_count} (i.e. max [unique] ids for each
        // sc), which can be further aggregated(max) for a single value for
        // all SCs.
-         py::array::ShapeContainer max_ids_stats_shape =
-             GetArrayShapeBasedOnLeadingDim(
-                 /*has_leading_dimension=*/false, local_device_count,
-                 stats_size_per_device);
-         py::array_t<int> max_ids_per_partition_per_sc =
-             py::array_t<int>(max_ids_stats_shape);
-         py::array_t<int> max_unique_ids_per_partition_per_sc =
-             py::array_t<int>(max_ids_stats_shape);
+         MatrixXi max_ids_per_partition_per_sc(local_device_count,
+                                               stats_size_per_device);
+         MatrixXi max_unique_ids_per_partition_per_sc(local_device_count,
+                                                      stats_size_per_device);
        // NOTE: required buffer size is {local_sc_count * num_devices}, which
        // is same as {global_sc_count}, and can be further aggregated to get
        // the maximum size of any SC buffer shard.
-         py::array_t<int> required_buffer_size_per_sc =
-             py::array_t<int>(GetArrayShapeBasedOnLeadingDim(
-                 false, local_device_count, num_sc_per_device));
+         MatrixXi required_buffer_size_per_sc(local_device_count,
+                                              num_sc_per_device);
        for (int local_device = 0; local_device < local_device_count;
             ++local_device) {
          // Get the tuple outputs for the current split.
-           auto row_pointer_buffer =
-               row_pointers_per_device[GetBufferSliceForGivenDevice(
-                   has_leading_dimension, local_device,
-                   row_pointers_size_per_sc * num_sc_per_device)];
-           py::slice static_buffer_slice = GetBufferSliceForGivenDevice(
-               has_leading_dimension, local_device, coo_buffer_size_per_device);
-           auto embedding_id_buffer =
-               embedding_ids_per_device[static_buffer_slice];
-           auto sample_id_buffer = sample_ids_per_device[static_buffer_slice];
-           auto gain_buffer = gains_per_device[static_buffer_slice];
-           py::slice max_ids_stats_slice =
-               GetBufferSliceForGivenDevice(/*has_leading_dimension=*/false,
-                                            local_device, stats_size_per_device);
-           auto max_ids_per_partition_per_sc_buffer =
-               max_ids_per_partition_per_sc[max_ids_stats_slice];
-           auto max_unique_ids_per_partition_per_sc_buffer =
-               max_unique_ids_per_partition_per_sc[max_ids_stats_slice];
-
-           auto required_buffer_size_per_sc_buffer =
-               required_buffer_size_per_sc[GetBufferSliceForGivenDevice(
-                   false, local_device, num_sc_per_device)];
+           Eigen::Ref<Eigen::VectorXi> row_pointer_buffer =
+               row_pointers_per_device.row(local_device);
+           Eigen::Ref<Eigen::VectorXi> embedding_id_buffer =
+               embedding_ids_per_device.row(local_device);
+           Eigen::Ref<Eigen::VectorXi> sample_id_buffer =
+               sample_ids_per_device.row(local_device);
+           Eigen::Ref<Eigen::VectorXf> gain_buffer =
+               gains_per_device.row(local_device);
+           Eigen::Ref<Eigen::VectorXi> max_ids_per_partition_per_sc_buffer =
+               max_ids_per_partition_per_sc.row(local_device);
+           Eigen::Ref<Eigen::VectorXi>
+               max_unique_ids_per_partition_per_sc_buffer =
+                   max_unique_ids_per_partition_per_sc.row(local_device);
+           Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc_buffer =
+               required_buffer_size_per_sc.row(local_device);
+
+           // Acquire GIL
+           py::gil_scoped_acquire acq;
          PreprocessInputForStackedTablePerLocalDevice(
              stacked_table_metadata, features, feature_weights, local_device,
              local_device_count, coo_buffer_size_per_device,
              row_pointers_size_per_sc, global_device_count, num_sc_per_device,
              sharding_strategy, stacked_table_name, allow_id_dropping,
-             py::cast<py::array_t<int>>(row_pointer_buffer),
-             py::cast<py::array_t<int>>(embedding_id_buffer),
-             py::cast<py::array_t<int>>(sample_id_buffer),
-             py::cast<py::array_t<float>>(gain_buffer),
-             py::cast<py::array_t<int>>(max_ids_per_partition_per_sc_buffer),
-             py::cast<py::array_t<int>>(
-                 max_unique_ids_per_partition_per_sc_buffer),
-             py::cast<py::array_t<int>>(required_buffer_size_per_sc_buffer));
+             row_pointer_buffer, embedding_id_buffer, sample_id_buffer,
+             gain_buffer, max_ids_per_partition_per_sc_buffer,
+             max_unique_ids_per_partition_per_sc_buffer,
+             required_buffer_size_per_sc_buffer);
        }
+         // Acquire GIL before updating Python dicts.
+         py::gil_scoped_acquire acq;
        lhs_row_pointers[stacked_table_name.c_str()] =
            std::move(row_pointers_per_device);
        lhs_embedding_ids[stacked_table_name.c_str()] =
@@ -517,14 +477,29 @@ py::tuple PreprocessSparseDenseMatmulInput(
            std::move(max_unique_ids_per_partition_per_sc);
        required_buffer_sizes[stacked_table_name.c_str()] =
            std::move(required_buffer_size_per_sc);
+         // To be eventually extracted out of the library
+         if (!has_leading_dimension) {
+           for (auto& vec : {lhs_row_pointers, lhs_embedding_ids, lhs_gains,
+                             lhs_sample_ids}) {
+             vec[stacked_table_name.c_str()] =
+                 py::cast<py::array>(vec[stacked_table_name.c_str()])
+                     .reshape({-1});
+           }
+         }
+         for (auto& vec : {max_ids_per_partition, max_unique_ids_per_partition,
+                           required_buffer_sizes}) {
+           vec[stacked_table_name.c_str()] =
+               py::cast<py::array>(vec[stacked_table_name.c_str()])
+                   .reshape({-1});
+         }
        counter.DecrementCount();
      });
    }
    counter.Wait();
  }
  py::dict stats;
-   stats["max_ids"] = max_ids_per_partition;
-   stats["max_unique_ids"] = max_unique_ids_per_partition;
+   stats["max_ids"] = std::move(max_ids_per_partition);
+   stats["max_unique_ids"] = std::move(max_unique_ids_per_partition);
  stats["required_buffer_size"] = std::move(required_buffer_sizes);

  // GIL is held at this point.
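For reference on the hand-off above: with pybind11/eigen.h included, assigning one of these Eigen matrices into a py::dict materializes it as a NumPy array, and the reshape({-1}) step then drops the leading device dimension when has_leading_dimension is false. A minimal embedded-interpreter sketch of that round trip, assuming a pybind11 version that provides py::array::reshape (illustrative names only):

#include <utility>

#include <Eigen/Core>
#include <pybind11/eigen.h>
#include <pybind11/embed.h>
#include <pybind11/numpy.h>

namespace py = pybind11;

using RowMajorMatrixXi =
    Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

int main() {
  py::scoped_interpreter guard;

  RowMajorMatrixXi per_device(2, 3);
  per_device << 0, 1, 2, 10, 11, 12;

  // The Eigen caster turns the matrix into a NumPy array of shape (2, 3).
  py::dict out;
  out["row_pointers"] = std::move(per_device);

  // Flatten to shape (6,), as done above when has_leading_dimension is false.
  out["row_pointers"] =
      py::cast<py::array>(out["row_pointers"]).reshape({-1});

  return py::cast<py::array>(out["row_pointers"]).ndim() == 1 ? 0 : 1;
}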