Skip to content

Commit

Permalink
Add aten.isclose support and its torch-to-tosa lowering (llvm#2512)
Browse files Browse the repository at this point in the history
Add aten.isclose op
Add its torch-to-tosa lowering
Update the TorchToTosa/basic.mlir tests


To test e2e tosa lowering:
`python -m e2e_testing.main -v -c=tosa`

---------

Co-authored-by: Ze Zhang <ze.zhang@getcruise.com>
  • Loading branch information
zezhang and Ze Zhang authored Oct 16, 2023
1 parent e649e06 commit f2c53b8
Show file tree
Hide file tree
Showing 8 changed files with 175 additions and 0 deletions.
4 changes: 4 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8
"Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier",
"UnflattenStaticModule_basic",
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
}

TORCHDYNAMO_XFAIL_SET = {
Expand Down Expand Up @@ -928,6 +930,8 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"IscloseStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"TileBigDimsSizeModule_basic",
"TileSmallDimsSizeModule_basic",
"IndexPutImpl2DNoneIndexStaticModule_basic",
Expand Down
27 changes: 27 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -4162,6 +4162,33 @@ def Torch_AtenViewAsRealOp : Torch_Op<"aten.view_as_real", [
}];
}

def Torch_AtenIscloseOp : Torch_Op<"aten.isclose", [
AllowsTypeRefinement,
HasValueSemantics,
ReadOnly
]> {
let summary = "Generated op for `aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchTensorType:$other,
Torch_FloatType:$rtol,
Torch_FloatType:$atol,
Torch_BoolType:$equal_nan
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenIscloseOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 5, 1);
}
void AtenIscloseOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 5, 1);
}
}];
}

def Torch_AtenUnbindCopyIntOp : Torch_Op<"aten.unbind_copy.int", [
AllowsTypeRefinement,
HasValueSemantics,
Expand Down
54 changes: 54 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3920,6 +3920,59 @@ LogicalResult ConvertAtenOp<AtenLeTensorOp>::matchAndRewrite(
return success();
}

template <>
LogicalResult ConvertAtenOp<AtenIscloseOp>::matchAndRewrite(
AtenIscloseOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// check args
double rtol, atol;
bool equalNan;
if (!matchPattern(op.getRtol(), m_TorchConstantFloat(&rtol)))
return rewriter.notifyMatchFailure(op, "rtol must be a scalar constant");
if (!matchPattern(op.getAtol(), m_TorchConstantFloat(&atol)))
return rewriter.notifyMatchFailure(op, "atol must be a scalar constant");
if (!matchPattern(op.getEqualNan(), m_TorchConstantBool(&equalNan)))
return rewriter.notifyMatchFailure(
op, "unimplemented: equal_nan is expected to be false");

// check tensor type.
auto selfType = adaptor.getSelf().getType().dyn_cast<TensorType>();
auto otherType = adaptor.getOther().getType().dyn_cast<TensorType>();
if (!selfType || !otherType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
if (!selfType.hasStaticShape() || !otherType.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "Only tensor types with static shape are supported");
if (!selfType.getElementType().isa<mlir::FloatType>() ||
!otherType.getElementType().isa<mlir::FloatType>()) {
return rewriter.notifyMatchFailure(
op, "unimplemented: only FP element type is supported");
}

auto rhsSubOp = rewriter.create<tosa::SubOp>(
op->getLoc(), selfType, adaptor.getSelf(), adaptor.getOther());
auto rhsAbsOp =
rewriter.create<tosa::AbsOp>(op->getLoc(), selfType, rhsSubOp);

auto lhsAbsOp =
rewriter.create<tosa::AbsOp>(op->getLoc(), otherType, adaptor.getOther());
auto rtolConstOp =
tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(rtol));
auto mulOp = rewriter.create<tosa::MulOp>(op->getLoc(), otherType,
rtolConstOp, lhsAbsOp, /*shift=*/0);
auto atolConstOp =
tosa::getTosaConstTensorSingleF32(rewriter, op, static_cast<float>(atol));
auto addOp =
rewriter.create<tosa::AddOp>(op->getLoc(), otherType, atolConstOp, mulOp);

auto outType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tosa::GreaterEqualOp>(op, outType, addOp,
rhsAbsOp);

return success();
}

template <>
LogicalResult ConvertAtenOp<AtenClampOp>::matchAndRewrite(
AtenClampOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -5134,6 +5187,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase<ConvertTorchToTosa> {
INSERT_ATENOP_PATTERN(AtenRemainderScalarOp);
INSERT_ATENOP_PATTERN(AtenCatOp);
INSERT_ATENOP_PATTERN(AtenSqrtOp);
INSERT_ATENOP_PATTERN(AtenIscloseOp);
#undef INSERT_ATENOP_PATTERN

#define INSERT_CLONE_ATENOP_PATTERN(AtenOp) \
Expand Down
8 changes: 8 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7480,6 +7480,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.isclose\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list<int>, !torch.list<int>) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.unsqueeze\"(%arg0: !torch.list<int>, %arg1: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.unsqueeze(%arg0, %arg1) : (!torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
Expand Down Expand Up @@ -9093,6 +9097,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.isclose\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.float, %arg3: !torch.float, %arg4: !torch.bool) -> !torch.int {\n"
" %int11 = torch.constant.int 11\n"
" return %int11 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.scaled_dot_product_attention\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.tuple<int, int>, %arg2: !torch.tuple<int, int>, %arg3: !torch.optional<tuple<int, int>>, %arg4: !torch.float, %arg5: !torch.bool, %arg6: !torch.optional<float>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,9 @@ def aten〇lt〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
def aten〇le〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇isclose〡shape(self: List[int], other: List[int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> List[int]:
return upstream_shape_functions.broadcast(self, other)

def aten〇unsqueeze〡shape(self: List[int], dim: int) -> List[int]:
return upstream_shape_functions.unsqueeze(self, dim)

Expand Down Expand Up @@ -2171,6 +2174,10 @@ def aten〇logical_and〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtyp
def aten〇logical_not〡dtype(self_rank_dtype: Tuple[int, int]) -> int:
return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=2))
def aten〇isclose〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], rtol: float = 1.0000000000000001e-05, atol: float = 1e-08, equal_nan: bool = False) -> int:
return torch.bool

@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(3, 4, 32, 16), (3, 4, 32, 16), (3, 4, 32, 16)]))
def aten〇scaled_dot_product_attention〡dtype(query_rank_dtype: Tuple[int, int], key_rank_dtype: Tuple[int, int], value_rank_dtype: Tuple[int, int], attn_mask_rank_dtype: Optional[Tuple[int, int]] = None, dropout_p: float = 0., is_causal: bool = False, scale: Optional[float] = None) -> int:
_, query_dtype = query_rank_dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def emit_with_mutating_variants(key, **kwargs):
emit("aten::imag : (Tensor) -> (Tensor)")
emit("aten::view_as_complex : (Tensor) -> (Tensor)")
emit("aten::view_as_real : (Tensor) -> (Tensor)")
emit("aten::isclose : (Tensor, Tensor, float, float, bool) -> (Tensor)")

# Ops with dynamic number of outputs
emit("aten::unbind_copy.int : (Tensor, int) -> (Tensor[])")
Expand Down
45 changes: 45 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4580,3 +4580,48 @@ def forward(self, x):
@register_test_case(module_factory=lambda: Add_Module())
def Add_Module_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 3))


# ==============================================================================


class IscloseStaticModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([
None,
([5, 5], torch.float32, True),
([5, 5], torch.float32, True),
])
def forward(self, x, y):
return torch.isclose(x, y)


@register_test_case(module_factory=lambda: IscloseStaticModule())
def IscloseStaticModule_basic(module, tu: TestUtils):
module.forward(tu.rand(5, 5), tu.rand(5, 5))


# ==============================================================================


class IscloseStaticModuleTrue(torch.nn.Module):

def __init__(self):
super().__init__()
self.register_buffer('tensor', torch.ones(1))

@export
@annotate_args([
None,
([5, 5], torch.float32, True),
])
def forward(self, x):
return torch.isclose(x, self.tensor)

@register_test_case(module_factory=lambda: IscloseStaticModuleTrue())
def IscloseStaticModuleTrue_basic(module, tu: TestUtils):
module.forward(torch.ones(5, 5))
29 changes: 29 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1155,3 +1155,32 @@ func.func @torch.aten.remainder.Scalar(%arg0: !torch.vtensor<[2, 4],f32>) -> !to
%0 = torch.aten.remainder.Scalar %arg0, %int2 : !torch.vtensor<[2, 4],f32>, !torch.int -> !torch.vtensor<[2, 4],f32>
return %0 : !torch.vtensor<[2, 4],f32>
}

// -----

// CHECK-LABEL: func.func @forward(
// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[5,5],f32>,
// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[5,5],f32> -> tensor<5x5xf32>
// CHECK: %[[ATOL:.*]] = torch.constant.float 1.000000e-08
// CHECK: %[[RTOL:.*]] = torch.constant.float 1.000000e-05
// CHECK: %[[FALSE:.*]] = torch.constant.bool false
// CHECK: %[[VAL_2:.*]] = tosa.sub %[[VAL_0]], %[[VAL_1]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_3:.*]] = tosa.abs %[[VAL_2]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_4:.*]] = tosa.abs %[[VAL_1]] : (tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_4]] {shift = 0 : i32} : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<9.99999993E-9> : tensor<f32>}> : () -> tensor<f32>
// CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_7]], %[[VAL_6]] : (tensor<f32>, tensor<5x5xf32>) -> tensor<5x5xf32>
// CHECK: %[[VAL_9:.*]] = tosa.greater_equal %[[VAL_8]], %[[VAL_3]] : (tensor<5x5xf32>, tensor<5x5xf32>) -> tensor<5x5xi1>
// CHECK: %[[VAL_10:.*]] = torch_c.from_builtin_tensor %[[VAL_9]] : tensor<5x5xi1> -> !torch.vtensor<[5,5],i1>
// CHECK: return %[[VAL_10]] : !torch.vtensor<[5,5],i1>
// CHECK: }
func.func @forward(%arg0: !torch.vtensor<[5,5],f32>, %arg1: !torch.vtensor<[5,5],f32>) -> !torch.vtensor<[5,5],i1> {
%float1.000000e-08 = torch.constant.float 1.000000e-08
%float1.000000e-05 = torch.constant.float 1.000000e-05
%false = torch.constant.bool false
%0 = torch.aten.isclose %arg0, %arg1, %float1.000000e-05, %float1.000000e-08, %false : !torch.vtensor<[5,5],f32>, !torch.vtensor<[5,5],f32>, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[5,5],i1>
return %0 : !torch.vtensor<[5,5],i1>
}

0 comments on commit f2c53b8

Please sign in to comment.