#include "tsl/profiler/lib/traceme.h"

namespace jax_sc_embedding {
+ namespace {
+
+ void ValidateMaxIdsOrDie(const int32_t observed_max_ids_per_partition,
+                          const int32_t observed_max_unique_ids_per_partition,
+                          const int32_t max_ids_per_partition,
+                          const int32_t max_unique_ids_per_partition,
+                          const absl::string_view stacked_table_name,
+                          const bool allow_id_dropping) {
+   // If id dropping is allowed, we log a warning if the observed max ids per
+   // partition is greater than the set max ids per partition.
+   if (observed_max_ids_per_partition > max_ids_per_partition) {
+     if (allow_id_dropping) {
+       LOG(WARNING) << "Allowing ID dropping for table: " << stacked_table_name
+                    << " observed max ids per partition: "
+                    << observed_max_ids_per_partition
+                    << " is greater than the set max ids per partition: "
+                    << max_ids_per_partition;
+     } else {
+       LOG(FATAL) << "Observed max ids per partition: "
+                  << observed_max_ids_per_partition
+                  << " for table: " << stacked_table_name
+                  << " is greater than the set max ids per partition: "
+                  << max_ids_per_partition;
+     }
+   }
+   if (observed_max_unique_ids_per_partition > max_unique_ids_per_partition) {
+     if (allow_id_dropping) {
+       LOG(WARNING) << "Allowing ID dropping for table: " << stacked_table_name
+                    << " observed max unique ids per partition: "
+                    << observed_max_unique_ids_per_partition
+                    << " is greater than the set max unique ids per partition: "
+                    << max_unique_ids_per_partition;
+     } else {
+       LOG(FATAL) << "Observed max unique ids per partition: "
+                  << observed_max_unique_ids_per_partition
+                  << " for table: " << stacked_table_name
+                  << " is greater than the set max unique ids per partition: "
+                  << max_unique_ids_per_partition;
+     }
+   }
+ }
+
+ }  // namespace

int GetColId(const int col_id, const int col_shift, const int col_offset,
             const int num_scs_mod, const int num_scs_mod_inv) {
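The new helper centralizes a warn-or-die policy: with `allow_id_dropping` set it only warns, otherwise `LOG(FATAL)` logs and terminates the process. A minimal standalone sketch of the same pattern, reduced to a single limit (the function name and values below are illustrative, not part of this change):

    #include "absl/log/log.h"
    #include "absl/strings/string_view.h"

    // Warn-or-die check in the style of ValidateMaxIdsOrDie (illustrative).
    void CheckLimitOrDie(int observed, int limit, absl::string_view table,
                         bool allow_id_dropping) {
      if (observed <= limit) return;
      if (allow_id_dropping) {
        // Excess ids will be dropped downstream; continue with a warning.
        LOG(WARNING) << "Allowing ID dropping for table: " << table;
      } else {
        // LOG(FATAL) logs the message and aborts the process.
        LOG(FATAL) << "Observed max ids " << observed << " for table: "
                   << table << " exceeds the configured limit " << limit;
      }
    }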
@@ -55,19 +98,28 @@ RowCombiner GetRowCombiner(absl::string_view combiner) {
  return RowCombiner::kSum;
}

- void SortAndGroupCooTensorsPerLocalDevice(
+ std::vector<std::vector<CooFormat>> SortAndGroupCooTensorsPerLocalDevice(
    absl::Span<const CooFormat> coo_tensors, const int batch_size_per_sc,
    const int global_sc_count, const int32_t batch_size_for_device,
    const int32_t max_ids_per_partition,
    const int32_t max_unique_ids_per_partition,
    const absl::string_view stacked_table_name, const bool allow_id_dropping,
-   std::vector<std::vector<CooFormat>>& coo_tensors_by_id, int* max_ids_per_sc,
-   int* max_unique_ids_per_sc) {
+   const int num_sc_per_device, const int total_num_coo_tensors,
+   int max_ids_per_sc[], int max_unique_ids_per_sc[],
+   int required_buffer_size_per_sc[]) {
  tsl::profiler::TraceMe t("SortAndGroupCooTensors");
  const int local_sc_count = batch_size_for_device / batch_size_per_sc;
- uint32_t index = 0;
+ std::vector<std::vector<CooFormat>> coo_tensors_by_id;
+ coo_tensors_by_id.resize(num_sc_per_device);
+ const int approximate_num_coo_tensors_per_sc =
+     total_num_coo_tensors / num_sc_per_device + 1;
+ for (int i = 0; i < num_sc_per_device; ++i) {
+   // Roughly estimate the number of COO tensors for each SC.
+   coo_tensors_by_id[i].reserve(approximate_num_coo_tensors_per_sc);
+ }
+
+ uint32_t coo_tensor_index = 0;
  const int32_t num_scs_bit = std::log2(global_sc_count);
- const int total_coo_tensors = coo_tensors.size();
  // Initialize the aggregated max ids and unique ids per SC to 0.
  for (int32_t global_sc_id = 0; global_sc_id < global_sc_count;
       ++global_sc_id) {
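The refactored signature returns the grouped vectors by value (cheap via move/RVO) instead of filling a caller-provided out-parameter, while the per-SC statistics are still written through caller-owned arrays. A hypothetical call site under that reading (the surrounding variables and array sizes are assumptions inferred from how the function indexes them, not shown in this diff):

    // Hypothetical caller; max ids arrays are indexed by global SC id,
    // the buffer-size array by local SC id.
    std::vector<int> max_ids_per_sc(global_sc_count, 0);
    std::vector<int> max_unique_ids_per_sc(global_sc_count, 0);
    std::vector<int> required_buffer_size_per_sc(num_sc_per_device, 0);

    std::vector<std::vector<CooFormat>> grouped =
        SortAndGroupCooTensorsPerLocalDevice(
            coo_tensors, batch_size_per_sc, global_sc_count,
            batch_size_for_device, max_ids_per_partition,
            max_unique_ids_per_partition, stacked_table_name,
            allow_id_dropping, num_sc_per_device,
            /*total_num_coo_tensors=*/static_cast<int>(coo_tensors.size()),
            max_ids_per_sc.data(), max_unique_ids_per_sc.data(),
            required_buffer_size_per_sc.data());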
@@ -76,29 +128,30 @@ void SortAndGroupCooTensorsPerLocalDevice(
  }
  // Loop over scs for this device.
  for (int32_t local_sc_id = 0; local_sc_id < local_sc_count; ++local_sc_id) {
-   const int num_partitions = global_sc_count;
-   std::vector<int32_t> ids_per_sc_partition(num_partitions, 0);
-   std::vector<int32_t> unique_ids_per_sc_partition(num_partitions, 0);
+   std::vector<int32_t> ids_per_sc_partition(global_sc_count, 0);
+   std::vector<int32_t> unique_ids_per_sc_partition(global_sc_count, 0);
    std::vector<uint64_t> keys;
    keys.reserve(batch_size_per_sc);
    // We take advantage of the fact that the row_ids are already sorted
    // within each batch.
-   while (index < total_coo_tensors &&
-          coo_tensors[index].row_id < (local_sc_id + 1) * batch_size_per_sc) {
+   for (; coo_tensor_index < coo_tensors.size() &&
+          coo_tensors[coo_tensor_index].row_id <
+              (local_sc_id + 1) * batch_size_per_sc;
+        coo_tensor_index++) {
      // The key here is [col_ids % num_scs, col_ids / num_scs, index].
      // Note that this assumes `num_scs` is a power of 2.
      keys.push_back(
          (static_cast<uint64_t>(absl::rotr(
-              static_cast<uint32_t>(coo_tensors[index].col_id), num_scs_bit))
+              static_cast<uint32_t>(coo_tensors[coo_tensor_index].col_id),
+              num_scs_bit))
           << 32) +
-         index);
-     ++index;
+         coo_tensor_index);
    }
    hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending());

    uint32_t prev_col_id = std::numeric_limits<uint32_t>::max();
    uint32_t prev_row_id = std::numeric_limits<uint32_t>::max();
-   for (const auto key : keys) {
+   for (const uint64_t key : keys) {
      const uint32_t index = static_cast<uint32_t>(key);
      const CooFormat& coo_tensor = coo_tensors[index];
      const uint32_t global_sc_id =
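The key layout in the loop above relies on `num_scs` being a power of two: `absl::rotr` on the 32-bit `col_id` wraps its low `num_scs_bit` bits (`col_id % num_scs`, the destination SC) around to the most significant positions, so an ascending sort orders by SC first, then by `col_id / num_scs`, then by the original index held in the low 32 bits. A standalone sketch of the packing (illustrative helper, not from this change):

    #include <cstdint>
    #include "absl/numeric/bits.h"

    // Builds the 64-bit sort key [col_id % num_scs, col_id / num_scs, index],
    // assuming num_scs is a power of two and num_scs_bit == log2(num_scs).
    uint64_t PackSortKey(uint32_t col_id, uint32_t index, int num_scs_bit) {
      // rotr wraps the low num_scs_bit bits around to the top of the word.
      const uint32_t rotated = absl::rotr(col_id, num_scs_bit);
      // High 32 bits: rotated col_id; low 32 bits: the original position,
      // which keeps the sort stable with respect to input order.
      return (static_cast<uint64_t>(rotated) << 32) | index;
    }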
@@ -133,6 +186,8 @@ void SortAndGroupCooTensorsPerLocalDevice(
    for (int global_sc_id = 0; global_sc_id < global_sc_count; ++global_sc_id) {
      max_ids_per_sc[global_sc_id] = std::max(
          max_ids_per_sc[global_sc_id], ids_per_sc_partition[global_sc_id]);
+     required_buffer_size_per_sc[local_sc_id] +=
+         jax_sc_embedding::RoundUpTo(ids_per_sc_partition[global_sc_id], 8);
      max_unique_ids_per_sc[global_sc_id] =
          std::max(max_unique_ids_per_sc[global_sc_id],
                   unique_ids_per_sc_partition[global_sc_id]);
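`required_buffer_size_per_sc` accumulates each partition's id count rounded up to a multiple of 8, so each partition contributes an 8-aligned region to the per-SC buffer requirement. A sketch of the rounding arithmetic (a stand-in for `jax_sc_embedding::RoundUpTo`, whose definition isn't part of this diff; the bit trick assumes a power-of-two multiple):

    // Stand-in for jax_sc_embedding::RoundUpTo (illustrative).
    constexpr int RoundUpTo(int value, int multiple) {
      // Valid for power-of-two multiples such as 8.
      return (value + multiple - 1) & ~(multiple - 1);
    }

    static_assert(RoundUpTo(0, 8) == 0);
    static_assert(RoundUpTo(1, 8) == 8);
    static_assert(RoundUpTo(13, 8) == 16);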
@@ -158,44 +213,14 @@ void SortAndGroupCooTensorsPerLocalDevice(
        *absl::c_max_element(ids_per_sc_partition);
    const int32_t observed_max_unique_ids_per_partition =
        *absl::c_max_element(unique_ids_per_sc_partition);
-   // If id dropping is allowed, we log a warning if the observed max ids per
-   // partition is greater than the set max ids per partition.
-   if (observed_max_ids_per_partition > max_ids_per_partition) {
-     if (allow_id_dropping) {
-       LOG(WARNING) << "Allowing ID dropping for table: " << stacked_table_name
-                    << " observed max ids per partition: "
-                    << observed_max_ids_per_partition
-                    << " is greater than the set max ids per partition: "
-                    << max_ids_per_partition;
-     } else {
-       LOG(FATAL) << "Observed max ids per partition: "
-                  << observed_max_ids_per_partition
-                  << " for table: " << stacked_table_name
-                  << " is greater than the set max ids per partition: "
-                  << max_ids_per_partition;
-     }
-   }
-   if (observed_max_unique_ids_per_partition > max_unique_ids_per_partition) {
-     if (allow_id_dropping) {
-       LOG(WARNING)
-           << "Allowing ID dropping for table: " << stacked_table_name
-           << " observed max unique ids per partition: "
-           << observed_max_unique_ids_per_partition
-           << " is greater than the set max unique ids per partition: "
-           << max_unique_ids_per_partition;
-     } else {
-       LOG(FATAL) << "Observed max unique ids per partition: "
-                  << observed_max_unique_ids_per_partition
-                  << " for table: " << stacked_table_name
-                  << " is greater than the set max unique ids per partition: "
-                  << max_unique_ids_per_partition;
-     }
-   }
+
+   ValidateMaxIdsOrDie(observed_max_ids_per_partition,
+                       observed_max_unique_ids_per_partition,
+                       max_ids_per_partition, max_unique_ids_per_partition,
+                       stacked_table_name, allow_id_dropping);
  }
+ return coo_tensors_by_id;
}
-
-
-
int ComputeCooBufferSize(
    const int num_scs, const int num_scs_per_device,
    absl::Span<const StackedTableMetadata> stacked_table_metadata,
@@ -251,7 +276,7 @@ void FillRowPointersPerLocalDevice(
    absl::Span<const std::vector<CooFormat>> coo_tensors_by_id,
    const int row_pointers_size_per_sc, const int coo_buffer_size_per_sc,
    const int batch_size_per_sc, const int num_scs, const int num_sc_per_device,
-   int* row_pointers, int* embedding_ids, int* sample_ids, float* gains) {
+   int row_pointers[], int embedding_ids[], int sample_ids[], float gains[]) {
  tsl::profiler::TraceMe t("FillRowPointers");
  for (int local_sc_id = 0; local_sc_id < num_sc_per_device; ++local_sc_id) {
    int lhs_row_index = 0;