52 changes: 52 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1730,4 +1730,56 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, input);
return success();
});

  patterns.onOp(
      "Hardmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) {
        // onnx.Hardmax can be expanded into the following python code:
        //
        //   import torch.nn.functional as F
        //   def hardmax(tensor, dim=-1):
        //       maximums = torch.argmax(tensor, dim=dim, keepdim=False)
        //       one_hot = F.one_hot(maximums, num_classes=tensor.shape[dim])
        //       return one_hot.movedim(-1, dim).to(tensor.dtype)
        //
        // Given the example input:
        //   tensor([[1, 2, 3],
        //           [4, 6, 5],
        //           [9, 8, 7]])
        // the code above yields:
        //   tensor([[0, 0, 1],
        //           [0, 1, 0],
        //           [1, 0, 0]])

        Torch::ValueTensorType resultType;
        int64_t axisValue;
        Value input;
        // Per the ONNX spec, `axis` defaults to -1 for Hardmax since opset 13.
        if (binder.tensorOperand(input) ||
            binder.s64IntegerAttr(axisValue, "axis", /*defaultValue=*/-1) ||
            binder.tensorResultType(resultType))
          return failure();

        auto loc = binder.getLoc();

        if (!resultType.hasSizes() || !resultType.hasDtype())
          return rewriter.notifyMatchFailure(
              binder.op, "expected result type to have sizes and dtype");

        // `axis` is an index, not a dtype, so it must not be run through the
        // dtype mapping; a negative axis counts from the back.
        int64_t rank = resultType.getSizes().size();
        if (axisValue < 0)
          axisValue += rank;
        if (axisValue < 0 || axisValue >= rank)
          return rewriter.notifyMatchFailure(binder.op, "axis out of range");
        Value axis = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(axisValue));

        // torch.argmax with keepdim=False reduces `axis`, so its result
        // drops that dimension and carries an si64 element type.
        SmallVector<int64_t> argmaxSizes(resultType.getSizes());
        argmaxSizes.erase(argmaxSizes.begin() + axisValue);
        auto si64Type =
            rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true);
        auto argmaxType =
            rewriter.getType<Torch::ValueTensorType>(argmaxSizes, si64Type);
        Value constKeepDims = rewriter.create<Torch::ConstantBoolOp>(
            loc, rewriter.getType<Torch::BoolType>(),
            rewriter.getBoolAttr(false));
        Value argmax = rewriter.create<Torch::AtenArgmaxOp>(
            loc, argmaxType, input, axis, constKeepDims);

        // one_hot: the number of classes is the size of the reduced
        // dimension, so the expansion restores it. The new one-hot
        // dimension is appended at the end of the shape.
        Value numClasses =
            rewriter.create<Torch::AtenSizeIntOp>(loc, input, axis);
        argmaxSizes.push_back(resultType.getSizes()[axisValue]);
        auto oneHotType =
            rewriter.getType<Torch::ValueTensorType>(argmaxSizes, si64Type);
        Value oneHot = rewriter.create<Torch::AtenOneHotOp>(
            loc, oneHotType, argmax, numClasses);

        // Move the appended one-hot dimension back to `axis`, then cast
        // from one_hot's si64 result to the expected element type.
        Value srcDim = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(rank - 1));
        auto movedType = rewriter.getType<Torch::ValueTensorType>(
            resultType.getSizes(), si64Type);
        Value moved = rewriter.create<Torch::AtenMovedimIntOp>(
            loc, movedType, oneHot, srcDim, axis);

        Value dtype = rewriter.create<Torch::ConstantIntOp>(
            loc, rewriter.getI64IntegerAttr(static_cast<int64_t>(
                     Torch::getScalarTypeForType(resultType.getDtype()))));
        Value none = rewriter.create<Torch::ConstantNoneOp>(loc);
        rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
            binder.op, resultType, moved, dtype,
            /*non_blocking=*/constKeepDims, /*copy=*/constKeepDims,
            /*memory_format=*/none);

        return success();
      });
}
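As a quick plausibility check of the decomposition outside of MLIR, the expansion can be run directly in PyTorch. This is a minimal sketch, assuming a recent PyTorch; the helper name hardmax_decomposed is illustrative and not part of this patch:

    import torch
    import torch.nn.functional as F

    def hardmax_decomposed(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
        dim = dim % t.dim()                     # normalize a negative axis
        idx = torch.argmax(t, dim=dim)          # reduces `dim`, returns int64
        one_hot = F.one_hot(idx, num_classes=t.shape[dim])  # new dim appended last
        return one_hot.movedim(-1, dim).to(t.dtype)  # restore layout and dtype

    x = torch.tensor([[1., 2., 3.],
                      [4., 6., 5.],
                      [9., 8., 7.]])
    print(hardmax_decomposed(x))
    # tensor([[0., 0., 1.],
    #         [0., 1., 0.],
    #         [1., 0., 0.]])

This mirrors the op sequence emitted by the pattern: argmax, size.int for the class count, one_hot, movedim, and a final dtype cast.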
10 changes: 10 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -959,3 +959,13 @@ func.func @test_hardswish(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<
%0 = torch.operator "onnx.HardSwish"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}

// -----

// CHECK-LABEL: func.func @test_hardmax
func.func @test_hardmax(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[ARGMAX:.+]] = torch.aten.argmax %arg0, %int1, %false : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.bool -> !torch.vtensor<[3,5],si64>
// CHECK: %[[CLASSES:.+]] = torch.aten.size.int %arg0, %int1 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int
// CHECK: %[[ONEHOT:.+]] = torch.aten.one_hot %[[ARGMAX]], %[[CLASSES]] : !torch.vtensor<[3,5],si64>, !torch.int -> !torch.vtensor<[3,5,4],si64>
// CHECK: %[[MOVED:.+]] = torch.aten.movedim.int %[[ONEHOT]], %int2, %int1 : !torch.vtensor<[3,5,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[3,4,5],si64>
// CHECK: torch.aten.to.dtype %[[MOVED]], %int6, %false, %false, %none : !torch.vtensor<[3,4,5],si64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4,5],f32>
%0 = torch.operator "onnx.Hardmax"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32>
return %0 : !torch.vtensor<[3,4,5],f32>
}
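As a usage note: tests in this directory are normally driven by lit via the file's RUN line, but assuming a local build of torch-mlir-opt, the lowering can also be inspected directly with the same pass, e.g. `torch-mlir-opt -convert-torch-onnx-to-torch -split-input-file test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir`.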