@@ -141,19 +141,22 @@ void simplify_iteration_space(int &nd,
141
141
assert (simplified_shape.size () == static_cast <size_t >(nd));
142
142
143
143
simplified_src_strides.reserve (nd);
144
- simplified_src_strides.push_back (
145
- (src_strides[0 ] >= 0 ) ? src_strides[0 ] : -src_strides[0 ]);
146
- if ((src_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
147
- src_offset += (shape[0 ] - 1 ) * src_strides[0 ];
148
- }
149
- assert (simplified_src_strides.size () == static_cast <size_t >(nd));
150
-
151
144
simplified_dst_strides.reserve (nd);
152
- simplified_dst_strides.push_back (
153
- (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
154
- if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
155
- dst_offset += (shape[0 ] - 1 ) * dst_strides[0 ];
145
+
146
+ if (src_strides[0 ] < 0 && dst_strides[0 ] < 0 ) {
147
+ simplified_src_strides.push_back (-src_strides[0 ]);
148
+ simplified_dst_strides.push_back (-dst_strides[0 ]);
149
+ if (shape[0 ] > 1 ) {
150
+ src_offset += (shape[0 ] - 1 ) * src_strides[0 ];
151
+ dst_offset += (shape[0 ] - 1 ) * dst_strides[0 ];
152
+ }
153
+ }
154
+ else {
155
+ simplified_src_strides.push_back (src_strides[0 ]);
156
+ simplified_dst_strides.push_back (dst_strides[0 ]);
156
157
}
158
+
159
+ assert (simplified_src_strides.size () == static_cast <size_t >(nd));
157
160
assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
158
161
}
159
162
}
@@ -226,27 +229,28 @@ void simplify_iteration_space_3(
226
229
assert (simplified_shape.size () == static_cast <size_t >(nd));
227
230
228
231
simplified_src1_strides.reserve (nd);
229
- simplified_src1_strides.push_back (
230
- (src1_strides[0 ] >= 0 ) ? src1_strides[0 ] : -src1_strides[0 ]);
231
- if ((src1_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
232
- src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
233
- }
234
- assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
235
-
236
232
simplified_src2_strides.reserve (nd);
237
- simplified_src2_strides.push_back (
238
- (src2_strides[0 ] >= 0 ) ? src2_strides[0 ] : -src2_strides[0 ]);
239
- if ((src2_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
240
- src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
241
- }
242
- assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
243
-
244
233
simplified_dst_strides.reserve (nd);
245
- simplified_dst_strides.push_back (
246
- (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
247
- if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
248
- dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
234
+
235
+ if ((src1_strides[0 ] < 0 ) && (src2_strides[0 ] < 0 ) &&
236
+ (dst_strides[0 ] < 0 )) {
237
+ simplified_src1_strides.push_back (-src1_strides[0 ]);
238
+ simplified_src2_strides.push_back (-src2_strides[0 ]);
239
+ simplified_dst_strides.push_back (-dst_strides[0 ]);
240
+ if (shape[0 ] > 1 ) {
241
+ src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
242
+ src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
243
+ dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
244
+ }
245
+ }
246
+ else {
247
+ simplified_src1_strides.push_back (src1_strides[0 ]);
248
+ simplified_src2_strides.push_back (src2_strides[0 ]);
249
+ simplified_dst_strides.push_back (dst_strides[0 ]);
249
250
}
251
+
252
+ assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
253
+ assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
250
254
assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
251
255
}
252
256
}
@@ -333,35 +337,34 @@ void simplify_iteration_space_4(
333
337
assert (simplified_shape.size () == static_cast <size_t >(nd));
334
338
335
339
simplified_src1_strides.reserve (nd);
336
- simplified_src1_strides.push_back (
337
- (src1_strides[0 ] >= 0 ) ? src1_strides[0 ] : -src1_strides[0 ]);
338
- if ((src1_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
339
- src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
340
- }
341
- assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
342
-
343
340
simplified_src2_strides.reserve (nd);
344
- simplified_src2_strides.push_back (
345
- (src2_strides[0 ] >= 0 ) ? src2_strides[0 ] : -src2_strides[0 ]);
346
- if ((src2_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
347
- src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
348
- }
349
- assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
350
-
351
341
simplified_src3_strides.reserve (nd);
352
- simplified_src3_strides.push_back (
353
- (src3_strides[0 ] >= 0 ) ? src3_strides[0 ] : -src3_strides[0 ]);
354
- if ((src3_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
355
- src3_offset += src3_strides[0 ] * (shape[0 ] - 1 );
356
- }
357
- assert (simplified_src3_strides.size () == static_cast <size_t >(nd));
358
-
359
342
simplified_dst_strides.reserve (nd);
360
- simplified_dst_strides.push_back (
361
- (dst_strides[0 ] >= 0 ) ? dst_strides[0 ] : -dst_strides[0 ]);
362
- if ((dst_strides[0 ] < 0 ) && (shape[0 ] > 1 )) {
363
- dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
343
+
344
+ if ((src1_strides[0 ] < 0 ) && (src2_strides[0 ] < 0 ) &&
345
+ (src3_strides[0 ] < 0 ) && (dst_strides[0 ] < 0 ))
346
+ {
347
+ simplified_src1_strides.push_back (-src1_strides[0 ]);
348
+ simplified_src2_strides.push_back (-src2_strides[0 ]);
349
+ simplified_src3_strides.push_back (-src3_strides[0 ]);
350
+ simplified_dst_strides.push_back (-dst_strides[0 ]);
351
+ if (shape[0 ] > 1 ) {
352
+ src1_offset += src1_strides[0 ] * (shape[0 ] - 1 );
353
+ src2_offset += src2_strides[0 ] * (shape[0 ] - 1 );
354
+ src3_offset += src3_strides[0 ] * (shape[0 ] - 1 );
355
+ dst_offset += dst_strides[0 ] * (shape[0 ] - 1 );
356
+ }
357
+ }
358
+ else {
359
+ simplified_src1_strides.push_back (src1_strides[0 ]);
360
+ simplified_src2_strides.push_back (src2_strides[0 ]);
361
+ simplified_src3_strides.push_back (src3_strides[0 ]);
362
+ simplified_dst_strides.push_back (dst_strides[0 ]);
364
363
}
364
+
365
+ assert (simplified_src1_strides.size () == static_cast <size_t >(nd));
366
+ assert (simplified_src2_strides.size () == static_cast <size_t >(nd));
367
+ assert (simplified_src3_strides.size () == static_cast <size_t >(nd));
365
368
assert (simplified_dst_strides.size () == static_cast <size_t >(nd));
366
369
}
367
370
}
0 commit comments