[MLIR][ONNX] Add OnnxToTorch support for Bernoulli and CastLike op
Signed-Off By: Vivek Khandelwal <vivekkhandelwal1424@gmail.com>
vivekkhandelwal1 committed Jan 10, 2024
1 parent 35e8f86 commit 4707d3b
Showing 2 changed files with 160 additions and 15 deletions.
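In short, the new patterns lower onnx.Bernoulli to torch.aten.bernoulli (adding a torch.aten.to.dtype when the optional dtype attribute is set) and onnx.CastLike to torch.aten.to.dtype using the dtype of its second operand. A condensed sketch of the Bernoulli-with-dtype lowering, adapted from the new test cases further down (shapes and dtypes are illustrative only):

// onnx.Bernoulli with dtype = 11 (DOUBLE) on an f32 input lowers to:
%none = torch.constant.none
%bernoulli = torch.aten.bernoulli %arg0, %none : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
%int7 = torch.constant.int 7
%false = torch.constant.bool false
%cast = torch.aten.to.dtype %bernoulli, %int7, %false, %false, %none : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f64>
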
116 changes: 101 additions & 15 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -15,6 +15,28 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::onnx_c;

static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
int64_t dtypeIntTorch;
// TODO: Add complete mapping.
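// The incoming integers are ONNX TensorProto dtype codes; the returned
// integers are the corresponding torch ScalarType values.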
switch (dtypeIntOnnx) {
case 1:
dtypeIntTorch = 6; // float
break;
case 10:
dtypeIntTorch = 5; // half
break;
case 11:
dtypeIntTorch = 7; // double
break;
case 16:
dtypeIntTorch = 15; // bfloat16
break;
default:
dtypeIntTorch = -1; // No dtype
}
return dtypeIntTorch;
}

// Simple rewrites for the default domain.
// See: https://onnx.ai/onnx/operators/
// For operators that are effectively version invariant, we register with
@@ -311,6 +333,53 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
}
return failure();
});
patterns.onOp(
"Bernoulli", 15,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input;
int64_t dtypeIntOnnx, dtypeIntTorch;
if (binder.tensorOperand(input) ||
binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) ||
binder.tensorResultType(resultType))
return failure();

SmallString<64> name("torch.onnx.");
name.append("seed");
auto attr = binder.op->getAttr(name);
if (attr) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented: support not present for seed attribute");
}

Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value bernoulli = rewriter.create<Torch::AtenBernoulliOp>(
binder.getLoc(), input.getType(), input, /*generator=*/none);

if (dtypeIntOnnx == -1) {
// A dtype of -1 means the attribute was not present; the output keeps the input's dtype.
rewriter.replaceOp(binder.op, bernoulli);
return success();
}
dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
if (dtypeIntTorch == -1) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
}
Value constDtype = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(),
rewriter.getIntegerAttr(rewriter.getIntegerType(64),
dtypeIntTorch));
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, bernoulli, constDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
return success();
});
patterns.onOp(
"BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
@@ -386,21 +455,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
binder.tensorResultType(resultType))
return failure();

- // TODO: Add complete mapping.
- switch (dtypeIntOnnx) {
- case 1:
- dtypeIntTorch = 6; // float
- break;
- case 10:
- dtypeIntTorch = 5; // half
- break;
- case 11:
- dtypeIntTorch = 7; // double
- break;
- case 16:
- dtypeIntTorch = 15; // bfloat16
- break;
- default:
+ dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx);
+ if (dtypeIntTorch == -1) {
return rewriter.notifyMatchFailure(
binder.op,
"unimplemented support for the given dtype conversion");
@@ -418,6 +474,36 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
/*memory_format=*/none);
return success();
});
patterns.onOp(
"CastLike", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value input, target;
if (binder.tensorOperands(input, target) ||
binder.tensorResultType(resultType))
return failure();

// TODO: Add support for handling the `saturate` attribute.
// It is ignored for now, since it is only used during float8
// conversions, which are not yet supported in Torch-MLIR.

Torch::ValueTensorType targetTy =
target.getType().cast<Torch::ValueTensorType>();
if (!targetTy.hasDtype()) {
return rewriter.notifyMatchFailure(binder.op,
"target tensor must have a dtype");
}
Type targetDtype = targetTy.getDtype();
Value constDtype = Torch::getDtypeIntValueForType(
rewriter, binder.getLoc(), targetDtype);
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, input, constDtype,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);
return success();
});
patterns.onOp("Ceil", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
59 changes: 59 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
@@ -110,6 +110,25 @@ func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,
return %0 : !torch.vtensor<[3,4,5],f32>
}

// CHECK-LABEL: @test_bernoulli
func.func @test_bernoulli(%arg0: !torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %0 = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f64>, !torch.none -> !torch.vtensor<[10],f64>
%0 = torch.operator "onnx.Bernoulli"(%arg0) : (!torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64>
return %0 : !torch.vtensor<[10],f64>
}

// CHECK-LABEL: @test_bernoulli_double
func.func @test_bernoulli_double(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[BERNOULLI:.*]] = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32>
// CHECK: %[[DTYPE:.*]] = torch.constant.int 7
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.to.dtype %[[BERNOULLI]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f64>
%0 = torch.operator "onnx.Bernoulli"(%arg0) {torch.onnx.dtype = 11 : si64} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64>
return %0 : !torch.vtensor<[10],f64>
}

// CHECK-LABEL: @test_bitshift_left_uint8
func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8>
@@ -323,6 +342,46 @@ func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torc
return %0 : !torch.vtensor<[3,4],f32>
}

// CHECK-LABEL: @test_castlike_BFLOAT16_to_FLOAT
func.func @test_castlike_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 6
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
%0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],bf16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}

// CHECK-LABEL: @test_castlike_DOUBLE_to_FLOAT
func.func @test_castlike_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 6
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
%0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f64>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}

// CHECK-LABEL: @test_castlike_FLOAT_to_DOUBLE
func.func @test_castlike_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 7
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64>
%0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64>
return %0 : !torch.vtensor<[3,4],f64>
}

// CHECK-LABEL: @test_castlike_FLOAT16_to_FLOAT
func.func @test_castlike_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT:.*]] = torch.constant.int 6
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32>
%0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}

// CHECK-LABEL: @test_ceil_example
func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32>