Skip to content

Commit 90e7687

Browse files
[JAX SC] Update MergeAll method to take an absl::Span<PartitionedCooTensors>. This allows for more flexible usage and avoids unnecessary moves.
PiperOrigin-RevId: 845459258
1 parent 41d81e6 commit 90e7687

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ class PartitionedCooTensors {
207207
int GetNumMinibatches() const { return bucket_count_per_sc_; }
208208

209209
static PartitionedCooTensors MergeAll(
210-
std::vector<PartitionedCooTensors>&& parts) {
210+
absl::Span<PartitionedCooTensors> parts) {
211211
DCHECK(!parts.empty());
212212
// If there is only one part, no merging is needed.
213213
if (parts.size() == 1) {

jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors_test.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include <gmock/gmock.h>
2121
#include <gtest/gtest.h>
22+
#include "absl/types/span.h" // from @com_google_absl
2223
#include "jax_tpu_embedding/sparsecore/lib/core/coo_format.h"
2324

2425
namespace jax_sc_embedding {
@@ -217,7 +218,7 @@ TEST(PartitionedCooTensorsTest, MergeDistinctSparseCores) {
217218
parts.push_back(std::move(part2));
218219

219220
PartitionedCooTensors result =
220-
PartitionedCooTensors::MergeAll(std::move(parts));
221+
PartitionedCooTensors::MergeAll(absl::MakeSpan(parts));
221222

222223
// Verify that the result contains 2 SCs.
223224
EXPECT_EQ(result.GetNumMinibatches(), 2);
@@ -259,7 +260,7 @@ TEST(PartitionedCooTensorsTest, MergeDistinctSparseCoresWithOffsets) {
259260
parts.push_back(std::move(part2));
260261

261262
PartitionedCooTensors result =
262-
PartitionedCooTensors::MergeAll(std::move(parts));
263+
PartitionedCooTensors::MergeAll(absl::MakeSpan(parts));
263264

264265
EXPECT_EQ(result.GetNumMinibatches(), 2);
265266
// SC 0 has 2 buckets, sizes 1 and 1.

jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ SortAndGroupCooTensorsPerLocalDeviceImpl(
523523
// Merge the parts for this task.
524524
SparseCoreTaskResult task_result = {
525525
.grouped_coo_tensors =
526-
PartitionedCooTensors::MergeAll(std::move(sc_parts)),
526+
PartitionedCooTensors::MergeAll(absl::MakeSpan(sc_parts)),
527527
.stats_host = std::move(local_stats_host),
528528
.dropped_id_count = total_dropped,
529529
.split_val = task_split};
@@ -575,7 +575,7 @@ SortAndGroupCooTensorsPerLocalDeviceImpl(
575575
}
576576

577577
PartitionedCooTensors merged =
578-
PartitionedCooTensors::MergeAll(std::move(parts));
578+
PartitionedCooTensors::MergeAll(absl::MakeSpan(parts));
579579
final_result.emplace(std::make_pair(std::move(merged), total_dropped));
580580
});
581581

0 commit comments

Comments
 (0)