84 changes: 84 additions & 0 deletions python/test/eager_mode/normalize_args_kwargs.py
@@ -0,0 +1,84 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

# RUN: %PYTHON %s | FileCheck %s


import torch

from torch_mlir.eager_mode.torch_mlir_dispatch import normalize_args_kwargs


def run_test(test, XFAIL=False, XPASS=False):
try:
test()
print(("X" if XPASS else "") + f"PASS - {test.__name__}")
except Exception as e:
print(("X" if XFAIL else "") + f"FAIL - {test.__name__}")
print(e)


# CHECK: PASS - should_normalize
def should_normalize():
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"kernel_size": [3, 3]}
golden = {
"kernel_size": [3, 3],
        # This is due to the schema for max_pool2d_with_indices defining
        # the stride arg as int[2] stride=[]; see the schema dump sketched
        # after this file's diff.
"stride": [],
"padding": [0, 0],
"dilation": [1, 1],
"ceil_mode": False,
}

new_args, new_kwargs = normalize_args_kwargs(target, args, kwargs)
for arg, new_arg in zip(args, new_args):
assert torch.allclose(arg, new_arg)
for k, v in new_kwargs.items():
assert v == golden[k]


# CHECK: FAIL - shouldnt_normalize1
# CHECK: Couldn't normalize args and kwargs
def shouldnt_normalize1():
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"stride": []}
normalize_args_kwargs(target, args, kwargs)


# The next two tests are XPASS because of https://github.com/pytorch/pytorch/issues/75342
# I.e., they should fail, but in fact they pass because of the upstream bug.
# The reason for the bug is a fast-path branch in operator_schemas.normalize_function
# that doesn't do rigorous type checking, and hence lets type mismatches slip through.
# TODO(max): change these to FAIL when the upstream bug is fixed.

# CHECK: XPASS - shouldnt_normalize2
def shouldnt_normalize2():
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"kernel_size": []}
normalize_args_kwargs(target, args, kwargs)


# CHECK: XPASS - shouldnt_normalize3
def shouldnt_normalize3():
target = torch.ops.aten.max_pool2d_with_indices.default.overloadpacket
args = (torch.randn((1, 3, 32, 32)),)
kwargs = {"kernel_size": [3, 3], "padding": None}
normalize_args_kwargs(target, args, kwargs)


def main():
run_test(should_normalize)
run_test(shouldnt_normalize1)
run_test(shouldnt_normalize2, XPASS=True)
run_test(shouldnt_normalize3, XPASS=True)


if __name__ == "__main__":
main()
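
Reviewer note, not part of the diff: the golden dict above hard-codes "stride": [], and the inline comment points at the op schema as the reason. A quick way to confirm is to dump the registered schema for the overload the tests exercise. A minimal sketch, relying on the private _schema attribute (stable in practice, but not a public API); the commented output is what a recent torch release prints:

import torch

# The schema declares `int[2] stride=[]`, so normalization fills in an
# empty list rather than a concrete pair when the caller omits stride.
print(torch.ops.aten.max_pool2d_with_indices.default._schema)
# max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[],
#     int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)
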
6 changes: 4 additions & 2 deletions python/torch_mlir/eager_mode/torch_mlir_dispatch.py
@@ -57,8 +57,10 @@ def normalize_args_kwargs(target: Callable, args: Tuple[Any], kwargs: Dict[str,

     arg_types = map_aggregate(args, type)
     assert isinstance(arg_types, tuple)
-    arg_types = tuple([create_type_hint(i) for i in arg_types])
-    kwarg_types = {k: type(v) for k, v in kwargs.items()}
+    arg_types = map_aggregate(map_aggregate(args, type), create_type_hint)
+    kwarg_types = {
+        k: create_type_hint(map_aggregate(v, type)) for k, v in kwargs.items()
+    }

     new_args_and_kwargs = normalize_function(
         target, args, kwargs, arg_types, kwarg_types, normalize_to_only_use_kwargs=False
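
A minimal sketch of what the kwarg-side change buys, using the same torch.fx helpers the module already imports (map_aggregate from torch.fx.node, create_type_hint from torch.fx.operator_schemas); the commented results are what a recent torch release produces:

from torch.fx.node import map_aggregate
from torch.fx.operator_schemas import create_type_hint

kernel_size = [3, 3]

# Old behavior: a bare type(v) yields plain `list`, which normalize_function
# cannot match against a schema parameter typed int[2].
print(type(kernel_size))  # <class 'list'>

# New behavior: recurse into the aggregate first, then build a hint.
# map_aggregate([3, 3], type) produces [int, int], and create_type_hint
# collapses that into List[int], which the schema matcher accepts.
print(create_type_hint(map_aggregate(kernel_size, type)))  # typing.List[int]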