Skip to content

Commit 3528c3a

Browse files
author
Tanyo Kwok
committed
BladeDISC related patches
* fix float width * fix divide_floor & export promoteTypes api (#9) * To comply with the old pytorch versions * Add native_dropout_backward & native_layer_norm_backward decomposition (#15) * add native_dropout and related ops pattern (#1211) * [MHLO] fix dot general contract * Fix batch_norm, div.Tensor_mode and folder (#21) * reimplement linear lowering * reimplement 2-D rhs for matmul * add torchdynamo
1 parent 9536174 commit 3528c3a

File tree

23 files changed

+782
-152
lines changed

23 files changed

+782
-152
lines changed

include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6596,9 +6596,10 @@ def Torch_AtenOnesLikeOp : Torch_Op<"aten.ones_like", [
65966596
}
65976597

65986598
def Torch_AtenEmptyMemoryFormatOp : Torch_Op<"aten.empty.memory_format", [
6599+
Pure,
65996600
AllowsTypeRefinement,
66006601
HasValueSemantics,
6601-
ReadOnly
6602+
ReadOnly,
66026603
]> {
66036604
let summary = "Generated op for `aten::empty.memory_format : (int[], int?, int?, Device?, bool?, int?) -> (Tensor)`";
66046605
let arguments = (ins
@@ -7129,6 +7130,31 @@ def Torch_AtenMaxOp : Torch_Op<"aten.max", [
71297130
}];
71307131
}
71317132

7133+
def Torch_AtenAmaxOp : Torch_Op<"aten.amax", [
7134+
AllowsTypeRefinement,
7135+
HasValueSemantics,
7136+
ReadOnly
7137+
]> {
7138+
let summary = "Generated op for `aten::amax : (Tensor, int[]?, bool) -> Tensor`";
7139+
let arguments = (ins
7140+
AnyTorchTensorType:$self,
7141+
AnyTorchOptionalListOfTorchIntType:$dim,
7142+
Torch_BoolType:$keepdim
7143+
);
7144+
let results = (outs
7145+
AnyTorchTensorType:$results
7146+
);
7147+
let hasCustomAssemblyFormat = 1;
7148+
let extraClassDefinition = [{
7149+
ParseResult AtenAmaxOp::parse(OpAsmParser &parser, OperationState &result) {
7150+
return parseDefaultTorchOp(parser, result, 3, 1);
7151+
}
7152+
void AtenAmaxOp::print(OpAsmPrinter &printer) {
7153+
printDefaultTorchOp(printer, *this, 3, 1);
7154+
}
7155+
}];
7156+
}
7157+
71327158
def Torch_AtenMaxDimOp : Torch_Op<"aten.max.dim", [
71337159
AllowsTypeRefinement,
71347160
HasValueSemantics,

include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ enum Layout { Strided, Sparse, SparseCsr, Mkldnn, NumOptions };
160160
//===-----------------------------------------------------------------------===//
161161
enum EmbeddingBagMode { MODE_SUM, MODE_MEAN, MODE_MAX };
162162

163+
ScalarType promoteTypes(ScalarType a, ScalarType b);
163164
} // namespace torch_upstream
164165
} // namespace torch
165166
} // namespace mlir

include/torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
4242
let assemblyFormat = [{
4343
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
4444
}];
45+
let hasCanonicalizer = 1;
4546
let hasVerifier = 1;
4647
}
4748

@@ -61,6 +62,7 @@ def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tenso
6162
let assemblyFormat = [{
6263
$operand attr-dict `:` qualified(type($operand)) `->` qualified(type($result))
6364
}];
65+
let hasCanonicalizer = 1;
6466
let hasVerifier = 1;
6567
}
6668

@@ -80,6 +82,7 @@ def TorchConversion_ToI1Op : TorchConversion_Op<"to_i1", [
8082
let assemblyFormat = [{
8183
$operand attr-dict
8284
}];
85+
let hasFolder = 1;
8386
}
8487

8588
def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
@@ -98,6 +101,7 @@ def TorchConversion_FromI1Op : TorchConversion_Op<"from_i1", [
98101
let assemblyFormat = [{
99102
$operand attr-dict
100103
}];
104+
let hasFolder = 1;
101105
}
102106

103107
def TorchConversion_ToI64Op : TorchConversion_Op<"to_i64", [

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,13 +383,17 @@ class ConvertTorchToArith : public ConvertTorchToArithBase<ConvertTorchToArith>
383383
target.addIllegalOp<Torch::ConstantIntOp>();
384384
patterns.add<ConvertTorchConstantOp<Torch::ConstantIntOp>>(typeConverter,
385385
context);
386-
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp>();
386+
target.addIllegalOp<AtenAddIntOp, AtenSubIntOp, AtenMulIntOp,
387+
AtenRemainderIntOp>();
387388
patterns.add<ConvertAtenBinaryOp<AtenAddIntOp, arith::AddIOp>>(
388389
typeConverter, context);
389390
patterns.add<ConvertAtenBinaryOp<AtenSubIntOp, arith::SubIOp>>(
390391
typeConverter, context);
391392
patterns.add<ConvertAtenBinaryOp<AtenMulIntOp, arith::MulIOp>>(
392393
typeConverter, context);
394+
patterns.add<ConvertAtenBinaryOp<AtenRemainderIntOp, arith::RemSIOp>>(
395+
typeConverter, context);
396+
393397
target.addIllegalOp<AtenSubFloatOp>();
394398
patterns.add<ConvertAtenBinaryOp<AtenSubFloatOp, arith::SubFOp>>(
395399
typeConverter, context);

lib/Conversion/TorchToMhlo/Basic.cpp

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
using namespace mlir;
3030
using namespace mlir::torch;
3131
using namespace mlir::torch::Torch;
32+
using namespace mlir::torch::TorchConversion;
3233
using namespace mlir::torch::torch_to_mhlo;
3334

3435
LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op,
@@ -166,16 +167,19 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern<AtenOpT> {
166167
if (!selfTy)
167168
return op.emitError("only Tensor types supported in MHLO");
168169

169-
if (selfTy.getElementType().isa<mlir::FloatType>()) {
170+
auto outTy = OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
171+
op.getType());
172+
if (selfTy != outTy) {
173+
auto out = rewriter.create<MhloOpT>(op.getLoc(), selfTy, self);
174+
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, outTy, out);
175+
return success();
176+
} else {
170177
rewriter.replaceOpWithNewOp<MhloOpT>(
171178
op,
172179
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
173180
op.getType()),
174181
self);
175182
return success();
176-
} else {
177-
return op.emitError(
178-
"only floating-point datatype legalization supported");
179183
}
180184
}
181185
};
@@ -345,15 +349,10 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
345349
} else if (!rhsType) {
346350
rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy);
347351
}
348-
DenseIntElementsAttr bcastDimensions;
349-
lhs = mhlo::promoteType(rewriter, lhs, outType);
350-
rhs = mhlo::promoteType(rewriter, rhs, outType);
351-
auto loc = op.getLoc();
352-
Value result =
353-
rewriter.create<ChloOpT>(loc, outType, lhs, rhs, bcastDimensions);
354-
355352
if (!isa<AtenDivTensorModeOp>(op)) {
356-
rewriter.replaceOp(op, result);
353+
lhs = mhlo::promoteType(rewriter, lhs, outType);
354+
rhs = mhlo::promoteType(rewriter, rhs, outType);
355+
rewriter.replaceOpWithNewOp<ChloOpT>(op, outType, lhs, rhs, nullptr);
357356
return success();
358357
}
359358

@@ -365,6 +364,17 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
365364
return rewriter.notifyMatchFailure(
366365
op, "only support constant str rounding mode");
367366

367+
auto computeTy = outType;
368+
if (outElemTy.isIntOrIndex()) {
369+
computeTy =
370+
RankedTensorType::get(outType.getShape(), rewriter.getF32Type());
371+
}
372+
lhs = mhlo::promoteType(rewriter, lhs, computeTy);
373+
rhs = mhlo::promoteType(rewriter, rhs, computeTy);
374+
auto loc = op.getLoc();
375+
auto result =
376+
rewriter.create<ChloOpT>(loc, computeTy, lhs, rhs, nullptr).getResult();
377+
368378
if (roundingMode == "trunc") {
369379
// "trunc" - rounds the results of the division towards zero. Equivalent
370380
// to C-style integer division.
@@ -378,7 +388,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern<AtenOpT> {
378388
// floor division in Python (the // operator)
379389
result = rewriter.create<mhlo::FloorOp>(loc, result).getResult();
380390
}
381-
rewriter.replaceOp(op, result);
391+
rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, outType, result);
382392
return success();
383393
}
384394
};
@@ -836,7 +846,11 @@ LogicalResult ConvertAtenOp<AtenReluOp>::matchAndRewrite(
836846
APFloat::getZero(lhsElemTy.cast<mlir::FloatType>().getFloatSemantics(),
837847
false),
838848
lhs);
839-
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, lhs, zeroTensor);
849+
auto outType = getTypeConverter()
850+
->convertType(op.getType())
851+
.template dyn_cast<TensorType>();
852+
853+
rewriter.replaceOpWithNewOp<mhlo::MaxOp>(op, outType, lhs, zeroTensor);
840854
return success();
841855
}
842856

@@ -862,7 +876,11 @@ LogicalResult ConvertAtenOp<AtenGeluOp>::matchAndRewrite(
862876
auto erf = rewriter.create<mlir::chlo::ErfOp>(loc, erfElement);
863877
auto erfAdd = rewriter.create<mhlo::AddOp>(loc, erf, one);
864878
auto halfMul = rewriter.create<mhlo::MulOp>(loc, erfAdd, half);
865-
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, input, halfMul);
879+
auto outType = getTypeConverter()
880+
->convertType(op.getType())
881+
.template dyn_cast<TensorType>();
882+
883+
rewriter.replaceOpWithNewOp<mhlo::MulOp>(op, outType, input, halfMul);
866884
return success();
867885
}
868886

@@ -1463,7 +1481,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
14631481
INSERT_ATENOP_PATTERN(ValueTensorLiteralOp);
14641482
INSERT_ATENOP_PATTERN(AtenReciprocalOp);
14651483
INSERT_ATENOP_PATTERN(PrimNumToTensorScalarOp);
1466-
INSERT_ATENOP_PATTERN(AtenContiguousOp);
14671484

14681485
INSERT_ATENOP_PATTERN(AtenReluOp);
14691486
INSERT_ATENOP_PATTERN(AtenGeluOp);

0 commit comments

Comments
 (0)