Skip to content

Commit 2e1dadf

Browse files
authored
Unbreak BroadcastIndexesRange::operator+= when there is no broadcasting (#9374)
operator+= had a loop over 0 elements in this case, resulting in the indices array being full of zeros. Added a += test to our test case that covers this.
1 parent f86d7e3 commit 2e1dadf

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,11 @@ class BroadcastIndexesIterator {
122122
}
123123

124124
output_index() += n;
125+
if (output_dim_or_zero_if_no_broadcasting_ == 0) {
126+
std::fill(
127+
current_indexes_.begin() + 1, current_indexes_.end(), output_index());
128+
return *this;
129+
}
125130
delinearize_index(
126131
output_index(),
127132
output_shape_,

kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ TEST(BroadcastIndexesRangeTest, OneDNotBroadcasted) {
4444

4545
Tensor out = tf.zeros({5});
4646
int idx = 0;
47-
for (const auto& elem : range_to_vec(BroadcastIndexesRange<1>(out, out))) {
47+
const auto range = BroadcastIndexesRange<1>(out, out);
48+
for (const auto& elem : range_to_vec(range)) {
49+
EXPECT_EQ(*(range.begin() + idx), elem);
4850
EXPECT_EQ(elem[0], idx++);
4951
EXPECT_EQ(elem[0], elem[1]);
5052
}
@@ -71,7 +73,7 @@ TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) {
7173
template <typename Range>
7274
void test_operator_plus(const Range& range) {
7375
size_t idx = 0;
74-
for (const auto indexes : range) {
76+
for (const auto& indexes : range) {
7577
EXPECT_EQ(*(range.begin() + idx), indexes);
7678
idx++;
7779
}

0 commit comments

Comments
 (0)