
Commit e628bf6

Replace numpy arrays with Eigen Matrix
PiperOrigin-RevId: 765328968
1 parent 40f8ee3 commit e628bf6

6 files changed: +148 -157 lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 5 additions & 0 deletions
@@ -50,6 +50,7 @@ cc_library(
     srcs = ["input_preprocessing_util.cc"],
     hdrs = ["input_preprocessing_util.h"],
     deps = [
+        "//third_party/eigen3",
         "@com_google_absl//absl/algorithm:container",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
@@ -68,6 +69,7 @@ cc_test(
     env = {"JAX_PLATFORMS": "cpu"},
     deps = [
         ":input_preprocessing_util",
+        "//third_party/eigen3",
         "@com_google_googletest//:gtest_main",
     ],
 )
@@ -78,11 +80,14 @@ pybind_extension(
     deps = [
         ":input_preprocessing_threads",
         ":input_preprocessing_util",
+        "//third_party/eigen3",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/synchronization",
         "@com_google_absl//absl/types:span",
+        "@pybind11//:pybind11_eigen",
         "@tsl//tsl/profiler/lib:connected_traceme",
         "@tsl//tsl/profiler/lib:traceme",
     ],
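The new //third_party/eigen3 and @pybind11//:pybind11_eigen deps are what let the extension hand numpy buffers to C++ as Eigen types. As a minimal sketch of the mechanism (module and function names here are hypothetical, not part of this commit): including pybind11/eigen.h registers casters so a bound function can accept a writable numpy vector as an Eigen::Ref without copying.

#include "pybind11/eigen.h"  // from @pybind11
#include "pybind11/pybind11.h"  // from @pybind11
#include "third_party/eigen3/Eigen/Core"

namespace py = pybind11;

// Writes into the caller's numpy buffer in place. For a mutable Eigen::Ref
// parameter, pybind11 requires a writable int32 array of compatible layout
// and never copies, so the writes are visible to the caller.
void FillWithValue(Eigen::Ref<Eigen::VectorXi> out, int value) {
  out.setConstant(value);
}

PYBIND11_MODULE(eigen_demo, m) {
  m.def("fill_with_value", &FillWithValue, py::arg("out"), py::arg("value"));
}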

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 82 additions & 106 deletions
@@ -13,10 +13,12 @@
 // limitations under the License.
 #include <algorithm>
 #include <cmath>
+#include <functional>
 #include <limits>
 #include <optional>
 #include <string>
 #include <utility>
+#include <variant>
 #include <vector>
 
 #include "absl/container/flat_hash_map.h" // from @com_google_absl
@@ -25,13 +27,16 @@
 #include "absl/strings/string_view.h" // from @com_google_absl
 #include "absl/synchronization/blocking_counter.h" // from @com_google_absl
 #include "absl/types/span.h" // from @com_google_absl
+#include "third_party/eigen3/Eigen/Core"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
 #include "pybind11/cast.h" // from @pybind11
+#include "pybind11/eigen.h" // from @pybind11
 #include "pybind11/gil.h" // from @pybind11
 #include "pybind11/numpy.h" // from @pybind11
 #include "pybind11/pybind11.h" // from @pybind11
 #include "pybind11/pytypes.h" // from @pybind11
+#include "pybind11/stl.h" // from @pybind11
 #include "tsl/profiler/lib/connected_traceme.h"
 #include "tsl/profiler/lib/traceme.h"
 
@@ -40,6 +45,10 @@ namespace jax_sc_embedding {
 namespace {
 
 namespace py = ::pybind11;
+using MatrixXi =
+    Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
+using MatrixXf =
+    Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
 
 namespace {
 
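Eigen stores matrices column-major by default, while numpy's default is C (row-major) order; pinning Eigen::RowMajor in these aliases keeps each device's data contiguous per row and matches the layout numpy expects on conversion. Contiguity is also what lets a row bind to a default-strided Eigen::Ref later in this file. A standalone sketch of the layout point (not from this commit):

// Each device's slice is one contiguous row of the row-major matrix.
MatrixXi buffers(2, 4);                                // 2 devices x 4 slots
Eigen::Ref<Eigen::VectorXi> device0 = buffers.row(0);  // view, no copy
device0.setZero();
// With Eigen's default column-major layout, row(0) would have a non-unit
// inner stride and could not bind to a plain Eigen::Ref<Eigen::VectorXi>.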
@@ -126,16 +135,15 @@ int ExtractCooTensorsFrom1dArray(const py::array& features,
                                  const int global_device_count,
                                  const RowCombiner combiner,
                                  std::vector<CooFormat>& coo_tensors) {
+  // We use proxy objects to the python array for the remainder of the function
+  // and can hence release the GIL.
+  py::gil_scoped_release release_gil;
   // The assumption here is that the gains are always represented as 32bit
   // float arrays (np array with dtype=np.float32) and the features are always
   // represented as 32bit int arrays (np array with dtype=np.int32).
   auto f = features.unchecked<py::array_t<int>, 1>();
   auto fw = feature_weights.unchecked<py::array_t<float>, 1>();
 
-  // We use proxy objects to the python array for the remainder of the function
-  // and can hence release the GIL.
-  py::gil_scoped_release release_gil;
-
   coo_tensors.reserve(f.shape(0));
   int coo_tensors_extracted = 0;
 
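Hoisting the py::gil_scoped_release to the top of the function drops the GIL for the whole extraction, not just its tail: the unchecked proxies read the numpy buffer directly and make no further Python API calls. A hedged sketch of the pattern (hypothetical function, simplified to 1-D int32 input):

// Release the GIL once only raw buffer access remains; it is reacquired
// automatically when `release` leaves scope.
int SumFeatures(const py::array_t<int>& features) {
  auto f = features.unchecked<1>();  // unchecked view into the int32 buffer
  py::gil_scoped_release release;    // no Python calls beyond this point
  int sum = 0;
  for (py::ssize_t i = 0; i < f.shape(0); ++i) sum += f(i);
  return sum;
}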
@@ -257,10 +265,13 @@ void PreprocessInputForStackedTablePerLocalDevice(
     const int row_pointers_size_per_sc, const int num_global_devices,
     const int num_sc_per_device, const int sharding_strategy,
     const absl::string_view stacked_table_name, const bool allow_id_dropping,
-    py::array_t<int> row_pointer_buffer, py::array_t<int> embedding_id_buffer,
-    py::array_t<int> sample_id_buffer, py::array_t<float> gain_buffer,
-    py::array_t<int> max_ids_buffer, py::array_t<int> max_unique_ids_buffer,
-    py::array_t<int> required_buffer_size_per_sc_buffer) {
+    Eigen::Ref<Eigen::VectorXi> row_pointer_buffer,
+    Eigen::Ref<Eigen::VectorXi> embedding_id_buffer,
+    Eigen::Ref<Eigen::VectorXi> sample_id_buffer,
+    Eigen::Ref<Eigen::VectorXf> gain_buffer,
+    Eigen::Ref<Eigen::VectorXi> max_ids_buffer,
+    Eigen::Ref<Eigen::VectorXi> max_unique_ids_buffer,
+    Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc_buffer) {
   const int num_scs = num_sc_per_device * num_global_devices;
   int batch_size_for_device = 0;
   int total_num_coo_tensors = 0;
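Eigen::Ref is the idiom doing the heavy lifting in this signature: it binds without copying to any contiguous int32/float32 buffer of the right shape, whether a standalone vector, a row of a larger matrix, or a mapped numpy array, and it carries its own size. A minimal sketch (names hypothetical) of why one signature now serves every per-device slice:

// One out-parameter type covers standalone vectors and matrix rows alike.
void BumpAll(Eigen::Ref<Eigen::VectorXi> out) { out.array() += 1; }

void Demo() {
  Eigen::VectorXi v = Eigen::VectorXi::Zero(8);
  MatrixXi per_device = MatrixXi::Zero(2, 8);  // row-major alias from above
  BumpAll(v);                  // whole vector
  BumpAll(per_device.row(1));  // contiguous row view, still zero-copy
}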
@@ -302,19 +313,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
         feature_split, feature_weights_split, row_offset, col_offset, col_shift,
         num_scs, num_global_devices, metadata.row_combiner, coo_tensors);
   }
-  row_pointer_buffer[py::make_tuple(py::ellipsis())] = coo_buffer_size;
-
-  auto* row_pointer_data = row_pointer_buffer.mutable_data();
-  auto* embedding_ids_data = embedding_id_buffer.mutable_data();
-  auto* sample_ids_data = sample_id_buffer.mutable_data();
-  auto* gains_data = gain_buffer.mutable_data();
-  auto* total_max_ids_per_sc = max_ids_buffer.mutable_data();
-  auto* total_max_unique_ids_per_sc = max_unique_ids_buffer.mutable_data();
-  auto* required_buffer_size_per_sc =
-      required_buffer_size_per_sc_buffer.mutable_data();
   // The remaining section does not require GIL.
   py::gil_scoped_release release;
 
+  row_pointer_buffer.setConstant(coo_buffer_size);
+
   //
   // Step 2: Sort the COO tensors and group them by SC.
   //
@@ -327,11 +330,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
           stacked_table_metadata[0].max_ids_per_partition,
           stacked_table_metadata[0].max_unique_ids_per_partition,
           stacked_table_name, allow_id_dropping, num_sc_per_device,
-          total_num_coo_tensors, total_max_ids_per_sc,
-          total_max_unique_ids_per_sc, required_buffer_size_per_sc);
+          total_num_coo_tensors, max_ids_buffer, max_unique_ids_buffer,
+          required_buffer_size_per_sc_buffer);
   for (int i = 0; i < num_sc_per_device; ++i) {
     coo_tensors_by_id[i].emplace_back(batch_size_per_sc * (i + 1), 0, 0.0);
-    required_buffer_size_per_sc[i]++;
+    required_buffer_size_per_sc_buffer[i]++;
   }
   //
   // Step 3: Compute the row pointers for each group of IDs.
@@ -340,35 +343,11 @@ void PreprocessInputForStackedTablePerLocalDevice(
   const int coo_buffer_size_per_sc = coo_buffer_size / num_sc_per_device;
   FillRowPointersPerLocalDevice(
       coo_tensors_by_id, row_pointers_size_per_sc, coo_buffer_size_per_sc,
-      batch_size_per_sc, num_scs, num_sc_per_device, row_pointer_data,
-      embedding_ids_data, sample_ids_data, gains_data);
+      batch_size_per_sc, num_scs, num_sc_per_device, row_pointer_buffer,
+      embedding_id_buffer, sample_id_buffer, gain_buffer);
   }
 }
 
-// Helper function to get the shape container for the output arrays.
-// If `has_leading_dimension` is true, the shape will be
-// [local_device_count, inner_dim_size]. Otherwise, the shape will be
-// [local_device_count * inner_dim_size].
-static inline py::array::ShapeContainer GetArrayShapeBasedOnLeadingDim(
-    bool has_leading_dimension, int local_device_count, int inner_dim_size) {
-  return has_leading_dimension
-             ? py::array::ShapeContainer({local_device_count, inner_dim_size})
-             : py::array::ShapeContainer({local_device_count * inner_dim_size});
-}
-
-// Helper function to get the slice for a given device.
-// If `has_leading_dimension` is true, the slice will be
-// [device_index:device_index+1]. Otherwise, the slice will be
-// [device_index * first_dim_size:(device_index + 1) * first_dim_size].
-static inline py::slice GetBufferSliceForGivenDevice(bool has_leading_dimension,
-                                                     int start_index,
-                                                     int first_dim_size) {
-  return has_leading_dimension
-             ? py::slice(start_index, start_index + 1, 1)
-             : py::slice(start_index * first_dim_size,
-                         (start_index + 1) * first_dim_size, 1);
-}
-
 py::tuple PreprocessSparseDenseMatmulInput(
     py::list features, py::list feature_weights, py::list feature_specs,
     const int local_device_count, const int global_device_count,
@@ -429,81 +408,63 @@ py::tuple PreprocessSparseDenseMatmulInput(
       const int coo_buffer_size_per_device = ComputeCooBufferSizePerDevice(
           num_scs, num_sc_per_device, stacked_table_metadata);
 
-      // Acquire GIL before creating Python arrays.
-      py::gil_scoped_acquire acq;
-      py::array_t<int> row_pointers_per_device =
-          py::array_t<int>(GetArrayShapeBasedOnLeadingDim(
-              has_leading_dimension, local_device_count,
-              row_pointers_size_per_sc * num_sc_per_device));
-
-      py::array::ShapeContainer shape_container =
-          GetArrayShapeBasedOnLeadingDim(has_leading_dimension,
-                                         local_device_count,
-                                         coo_buffer_size_per_device);
-      py::array_t<int> embedding_ids_per_device =
-          py::array_t<int>(shape_container);
-      py::array_t<int> sample_ids_per_device =
-          py::array_t<int>(shape_container);
-      py::array_t<float> gains_per_device =
-          py::array_t<float>(shape_container);
+      MatrixXi row_pointers_per_device(
+          local_device_count, row_pointers_size_per_sc * num_sc_per_device);
+      MatrixXi embedding_ids_per_device(local_device_count,
+                                        coo_buffer_size_per_device);
+      MatrixXi sample_ids_per_device(local_device_count,
+                                     coo_buffer_size_per_device);
+      MatrixXf gains_per_device(local_device_count,
+                                coo_buffer_size_per_device);
+
       const int stats_size_per_device = num_scs;
       // NOTE: max ids and max unique ids are {global_sc_count *
       // num_devices}, where they are then aggregated(max) along device
       // dimension to get {global_sc_count} (i.e. max [unique] ids for each
       // sc), which can be further aggregated(max) for a single value for
       // all SCs.
-      py::array::ShapeContainer max_ids_stats_shape =
-          GetArrayShapeBasedOnLeadingDim(
-              /*has_leading_dimension=*/false, local_device_count,
-              stats_size_per_device);
-      py::array_t<int> max_ids_per_partition_per_sc =
-          py::array_t<int>(max_ids_stats_shape);
-      py::array_t<int> max_unique_ids_per_partition_per_sc =
-          py::array_t<int>(max_ids_stats_shape);
+      MatrixXi max_ids_per_partition_per_sc(local_device_count,
+                                            stats_size_per_device);
+      MatrixXi max_unique_ids_per_partition_per_sc(local_device_count,
+                                                   stats_size_per_device);
      // NOTE: required buffer size is {local_sc_count * num_devices}, which
      // is same as {global_sc_count}, and can be further aggregated to get
      // the maximum size of any SC buffer shard.
-      py::array_t<int> required_buffer_size_per_sc =
-          py::array_t<int>(GetArrayShapeBasedOnLeadingDim(
-              false, local_device_count, num_sc_per_device));
+      MatrixXi required_buffer_size_per_sc(local_device_count,
+                                           num_sc_per_device);
       for (int local_device = 0; local_device < local_device_count;
            ++local_device) {
         // Get the tuple outputs for the current split.
-        auto row_pointer_buffer =
-            row_pointers_per_device[GetBufferSliceForGivenDevice(
-                has_leading_dimension, local_device,
-                row_pointers_size_per_sc * num_sc_per_device)];
-        py::slice static_buffer_slice = GetBufferSliceForGivenDevice(
-            has_leading_dimension, local_device, coo_buffer_size_per_device);
-        auto embedding_id_buffer =
-            embedding_ids_per_device[static_buffer_slice];
-        auto sample_id_buffer = sample_ids_per_device[static_buffer_slice];
-        auto gain_buffer = gains_per_device[static_buffer_slice];
-        py::slice max_ids_stats_slice =
-            GetBufferSliceForGivenDevice(/*has_leading_dimension=*/false,
-                                         local_device, stats_size_per_device);
-        auto max_ids_per_partition_per_sc_buffer =
-            max_ids_per_partition_per_sc[max_ids_stats_slice];
-        auto max_unique_ids_per_partition_per_sc_buffer =
-            max_unique_ids_per_partition_per_sc[max_ids_stats_slice];
-
-        auto required_buffer_size_per_sc_buffer =
-            required_buffer_size_per_sc[GetBufferSliceForGivenDevice(
-                false, local_device, num_sc_per_device)];
+        Eigen::Ref<Eigen::VectorXi> row_pointer_buffer =
+            row_pointers_per_device.row(local_device);
+        Eigen::Ref<Eigen::VectorXi> embedding_id_buffer =
+            embedding_ids_per_device.row(local_device);
+        Eigen::Ref<Eigen::VectorXi> sample_id_buffer =
+            sample_ids_per_device.row(local_device);
+        Eigen::Ref<Eigen::VectorXf> gain_buffer =
+            gains_per_device.row(local_device);
+        Eigen::Ref<Eigen::VectorXi> max_ids_per_partition_per_sc_buffer =
+            max_ids_per_partition_per_sc.row(local_device);
+        Eigen::Ref<Eigen::VectorXi>
+            max_unique_ids_per_partition_per_sc_buffer =
+                max_unique_ids_per_partition_per_sc.row(local_device);
+        Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc_buffer =
+            required_buffer_size_per_sc.row(local_device);
+
+        // Acquire GIL
+        py::gil_scoped_acquire acq;
         PreprocessInputForStackedTablePerLocalDevice(
             stacked_table_metadata, features, feature_weights, local_device,
             local_device_count, coo_buffer_size_per_device,
             row_pointers_size_per_sc, global_device_count, num_sc_per_device,
             sharding_strategy, stacked_table_name, allow_id_dropping,
-            py::cast<py::array_t<int>>(row_pointer_buffer),
-            py::cast<py::array_t<int>>(embedding_id_buffer),
-            py::cast<py::array_t<int>>(sample_id_buffer),
-            py::cast<py::array_t<float>>(gain_buffer),
-            py::cast<py::array_t<int>>(max_ids_per_partition_per_sc_buffer),
-            py::cast<py::array_t<int>>(
-                max_unique_ids_per_partition_per_sc_buffer),
-            py::cast<py::array_t<int>>(required_buffer_size_per_sc_buffer));
+            row_pointer_buffer, embedding_id_buffer, sample_id_buffer,
+            gain_buffer, max_ids_per_partition_per_sc_buffer,
+            max_unique_ids_per_partition_per_sc_buffer,
+            required_buffer_size_per_sc_buffer);
       }
+      // Acquire GIL before updating Python dicts.
+      py::gil_scoped_acquire acq;
       lhs_row_pointers[stacked_table_name.c_str()] =
           std::move(row_pointers_per_device);
       lhs_embedding_ids[stacked_table_name.c_str()] =
@@ -517,14 +478,29 @@ py::tuple PreprocessSparseDenseMatmulInput(
           std::move(max_unique_ids_per_partition_per_sc);
       required_buffer_sizes[stacked_table_name.c_str()] =
           std::move(required_buffer_size_per_sc);
+      // To be eventually extracted out of the library
+      if (!has_leading_dimension) {
+        for (auto& vec : {lhs_row_pointers, lhs_embedding_ids, lhs_gains,
+                          lhs_sample_ids}) {
+          vec[stacked_table_name.c_str()] =
+              py::cast<py::array>(vec[stacked_table_name.c_str()])
+                  .reshape({-1});
+        }
+      }
+      for (auto& vec : {max_ids_per_partition, max_unique_ids_per_partition,
+                        required_buffer_sizes}) {
+        vec[stacked_table_name.c_str()] =
+            py::cast<py::array>(vec[stacked_table_name.c_str()])
+                .reshape({-1});
+      }
       counter.DecrementCount();
     });
   }
   counter.Wait();
 }
 py::dict stats;
-  stats["max_ids"] = max_ids_per_partition;
-  stats["max_unique_ids"] = max_unique_ids_per_partition;
+  stats["max_ids"] = std::move(max_ids_per_partition);
+  stats["max_unique_ids"] = std::move(max_unique_ids_per_partition);
   stats["required_buffer_size"] = std::move(required_buffer_sizes);
 
   // GIL is held at this point.
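Two things are worth noting about the Python boundary above. First, the GIL is now acquired only where Python objects are actually touched: briefly inside the per-device loop (the feature lists are still Python objects) and once more for the dict updates. Second, with pybind11/eigen.h, assigning a MatrixXi into a py::dict produces a 2-D numpy array with one row per local device, which the reshape({-1}) calls flatten for callers expecting the legacy 1-D layout. A self-contained sketch of that conversion (variable names hypothetical; requires the GIL):

const int local_device_count = 2, row_pointers_size = 16;
MatrixXi row_pointers =
    MatrixXi::Zero(local_device_count, row_pointers_size);

py::dict out;
out["table0"] = std::move(row_pointers);  // -> numpy array, shape (2, 16)
// Drop the leading device dimension when has_leading_dimension is false.
out["table0"] = py::cast<py::array>(out["table0"]).reshape({-1});  // (32,)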

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc

Lines changed: 10 additions & 9 deletions
@@ -30,6 +30,7 @@
 #include "absl/strings/str_join.h" // from @com_google_absl
 #include "absl/strings/string_view.h" // from @com_google_absl
 #include "absl/types/span.h" // from @com_google_absl
+#include "third_party/eigen3/Eigen/Core"
 #include "hwy/contrib/sort/order.h" // from @highway
 #include "hwy/contrib/sort/vqsort.h" // from @highway
 #include "tsl/profiler/lib/traceme.h"
@@ -106,8 +107,9 @@ std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
     const int32_t max_unique_ids_per_partition,
     const absl::string_view stacked_table_name, const bool allow_id_dropping,
     const int num_sc_per_device, const int total_num_coo_tensors,
-    int max_ids_per_sc[], int max_unique_ids_per_sc[],
-    int required_buffer_size_per_sc[]) {
+    Eigen::Ref<Eigen::VectorXi> max_ids_per_sc,
+    Eigen::Ref<Eigen::VectorXi> max_unique_ids_per_sc,
+    Eigen::Ref<Eigen::VectorXi> required_buffer_size_per_sc) {
   tsl::profiler::TraceMe t("SortAndGroupCooTensors");
   const int local_sc_count = batch_size_for_device / batch_size_per_sc;
   std::vector<std::vector<CooFormat>> coo_tensors_by_id;
@@ -122,16 +124,13 @@ std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
   uint32_t coo_tensor_index = 0;
   const int32_t num_scs_bit = std::log2(global_sc_count);
   // Initialize the aggregated max ids and unique ids per SC to 0.
-  for (int32_t global_sc_id = 0; global_sc_id < global_sc_count;
-       ++global_sc_id) {
-    max_ids_per_sc[global_sc_id] = 0;
-    max_unique_ids_per_sc[global_sc_id] = 0;
-  }
+  max_ids_per_sc.fill(0);
+  max_unique_ids_per_sc.fill(0);
+  required_buffer_size_per_sc.fill(0);
   // Loop over scs for this device.
   for (int32_t local_sc_id = 0; local_sc_id < local_sc_count; ++local_sc_id) {
     std::vector<int32_t> ids_per_sc_partition(global_sc_count, 0);
     std::vector<int32_t> unique_ids_per_sc_partition(global_sc_count, 0);
-    required_buffer_size_per_sc[local_sc_id] = 0;
     std::vector<uint64_t> keys;
     keys.reserve(batch_size_per_sc);
     // We take the advantage of the fact that the row_ids are already sorted
@@ -295,7 +294,9 @@ void FillRowPointersPerLocalDevice(
     absl::Span<const std::vector<CooFormat>> coo_tensors_by_id,
     const int row_pointers_size_per_sc, const int coo_buffer_size_per_sc,
     const int batch_size_per_sc, const int num_scs, const int num_sc_per_device,
-    int row_pointers[], int embedding_ids[], int sample_ids[], float gains[]) {
+    Eigen::Ref<Eigen::VectorXi> row_pointers,
+    Eigen::Ref<Eigen::VectorXi> embedding_ids,
+    Eigen::Ref<Eigen::VectorXi> sample_ids, Eigen::Ref<Eigen::VectorXf> gains) {
   tsl::profiler::TraceMe t("FillRowPointers");
   for (int local_sc_id = 0; local_sc_id < num_sc_per_device; ++local_sc_id) {
     int lhs_row_index = 0;
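On the utility side, the raw int[] out-parameters become Eigen::Ref vectors, which know their own length and support whole-vector operations; that is what collapses the explicit zeroing loop into fill(0) calls and moves the per-SC required_buffer_size reset out of the loop into a single up-front fill. A hedged sketch of the same API shape (function name hypothetical):

#include <cassert>

#include "third_party/eigen3/Eigen/Core"

// Unlike int[] parameters, Eigen::Ref carries its size, so no separate
// length argument is needed and initialization is one call per buffer.
void ResetStats(Eigen::Ref<Eigen::VectorXi> max_ids,
                Eigen::Ref<Eigen::VectorXi> max_unique_ids) {
  assert(max_ids.size() == max_unique_ids.size());
  max_ids.fill(0);
  max_unique_ids.fill(0);
}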
