@@ -11163,13 +11163,31 @@ struct WhileInductionReduction : public OpRewritePattern<stablehlo::WhileOp> {
11163
11163
for (int i = 0 ; i < pair.lowerBounds .size (); i++) {
11164
11164
update_starts.push_back (rewriter.create <stablehlo::ConstantOp>(
11165
11165
pair.argOperand .getLoc (), itype,
11166
- makeAttr (itype, pair.lowerUpdateBounds [i])
11166
+ makeAttr (itype,
11167
+ pair.lowerUpdateBounds [i] - pair.lowerBounds [i])
11167
11168
.cast <ElementsAttr>()));
11168
11169
}
11169
11170
11170
11171
newArg = rewriter.create <stablehlo::DynamicUpdateSliceOp>(
11171
- pair.argOperand .getLoc (), pair.outerOperand , newArg,
11172
+ pair.argOperand .getLoc (), pair.outerSlice , newArg,
11172
11173
update_starts);
11174
+
11175
+ auto ctype = RankedTensorType::get (
11176
+ {}, cast<RankedTensorType>(pair.argOperand .getType ())
11177
+ .getElementType ());
11178
+ auto padVal = rewriter.create <stablehlo::ConstantOp>(
11179
+ pair.argOperand .getLoc (), ctype,
11180
+ makeAttr (ctype, 0 ).cast <ElementsAttr>());
11181
+
11182
+ SmallVector<int64_t > slow = llvm::to_vector (pair.lowerBounds );
11183
+ SmallVector<int64_t > shigh = llvm::to_vector (
11184
+ cast<RankedTensorType>(pair.argOperand .getType ()).getShape ());
11185
+ for (int i = 0 ; i < shigh.size (); i++)
11186
+ shigh[i] -= pair.upperBounds [i];
11187
+ SmallVector<int64_t > sint (shigh.size (), 0 );
11188
+
11189
+ newArg = rewriter.create <stablehlo::PadOp>(
11190
+ pair.argOperand .getLoc (), newArg, padVal, slow, shigh, sint);
11173
11191
break ;
11174
11192
}
11175
11193
}
@@ -11229,13 +11247,30 @@ struct WhileInductionReduction : public OpRewritePattern<stablehlo::WhileOp> {
11229
11247
for (int i = 0 ; i < pair.lowerBounds .size (); i++) {
11230
11248
update_starts.push_back (rewriter.create <stablehlo::ConstantOp>(
11231
11249
pair.argOperand .getLoc (), itype,
11232
- makeAttr (itype, pair.lowerUpdateBounds [i])
11250
+ makeAttr (itype,
11251
+ pair.lowerUpdateBounds [i] - pair.lowerBounds [i])
11233
11252
.cast <ElementsAttr>()));
11234
11253
}
11235
11254
11236
11255
newArg = rewriter.create <stablehlo::DynamicUpdateSliceOp>(
11237
- pair.argOperand .getLoc (), pair.outerOperand , newArg,
11256
+ pair.argOperand .getLoc (), pair.outerSlice , newArg,
11238
11257
update_starts);
11258
+ auto ctype = RankedTensorType::get (
11259
+ {}, cast<RankedTensorType>(pair.condOperand .getType ())
11260
+ .getElementType ());
11261
+ auto padVal = rewriter.create <stablehlo::ConstantOp>(
11262
+ pair.condOperand .getLoc (), ctype,
11263
+ makeAttr (ctype, 0 ).cast <ElementsAttr>());
11264
+
11265
+ SmallVector<int64_t > slow = llvm::to_vector (pair.lowerBounds );
11266
+ SmallVector<int64_t > shigh = llvm::to_vector (
11267
+ cast<RankedTensorType>(pair.condOperand .getType ()).getShape ());
11268
+ for (int i = 0 ; i < shigh.size (); i++)
11269
+ shigh[i] -= pair.upperBounds [i];
11270
+ SmallVector<int64_t > sint (shigh.size (), 0 );
11271
+
11272
+ newArg = rewriter.create <stablehlo::PadOp>(
11273
+ pair.condOperand .getLoc (), newArg, padVal, slow, shigh, sint);
11239
11274
break ;
11240
11275
}
11241
11276
}
0 commit comments