Skip to content

Commit 419aeb4

Browse files
Replace numpy arrays with Eigen Matrix
PiperOrigin-RevId: 764419490
1 parent 40f8ee3 commit 419aeb4

File tree

5 files changed: +143 additions, −155 deletions

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ cc_library(
5050
srcs = ["input_preprocessing_util.cc"],
5151
hdrs = ["input_preprocessing_util.h"],
5252
deps = [
53+
"//third_party/eigen3",
5354
"@com_google_absl//absl/algorithm:container",
5455
"@com_google_absl//absl/log",
5556
"@com_google_absl//absl/log:check",
@@ -68,6 +69,7 @@ cc_test(
6869
env = {"JAX_PLATFORMS": "cpu"},
6970
deps = [
7071
":input_preprocessing_util",
72+
"//third_party/eigen3",
7173
"@com_google_googletest//:gtest_main",
7274
],
7375
)
@@ -78,11 +80,14 @@ pybind_extension(
7880
deps = [
7981
":input_preprocessing_threads",
8082
":input_preprocessing_util",
83+
"//third_party/eigen3",
8184
"@com_google_absl//absl/container:flat_hash_map",
85+
"@com_google_absl//absl/log",
8286
"@com_google_absl//absl/log:check",
8387
"@com_google_absl//absl/strings",
8488
"@com_google_absl//absl/synchronization",
8589
"@com_google_absl//absl/types:span",
90+
"@pybind11//:pybind11_eigen",
8691
"@tsl//tsl/profiler/lib:connected_traceme",
8792
"@tsl//tsl/profiler/lib:traceme",
8893
],

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 79 additions & 104 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313
// limitations under the License.
1414
#include <algorithm>
1515
#include <cmath>
16+
#include <functional>
1617
#include <limits>
1718
#include <optional>
1819
#include <string>
1920
#include <utility>
21+
#include <variant>
2022
#include <vector>
2123

2224
#include "absl/container/flat_hash_map.h" // from @com_google_absl
@@ -25,13 +27,16 @@
2527
#include "absl/strings/string_view.h" // from @com_google_absl
2628
#include "absl/synchronization/blocking_counter.h" // from @com_google_absl
2729
#include "absl/types/span.h" // from @com_google_absl
30+
#include "third_party/eigen3/Eigen/Core"
2831
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
2932
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
3033
#include "pybind11/cast.h" // from @pybind11
34+
#include "pybind11/eigen.h" // from @pybind11
3135
#include "pybind11/gil.h" // from @pybind11
3236
#include "pybind11/numpy.h" // from @pybind11
3337
#include "pybind11/pybind11.h" // from @pybind11
3438
#include "pybind11/pytypes.h" // from @pybind11
39+
#include "pybind11/stl.h" // from @pybind11
3540
#include "tsl/profiler/lib/connected_traceme.h"
3641
#include "tsl/profiler/lib/traceme.h"
3742

@@ -40,6 +45,10 @@ namespace jax_sc_embedding {
4045
namespace {
4146

4247
namespace py = ::pybind11;
48+
using MatrixXi =
49+
Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
50+
using MatrixXf =
51+
Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
4352

4453
namespace {
4554

@@ -248,19 +257,20 @@ GetStackedTableMetadata(py::list feature_specs, py::list features) {
248257
// a table that has no parent in the table stacking hierarchy. So in the case
249258
// of table stacking, the stacked table is the top level table and in the case
250259
// where we don't have any table stacking, the table itself is top level.
251-
//
252-
// IMPORTANT: Assumes that GIL is held.
253260
void PreprocessInputForStackedTablePerLocalDevice(
254261
const absl::Span<const StackedTableMetadata> stacked_table_metadata,
255262
py::list features, py::list feature_weights, const int local_device_id,
256263
const int local_device_count, const int coo_buffer_size,
257264
const int row_pointers_size_per_sc, const int num_global_devices,
258265
const int num_sc_per_device, const int sharding_strategy,
259266
const absl::string_view stacked_table_name, const bool allow_id_dropping,
260-
py::array_t<int> row_pointer_buffer, py::array_t<int> embedding_id_buffer,
261-
py::array_t<int> sample_id_buffer, py::array_t<float> gain_buffer,
262-
py::array_t<int> max_ids_buffer, py::array_t<int> max_unique_ids_buffer,
263-
py::array_t<int> required_buffer_size_per_sc_buffer) {
267+
Eigen::Ref<Eigen::VectorXi> row_pointer_buffer,
268+
Eigen::Ref<Eigen::VectorXi> embedding_id_buffer,
269+
Eigen::Ref<Eigen::VectorXi> sample_id_buffer,
270+
Eigen::Ref<Eigen::VectorXf> gain_buffer,
271+
Eigen::Ref<Eigen::VectorXi> max_ids_buffer,
272+
Eigen::Ref<Eigen::VectorXi> max_unique_ids_buffer,
273+
Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc_buffer) {
264274
const int num_scs = num_sc_per_device * num_global_devices;
265275
int batch_size_for_device = 0;
266276
int total_num_coo_tensors = 0;
@@ -302,16 +312,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
302312
feature_split, feature_weights_split, row_offset, col_offset, col_shift,
303313
num_scs, num_global_devices, metadata.row_combiner, coo_tensors);
304314
}
305-
row_pointer_buffer[py::make_tuple(py::ellipsis())] = coo_buffer_size;
306-
307-
auto* row_pointer_data = row_pointer_buffer.mutable_data();
308-
auto* embedding_ids_data = embedding_id_buffer.mutable_data();
309-
auto* sample_ids_data = sample_id_buffer.mutable_data();
310-
auto* gains_data = gain_buffer.mutable_data();
311-
auto* total_max_ids_per_sc = max_ids_buffer.mutable_data();
312-
auto* total_max_unique_ids_per_sc = max_unique_ids_buffer.mutable_data();
313-
auto* required_buffer_size_per_sc =
314-
required_buffer_size_per_sc_buffer.mutable_data();
315+
row_pointer_buffer.setConstant(coo_buffer_size);
316+
315317
// The remaining section does not require GIL.
316318
py::gil_scoped_release release;
317319

@@ -327,11 +329,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
327329
stacked_table_metadata[0].max_ids_per_partition,
328330
stacked_table_metadata[0].max_unique_ids_per_partition,
329331
stacked_table_name, allow_id_dropping, num_sc_per_device,
330-
total_num_coo_tensors, total_max_ids_per_sc,
331-
total_max_unique_ids_per_sc, required_buffer_size_per_sc);
332+
total_num_coo_tensors, max_ids_buffer, max_unique_ids_buffer,
333+
required_buffer_size_per_sc_buffer);
332334
for (int i = 0; i < num_sc_per_device; ++i) {
333335
coo_tensors_by_id[i].emplace_back(batch_size_per_sc * (i + 1), 0, 0.0);
334-
required_buffer_size_per_sc[i]++;
336+
required_buffer_size_per_sc_buffer[i]++;
335337
}
336338
//
337339
// Step 3: Compute the row pointers for each group of IDs.
@@ -340,35 +342,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
340342
const int coo_buffer_size_per_sc = coo_buffer_size / num_sc_per_device;
341343
FillRowPointersPerLocalDevice(
342344
coo_tensors_by_id, row_pointers_size_per_sc, coo_buffer_size_per_sc,
343-
batch_size_per_sc, num_scs, num_sc_per_device, row_pointer_data,
344-
embedding_ids_data, sample_ids_data, gains_data);
345+
batch_size_per_sc, num_scs, num_sc_per_device, row_pointer_buffer,
346+
embedding_id_buffer, sample_id_buffer, gain_buffer);
345347
}
346348
}
347349

348-
// Helper function to get the shape container for the output arrays.
349-
// If `has_leading_dimension` is true, the shape will be
350-
// [local_device_count, inner_dim_size]. Otherwise, the shape will be
351-
// [local_device_count * inner_dim_size].
352-
static inline py::array::ShapeContainer GetArrayShapeBasedOnLeadingDim(
353-
bool has_leading_dimension, int local_device_count, int inner_dim_size) {
354-
return has_leading_dimension
355-
? py::array::ShapeContainer({local_device_count, inner_dim_size})
356-
: py::array::ShapeContainer({local_device_count * inner_dim_size});
357-
}
358-
359-
// Helper function to get the slice for a given device.
360-
// If `has_leading_dimension` is true, the slice will be
361-
// [device_index:device_index+1]. Otherwise, the slice will be
362-
// [device_index * first_dim_size:(device_index + 1) * first_dim_size].
363-
static inline py::slice GetBufferSliceForGivenDevice(bool has_leading_dimension,
364-
int start_index,
365-
int first_dim_size) {
366-
return has_leading_dimension
367-
? py::slice(start_index, start_index + 1, 1)
368-
: py::slice(start_index * first_dim_size,
369-
(start_index + 1) * first_dim_size, 1);
370-
}
371-
372350
py::tuple PreprocessSparseDenseMatmulInput(
373351
py::list features, py::list feature_weights, py::list feature_specs,
374352
const int local_device_count, const int global_device_count,
@@ -429,81 +407,63 @@ py::tuple PreprocessSparseDenseMatmulInput(
429407
const int coo_buffer_size_per_device = ComputeCooBufferSizePerDevice(
430408
num_scs, num_sc_per_device, stacked_table_metadata);
431409

432-
// Acquire GIL before creating Python arrays.
433-
py::gil_scoped_acquire acq;
434-
py::array_t<int> row_pointers_per_device =
435-
py::array_t<int>(GetArrayShapeBasedOnLeadingDim(
436-
has_leading_dimension, local_device_count,
437-
row_pointers_size_per_sc * num_sc_per_device));
438-
439-
py::array::ShapeContainer shape_container =
440-
GetArrayShapeBasedOnLeadingDim(has_leading_dimension,
441-
local_device_count,
442-
coo_buffer_size_per_device);
443-
py::array_t<int> embedding_ids_per_device =
444-
py::array_t<int>(shape_container);
445-
py::array_t<int> sample_ids_per_device =
446-
py::array_t<int>(shape_container);
447-
py::array_t<float> gains_per_device =
448-
py::array_t<float>(shape_container);
410+
MatrixXi row_pointers_per_device(
411+
local_device_count, row_pointers_size_per_sc * num_sc_per_device);
412+
MatrixXi embedding_ids_per_device(local_device_count,
413+
coo_buffer_size_per_device);
414+
MatrixXi sample_ids_per_device(local_device_count,
415+
coo_buffer_size_per_device);
416+
MatrixXf gains_per_device(local_device_count,
417+
coo_buffer_size_per_device);
418+
449419
const int stats_size_per_device = num_scs;
450420
// NOTE: max ids and max unique ids are {global_sc_count *
451421
// num_devices}, where they are then aggregated(max) along device
452422
// dimension to get {global_sc_count} (i.e. max [unique] ids for each
453423
// sc), which can be further aggregated(max) for a single value for
454424
// all SCs.
455-
py::array::ShapeContainer max_ids_stats_shape =
456-
GetArrayShapeBasedOnLeadingDim(
457-
/*has_leading_dimension=*/false, local_device_count,
458-
stats_size_per_device);
459-
py::array_t<int> max_ids_per_partition_per_sc =
460-
py::array_t<int>(max_ids_stats_shape);
461-
py::array_t<int> max_unique_ids_per_partition_per_sc =
462-
py::array_t<int>(max_ids_stats_shape);
425+
MatrixXi max_ids_per_partition_per_sc(local_device_count,
426+
stats_size_per_device);
427+
MatrixXi max_unique_ids_per_partition_per_sc(local_device_count,
428+
stats_size_per_device);
463429
// NOTE: required buffer size is {local_sc_count * num_devices}, which
464430
// is same as {global_sc_count}, and can be further aggregated to get
465431
// the maximum size of any SC buffer shard.
466-
py::array_t<int> required_buffer_size_per_sc =
467-
py::array_t<int>(GetArrayShapeBasedOnLeadingDim(
468-
false, local_device_count, num_sc_per_device));
432+
MatrixXi required_buffer_size_per_sc(local_device_count,
433+
num_sc_per_device);
469434
for (int local_device = 0; local_device < local_device_count;
470435
++local_device) {
471436
// Get the tuple outputs for the current split.
472-
auto row_pointer_buffer =
473-
row_pointers_per_device[GetBufferSliceForGivenDevice(
474-
has_leading_dimension, local_device,
475-
row_pointers_size_per_sc * num_sc_per_device)];
476-
py::slice static_buffer_slice = GetBufferSliceForGivenDevice(
477-
has_leading_dimension, local_device, coo_buffer_size_per_device);
478-
auto embedding_id_buffer =
479-
embedding_ids_per_device[static_buffer_slice];
480-
auto sample_id_buffer = sample_ids_per_device[static_buffer_slice];
481-
auto gain_buffer = gains_per_device[static_buffer_slice];
482-
py::slice max_ids_stats_slice =
483-
GetBufferSliceForGivenDevice(/*has_leading_dimension=*/false,
484-
local_device, stats_size_per_device);
485-
auto max_ids_per_partition_per_sc_buffer =
486-
max_ids_per_partition_per_sc[max_ids_stats_slice];
487-
auto max_unique_ids_per_partition_per_sc_buffer =
488-
max_unique_ids_per_partition_per_sc[max_ids_stats_slice];
489-
490-
auto required_buffer_size_per_sc_buffer =
491-
required_buffer_size_per_sc[GetBufferSliceForGivenDevice(
492-
false, local_device, num_sc_per_device)];
437+
Eigen::Ref<Eigen::VectorXi> row_pointer_buffer =
438+
row_pointers_per_device.row(local_device);
439+
Eigen::Ref<Eigen::VectorXi> embedding_id_buffer =
440+
embedding_ids_per_device.row(local_device);
441+
Eigen::Ref<Eigen::VectorXi> sample_id_buffer =
442+
sample_ids_per_device.row(local_device);
443+
Eigen::Ref<Eigen::VectorXf> gain_buffer =
444+
gains_per_device.row(local_device);
445+
Eigen::Ref<Eigen::VectorXi> max_ids_per_partition_per_sc_buffer =
446+
max_ids_per_partition_per_sc.row(local_device);
447+
Eigen::Ref<Eigen::VectorXi>
448+
max_unique_ids_per_partition_per_sc_buffer =
449+
max_unique_ids_per_partition_per_sc.row(local_device);
450+
Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc_buffer =
451+
required_buffer_size_per_sc.row(local_device);
452+
453+
// Acquire GIL
454+
py::gil_scoped_acquire acq;
493455
PreprocessInputForStackedTablePerLocalDevice(
494456
stacked_table_metadata, features, feature_weights, local_device,
495457
local_device_count, coo_buffer_size_per_device,
496458
row_pointers_size_per_sc, global_device_count, num_sc_per_device,
497459
sharding_strategy, stacked_table_name, allow_id_dropping,
498-
py::cast<py::array_t<int>>(row_pointer_buffer),
499-
py::cast<py::array_t<int>>(embedding_id_buffer),
500-
py::cast<py::array_t<int>>(sample_id_buffer),
501-
py::cast<py::array_t<float>>(gain_buffer),
502-
py::cast<py::array_t<int>>(max_ids_per_partition_per_sc_buffer),
503-
py::cast<py::array_t<int>>(
504-
max_unique_ids_per_partition_per_sc_buffer),
505-
py::cast<py::array_t<int>>(required_buffer_size_per_sc_buffer));
460+
row_pointer_buffer, embedding_id_buffer, sample_id_buffer,
461+
gain_buffer, max_ids_per_partition_per_sc_buffer,
462+
max_unique_ids_per_partition_per_sc_buffer,
463+
required_buffer_size_per_sc_buffer);
506464
}
465+
// Acquire GIL before updating Python dicts.
466+
py::gil_scoped_acquire acq;
507467
lhs_row_pointers[stacked_table_name.c_str()] =
508468
std::move(row_pointers_per_device);
509469
lhs_embedding_ids[stacked_table_name.c_str()] =
@@ -517,14 +477,29 @@ py::tuple PreprocessSparseDenseMatmulInput(
517477
std::move(max_unique_ids_per_partition_per_sc);
518478
required_buffer_sizes[stacked_table_name.c_str()] =
519479
std::move(required_buffer_size_per_sc);
480+
// To be eventually extracted out of the library
481+
if (!has_leading_dimension) {
482+
for (auto& vec : {lhs_row_pointers, lhs_embedding_ids, lhs_gains,
483+
lhs_sample_ids}) {
484+
vec[stacked_table_name.c_str()] =
485+
py::cast<py::array>(vec[stacked_table_name.c_str()])
486+
.reshape({-1});
487+
}
488+
}
489+
for (auto& vec : {max_ids_per_partition, max_unique_ids_per_partition,
490+
required_buffer_sizes}) {
491+
vec[stacked_table_name.c_str()] =
492+
py::cast<py::array>(vec[stacked_table_name.c_str()])
493+
.reshape({-1});
494+
}
520495
counter.DecrementCount();
521496
});
522497
}
523498
counter.Wait();
524499
}
525500
py::dict stats;
526-
stats["max_ids"] = max_ids_per_partition;
527-
stats["max_unique_ids"] = max_unique_ids_per_partition;
501+
stats["max_ids"] = std::move(max_ids_per_partition);
502+
stats["max_unique_ids"] = std::move(max_unique_ids_per_partition);
528503
stats["required_buffer_size"] = std::move(required_buffer_sizes);
529504

530505
// GIL is held at this point.

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "absl/strings/str_join.h" // from @com_google_absl
3131
#include "absl/strings/string_view.h" // from @com_google_absl
3232
#include "absl/types/span.h" // from @com_google_absl
33+
#include "third_party/eigen3/Eigen/Core"
3334
#include "hwy/contrib/sort/order.h" // from @highway
3435
#include "hwy/contrib/sort/vqsort.h" // from @highway
3536
#include "tsl/profiler/lib/traceme.h"
@@ -106,8 +107,9 @@ std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
106107
const int32_t max_unique_ids_per_partition,
107108
const absl::string_view stacked_table_name, const bool allow_id_dropping,
108109
const int num_sc_per_device, const int total_num_coo_tensors,
109-
int max_ids_per_sc[], int max_unique_ids_per_sc[],
110-
int required_buffer_size_per_sc[]) {
110+
Eigen::Ref<Eigen::VectorXi> max_ids_per_sc,
111+
Eigen::Ref<Eigen::VectorXi> max_unique_ids_per_sc,
112+
Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc) {
111113
tsl::profiler::TraceMe t("SortAndGroupCooTensors");
112114
const int local_sc_count = batch_size_for_device / batch_size_per_sc;
113115
std::vector<std::vector<CooFormat>> coo_tensors_by_id;
@@ -122,16 +124,13 @@ std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
122124
uint32_t coo_tensor_index = 0;
123125
const int32_t num_scs_bit = std::log2(global_sc_count);
124126
// Initialize the aggregated max ids and unique ids per SC to 0.
125-
for (int32_t global_sc_id = 0; global_sc_id < global_sc_count;
126-
++global_sc_id) {
127-
max_ids_per_sc[global_sc_id] = 0;
128-
max_unique_ids_per_sc[global_sc_id] = 0;
129-
}
127+
max_ids_per_sc.fill(0);
128+
max_unique_ids_per_sc.fill(0);
129+
required_buffer_size_per_sc.fill(0);
130130
// Loop over scs for this device.
131131
for (int32_t local_sc_id = 0; local_sc_id < local_sc_count; ++local_sc_id) {
132132
std::vector<int32_t> ids_per_sc_partition(global_sc_count, 0);
133133
std::vector<int32_t> unique_ids_per_sc_partition(global_sc_count, 0);
134-
required_buffer_size_per_sc[local_sc_id] = 0;
135134
std::vector<uint64_t> keys;
136135
keys.reserve(batch_size_per_sc);
137136
// We take the advantage of the fact that the row_ids are already sorted
@@ -295,7 +294,9 @@ void FillRowPointersPerLocalDevice(
295294
absl::Span<const std::vector<CooFormat>> coo_tensors_by_id,
296295
const int row_pointers_size_per_sc, const int coo_buffer_size_per_sc,
297296
const int batch_size_per_sc, const int num_scs, const int num_sc_per_device,
298-
int row_pointers[], int embedding_ids[], int sample_ids[], float gains[]) {
297+
Eigen::Ref<Eigen::VectorXi> row_pointers,
298+
Eigen::Ref<Eigen::VectorXi> embedding_ids,
299+
Eigen::Ref<Eigen::VectorXi> sample_ids, Eigen::Ref<Eigen::VectorXf> gains) {
299300
tsl::profiler::TraceMe t("FillRowPointers");
300301
for (int local_sc_id = 0; local_sc_id < num_sc_per_device; ++local_sc_id) {
301302
int lhs_row_index = 0;

Commit comments (0)