Skip to content

Commit b2badda

Browse files
authored
Add BroadcastIndexesIterator::operator+ (#9057)
Needed to efficiently use parallel_for with BroadcastIndexesRange.
1 parent 6ba6e3e commit b2badda

File tree

2 files changed

+85
-32
lines changed

2 files changed

+85
-32
lines changed

kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include <iterator>
1515
#include <tuple>
1616

17+
#include <executorch/kernels/portable/cpu/util/delinearize_index.h>
1718
#include <executorch/runtime/core/exec_aten/exec_aten.h>
1819
#include <executorch/runtime/core/exec_aten/util/tensor_dimension_limit.h>
1920

@@ -78,7 +79,9 @@ class BroadcastIndexesIterator {
7879
// You might wonder what happens if output_shape_[ii] == 0. In
7980
// that case, output.numel() would be 0, and thus we would have
8081
// begin() == end() and no iteration.
81-
if ET_UNLIKELY (delinearized_output_index_[ii] == output_shape_[ii] - 1) {
82+
if ET_UNLIKELY (
83+
static_cast<exec_aten::SizesType>(delinearized_output_index_[ii]) ==
84+
output_shape_[ii] - 1) {
8285
const auto old_delinearized_output_index_item =
8386
delinearized_output_index_[ii];
8487
delinearized_output_index_[ii] = 0;
@@ -104,11 +107,42 @@ class BroadcastIndexesIterator {
104107
return it;
105108
}
106109

110+
BroadcastIndexesIterator& operator+=(difference_type n) {
111+
if (n <= 3) {
112+
std::advance(*this, n);
113+
return *this;
114+
}
115+
116+
output_index() += n;
117+
delinearize_index(
118+
output_index(),
119+
output_shape_,
120+
delinearized_output_index_.data(),
121+
delinearized_output_index_.size());
122+
for (const auto ii : c10::irange(1, kNumInputs + 1)) {
123+
current_indexes_[ii] = 0;
124+
for (const auto jj : c10::irange(output_dim_)) {
125+
current_indexes_[ii] += delinearized_output_index_[jj] *
126+
effective_input_broadcast_strides_[ii - 1][jj];
127+
}
128+
}
129+
return *this;
130+
}
131+
132+
BroadcastIndexesIterator operator+(difference_type n) {
133+
auto it = *this;
134+
it += n;
135+
return it;
136+
}
137+
107138
difference_type operator-(const BroadcastIndexesIterator& rhs) const {
108139
return difference_type(output_index() - rhs.output_index());
109140
}
110141

111142
private:
143+
using ShapeType =
144+
std::array<std::size_t, executorch::runtime::kTensorDimensionLimit>;
145+
112146
ssize_t output_index() const {
113147
return current_indexes_[0];
114148
}
@@ -117,11 +151,10 @@ class BroadcastIndexesIterator {
117151
return current_indexes_[0];
118152
}
119153

120-
std::array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>
121-
effective_input_broadcast_stride(const Tensor& output, const Tensor& t)
122-
const {
123-
std::array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>
124-
result = {0};
154+
ShapeType effective_input_broadcast_stride(
155+
const Tensor& output,
156+
const Tensor& t) const {
157+
ShapeType result = {0};
125158
ET_CHECK_MSG(
126159
t.dim() <= output.dim(),
127160
"input to broadcasting op should have dim at most output dim, but %d > %d!",
@@ -146,8 +179,6 @@ class BroadcastIndexesIterator {
146179
// The 0th entry is the current linear index into the output,
147180
// followed by kNumInputs input indexes.
148181
std::array<ssize_t, kNumInputs + 1> current_indexes_ = {0};
149-
using ShapeType = std::
150-
array<exec_aten::SizesType, executorch::runtime::kTensorDimensionLimit>;
151182
ShapeType delinearized_output_index_ = {0};
152183
ssize_t output_dim_;
153184
ArrayRef<exec_aten::SizesType> output_shape_;

kernels/portable/cpu/util/test/broadcast_indexes_range_test.cpp

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,15 @@ TEST(BroadcastIndexesRangeTest, ScalarBroadcastToOneD) {
6868
EXPECT_EQ(expected, actual);
6969
}
7070

71+
template <typename Range>
72+
void test_operator_plus(const Range& range) {
73+
size_t idx = 0;
74+
for (const auto indexes : range) {
75+
EXPECT_EQ(*(range.begin() + idx), indexes);
76+
idx++;
77+
}
78+
}
79+
7180
// [1] -> [H, W]
7281
// [W] -> [H, W]
7382
// [1, 1] -> [H, W]
@@ -87,14 +96,15 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDExhaustive) {
8796

8897
Tensor in_not_broadcast = tf.zeros({3, 4});
8998

90-
auto actual = range_to_vec(BroadcastIndexesRange<6>(
99+
const auto range = BroadcastIndexesRange<6>(
91100
out,
92101
in_0d_scalar,
93102
in_1d_scalar,
94103
in_2d_scalar,
95104
in_row,
96105
in_col,
97-
in_not_broadcast));
106+
in_not_broadcast);
107+
auto actual = range_to_vec(range);
98108
decltype(actual) expected = {
99109
{0, 0, 0, 0, 0, 0, 0},
100110
{1, 0, 0, 0, 1, 0, 1},
@@ -110,6 +120,8 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDExhaustive) {
110120
{11, 0, 0, 0, 3, 2, 11},
111121
};
112122
EXPECT_EQ(expected, actual);
123+
124+
test_operator_plus(range);
113125
}
114126

115127
// Make sure nothing is thrown off by a size-1 dim in the output:
@@ -138,20 +150,20 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDWith1InOutputShapeExhaustive) {
138150
Tensor in_col = tf.zeros({H, 1});
139151

140152
size_t idx = 0;
153+
const auto range_row = BroadcastIndexesRange<5>(
154+
out_row,
155+
in_0d_scalar,
156+
in_1d_scalar,
157+
in_2d_scalar,
158+
in_row,
159+
in_leading_one_row);
141160
for (const auto
142161
[out_idx,
143162
in_0d_idx,
144163
in_1d_idx,
145164
in_2d_idx,
146165
in_row_idx,
147-
in_leading_one_row_idx] :
148-
BroadcastIndexesRange<5>(
149-
out_row,
150-
in_0d_scalar,
151-
in_1d_scalar,
152-
in_2d_scalar,
153-
in_row,
154-
in_leading_one_row)) {
166+
in_leading_one_row_idx] : range_row) {
155167
EXPECT_EQ(out_idx, idx++);
156168
EXPECT_EQ(in_0d_idx, 0);
157169
EXPECT_EQ(in_1d_idx, 0);
@@ -160,16 +172,21 @@ TEST(BroadcastIndexesRangeTest, OneAndTwoDWith1InOutputShapeExhaustive) {
160172
EXPECT_EQ(in_leading_one_row_idx, out_idx);
161173
}
162174

175+
test_operator_plus(range_row);
176+
163177
idx = 0;
178+
const auto range_col = BroadcastIndexesRange<4>(
179+
out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col);
164180
for (const auto [out_idx, in_0d_idx, in_1d_idx, in_2d_idx, in_col_idx] :
165-
BroadcastIndexesRange<4>(
166-
out_col, in_0d_scalar, in_1d_scalar, in_2d_scalar, in_col)) {
181+
range_col) {
167182
EXPECT_EQ(out_idx, idx++);
168183
EXPECT_EQ(in_0d_idx, 0);
169184
EXPECT_EQ(in_1d_idx, 0);
170185
EXPECT_EQ(in_2d_idx, 0);
171186
EXPECT_EQ(in_col_idx, out_idx);
172187
}
188+
189+
test_operator_plus(range_col);
173190
}
174191

175192
// [1, 1, 1] -> [C, H, W]
@@ -197,16 +214,17 @@ TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) {
197214
// take the opportunity to mutation test against delinearize_index
198215
// and linearize_access_indexes.
199216
int idx = 0;
200-
for (const auto indexes : BroadcastIndexesRange<8>(
201-
out,
202-
input_tensors[0],
203-
input_tensors[1],
204-
input_tensors[2],
205-
input_tensors[3],
206-
input_tensors[4],
207-
input_tensors[5],
208-
input_tensors[6],
209-
input_tensors[7])) {
217+
const auto range = BroadcastIndexesRange<8>(
218+
out,
219+
input_tensors[0],
220+
input_tensors[1],
221+
input_tensors[2],
222+
input_tensors[3],
223+
input_tensors[4],
224+
input_tensors[5],
225+
input_tensors[6],
226+
input_tensors[7]);
227+
for (const auto indexes : range) {
210228
const auto out_idx = indexes[0];
211229
EXPECT_EQ(out_idx, idx++);
212230
size_t out_indexes[executorch::runtime::kTensorDimensionLimit];
@@ -219,6 +237,7 @@ TEST(BroadcastIndexesRangeTest, ThreeDBroadcasting) {
219237
out_indexes, out.dim(), input_tensors[tensor_idx]));
220238
}
221239
}
240+
test_operator_plus(range);
222241
}
223242

224243
// 4-D should generalize, but we will go ahead and test:
@@ -235,8 +254,9 @@ void four_d_broadcasting_test() {
235254
// take the opportunity to mutation test against delinearize_index
236255
// and linearize_access_indexes.
237256
int idx = 0;
238-
for (const auto [out_idx, in_cw_idx, in_nh_idx] :
239-
BroadcastIndexesRange<2>(out, in_broadcast_cw, in_broadcast_nh)) {
257+
const auto range =
258+
BroadcastIndexesRange<2>(out, in_broadcast_cw, in_broadcast_nh);
259+
for (const auto [out_idx, in_cw_idx, in_nh_idx] : range) {
240260
EXPECT_EQ(out_idx, idx++);
241261
size_t out_indexes[executorch::runtime::kTensorDimensionLimit];
242262
delinearize_index(
@@ -248,6 +268,8 @@ void four_d_broadcasting_test() {
248268
in_nh_idx,
249269
linearize_access_indexes(out_indexes, out.dim(), in_broadcast_nh));
250270
}
271+
272+
test_operator_plus(range);
251273
}
252274

253275
TEST(BroadcastIndexesRangeTest, FourDBroadcasting) {

0 commit comments

Comments (0)