Skip to content

Commit

Permalink
[onnx] Handle optional arguments in Clip op pattern. (#2976)
Browse files Browse the repository at this point in the history
  • Loading branch information
ScottTodd authored Mar 7, 2024
1 parent 6e84752 commit 7b18646
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 34 deletions.
80 changes: 54 additions & 26 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,38 +602,66 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
return success();
});
patterns.onOp(
"Clip", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
"Clip", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
// https://onnx.ai/onnx/operators/onnx__Clip.html

// Inputs and outputs must be tensors.
Value source;
Torch::ValueTensorType resultType;
if (binder.op->getNumOperands() == 1) {
Value source;
if (binder.tensorOperand(source) ||
binder.tensorResultType(resultType))
if (binder.tensorOperandAtIndex(source, 0) ||
binder.tensorResultType(resultType)) {
return failure();
}

// Min and max can be args (version 11+) or attributes (version 6-).
// They default to numeric_limits::lowest() and numeric_limits::max().
Value min;
Value max;
if (binder.op->getNumOperands() >= 2)
min = binder.op->getOperand(1);
if (binder.op->getNumOperands() == 3)
max = binder.op->getOperand(2);

// Note: attribute versions of the op only support float types.
auto resultDtype = resultType.getDtype();
if (!min && binder.op->hasAttr("torch.onnx.min")) {
float minValue;
if (binder.f32FloatAttr(minValue, "min",
std::numeric_limits<float>::lowest()))
return failure();
Value cstNone =
rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
rewriter.replaceOpWithNewOp<Torch::AtenClampOp>(
binder.op, resultType, source, /*min=*/cstNone, /*max=*/cstNone);
return success();
} else if (binder.op->getNumOperands() == 2) {
Value source, min;
if (binder.tensorOperands(source, min) ||
binder.tensorResultType(resultType))
auto minSplatAttr = SplatElementsAttr::get(
resultType.toBuiltinTensor().clone(resultDtype),
rewriter.getFloatAttr(resultDtype, minValue));
min = rewriter.create<Torch::ValueTensorLiteralOp>(
binder.getLoc(), resultType, minSplatAttr);
}
if (!max && binder.op->hasAttr("torch.onnx.max")) {
float maxValue;
if (binder.f32FloatAttr(maxValue, "max",
std::numeric_limits<float>::max()))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenClampMinTensorOp>(
binder.op, resultType, source, /*min=*/min);
auto maxSplatAttr = SplatElementsAttr::get(
resultType.toBuiltinTensor().clone(resultDtype),
rewriter.getFloatAttr(resultDtype, maxValue));
max = rewriter.create<Torch::ValueTensorLiteralOp>(
binder.getLoc(), resultType, maxSplatAttr);
}

if (!min && !max) {
        // Clipping with no limits is a no-op.
rewriter.replaceOp(binder.op, source);
return success();
} else if (binder.op->getNumOperands() == 3) {
Value source, min, max;
if (binder.tensorOperandAtIndex(source, 0) ||
binder.tensorOperandAtIndex(min, 1) ||
binder.tensorOperandAtIndex(max, 2) ||
binder.tensorResultType(resultType))
return failure();
rewriter.replaceOpWithNewOp<Torch::AtenClampTensorOp>(
binder.op, resultType, source, min, max);
}

if (!max) {
rewriter.replaceOpWithNewOp<Torch::AtenClampMinTensorOp>(
binder.op, resultType, source, min);
return success();
}
return failure();

rewriter.replaceOpWithNewOp<Torch::AtenClampTensorOp>(
binder.op, resultType, source, min, max);
return success();
});
patterns.onOp(
"Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
Expand Down
8 changes: 0 additions & 8 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1913,14 +1913,6 @@
"TypeConversionI64ToI32Module_basic",

# Failure - onnx_lowering: onnx.Clip
"ElementwiseClampMaxModule_basic",
"ElementwiseClampMinModule_basic",
"ElementwiseClampMinTensorFloatModule_basic",
"ElementwiseClampMinTensorIntModule_basic",
"ElementwiseClampModule_basic",
"ElementwiseClampTensorFloatModule_basic",
"ElementwiseClampTensorInt8Module_basic",
"ElementwiseClampTensorIntModule_basic",
"NormalizeModule_basic",

# Failure - onnx_lowering: onnx.Einsum
Expand Down
24 changes: 24 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,16 @@ func.func @test_clip_default_int8_min(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1:

// -----

// CHECK-LABEL: @test_clip_default_int8_max
func.func @test_clip_default_int8_max(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
%none = torch.constant.none
// CHECK: torch.aten.clamp.Tensor %arg0, %none, %arg1 : !torch.vtensor<[3,4,5],si8>, !torch.none, !torch.vtensor<[],si8> -> !torch.vtensor<[3,4,5],si8>
%0 = torch.operator "onnx.Clip"(%arg0, %none, %arg1) : (!torch.vtensor<[3,4,5],si8>, !torch.none, !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8>
return %0 : !torch.vtensor<[3,4,5],si8>
}

// -----

// CHECK-LABEL: @test_clip_default_min
func.func @test_clip_default_min(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.clamp_min.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3,4,5],f32>
Expand Down Expand Up @@ -549,6 +559,20 @@ func.func @test_clip(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[

// -----

module {
func.func @test_clip_attrs(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} {
%none = torch.constant.none

// CHECK: %[[MIN:.+]] = torch.vtensor.literal(dense<-5.000000e-01> : tensor<3x4xf32>) : !torch.vtensor<[3,4],f32>
// CHECK: %[[MAX:.+]] = torch.vtensor.literal(dense<5.000000e-01> : tensor<3x4xf32>) : !torch.vtensor<[3,4],f32>
// CHECK: %[[CLAMP:.+]] = torch.aten.clamp.Tensor %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32>
%0 = torch.operator "onnx.Clip"(%arg0) {torch.onnx.max = 5.000000e-01 : f32, torch.onnx.min = -5.000000e-01 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32>
return %0 : !torch.vtensor<[3,4],f32>
}
}

// -----

// CHECK-LABEL: @test_cos_example
func.func @test_cos_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: torch.aten.cos %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32>
Expand Down

0 comments on commit 7b18646

Please sign in to comment.