 // limitations under the License.
 #include <algorithm>
 #include <cmath>
+#include <functional>
 #include <limits>
 #include <optional>
 #include <string>
 #include <utility>
+#include <variant>
 #include <vector>

 #include "absl/container/flat_hash_map.h"  // from @com_google_absl
 #include "absl/strings/string_view.h"  // from @com_google_absl
 #include "absl/synchronization/blocking_counter.h"  // from @com_google_absl
 #include "absl/types/span.h"  // from @com_google_absl
+#include "third_party/eigen3/Eigen/Core"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
 #include "pybind11/cast.h"  // from @pybind11
+#include "pybind11/eigen.h"  // from @pybind11
 #include "pybind11/gil.h"  // from @pybind11
 #include "pybind11/numpy.h"  // from @pybind11
 #include "pybind11/pybind11.h"  // from @pybind11
 #include "pybind11/pytypes.h"  // from @pybind11
+#include "pybind11/stl.h"  // from @pybind11
 #include "tsl/profiler/lib/connected_traceme.h"
 #include "tsl/profiler/lib/traceme.h"
@@ -40,6 +45,10 @@ namespace jax_sc_embedding {
 namespace {

 namespace py = ::pybind11;
+using MatrixXi =
+    Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+using MatrixXf =
+    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

 namespace {
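
The row-major aliases above, together with the new pybind11/eigen.h include, are what allow the per-table output buffers later in this change to be plain Eigen matrices that pybind11 converts into NumPy arrays when they are stored into the Python dicts. A minimal standalone sketch of that conversion (module name example_buffers and function MakeBuffer are hypothetical, not part of this change):

#include <pybind11/eigen.h>
#include <pybind11/pybind11.h>

#include <Eigen/Core>

namespace py = pybind11;
using MatrixXi =
    Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

// Returning a dense Eigen matrix by value hands pybind11 a copy, which it
// exposes to Python as a 2-D NumPy array of matching dtype.
MatrixXi MakeBuffer(int rows, int cols) {
  MatrixXi buffer(rows, cols);
  buffer.setZero();
  return buffer;
}

PYBIND11_MODULE(example_buffers, m) { m.def("make_buffer", &MakeBuffer); }
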
@@ -126,16 +135,15 @@ int ExtractCooTensorsFrom1dArray(const py::array& features,
                                  const int global_device_count,
                                  const RowCombiner combiner,
                                  std::vector<CooFormat>& coo_tensors) {
+  // We use proxy objects to the python array for the remainder of the function
+  // and can hence release the GIL.
+  py::gil_scoped_release release_gil;
   // The assumption here is that the gains are always represented as 32bit
   // float arrays (np array with dtype=np.float32) and the features are always
   // represented as 32bit int arrays (np array with dtype=np.int32).
   auto f = features.unchecked<int, 1>();
   auto fw = feature_weights.unchecked<float, 1>();

-  // We use proxy objects to the python array for the remainder of the function
-  // and can hence release the GIL.
-  py::gil_scoped_release release_gil;
-
   coo_tensors.reserve(f.shape(0));
   int coo_tensors_extracted = 0;
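
The hunk above moves the GIL release ahead of the proxy creation: the unchecked proxies only hold the NumPy buffer's shape and data pointers, so the extraction loop itself needs no Python API calls. A standalone sketch of the same pattern under that assumption (SumRows is a hypothetical helper, not from this file):

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

#include <vector>

namespace py = pybind11;

// Sums each row of a 2-D float32 array while the GIL is released.
void SumRows(const py::array_t<float>& arr, std::vector<float>& out) {
  auto a = arr.unchecked<2>();         // borrows shape/data, no Python calls
  py::gil_scoped_release release_gil;  // everything below is pure C++
  out.assign(static_cast<size_t>(a.shape(0)), 0.0f);
  for (py::ssize_t i = 0; i < a.shape(0); ++i) {
    for (py::ssize_t j = 0; j < a.shape(1); ++j) {
      out[static_cast<size_t>(i)] += a(i, j);
    }
  }
}
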
@@ -257,10 +265,13 @@ void PreprocessInputForStackedTablePerLocalDevice(
     const int row_pointers_size_per_sc, const int num_global_devices,
     const int num_sc_per_device, const int sharding_strategy,
     const absl::string_view stacked_table_name, const bool allow_id_dropping,
-    py::array_t<int> row_pointer_buffer, py::array_t<int> embedding_id_buffer,
-    py::array_t<int> sample_id_buffer, py::array_t<float> gain_buffer,
-    py::array_t<int> max_ids_buffer, py::array_t<int> max_unique_ids_buffer,
-    py::array_t<int> required_buffer_size_per_sc_buffer) {
+    Eigen::Ref<Eigen::VectorXi> row_pointer_buffer,
+    Eigen::Ref<Eigen::VectorXi> embedding_id_buffer,
+    Eigen::Ref<Eigen::VectorXi> sample_id_buffer,
+    Eigen::Ref<Eigen::VectorXf> gain_buffer,
+    Eigen::Ref<Eigen::VectorXi> max_ids_buffer,
+    Eigen::Ref<Eigen::VectorXi> max_unique_ids_buffer,
+    Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc_buffer) {
   const int num_scs = num_sc_per_device * num_global_devices;
   int batch_size_for_device = 0;
   int total_num_coo_tensors = 0;
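
Each Eigen::Ref parameter above is intended to view one device's row of a larger row-major matrix (see the .row(local_device) calls later in this diff), so writes land directly in the shared buffer without a copy. A small standalone sketch of that aliasing (RowAliasDemo is a hypothetical name):

#include <Eigen/Core>

using MatrixXi =
    Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;

void RowAliasDemo() {
  MatrixXi buffers(2, 4);  // e.g. 2 local devices x 4 slots
  buffers.setZero();
  // A row of a row-major matrix is contiguous, so the Ref binds in place.
  Eigen::Ref<Eigen::VectorXi> device0 = buffers.row(0);
  device0.setConstant(7);  // buffers.row(0) now reads [7, 7, 7, 7]
}
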
@@ -302,19 +313,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
         feature_split, feature_weights_split, row_offset, col_offset, col_shift,
         num_scs, num_global_devices, metadata.row_combiner, coo_tensors);
   }
-  row_pointer_buffer[py::make_tuple(py::ellipsis())] = coo_buffer_size;
-
-  auto* row_pointer_data = row_pointer_buffer.mutable_data();
-  auto* embedding_ids_data = embedding_id_buffer.mutable_data();
-  auto* sample_ids_data = sample_id_buffer.mutable_data();
-  auto* gains_data = gain_buffer.mutable_data();
-  auto* total_max_ids_per_sc = max_ids_buffer.mutable_data();
-  auto* total_max_unique_ids_per_sc = max_unique_ids_buffer.mutable_data();
-  auto* required_buffer_size_per_sc =
-      required_buffer_size_per_sc_buffer.mutable_data();
   // The remaining section does not require GIL.
   py::gil_scoped_release release;

+  row_pointer_buffer.setConstant(coo_buffer_size);
+
   //
   // Step 2: Sort the COO tensors and group them by SC.
   //
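
setConstant on the Eigen::Ref replaces the removed NumPy-style ellipsis assignment; both fill every element of the device's row-pointer slice with coo_buffer_size before Step 3 writes the real row pointers. A minimal sketch of the equivalence (FillDemo is a hypothetical name):

#include <Eigen/Core>

// Same effect as the removed `row_pointer_buffer[...] = coo_buffer_size` on a
// NumPy array: every element of the referenced slice gets the same value.
void FillDemo(Eigen::Ref<Eigen::VectorXi> row_pointers, int coo_buffer_size) {
  row_pointers.setConstant(coo_buffer_size);
}
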
@@ -327,11 +330,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
       stacked_table_metadata[0].max_ids_per_partition,
       stacked_table_metadata[0].max_unique_ids_per_partition,
       stacked_table_name, allow_id_dropping, num_sc_per_device,
-      total_num_coo_tensors, total_max_ids_per_sc,
-      total_max_unique_ids_per_sc, required_buffer_size_per_sc);
+      total_num_coo_tensors, max_ids_buffer, max_unique_ids_buffer,
+      required_buffer_size_per_sc_buffer);
   for (int i = 0; i < num_sc_per_device; ++i) {
     coo_tensors_by_id[i].emplace_back(batch_size_per_sc * (i + 1), 0, 0.0);
-    required_buffer_size_per_sc[i]++;
+    required_buffer_size_per_sc_buffer[i]++;
   }
   //
   // Step 3: Compute the row pointers for each group of IDs.
@@ -340,35 +343,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
   const int coo_buffer_size_per_sc = coo_buffer_size / num_sc_per_device;
   FillRowPointersPerLocalDevice(
       coo_tensors_by_id, row_pointers_size_per_sc, coo_buffer_size_per_sc,
-      batch_size_per_sc, num_scs, num_sc_per_device, row_pointer_data,
-      embedding_ids_data, sample_ids_data, gains_data);
+      batch_size_per_sc, num_scs, num_sc_per_device, row_pointer_buffer,
+      embedding_id_buffer, sample_id_buffer, gain_buffer);
   }
 }

-// Helper function to get the shape container for the output arrays.
-// If `has_leading_dimension` is true, the shape will be
-// [local_device_count, inner_dim_size]. Otherwise, the shape will be
-// [local_device_count * inner_dim_size].
-static inline py::array::ShapeContainer GetArrayShapeBasedOnLeadingDim(
-    bool has_leading_dimension, int local_device_count, int inner_dim_size) {
-  return has_leading_dimension
-             ? py::array::ShapeContainer({local_device_count, inner_dim_size})
-             : py::array::ShapeContainer({local_device_count * inner_dim_size});
-}
-
-// Helper function to get the slice for a given device.
-// If `has_leading_dimension` is true, the slice will be
-// [device_index:device_index+1]. Otherwise, the slice will be
-// [device_index * first_dim_size:(device_index + 1) * first_dim_size].
-static inline py::slice GetBufferSliceForGivenDevice(bool has_leading_dimension,
-                                                     int start_index,
-                                                     int first_dim_size) {
-  return has_leading_dimension
-             ? py::slice(start_index, start_index + 1, 1)
-             : py::slice(start_index * first_dim_size,
-                         (start_index + 1) * first_dim_size, 1);
-}
-
 py::tuple PreprocessSparseDenseMatmulInput(
     py::list features, py::list feature_weights, py::list feature_specs,
     const int local_device_count, const int global_device_count,
@@ -429,81 +408,63 @@ py::tuple PreprocessSparseDenseMatmulInput(
         const int coo_buffer_size_per_device = ComputeCooBufferSizePerDevice(
             num_scs, num_sc_per_device, stacked_table_metadata);

-        // Acquire GIL before creating Python arrays.
-        py::gil_scoped_acquire acq;
-        py::array_t<int> row_pointers_per_device =
-            py::array_t<int>(GetArrayShapeBasedOnLeadingDim(
-                has_leading_dimension, local_device_count,
-                row_pointers_size_per_sc * num_sc_per_device));
-
-        py::array::ShapeContainer shape_container =
-            GetArrayShapeBasedOnLeadingDim(has_leading_dimension,
-                                           local_device_count,
-                                           coo_buffer_size_per_device);
-        py::array_t<int> embedding_ids_per_device =
-            py::array_t<int>(shape_container);
-        py::array_t<int> sample_ids_per_device =
-            py::array_t<int>(shape_container);
-        py::array_t<float> gains_per_device =
-            py::array_t<float>(shape_container);
+        MatrixXi row_pointers_per_device(
+            local_device_count, row_pointers_size_per_sc * num_sc_per_device);
+        MatrixXi embedding_ids_per_device(local_device_count,
+                                          coo_buffer_size_per_device);
+        MatrixXi sample_ids_per_device(local_device_count,
+                                       coo_buffer_size_per_device);
+        MatrixXf gains_per_device(local_device_count,
+                                  coo_buffer_size_per_device);
+
         const int stats_size_per_device = num_scs;
         // NOTE: max ids and max unique ids are {global_sc_count *
         // num_devices}, where they are then aggregated(max) along device
         // dimension to get {global_sc_count} (i.e. max [unique] ids for each
        // sc), which can be further aggregated(max) for a single value for
         // all SCs.
-        py::array::ShapeContainer max_ids_stats_shape =
-            GetArrayShapeBasedOnLeadingDim(
-                /*has_leading_dimension=*/false, local_device_count,
-                stats_size_per_device);
-        py::array_t<int> max_ids_per_partition_per_sc =
-            py::array_t<int>(max_ids_stats_shape);
-        py::array_t<int> max_unique_ids_per_partition_per_sc =
-            py::array_t<int>(max_ids_stats_shape);
+        MatrixXi max_ids_per_partition_per_sc(local_device_count,
+                                              stats_size_per_device);
+        MatrixXi max_unique_ids_per_partition_per_sc(local_device_count,
+                                                     stats_size_per_device);
         // NOTE: required buffer size is {local_sc_count * num_devices}, which
         // is same as {global_sc_count}, and can be further aggregated to get
         // the maximum size of any SC buffer shard.
-        py::array_t<int> required_buffer_size_per_sc =
-            py::array_t<int>(GetArrayShapeBasedOnLeadingDim(
-                false, local_device_count, num_sc_per_device));
+        MatrixXi required_buffer_size_per_sc(local_device_count,
+                                             num_sc_per_device);
         for (int local_device = 0; local_device < local_device_count;
              ++local_device) {
           // Get the tuple outputs for the current split.
-          auto row_pointer_buffer =
-              row_pointers_per_device[GetBufferSliceForGivenDevice(
-                  has_leading_dimension, local_device,
-                  row_pointers_size_per_sc * num_sc_per_device)];
-          py::slice static_buffer_slice = GetBufferSliceForGivenDevice(
-              has_leading_dimension, local_device, coo_buffer_size_per_device);
-          auto embedding_id_buffer =
-              embedding_ids_per_device[static_buffer_slice];
-          auto sample_id_buffer = sample_ids_per_device[static_buffer_slice];
-          auto gain_buffer = gains_per_device[static_buffer_slice];
-          py::slice max_ids_stats_slice =
-              GetBufferSliceForGivenDevice(/*has_leading_dimension=*/false,
-                                           local_device, stats_size_per_device);
-          auto max_ids_per_partition_per_sc_buffer =
-              max_ids_per_partition_per_sc[max_ids_stats_slice];
-          auto max_unique_ids_per_partition_per_sc_buffer =
-              max_unique_ids_per_partition_per_sc[max_ids_stats_slice];
-
-          auto required_buffer_size_per_sc_buffer =
-              required_buffer_size_per_sc[GetBufferSliceForGivenDevice(
-                  false, local_device, num_sc_per_device)];
+          Eigen::Ref<Eigen::VectorXi> row_pointer_buffer =
+              row_pointers_per_device.row(local_device);
+          Eigen::Ref<Eigen::VectorXi> embedding_id_buffer =
+              embedding_ids_per_device.row(local_device);
+          Eigen::Ref<Eigen::VectorXi> sample_id_buffer =
+              sample_ids_per_device.row(local_device);
+          Eigen::Ref<Eigen::VectorXf> gain_buffer =
+              gains_per_device.row(local_device);
+          Eigen::Ref<Eigen::VectorXi> max_ids_per_partition_per_sc_buffer =
+              max_ids_per_partition_per_sc.row(local_device);
+          Eigen::Ref<Eigen::VectorXi>
+              max_unique_ids_per_partition_per_sc_buffer =
+                  max_unique_ids_per_partition_per_sc.row(local_device);
+          Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc_buffer =
+              required_buffer_size_per_sc.row(local_device);
+
+          // Acquire GIL
+          py::gil_scoped_acquire acq;
           PreprocessInputForStackedTablePerLocalDevice(
               stacked_table_metadata, features, feature_weights, local_device,
               local_device_count, coo_buffer_size_per_device,
               row_pointers_size_per_sc, global_device_count, num_sc_per_device,
               sharding_strategy, stacked_table_name, allow_id_dropping,
-              py::cast<py::array_t<int>>(row_pointer_buffer),
-              py::cast<py::array_t<int>>(embedding_id_buffer),
-              py::cast<py::array_t<int>>(sample_id_buffer),
-              py::cast<py::array_t<float>>(gain_buffer),
-              py::cast<py::array_t<int>>(max_ids_per_partition_per_sc_buffer),
-              py::cast<py::array_t<int>>(
-                  max_unique_ids_per_partition_per_sc_buffer),
-              py::cast<py::array_t<int>>(required_buffer_size_per_sc_buffer));
+              row_pointer_buffer, embedding_id_buffer, sample_id_buffer,
+              gain_buffer, max_ids_per_partition_per_sc_buffer,
+              max_unique_ids_per_partition_per_sc_buffer,
+              required_buffer_size_per_sc_buffer);
         }
+        // Acquire GIL before updating Python dicts.
+        py::gil_scoped_acquire acq;
         lhs_row_pointers[stacked_table_name.c_str()] =
             std::move(row_pointers_per_device);
         lhs_embedding_ids[stacked_table_name.c_str()] =
@@ -517,14 +478,29 @@ py::tuple PreprocessSparseDenseMatmulInput(
             std::move(max_unique_ids_per_partition_per_sc);
         required_buffer_sizes[stacked_table_name.c_str()] =
             std::move(required_buffer_size_per_sc);
+        // To be eventually extracted out of the library.
+        if (!has_leading_dimension) {
+          for (auto& vec : {lhs_row_pointers, lhs_embedding_ids, lhs_gains,
+                            lhs_sample_ids}) {
+            vec[stacked_table_name.c_str()] =
+                py::cast<py::array>(vec[stacked_table_name.c_str()])
+                    .reshape({-1});
+          }
+        }
+        for (auto& vec : {max_ids_per_partition, max_unique_ids_per_partition,
+                          required_buffer_sizes}) {
+          vec[stacked_table_name.c_str()] =
+              py::cast<py::array>(vec[stacked_table_name.c_str()])
+                  .reshape({-1});
+        }
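
With the ShapeContainer helpers removed, the leading-dimension handling now lives here: the 2-D device-major arrays are flattened with reshape when has_leading_dimension is false, and the stats arrays are always flattened. A standalone sketch of that post-processing (MaybeFlatten is a hypothetical helper, not from this change):

#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Collapses a (local_device_count, inner_dim) array into one flat axis;
// reshape keeps the underlying data and only changes the view's shape.
py::array MaybeFlatten(py::array per_device, bool has_leading_dimension) {
  return has_leading_dimension ? per_device : per_device.reshape({-1});
}
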
         counter.DecrementCount();
       });
     }
     counter.Wait();
   }
   py::dict stats;
-  stats["max_ids"] = max_ids_per_partition;
-  stats["max_unique_ids"] = max_unique_ids_per_partition;
+  stats["max_ids"] = std::move(max_ids_per_partition);
+  stats["max_unique_ids"] = std::move(max_unique_ids_per_partition);
   stats["required_buffer_size"] = std::move(required_buffer_sizes);

   // GIL is held at this point.