Skip to content

Commit 2c7c216

Browse files
Introduce FdoStats struct for tracking FDO info.
PiperOrigin-RevId: 757849206
1 parent 8cae5b3 commit 2c7c216

File tree

7 files changed

+143
-13
lines changed

7 files changed

+143
-13
lines changed

jax_tpu_embedding/sparsecore/lib/core/BUILD

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
load("//jax_tpu_embedding/sparsecore:jax_tpu_embedding.bzl", "CORE_USERS")
15-
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension")
15+
load("//third_party/bazel/python:pybind11.bzl", "pybind_extension", "pybind_library")
1616
load("//third_party/bazel/python:pypi.bzl", "pypi_requirement")
1717
load("//third_party/bazel/python:pytype.bzl", "pytype_strict_contrib_test", "pytype_strict_library")
1818

@@ -78,6 +78,7 @@ pybind_extension(
7878
name = "input_preprocessing_cc",
7979
srcs = ["input_preprocessing.cc"],
8080
deps = [
81+
":fdo_types",
8182
":input_preprocessing_threads",
8283
":input_preprocessing_util",
8384
"@com_google_absl//absl/container:flat_hash_map",
@@ -90,6 +91,21 @@ pybind_extension(
9091
],
9192
)
9293

94+
pybind_library(
95+
name = "fdo_types",
96+
hdrs = ["fdo_types.h"],
97+
deps = ["@com_google_absl//absl/container:flat_hash_map"],
98+
)
99+
100+
pybind_extension(
101+
name = "fdo_types_cc",
102+
srcs = [
103+
"fdo_types.cc",
104+
"fdo_types.h",
105+
],
106+
deps = ["@com_google_absl//absl/container:flat_hash_map"],
107+
)
108+
93109
pytype_strict_library(
94110
name = "input_preprocessing",
95111
srcs = [
@@ -161,12 +177,17 @@ pytype_strict_library(
161177
srcs = ["__init__.py"],
162178
# C++ dependencies must go in "data".
163179
data = [
180+
":fdo_types", # buildcleaner: keep
164181
":input_preprocessing_threads", # buildcleaner: keep
165182
":input_preprocessing_util", # buildcleaner: keep
166183
],
167-
visibility = ["//jax_tpu_embedding/sparsecore/lib:__pkg__"],
184+
visibility = [
185+
"//jax_tpu_embedding/sparsecore/lib:__pkg__",
186+
"//jax_tpu_embedding/sparsecore/lib/extensions:__pkg__",
187+
],
168188
deps = [
169189
":constants", # buildcleaner: keep
190+
":fdo_types_cc", # buildcleaner: keep
170191
":input_preprocessing", # buildcleaner: keep
171192
":input_preprocessing_cc", # buildcleaner: keep
172193
"//jax_tpu_embedding/sparsecore/lib/core/primitives", # buildcleaner: keep
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright 2024 The JAX SC Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#include "jax_tpu_embedding/sparsecore/lib/core/fdo_types.h"
15+
16+
#include "pybind11/cast.h" // from @pybind11
17+
#include "pybind11/numpy.h" // from @pybind11
18+
#include "pybind11/pybind11.h" // from @pybind11
19+
#include "pybind11/pytypes.h" // from @pybind11
20+
21+
namespace jax_sc_embedding {
22+
23+
namespace py = ::pybind11;
24+
25+
// Defines the Python extension module `fdo_types_cc` and registers PyFdoStats
// as the Python class `FdoStats`, exposing two read/write attributes.
// NOTE(review): FdoStats also declares `id_drop_counters` (see fdo_types.h),
// which is not bound here - confirm whether that omission is intentional.
// NOTE(review): no py::init is registered, so the class is presumably only
// obtainable as a value returned from C++ - confirm.
PYBIND11_MODULE(fdo_types_cc, m) {
26+
py::class_<PyFdoStats>(m, "FdoStats")
27+
.def_readwrite("max_ids_per_partition",
28+
&PyFdoStats::max_ids_per_partition)
29+
.def_readwrite("max_unique_ids_per_partition",
30+
&PyFdoStats::max_unique_ids_per_partition);
31+
}
32+
33+
} // namespace jax_sc_embedding
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
// Copyright 2024 The JAX SC Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_
15+
#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_
16+
17+
#include <string>
18+
19+
#include "absl/container/flat_hash_map.h" // from @com_google_absl
20+
#include "pybind11/cast.h" // from @pybind11
21+
#include "pybind11/numpy.h" // from @pybind11
22+
#include "pybind11/pybind11.h" // from @pybind11
23+
#include "pybind11/pytypes.h" // from @pybind11
24+
25+
namespace jax_sc_embedding {
26+
27+
namespace py = ::pybind11;
28+
29+
// Compile-time trait consulted by FdoStats' static_assert to validate MapType.
// This primary template is the fallback: any MapType without a matching
// specialization yields value == false.
template <typename ArrayType, typename MapType>
30+
struct is_valid_map_type {
31+
// Intended question: is the value type of the map the same as ArrayType?
// NOTE(review): none of the specializations visible here compare the map's
// value type against ArrayType - confirm whether that check is still planned.
32+
static constexpr bool value = false;
33+
};
34+
35+
// Specialization accepting absl::flat_hash_map keyed by std::string.
// ValueType is deduced independently of ArrayType, so any mapped type passes.
template <typename ArrayType, typename ValueType>
36+
struct is_valid_map_type<ArrayType,
37+
absl::flat_hash_map<std::string, ValueType>> {
38+
static constexpr bool value = true;
39+
};
40+
41+
// Specialization accepting py::dict for any ArrayType.
template <typename ArrayType>
42+
struct is_valid_map_type<ArrayType, py::dict> {
43+
// py::dict does not support type annotations, so its value type cannot be
// checked here; it is accepted unconditionally.
44+
static constexpr bool value = true;
45+
};
46+
47+
// Aggregate of per-stacked-table FDO statistics, parameterized on the array
// type and on the map type that holds the statistics.
template <typename ArrayType, // std::vector<int> or py::array_t<int>
48+
typename MapType // absl::flat_hash_map<std::string,
49+
// std::vector<int>> or py::dict
50+
>
51+
struct FdoStats {
52+
// Reject any MapType that is_valid_map_type does not explicitly allow.
static_assert(is_valid_map_type<ArrayType, MapType>::value,
53+
"MapType must be a valid map type");
54+
55+
// Array type used for the statistics of one stacked table.
using FdoStatsPerSparseCore = ArrayType;
56+
57+
// Key type naming a stacked table.
using StackedTableName = std::string;
58+
59+
// Intended layout: <StackedTableName, FdoStatsPerSparseCore>.
// The alias is MapType itself; that layout is not enforced by this alias.
60+
using FdoStatsPerStackedTable = MapType;
61+
62+
// The three statistics below share the same map type.
// NOTE(review): their exact semantics are not shown here - see the producer
// in input_preprocessing.cc.
FdoStatsPerStackedTable max_ids_per_partition;
63+
FdoStatsPerStackedTable max_unique_ids_per_partition;
64+
FdoStatsPerStackedTable id_drop_counters;
65+
};
66+
67+
// FdoStats instantiated with pybind11 types: py::array_t<int> as the array
// type and py::dict as the map type.
using PyFdoStats = FdoStats<py::array_t<int>, py::dict>;
68+
69+
} // namespace jax_sc_embedding
70+
71+
#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_FDO_TYPES_H_

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414
#include <algorithm>
1515
#include <cmath>
16+
#include <cstdint>
1617
#include <optional>
1718
#include <string>
1819
#include <utility>
@@ -24,6 +25,7 @@
2425
#include "absl/strings/string_view.h" // from @com_google_absl
2526
#include "absl/synchronization/blocking_counter.h" // from @com_google_absl
2627
#include "absl/types/span.h" // from @com_google_absl
28+
#include "jax_tpu_embedding/sparsecore/lib/core/fdo_types.h"
2729
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_threads.h"
2830
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
2931
#include "pybind11/cast.h" // from @pybind11
@@ -382,8 +384,8 @@ py::tuple PreprocessSparseDenseMatmulInput(
382384
py::dict lhs_embedding_ids;
383385
py::dict lhs_sample_ids;
384386
py::dict lhs_gains;
385-
py::dict max_ids_per_partition;
386-
py::dict max_unique_ids_per_partition;
387+
PyFdoStats::FdoStatsPerStackedTable max_ids_per_partition;
388+
PyFdoStats::FdoStatsPerStackedTable max_unique_ids_per_partition;
387389
const int num_scs = num_sc_per_device * global_device_count;
388390
const int row_pointers_size_per_sc = std::max(num_scs, 8);
389391

@@ -495,9 +497,10 @@ py::tuple PreprocessSparseDenseMatmulInput(
495497
}
496498
counter.Wait();
497499
}
498-
py::dict stats;
499-
stats["max_ids"] = max_ids_per_partition;
500-
stats["max_unique_ids"] = max_unique_ids_per_partition;
500+
PyFdoStats stats{
501+
.max_ids_per_partition = max_ids_per_partition,
502+
.max_unique_ids_per_partition = max_unique_ids_per_partition,
503+
};
501504
// GIL is held at this point.
502505
return py::make_tuple(lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids,
503506
lhs_gains, stats);

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_cc_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,7 @@ def test_multi_process_fdo(self, has_leading_dimension):
763763
allow_id_dropping=False,
764764
)
765765
)
766-
stats = embedding.SparseDenseMatmulInputStats.from_dict(stats)
766+
stats = embedding.SparseDenseMatmulInputStats.from_cc(stats)
767767
fdo_client.record(stats)
768768
fdo_client.publish()
769769
# Duplicated ids on row 0 and 6 are combined.

jax_tpu_embedding/sparsecore/lib/nn/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ pytype_strict_library(
4040
deps = [
4141
":embedding_spec",
4242
":table_stacking",
43+
"//jax_tpu_embedding/sparsecore/lib/core:fdo_types_cc",
4344
"//jax_tpu_embedding/sparsecore/lib/core:input_preprocessing_cc",
4445
"//jax_tpu_embedding/sparsecore/lib/core/primitives:sparse_dense_matmul_csr",
4546
"//jax_tpu_embedding/sparsecore/lib/proto:embedding_spec_py_pb2",

jax_tpu_embedding/sparsecore/lib/nn/embedding.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from jax.experimental.layout import DeviceLocalLayout as DLL
2525
from jax.experimental.layout import Layout
2626
import jax.numpy as jnp
27+
from jax_tpu_embedding.sparsecore.lib.core import fdo_types_cc
2728
from jax_tpu_embedding.sparsecore.lib.core import input_preprocessing_cc
2829
from jax_tpu_embedding.sparsecore.lib.core.primitives import sparse_dense_matmul_csr
2930
from jax_tpu_embedding.sparsecore.lib.nn import embedding_spec
@@ -69,12 +70,12 @@ class SparseDenseMatmulInputStats:
6970
max_unique_ids_per_partition: Mapping[str, np.ndarray]
7071

7172
@classmethod
72-
def from_dict(
73-
cls, stats: Mapping[str, Mapping[str, np.ndarray]]
73+
def from_cc(
74+
cls, stats: fdo_types_cc.FdoStats
7475
) -> "SparseDenseMatmulInputStats":
7576
return cls(
76-
max_ids_per_partition=stats["max_ids"],
77-
max_unique_ids_per_partition=stats["max_unique_ids"],
77+
max_ids_per_partition=stats.max_ids_per_partition,
78+
max_unique_ids_per_partition=stats.max_unique_ids_per_partition,
7879
)
7980

8081

@@ -380,7 +381,7 @@ def preprocess_sparse_dense_matmul_input(
380381

381382
return SparseDenseMatmulInput(
382383
*preprocessed_inputs
383-
), SparseDenseMatmulInputStats.from_dict(stats)
384+
), SparseDenseMatmulInputStats.from_cc(stats)
384385

385386

386387
def _get_activation_for_feature(

0 commit comments

Comments
 (0)