
Commit 2d0481f

Compute required buffer size for SC
PiperOrigin-RevId: 759621757
1 parent 7f77ac4 commit 2d0481f

4 files changed (+272, -162 lines)


jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 23 additions & 16 deletions
@@ -250,7 +250,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
     const absl::string_view stacked_table_name, const bool allow_id_dropping,
     py::array_t<int> row_pointer_buffer, py::array_t<int> embedding_id_buffer,
     py::array_t<int> sample_id_buffer, py::array_t<float> gain_buffer,
-    py::array_t<int> max_ids_buffer, py::array_t<int> max_unique_ids_buffer) {
+    py::array_t<int> max_ids_buffer, py::array_t<int> max_unique_ids_buffer,
+    py::array_t<int> required_buffer_size_per_sc_buffer) {
   const int num_scs = num_sc_per_device * num_global_devices;
   int batch_size_for_device = 0;
   int total_num_coo_tensors = 0;
@@ -300,6 +301,8 @@ void PreprocessInputForStackedTablePerLocalDevice(
   auto* gains_data = gain_buffer.mutable_data();
   auto* total_max_ids_per_sc = max_ids_buffer.mutable_data();
   auto* total_max_unique_ids_per_sc = max_unique_ids_buffer.mutable_data();
+  auto* required_buffer_size_per_sc =
+      required_buffer_size_per_sc_buffer.mutable_data();
   // The remaining section does not require GIL.
   py::gil_scoped_release release;
 
@@ -308,21 +311,15 @@ void PreprocessInputForStackedTablePerLocalDevice(
   //
   const int batch_size_per_sc =
       CeilOfRatio(batch_size_for_device, num_sc_per_device);
-  std::vector<std::vector<CooFormat>> coo_tensors_by_id;
-  coo_tensors_by_id.resize(num_sc_per_device);
 
-  const int approximate_num_coo_tensors_per_sc =
-      total_num_coo_tensors / num_sc_per_device + 1;
-  for (int i = 0; i < num_sc_per_device; ++i) {
-    // Roughly estimate the number of COO tensors for each SC.
-    coo_tensors_by_id[i].reserve(approximate_num_coo_tensors_per_sc);
-  }
-  SortAndGroupCooTensorsPerLocalDevice(
-      coo_tensors, batch_size_per_sc, num_scs, batch_size_for_device,
-      stacked_table_metadata[0].max_ids_per_partition,
-      stacked_table_metadata[0].max_unique_ids_per_partition,
-      stacked_table_name, allow_id_dropping, coo_tensors_by_id,
-      total_max_ids_per_sc, total_max_unique_ids_per_sc);
+  std::vector<std::vector<CooFormat>> coo_tensors_by_id =
+      SortAndGroupCooTensorsPerLocalDevice(
+          coo_tensors, batch_size_per_sc, num_scs, batch_size_for_device,
+          stacked_table_metadata[0].max_ids_per_partition,
+          stacked_table_metadata[0].max_unique_ids_per_partition,
+          stacked_table_name, allow_id_dropping, num_sc_per_device,
+          total_num_coo_tensors, total_max_ids_per_sc,
+          total_max_unique_ids_per_sc, required_buffer_size_per_sc);
   for (int i = 0; i < num_sc_per_device; ++i) {
     coo_tensors_by_id[i].emplace_back(batch_size_per_sc * (i + 1), 0, 0.0);
   }
@@ -384,6 +381,7 @@ py::tuple PreprocessSparseDenseMatmulInput(
   py::dict lhs_gains;
   py::dict max_ids_per_partition;
   py::dict max_unique_ids_per_partition;
+  py::dict required_buffer_sizes;
   const int num_scs = num_sc_per_device * global_device_count;
   const int row_pointers_size_per_sc = std::max(num_scs, 8);
 
@@ -446,6 +444,8 @@ py::tuple PreprocessSparseDenseMatmulInput(
           py::array_t<int>(stats_shape);
       py::array_t<int> max_unique_ids_per_partition_per_sc =
           py::array_t<int>(stats_shape);
+      py::array_t<int> required_buffer_size_per_sc =
+          py::array_t<int>(stats_shape);
       for (int local_device = 0; local_device < local_device_count;
            ++local_device) {
         // Get the tuple outputs for the current split.
@@ -466,6 +466,8 @@ py::tuple PreprocessSparseDenseMatmulInput(
             max_ids_per_partition_per_sc[stats_slice];
         auto max_unique_ids_per_partition_per_sc_buffer =
             max_unique_ids_per_partition_per_sc[stats_slice];
+        auto required_buffer_size_per_sc_buffer =
+            required_buffer_size_per_sc[stats_slice];
         PreprocessInputForStackedTablePerLocalDevice(
             stacked_table_metadata, features, feature_weights, local_device,
             local_device_count, coo_buffer_size, row_pointers_size_per_sc,
@@ -477,7 +479,8 @@ py::tuple PreprocessSparseDenseMatmulInput(
             py::cast<py::array_t<float>>(gain_buffer),
             py::cast<py::array_t<int>>(max_ids_per_partition_per_sc_buffer),
             py::cast<py::array_t<int>>(
-                max_unique_ids_per_partition_per_sc_buffer));
+                max_unique_ids_per_partition_per_sc_buffer),
+            py::cast<py::array_t<int>>(required_buffer_size_per_sc_buffer));
       }
       lhs_row_pointers[stacked_table_name.c_str()] =
           std::move(row_pointers_per_device);
@@ -490,6 +493,8 @@ py::tuple PreprocessSparseDenseMatmulInput(
           std::move(max_ids_per_partition_per_sc);
       max_unique_ids_per_partition[stacked_table_name.c_str()] =
           std::move(max_unique_ids_per_partition_per_sc);
+      required_buffer_sizes[stacked_table_name.c_str()] =
+          std::move(required_buffer_size_per_sc);
       counter.DecrementCount();
     });
   }
@@ -498,6 +503,8 @@ py::tuple PreprocessSparseDenseMatmulInput(
   py::dict stats;
   stats["max_ids"] = max_ids_per_partition;
   stats["max_unique_ids"] = max_unique_ids_per_partition;
+  stats["required_buffer_size"] = std::move(required_buffer_sizes);
+
   // GIL is held at this point.
   return py::make_tuple(lhs_row_pointers, lhs_embedding_ids, lhs_sample_ids,
                         lhs_gains, stats);
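
The new `stats["required_buffer_size"]` entry exposes, per local SparseCore, how many COO slots the current batch actually needed. As a purely hypothetical sketch (not part of this commit), a caller could compare those values against the preallocated per-SC capacity; `BufferWasLargeEnough` and `coo_buffer_size_per_sc` are illustrative names only:

// Hypothetical helper, illustrative only: checks whether every local
// SparseCore's observed requirement fits in the preallocated per-SC buffer.
bool BufferWasLargeEnough(const int* required_buffer_size_per_sc,
                          int num_sc_per_device, int coo_buffer_size_per_sc) {
  for (int sc = 0; sc < num_sc_per_device; ++sc) {
    if (required_buffer_size_per_sc[sc] > coo_buffer_size_per_sc) {
      return false;
    }
  }
  return true;
}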

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc

Lines changed: 76 additions & 51 deletions
@@ -34,6 +34,49 @@
 #include "tsl/profiler/lib/traceme.h"
 
 namespace jax_sc_embedding {
+namespace {
+
+void ValidateMaxIdsOrDie(const int32_t observed_max_ids_per_partition,
+                         const int32_t observed_max_unique_ids_per_partition,
+                         const int32_t max_ids_per_partition,
+                         const int32_t max_unique_ids_per_partition,
+                         const absl::string_view stacked_table_name,
+                         const bool allow_id_dropping) {
+  // If id dropping is allowed, we log a warning if the observed max ids per
+  // partition is greater than the set max ids per partition.
+  if (observed_max_ids_per_partition > max_ids_per_partition) {
+    if (allow_id_dropping) {
+      LOG(WARNING) << "Allowing ID dropping for table: " << stacked_table_name
+                   << " observed max ids per partition: "
+                   << observed_max_ids_per_partition
+                   << " is greater than the set max ids per partition: "
+                   << max_ids_per_partition;
+    } else {
+      LOG(FATAL) << "Observed max ids per partition: "
+                 << observed_max_ids_per_partition
+                 << " for table: " << stacked_table_name
+                 << " is greater than the set max ids per partition: "
+                 << max_ids_per_partition;
+    }
+  }
+  if (observed_max_unique_ids_per_partition > max_unique_ids_per_partition) {
+    if (allow_id_dropping) {
+      LOG(WARNING) << "Allowing ID dropping for table: " << stacked_table_name
+                   << " observed max unique ids per partition: "
+                   << observed_max_unique_ids_per_partition
+                   << " is greater than the set max unique ids per partition: "
+                   << max_unique_ids_per_partition;
+    } else {
+      LOG(FATAL) << "Observed max unique ids per partition: "
+                 << observed_max_unique_ids_per_partition
+                 << " for table: " << stacked_table_name
+                 << " is greater than the set max unique ids per partition: "
+                 << max_unique_ids_per_partition;
+    }
+  }
+}
+
+}  // namespace
 
 int GetColId(const int col_id, const int col_shift, const int col_offset,
              const int num_scs_mod, const int num_scs_mod_inv) {
@@ -55,19 +98,28 @@ RowCombiner GetRowCombiner(absl::string_view combiner) {
   return RowCombiner::kSum;
 }
 
-void SortAndGroupCooTensorsPerLocalDevice(
+std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
     absl::Span<const CooFormat> coo_tensors, const int batch_size_per_sc,
     const int global_sc_count, const int32_t batch_size_for_device,
     const int32_t max_ids_per_partition,
     const int32_t max_unique_ids_per_partition,
     const absl::string_view stacked_table_name, const bool allow_id_dropping,
-    std::vector<std::vector<CooFormat>>& coo_tensors_by_id, int* max_ids_per_sc,
-    int* max_unique_ids_per_sc) {
+    const int num_sc_per_device, const int total_num_coo_tensors,
+    int max_ids_per_sc[], int max_unique_ids_per_sc[],
+    int required_buffer_size_per_sc[]) {
   tsl::profiler::TraceMe t("SortAndGroupCooTensors");
   const int local_sc_count = batch_size_for_device / batch_size_per_sc;
-  uint32_t index = 0;
+  std::vector<std::vector<CooFormat>> coo_tensors_by_id;
+  coo_tensors_by_id.resize(num_sc_per_device);
+  const int approximate_num_coo_tensors_per_sc =
+      total_num_coo_tensors / num_sc_per_device + 1;
+  for (int i = 0; i < num_sc_per_device; ++i) {
+    // Roughly estimate the number of COO tensors for each SC.
+    coo_tensors_by_id[i].reserve(approximate_num_coo_tensors_per_sc);
+  }
+
+  uint32_t coo_tensor_index = 0;
   const int32_t num_scs_bit = std::log2(global_sc_count);
-  const int total_coo_tensors = coo_tensors.size();
   // Initialize the aggregated max ids and unique ids per SC to 0.
   for (int32_t global_sc_id = 0; global_sc_id < global_sc_count;
        ++global_sc_id) {
@@ -76,29 +128,30 @@ void SortAndGroupCooTensorsPerLocalDevice(
   }
   // Loop over scs for this device.
   for (int32_t local_sc_id = 0; local_sc_id < local_sc_count; ++local_sc_id) {
-    const int num_partitions = global_sc_count;
-    std::vector<int32_t> ids_per_sc_partition(num_partitions, 0);
-    std::vector<int32_t> unique_ids_per_sc_partition(num_partitions, 0);
+    std::vector<int32_t> ids_per_sc_partition(global_sc_count, 0);
+    std::vector<int32_t> unique_ids_per_sc_partition(global_sc_count, 0);
     std::vector<uint64_t> keys;
     keys.reserve(batch_size_per_sc);
     // We take the advantage of the fact that the row_ids are already sorted
     // within each batch.
-    while (index < total_coo_tensors &&
-           coo_tensors[index].row_id < (local_sc_id + 1) * batch_size_per_sc) {
+    for (; coo_tensor_index < coo_tensors.size() &&
+           coo_tensors[coo_tensor_index].row_id <
+               (local_sc_id + 1) * batch_size_per_sc;
+         coo_tensor_index++) {
       // The key here is [col_ids % num_scs, col_ids / num_scs, index].
       // Note that this assumes `num_scs` is a power of 2.
       keys.push_back(
           (static_cast<uint64_t>(absl::rotr(
-               static_cast<uint32_t>(coo_tensors[index].col_id), num_scs_bit))
+               static_cast<uint32_t>(coo_tensors[coo_tensor_index].col_id),
+               num_scs_bit))
            << 32) +
-          index);
-      ++index;
+          coo_tensor_index);
     }
     hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending());
 
     uint32_t prev_col_id = std::numeric_limits<uint32_t>::max();
     uint32_t prev_row_id = std::numeric_limits<uint32_t>::max();
-    for (const auto key : keys) {
+    for (const uint64_t key : keys) {
       const uint32_t index = static_cast<uint32_t>(key);
       const CooFormat& coo_tensor = coo_tensors[index];
       const uint32_t global_sc_id =
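
For reference, the key built in the loop above groups ids by destination SparseCore: rotating the 32-bit `col_id` right by `num_scs_bit` moves its low bits (`col_id % num_scs`, the target SC) into the most significant bits, leaves `col_id / num_scs` below them, and keeps the original index in the low 32 bits as a tiebreaker, so an ascending sort orders entries by target SC, then by row within that SC's shard. A standalone sketch of that packing, illustrative only and assuming `num_scs` is a power of two:

#include <cstdint>

#include "absl/numeric/bits.h"

// Illustrative re-statement of the key layout used above:
//   [col_id % num_scs | col_id / num_scs | original index]
inline uint64_t MakeGroupingKey(uint32_t col_id, int num_scs_bit,
                                uint32_t index) {
  return (static_cast<uint64_t>(absl::rotr(col_id, num_scs_bit)) << 32) + index;
}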
@@ -133,6 +186,8 @@ void SortAndGroupCooTensorsPerLocalDevice(
     for (int global_sc_id = 0; global_sc_id < global_sc_count; ++global_sc_id) {
       max_ids_per_sc[global_sc_id] = std::max(
           max_ids_per_sc[global_sc_id], ids_per_sc_partition[global_sc_id]);
+      required_buffer_size_per_sc[local_sc_id] +=
+          jax_sc_embedding::RoundUpTo(ids_per_sc_partition[global_sc_id], 8);
       max_unique_ids_per_sc[global_sc_id] =
           std::max(max_unique_ids_per_sc[global_sc_id],
                    unique_ids_per_sc_partition[global_sc_id]);
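
The two added lines above are the core of this change: a local SparseCore's required buffer size is the id count of each source partition (global SC), padded to a multiple of 8, summed over all partitions. A minimal standalone sketch of the same computation, with illustrative names:

#include <vector>

// Illustrative: each partition occupies a multiple-of-8 slot range, so the
// buffer one local SparseCore needs is the sum of the padded per-partition
// id counts.
int RequiredBufferSizeForOneSc(const std::vector<int>& ids_per_sc_partition) {
  int required = 0;
  for (const int ids : ids_per_sc_partition) {
    required += (ids + 7) / 8 * 8;  // RoundUpTo(ids, 8)
  }
  return required;
}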
@@ -158,44 +213,14 @@ void SortAndGroupCooTensorsPerLocalDevice(
         *absl::c_max_element(ids_per_sc_partition);
     const int32_t observed_max_unique_ids_per_partition =
         *absl::c_max_element(unique_ids_per_sc_partition);
-    // If id dropping is allowed, we log a warning if the observed max ids per
-    // partition is greater than the set max ids per partition.
-    if (observed_max_ids_per_partition > max_ids_per_partition) {
-      if (allow_id_dropping) {
-        LOG(WARNING) << "Allowing ID dropping for table: " << stacked_table_name
-                     << " observed max ids per partition: "
-                     << observed_max_ids_per_partition
-                     << " is greater than the set max ids per partition: "
-                     << max_ids_per_partition;
-      } else {
-        LOG(FATAL) << "Observed max ids per partition: "
-                   << observed_max_ids_per_partition
-                   << " for table: " << stacked_table_name
-                   << " is greater than the set max ids per partition: "
-                   << max_ids_per_partition;
-      }
-    }
-    if (observed_max_unique_ids_per_partition > max_unique_ids_per_partition) {
-      if (allow_id_dropping) {
-        LOG(WARNING)
-            << "Allowing ID dropping for table: " << stacked_table_name
-            << " observed max unique ids per partition: "
-            << observed_max_unique_ids_per_partition
-            << " is greater than the set max unique ids per partition: "
-            << max_unique_ids_per_partition;
-      } else {
-        LOG(FATAL) << "Observed max unique ids per partition: "
-                   << observed_max_unique_ids_per_partition
-                   << " for table: " << stacked_table_name
-                   << " is greater than the set max unique ids per partition: "
-                   << max_unique_ids_per_partition;
-      }
-    }
+
+    ValidateMaxIdsOrDie(observed_max_ids_per_partition,
+                        observed_max_unique_ids_per_partition,
+                        max_ids_per_partition, max_unique_ids_per_partition,
+                        stacked_table_name, allow_id_dropping);
   }
+  return coo_tensors_by_id;
 }
-
-
-
 int ComputeCooBufferSize(
     const int num_scs, const int num_scs_per_device,
     absl::Span<const StackedTableMetadata> stacked_table_metadata,
@@ -251,7 +276,7 @@ void FillRowPointersPerLocalDevice(
     absl::Span<const std::vector<CooFormat>> coo_tensors_by_id,
     const int row_pointers_size_per_sc, const int coo_buffer_size_per_sc,
     const int batch_size_per_sc, const int num_scs, const int num_sc_per_device,
-    int* row_pointers, int* embedding_ids, int* sample_ids, float* gains) {
+    int row_pointers[], int embedding_ids[], int sample_ids[], float gains[]) {
   tsl::profiler::TraceMe t("FillRowPointers");
   for (int local_sc_id = 0; local_sc_id < num_sc_per_device; ++local_sc_id) {
     int lhs_row_index = 0;

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h

Lines changed: 7 additions & 5 deletions
@@ -51,7 +51,7 @@ struct CooFormat {
   // table.
   //
   // This packing allows for efficient storage and extractions using bitwise
-  // masks (assuming `num_scs` is a power of 2).
+  // masks (assuming number of sparsecores (SC) is a power of 2).
   int col_id;
   float gain;
 
@@ -66,6 +66,8 @@ int GetColId(int col_id, int col_shift, int col_offset, int num_scs_mod,
              int num_scs_mod_inv);
 
 // Rounds up the given value to the next multiple of the given alignment.
+// This is equivalent to ceil(value / align) * align, but implemented in an
+// integer-safe way.
 template <typename T>
 static inline T RoundUpTo(T value, T align) {
   return (value + align - 1) / align * align;
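
A few worked values for the documented formula; this is a self-contained constexpr copy for illustration, since the helper above is declared plain `static inline`:

// Illustrative copy of the formula documented above.
template <typename T>
constexpr T RoundUpToExample(T value, T align) {
  return (value + align - 1) / align * align;
}

static_assert(RoundUpToExample(13, 8) == 16);  // padded to the next multiple of 8
static_assert(RoundUpToExample(16, 8) == 16);  // already aligned: unchanged
static_assert(RoundUpToExample(0, 8) == 0);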
@@ -116,14 +118,14 @@ struct StackedTableMetadata {
   int max_col_id;
 };
 
-void SortAndGroupCooTensorsPerLocalDevice(
+std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
     absl::Span<const CooFormat> coo_tensors, int batch_size_per_sc,
     int global_sc_count,
     int32_t batch_size_for_device,  // Batch size for the local device.
     int32_t max_ids_per_partition, int32_t max_unique_ids_per_partition,
     absl::string_view stacked_table_name, bool allow_id_dropping,
-    std::vector<std::vector<CooFormat>>& coo_tensors_by_id, int* max_ids_per_sc,
-    int* max_unique_ids_per_sc);
+    int num_sc_per_device, int total_num_coo_tensors, int max_ids_per_sc[],
+    int max_unique_ids_per_sc[], int required_buffer_size_per_sc[]);
 
 int ComputeCooBufferSize(
     int num_scs, int num_scs_per_device,
@@ -140,7 +142,7 @@ void FillRowPointersPerLocalDevice(
     absl::Span<const std::vector<CooFormat>> coo_tensors_by_id,
     int row_pointers_size_per_sc, int coo_buffer_size_per_sc,
     int batch_size_per_sc, int num_scs, int num_sc_per_device,
-    int* row_pointers, int* embedding_ids, int* sample_ids, float* gains);
+    int row_pointers[], int embedding_ids[], int sample_ids[], float gains[]);
 
 }  // namespace jax_sc_embedding