Skip to content

Commit

Permalink
[MHLO] Evaluate RuntimeAssertOp at compile time (llvm#1732)
Browse files Browse the repository at this point in the history
  • Loading branch information
li-plus authored Dec 22, 2022
1 parent ddbcf56 commit 49071f8
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 0 deletions.
1 change: 1 addition & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
}

MHLO_PASS_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
Expand Down
25 changes: 25 additions & 0 deletions lib/Conversion/TorchToMhlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1297,13 +1297,38 @@ LogicalResult ConvertAtenOp<AtenGeluBackwardOp>::matchAndRewrite(
return success();
}

// RuntimeAssertOp
namespace {
class ConvertRuntimeAssertOp : public OpConversionPattern<RuntimeAssertOp> {
public:
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(RuntimeAssertOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
bool condition;
if (!matchPattern(op.getCondition(), m_TorchConstantBool(&condition))) {
return rewriter.notifyMatchFailure(
op, "unimplemented: condition must be a constant");
}
if (!condition) {
return op->emitError("condition must be true");
}
rewriter.eraseOp(op);
return success();
}
};
} // namespace

void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToMhloOptions &options) {
MLIRContext *context = patterns.getContext();

target.addIllegalOp<AtenTransposeIntOp>();
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
target.addIllegalOp<RuntimeAssertOp>();
patterns.add<ConvertRuntimeAssertOp>(typeConverter, context);

#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \
target.addIllegalOp<AtenOp>(); \
Expand Down
11 changes: 11 additions & 0 deletions test/Conversion/TorchToMhlo/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,14 @@ func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtenso
%1 = torch.aten.cat %0, %int0 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,?],f32>
return %1 : !torch.vtensor<[?,?],f32>
}

// -----

// CHECK-LABEL: func.func @torch.runtime.assert(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: return %[[ARG_0]] : !torch.vtensor<[?,?],f32>
func.func @torch.runtime.assert(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
%true = torch.constant.bool true
torch.runtime.assert %true, "this should not fail"
return %arg0: !torch.vtensor<[?,?],f32>
}

0 comments on commit 49071f8

Please sign in to comment.