Commit c8472c5

Introduce FdoStats struct for tracking FDO info.
PiperOrigin-RevId: 757849206
1 parent 2d0481f commit c8472c5

11 files changed: +159, -55 lines

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 26 additions & 2 deletions

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "CORE_USERS")
-load("//third_party/bazel/python:pybind11.bzl", "pybind_extension")
+load("//third_party/bazel/python:pybind11.bzl", "pybind_extension", "pybind_library")
 load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
 load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library")

@@ -77,6 +77,7 @@ pybind_extension(
     name = "input_preprocessing_cc",
     srcs = ["input_preprocessing.cc"],
     deps = [
+        ":fdo_types",
         ":input_preprocessing_threads",
         ":input_preprocessing_util",
         "@com_google_absl//absl/container:flat_hash_map",
@@ -89,6 +90,24 @@ pybind_extension(
     ],
 )

+pybind_library(
+    name = "fdo_types",
+    hdrs = ["fdo_types.h"],
+    deps = ["@com_google_absl//absl/container:flat_hash_map"],
+)
+
+pybind_extension(
+    name = "fdo_types_cc",
+    srcs = [
+        "fdo_types.cc",
+        "fdo_types.h",
+    ],
+    deps = [
+        "//third_party/pybind11_abseil:absl_casters",
+        "@com_google_absl//absl/container:flat_hash_map",
+    ],
+)
+
 pytype_strict_library(
     name = "input_preprocessing",
     srcs = [
@@ -160,12 +179,17 @@ pytype_strict_library(
     srcs = ["__init__.py"],
     # C++ dependencies must go in "data".
     data = [
+        ":fdo_types",  # buildcleaner: keep
         ":input_preprocessing_threads",  # buildcleaner: keep
         ":input_preprocessing_util",  # buildcleaner: keep
     ],
-    visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"],
+    visibility = [
+        "//jax_tpu_embedding/sparsecore/lib:__pkg__",
+        "//jax_tpu_embedding/sparsecore/lib/extensions:__pkg__",
+    ],
     deps = [
         ":constants",  # buildcleaner: keep
+        ":fdo_types_cc",  # buildcleaner: keep
         ":input_preprocessing",  # buildcleaner: keep
         ":input_preprocessing_cc",  # buildcleaner: keep
         "//jax_tpu_embedding/sparsecore/lib/core/primitives",  # buildcleaner: keep

jax_tpu_embedding/sparsecore/lib/core/fdo_types.cc

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
+// Copyright 2024 The JAX SC Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "jax_tpu_embedding/sparsecore/lib/core/fdo_types.h"
+
+#include "pybind11/cast.h"  // from @pybind11
+#include "pybind11/numpy.h"  // from @pybind11
+#include "pybind11/pybind11.h"  // from @pybind11
+#include "pybind11/pytypes.h"  // from @pybind11
+#include "pybind11/stl.h"  // from @pybind11
+#include "third_party/pybind11_abseil/absl_casters.h"
+
+namespace jax_sc_embedding {
+
+namespace py = ::pybind11;
+
+PYBIND11_MODULE(fdo_types_cc, m) {
+  py::class_<FdoStats>(m, "FdoStats")
+      .def_readonly("max_ids_per_partition", &FdoStats::max_ids_per_partition)
+      .def_readonly("max_unique_ids_per_partition",
+                    &FdoStats::max_unique_ids_per_partition)
+      .def_readonly("id_drop_counters", &FdoStats::id_drop_counters)
+      .def_readonly("required_buffer_sizes", &FdoStats::required_buffer_sizes);
+}
+
+}  // namespace jax_sc_embedding

jax_tpu_embedding/sparsecore/lib/core/fdo_types.h

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+// Copyright 2024 The JAX SC Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_
+#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"  // from @com_google_absl
+
+namespace jax_sc_embedding {
+
+struct FdoStats {
+  using FdoStatsPerSparseCore = std::vector<int>;
+
+  using StackedTableName = std::string;
+
+  using FdoStatsPerStackedTable =
+      absl::flat_hash_map<StackedTableName, FdoStatsPerSparseCore>;
+
+  FdoStatsPerStackedTable max_ids_per_partition;
+  FdoStatsPerStackedTable max_unique_ids_per_partition;
+  FdoStatsPerStackedTable id_drop_counters;
+  FdoStatsPerStackedTable required_buffer_sizes;
+};
+
+}  // namespace jax_sc_embedding
+
+#endif  // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_
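
A minimal sketch of how the new struct is meant to be used on the C++ side: each map is keyed by stacked-table name and holds one value per SparseCore. The table name "table_0", the numbers, and the main() wrapper below are invented for illustration; only the FdoStats fields and type aliases come from the header above.

// Illustrative only; not part of the commit.
#include <cstdio>

#include "jax_tpu_embedding/sparsecore/lib/core/fdo_types.h"

int main() {
  jax_sc_embedding::FdoStats stats;
  // One vector per stacked table, one entry per SparseCore (4 SCs here).
  stats.max_ids_per_partition["table_0"] = {16, 12, 20, 8};
  stats.max_unique_ids_per_partition["table_0"] = {10, 9, 14, 6};
  stats.id_drop_counters["table_0"] = {0, 0, 1, 0};
  stats.required_buffer_sizes["table_0"] = {256, 256, 320, 192};

  for (const auto& [table, per_sc] : stats.max_ids_per_partition) {
    for (int max_ids : per_sc) {
      std::printf("%s max_ids_per_partition: %d\n", table.c_str(), max_ids);
    }
  }
  return 0;
}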

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 39 additions & 41 deletions

@@ -13,6 +13,8 @@
 // limitations under the License.
 #include <algorithm>
 #include <cmath>
+#include <cstddef>
+#include <cstdint>
 #include <optional>
 #include <string>
 #include <utility>
@@ -24,6 +26,7 @@
 #include "absl/strings/string_view.h"  // from @com_google_absl
 #include "absl/synchronization/blocking_counter.h"  // from @com_google_absl
 #include "absl/types/span.h"  // from @com_google_absl
+#include "jax_tpu_embedding/sparsecore/lib/core/fdo_types.h"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
 #include "pybind11/cast.h"  // from @pybind11
@@ -250,8 +253,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
     const absl::string_view stacked_table_name, const bool allow_id_dropping,
     py::array_t<int> row_pointer_buffer, py::array_t<int> embedding_id_buffer,
     py::array_t<int> sample_id_buffer, py::array_t<float> gain_buffer,
-    py::array_t<int> max_ids_buffer, py::array_t<int> max_unique_ids_buffer,
-    py::array_t<int> required_buffer_size_per_sc_buffer) {
+    absl::Span<int> max_ids_buffer, absl::Span<int> max_unique_ids_buffer,
+    absl::Span<int> required_buffer_size_per_sc_buffer) {
   const int num_scs = num_sc_per_device * num_global_devices;
   int batch_size_for_device = 0;
   int total_num_coo_tensors = 0;
@@ -299,10 +302,6 @@ void PreprocessInputForStackedTablePerLocalDevice(
   auto* embedding_ids_data = embedding_id_buffer.mutable_data();
   auto* sample_ids_data = sample_id_buffer.mutable_data();
   auto* gains_data = gain_buffer.mutable_data();
-  auto* total_max_ids_per_sc = max_ids_buffer.mutable_data();
-  auto* total_max_unique_ids_per_sc = max_unique_ids_buffer.mutable_data();
-  auto* required_buffer_size_per_sc =
-      required_buffer_size_per_sc_buffer.mutable_data();
   // The remaining section does not require GIL.
   py::gil_scoped_release release;

@@ -318,8 +317,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
           stacked_table_metadata[0].max_ids_per_partition,
           stacked_table_metadata[0].max_unique_ids_per_partition,
           stacked_table_name, allow_id_dropping, num_sc_per_device,
-          total_num_coo_tensors, total_max_ids_per_sc,
-          total_max_unique_ids_per_sc, required_buffer_size_per_sc);
+          total_num_coo_tensors, max_ids_buffer, max_unique_ids_buffer,
+          required_buffer_size_per_sc_buffer);
   for (int i = 0; i < num_sc_per_device; ++i) {
     coo_tensors_by_id[i].emplace_back(batch_size_per_sc * (i + 1), 0, 0.0);
   }
@@ -359,6 +358,13 @@ static inline py::slice GetBufferSliceForGivenDevice(bool has_leading_dimension,
                    (start_index + 1) * first_dim_size, 1);
 }

+static inline absl::Span<int> GetStatsSliceForGivenDevice(
+    std::vector<int>& stats, int device_index, int stats_size_per_device) {
+  return absl::MakeSpan(stats).subspan(
+      device_index * stats_size_per_device,
+      (device_index + 1) * stats_size_per_device);
+}
+
 py::tuple PreprocessSparseDenseMatmulInput(
     py::list features, py::list feature_weights, py::list feature_specs,
     const int local_device_count, const int global_device_count,
@@ -379,9 +385,9 @@ py::tuple PreprocessSparseDenseMatmulInput(
   py::dict lhs_embedding_ids;
   py::dict lhs_sample_ids;
   py::dict lhs_gains;
-  py::dict max_ids_per_partition;
-  py::dict max_unique_ids_per_partition;
-  py::dict required_buffer_sizes;
+  FdoStats::FdoStatsPerStackedTable max_ids_per_partition;
+  FdoStats::FdoStatsPerStackedTable max_unique_ids_per_partition;
+  FdoStats::FdoStatsPerStackedTable required_buffer_sizes;
   const int num_scs = num_sc_per_device * global_device_count;
   const int row_pointers_size_per_sc = std::max(num_scs, 8);

@@ -437,15 +443,10 @@ py::tuple PreprocessSparseDenseMatmulInput(
        py::array_t<float> gains_per_device =
            py::array_t<float>(shape_container);
        const int stats_size_per_device = num_scs;
-        py::array::ShapeContainer stats_shape = GetArrayShapeBasedOnLeadingDim(
-            /*has_leading_dimension=*/false, local_device_count,
-            stats_size_per_device);
-        py::array_t<int> max_ids_per_partition_per_sc =
-            py::array_t<int>(stats_shape);
-        py::array_t<int> max_unique_ids_per_partition_per_sc =
-            py::array_t<int>(stats_shape);
-        py::array_t<int> required_buffer_size_per_sc =
-            py::array_t<int>(stats_shape);
+        size_t stats_size = local_device_count * stats_size_per_device;
+        std::vector<int> max_ids_per_partition_per_sc(stats_size);
+        std::vector<int> max_unique_ids_per_partition_per_sc(stats_size);
+        std::vector<int> required_buffer_size_per_sc(stats_size);
        for (int local_device = 0; local_device < local_device_count;
             ++local_device) {
          // Get the tuple outputs for the current split.
@@ -459,15 +460,14 @@
              embedding_ids_per_device[static_buffer_slice];
          auto sample_id_buffer = sample_ids_per_device[static_buffer_slice];
          auto gain_buffer = gains_per_device[static_buffer_slice];
-          py::slice stats_slice =
-              GetBufferSliceForGivenDevice(/*has_leading_dimension=*/false,
-                                           local_device, stats_size_per_device);
-          auto max_ids_per_partition_per_sc_buffer =
-              max_ids_per_partition_per_sc[stats_slice];
-          auto max_unique_ids_per_partition_per_sc_buffer =
-              max_unique_ids_per_partition_per_sc[stats_slice];
-          auto required_buffer_size_per_sc_buffer =
-              required_buffer_size_per_sc[stats_slice];
+          auto device_max_ids_per_partition =
+              GetStatsSliceForGivenDevice(max_ids_per_partition_per_sc,
+                                          local_device, stats_size_per_device);
+          auto device_max_unique_ids_per_partition =
+              GetStatsSliceForGivenDevice(max_unique_ids_per_partition_per_sc,
+                                          local_device, stats_size_per_device);
+          auto device_required_buffer_size = GetStatsSliceForGivenDevice(
+              required_buffer_size_per_sc, local_device, stats_size_per_device);
          PreprocessInputForStackedTablePerLocalDevice(
              stacked_table_metadata, features, feature_weights, local_device,
              local_device_count, coo_buffer_size, row_pointers_size_per_sc,
@@ -477,10 +477,8 @@
              py::cast<py::array_t<int>>(embedding_id_buffer),
              py::cast<py::array_t<int>>(sample_id_buffer),
              py::cast<py::array_t<float>>(gain_buffer),
-              py::cast<py::array_t<int>>(max_ids_per_partition_per_sc_buffer),
-              py::cast<py::array_t<int>>(
-                  max_unique_ids_per_partition_per_sc_buffer),
-              py::cast<py::array_t<int>>(required_buffer_size_per_sc_buffer));
+              device_max_ids_per_partition, device_max_unique_ids_per_partition,
+              device_required_buffer_size);
        }
        lhs_row_pointers[stacked_table_name.c_str()] =
            std::move(row_pointers_per_device);
@@ -490,21 +488,21 @@
            std::move(sample_ids_per_device);
        lhs_gains[stacked_table_name.c_str()] = std::move(gains_per_device);
        max_ids_per_partition[stacked_table_name.c_str()] =
-            std::move(max_ids_per_partition_per_sc);
+            max_ids_per_partition_per_sc;
        max_unique_ids_per_partition[stacked_table_name.c_str()] =
-            std::move(max_unique_ids_per_partition_per_sc);
+            max_unique_ids_per_partition_per_sc;
        required_buffer_sizes[stacked_table_name.c_str()] =
-            std::move(required_buffer_size_per_sc);
+            required_buffer_size_per_sc;
        counter.DecrementCount();
      });
    }
    counter.Wait();
  }
-  py::dict stats;
-  stats["max_ids"] = max_ids_per_partition;
-  stats["max_unique_ids"] = max_unique_ids_per_partition;
-  stats["required_buffer_size"] = std::move(required_buffer_sizes);
-
+  FdoStats stats{
+      .max_ids_per_partition = max_ids_per_partition,
+      .max_unique_ids_per_partition = max_unique_ids_per_partition,
+      .required_buffer_sizes = required_buffer_sizes,
+  };
   // GIL is held at this point.
   return py::make_tuple(lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids,
                         lhs_gains, stats);
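
The GetStatsSliceForGivenDevice helper added above replaces per-device py::array slices with non-owning views into one flat std::vector. A standalone sketch of that pattern follows; the names (StatsForDevice and the variables in main) are hypothetical, not the committed code.

// Illustrative only; mirrors the flat-vector-plus-span pattern above.
#include <vector>

#include "absl/types/span.h"

// Returns a non-owning view over one device's chunk of the flat stats vector.
absl::Span<int> StatsForDevice(std::vector<int>& flat_stats, int device_index,
                               int stats_size_per_device) {
  // subspan(offset, count): skip earlier devices, keep one device's worth.
  return absl::MakeSpan(flat_stats)
      .subspan(device_index * stats_size_per_device, stats_size_per_device);
}

int main() {
  const int local_device_count = 2;
  const int stats_size_per_device = 4;  // e.g. one entry per SparseCore
  std::vector<int> max_ids(local_device_count * stats_size_per_device, 0);

  // Writes through the span land directly in the shared flat vector, so each
  // device's preprocessing can fill its slice without extra copies.
  absl::Span<int> device1 =
      StatsForDevice(max_ids, /*device_index=*/1, stats_size_per_device);
  device1[0] = 42;  // max_ids[4] is now 42.
  return 0;
}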

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_cc_test.py

Lines changed: 1 addition & 1 deletion

@@ -763,7 +763,7 @@ def test_multi_process_fdo(self, has_leading_dimension):
             allow_id_dropping=False,
         )
     )
-    stats = embedding.SparseDenseMatmulInputStats.from_dict(stats)
+    stats = embedding.SparseDenseMatmulInputStats.from_cc(stats)
     fdo_client.record(stats)
     fdo_client.publish()
     # Duplicated ids on row 0 and 6 are combined.

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc

Lines changed: 2 additions & 2 deletions

@@ -105,8 +105,8 @@ std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
     const int32_t max_unique_ids_per_partition,
     const absl::string_view stacked_table_name, const bool allow_id_dropping,
     const int num_sc_per_device, const int total_num_coo_tensors,
-    int max_ids_per_sc[], int max_unique_ids_per_sc[],
-    int required_buffer_size_per_sc[]) {
+    absl::Span<int> max_ids_per_sc, absl::Span<int> max_unique_ids_per_sc,
+    absl::Span<int> required_buffer_size_per_sc) {
   tsl::profiler::TraceMe t("SortAndGroupCooTensors");
   const int local_sc_count = batch_size_for_device / batch_size_per_sc;
   std::vector<std::vector<CooFormat>> coo_tensors_by_id;

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 5 additions & 4 deletions

@@ -120,18 +120,19 @@ struct StackedTableMetadata {

 std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
     absl::Span<const CooFormat> coo_tensors, int batch_size_per_sc,
-    int global_sc_count,
-    int32_t batch_size_for_device,  // Batch size for the local device.
+    int global_sc_count, int32_t batch_size_for_device,
     int32_t max_ids_per_partition, int32_t max_unique_ids_per_partition,
     absl::string_view stacked_table_name, bool allow_id_dropping,
-    int num_sc_per_device, int total_num_coo_tensors, int max_ids_per_sc[],
-    int max_unique_ids_per_sc[], int required_buffer_size_per_sc[]);
+    int num_sc_per_device, int total_num_coo_tensors,
+    absl::Span<int> max_ids_per_sc, absl::Span<int> max_unique_ids_per_sc,
+    absl::Span<int> required_buffer_size_per_sc);

 int ComputeCooBufferSize(
     int num_scs, int num_scs_per_device,
     absl::Span<const StackedTableMetadata> stacked_table_metadata,
     int static_buffer_size_multiplier);

+
 void IncrementScId(std::pair<int, int>& sc_id, int num_scs,
                    int num_scs_per_device);

jax_tpu_embedding/sparsecore/lib/nn/BUILD

Lines changed: 1 addition & 0 deletions

@@ -40,6 +40,7 @@ pytype_strict_library(
     deps = [
         ":embedding_spec",
         ":table_stacking",
+        "//jax_tpu_embedding/sparsecore/lib/core:fdo_types_cc",
         "//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_cc",
         "//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr",
         "//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2",

jax_tpu_embedding/sparsecore/lib/nn/embedding.py

Lines changed: 6 additions & 5 deletions

@@ -24,6 +24,7 @@
 from jax.experimental.layout import DeviceLocalLayout as DLL
 from jax.experimental.layout import Layout
 import jax.numpy as jnp
+from jax_tpu_embedding.sparsecore.lib.core import fdo_types_cc
 from jax_tpu_embedding.sparsecore.lib.core import input_preprocessing_cc
 from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_csr
 from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
@@ -69,12 +70,12 @@ class SparseDenseMatmulInputStats:
   max_unique_ids_per_partition: Mapping[str, np.ndarray]

   @classmethod
-  def from_dict(
-      cls, stats: Mapping[str, Mapping[str, np.ndarray]]
+  def from_cc(
+      cls, stats: fdo_types_cc.FdoStats
   ) -> "SparseDenseMatmulInputStats":
     return cls(
-        max_ids_per_partition=stats["max_ids"],
-        max_unique_ids_per_partition=stats["max_unique_ids"],
+        max_ids_per_partition=stats.max_ids_per_partition,
+        max_unique_ids_per_partition=stats.max_unique_ids_per_partition,
     )


@@ -380,7 +381,7 @@ def preprocess_sparse_dense_matmul_input(

   return SparseDenseMatmulInput(
       *preprocessed_inputs
-  ), SparseDenseMatmulInputStats.from_dict(stats)
+  ), SparseDenseMatmulInputStats.from_cc(stats)


 def _get_activation_for_feature(

jax_tpu_embedding/sparsecore/lib/nn/tests/BUILD

Lines changed: 1 addition & 0 deletions

@@ -41,6 +41,7 @@ py_binary(
     name = "preprocess_input_benchmarks",
     srcs = ["preprocess_input_benchmarks.py"],
     deps = [
+        "//jax_tpu_embedding/sparsecore/lib/core:fdo_types_cc",
         "//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_cc",
         "//jax_tpu_embedding/sparsecore/lib/nn:embedding_spec",
         pypi_requirement("google_benchmark"),
