@@ -35,6 +35,14 @@ static Value createConst(Location loc, Type type, int value,
35
35
return rewriter.create <arith::ConstantOp>(loc, attr);
36
36
}
37
37
38
+ // / Creates shapedType using shape from cloneFrom and base type from cloneTo
39
+ static Type cloneToShapedType (Type cloneFrom, Type cloneTo) {
40
+ if (auto shapedTy = dyn_cast<ShapedType>(cloneFrom)) {
41
+ return shapedTy.clone (cloneTo);
42
+ }
43
+ return cloneTo;
44
+ }
45
+
38
46
namespace {
39
47
40
48
/// Expands CeilDivUIOp (n, m) into
@@ -225,12 +233,8 @@ struct BFloat16ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
225
233
return rewriter.notifyMatchFailure (op, " not a ext of bf16 to f32." );
226
234
}
227
235
228
- Type i16Ty = b.getI16Type ();
229
- Type i32Ty = b.getI32Type ();
230
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
231
- i16Ty = shapedTy.clone (i16Ty);
232
- i32Ty = shapedTy.clone (i32Ty);
233
- }
236
+ Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
237
+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
234
238
235
239
Value bitcast = b.create <arith::BitcastOp>(i16Ty, operand);
236
240
Value exti = b.create <arith::ExtUIOp>(i32Ty, bitcast);
@@ -264,14 +268,8 @@ struct BFloat16TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
264
268
op, " only applicable to default rounding mode." );
265
269
}
266
270
267
- Type i16Ty = b.getI16Type ();
268
- Type i32Ty = b.getI32Type ();
269
- Type f32Ty = b.getF32Type ();
270
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
271
- i16Ty = shapedTy.clone (i16Ty);
272
- i32Ty = shapedTy.clone (i32Ty);
273
- f32Ty = shapedTy.clone (f32Ty);
274
- }
271
+ Type i16Ty = cloneToShapedType (operandTy, b.getI16Type ());
272
+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
275
273
276
274
// Algorithm borrowed from this excellent code:
277
275
// https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/c10/util/BFloat16.h#L60-L79
@@ -340,14 +338,9 @@ struct F8E8M0ExtFOpConverter : public OpRewritePattern<arith::ExtFOp> {
340
338
return rewriter.notifyMatchFailure (op, " not a ext of F8E8M0FNU" );
341
339
}
342
340
343
- Type i8Ty = b.getI8Type ();
344
- Type i32Ty = b.getI32Type ();
345
- Type f32Ty = b.getF32Type ();
346
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
347
- i8Ty = shapedTy.clone (i8Ty);
348
- i32Ty = shapedTy.clone (i32Ty);
349
- f32Ty = shapedTy.clone (f32Ty);
350
- }
341
+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
342
+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
343
+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
351
344
352
345
Value bitcast = b.create <arith::BitcastOp>(i8Ty, operand);
353
346
// create constants for NaNs
@@ -397,14 +390,10 @@ struct F8E8M0TruncFOpConverter : public OpRewritePattern<arith::TruncFOp> {
397
390
op, " only applicable to default rounding mode." );
398
391
}
399
392
400
- Type i8Ty = b.getI8Type ();
401
- Type i32Ty = b.getI32Type ();
402
- Type f32Ty = b.getF32Type ();
403
- if (auto shapedTy = dyn_cast<ShapedType>(operandTy)) {
404
- i8Ty = shapedTy.clone (i8Ty);
405
- i32Ty = shapedTy.clone (i32Ty);
406
- f32Ty = shapedTy.clone (f32Ty);
407
- }
393
+ Type i8Ty = cloneToShapedType (operandTy, b.getI8Type ());
394
+ Type i32Ty = cloneToShapedType (operandTy, b.getI32Type ());
395
+ Type f32Ty = cloneToShapedType (operandTy, b.getF32Type ());
396
+
408
397
if (operandETy.getIntOrFloatBitWidth () < 32 ) {
409
398
operand = b.create <arith::ExtFOp>(f32Ty, operand);
410
399
} else if (operandETy.getIntOrFloatBitWidth () > 32 ) {
0 commit comments