Skip to content

Commit f72e953

Browse files
committed
add cloneToShapedType static method
1 parent 878e69a commit f72e953

File tree

1 file changed

+19
-30
lines changed

1 file changed

+19
-30
lines changed

mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
3535
return rewriter.create<arith::ConstantOp>(loc, attr);
3636
}
3737

38+
/// Creates shapedType using shape from cloneFrom and base type from cloneTo
39+
static Type cloneToShapedType(Type cloneFrom, Type cloneTo) {
40+
if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
41+
return shapedTy.clone(cloneTo);
42+
}
43+
return cloneTo;
44+
}
45+
3846
namespace {
3947

4048
/// Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
225233
return rewriter.notifyMatchFailure(op, "not a ext of bf16 to f32.");
226234
}
227235

228-
Type i16Ty = b.getI16Type();
229-
Type i32Ty = b.getI32Type();
230-
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
231-
i16Ty = shapedTy.clone(i16Ty);
232-
i32Ty = shapedTy.clone(i32Ty);
233-
}
236+
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
237+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
234238

235239
Value bitcast = b.create<arith::BitcastOp>(i16Ty, operand);
236240
Value exti = b.create<arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
264268
op, "only applicable to default rounding mode.");
265269
}
266270

267-
Type i16Ty = b.getI16Type();
268-
Type i32Ty = b.getI32Type();
269-
Type f32Ty = b.getF32Type();
270-
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
271-
i16Ty = shapedTy.clone(i16Ty);
272-
i32Ty = shapedTy.clone(i32Ty);
273-
f32Ty = shapedTy.clone(f32Ty);
274-
}
271+
Type i16Ty = cloneToShapedType(operandTy, b.getI16Type());
272+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
275273

276274
// Algorithm borrowed from this excellent code:
277275
// https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -340,14 +338,9 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
340338
return rewriter.notifyMatchFailure(op, "not a ext of F8E8M0FNU");
341339
}
342340

343-
Type i8Ty = b.getI8Type();
344-
Type i32Ty = b.getI32Type();
345-
Type f32Ty = b.getF32Type();
346-
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
347-
i8Ty = shapedTy.clone(i8Ty);
348-
i32Ty = shapedTy.clone(i32Ty);
349-
f32Ty = shapedTy.clone(f32Ty);
350-
}
341+
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
342+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
343+
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
351344

352345
Value bitcast = b.create<arith::BitcastOp>(i8Ty, operand);
353346
// create constants for NaNs
@@ -397,14 +390,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
397390
op, "only applicable to default rounding mode.");
398391
}
399392

400-
Type i8Ty = b.getI8Type();
401-
Type i32Ty = b.getI32Type();
402-
Type f32Ty = b.getF32Type();
403-
if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
404-
i8Ty = shapedTy.clone(i8Ty);
405-
i32Ty = shapedTy.clone(i32Ty);
406-
f32Ty = shapedTy.clone(f32Ty);
407-
}
393+
Type i8Ty = cloneToShapedType(operandTy, b.getI8Type());
394+
Type i32Ty = cloneToShapedType(operandTy, b.getI32Type());
395+
Type f32Ty = cloneToShapedType(operandTy, b.getF32Type());
396+
408397
if (operandETy.getIntOrFloatBitWidth() < 32) {
409398
operand = b.create<arith::ExtFOp>(f32Ty, operand);
410399
} else if (operandETy.getIntOrFloatBitWidth() > 32) {

0 commit comments

Comments
 (0)