Print args/kwargs types for utils and examples (pytorch#963)
Summary:
Printing arg types is useful for understanding how the op should be implemented; a short sketch of the resulting error appears after the changed-files summary below.

Test Plan:
python tutorials/developer_api_guide/print_op_and_shapes.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 authored and melvinebenezer committed Oct 7, 2024
1 parent 3dfcedd commit e86acdd
Showing 3 changed files with 48 additions and 4 deletions.
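As a minimal sketch of the richer error in practice (mirroring the new test below; `MyTensor` is a stand-in subclass that registers no ops):

```python
import torch
from torchao.utils import TorchAOBaseTensor

class MyTensor(TorchAOBaseTensor):
    def __new__(cls, data):
        return torch.Tensor._make_wrapper_subclass(cls, data.shape)

    def __init__(self, data):
        self.data = data

linear = torch.nn.Linear(10, 10)
try:
    # Parameter construction runs aten ops (e.g. detach) on MyTensor,
    # which fall through to the unimplemented-op error path
    linear.weight = torch.nn.Parameter(MyTensor(linear.weight))
except NotImplementedError as e:
    # the message now includes func=..., types=..., arg_types=..., kwarg_types=...
    print(e)
```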
22 changes: 21 additions & 1 deletion test/test_utils.py
@@ -1,6 +1,9 @@
import unittest
from unittest.mock import patch

import torch
from torchao.utils import torch_version_at_least
from torchao.utils import TorchAOBaseTensor

class TestTorchVersionAtLeast(unittest.TestCase):
def test_torch_version_at_least(self):
@@ -20,7 +23,24 @@ def test_torch_version_at_least(self):
result = torch_version_at_least(compare_version)

self.assertEqual(result, expected_result, f"Failed for torch.__version__={torch_version}, comparing with {compare_version}")
print(f"{torch_version}: {result}")


class TestTorchAOBaseTensor(unittest.TestCase):

def test_print_arg_types(self):
class MyTensor(TorchAOBaseTensor):
def __new__(cls, data):
shape = data.shape
return torch.Tensor._make_wrapper_subclass(cls, shape) # type: ignore[attr-defined]

def __init__(self, data):
self.data = data


l = torch.nn.Linear(10, 10)
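# wrapping the weight in a Parameter calls aten ops (e.g. detach) on MyTensor, which hits the unimplemented-op fallback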
with self.assertRaisesRegex(NotImplementedError, "arg_types"):
l.weight = torch.nn.Parameter(MyTensor(l.weight))


if __name__ == '__main__':
unittest.main()
4 changes: 3 additions & 1 deletion torchao/utils.py
@@ -388,7 +388,9 @@ class MyTensor(torch.Tensor):
func in cls._ATEN_OP_OR_TORCH_FN_TABLE:
return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func}")
arg_types = tuple(type(arg) for arg in args)
kwarg_types = {k: type(arg) for k, arg in kwargs.items()}
raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run unimplemented operator/function: {func=}, {types=}, {arg_types=}, {kwarg_types=}")

def _register_layout_cls(cls: Callable, layout_type_class: Callable):
"""Helper function for layout registrations, this is used to implement
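Once the error reports which op and argument types are missing, the op can be registered on the subclass via the `implements` hook that feeds the `_ATEN_OP_OR_TORCH_FN_TABLE` consulted above. A hedged sketch (the `detach` handler body is an assumption for illustration, not part of this commit):

```python
import torch

# assuming MyTensor is a TorchAOBaseTensor subclass like the one in the test above
implements = MyTensor.implements

@implements(torch.ops.aten.detach.default)
def _(func, types, args, kwargs):
    # unwrap, apply the op to the inner data, and rewrap
    return MyTensor(args[0].data.detach())
```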
26 changes: 24 additions & 2 deletions tutorials/developer_api_guide/print_op_and_shapes.py
@@ -1,5 +1,6 @@
import torch

PRINT_ARGS = False
linear_shapes = []
from torch.overrides import TorchFunctionMode
class TorchFunctionLoggingMode(TorchFunctionMode):
@@ -16,11 +17,28 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):
M, K = flattened_input_tensor.shape[0], flattened_input_tensor.shape[1]
assert K == weight_tensor.shape[1]
N = weight_tensor.shape[0]
print(f"TORCH_FUNC={str(func)} (M, K, N):", M, K, N)
print(f"TORCH_FUNC {func=} (M, K, N):", M, K, N)
linear_shapes.append((M, K, N))
else:
arg_shape = args[0].shape if len(args) > 0 and isinstance(args[0], torch.Tensor) else None
print(f"TORCH_FUNC={str(func)} args[0] shape:", arg_shape)
if PRINT_ARGS:
print(f"TORCH_FUNC {func=}, {types=}, {args=}, {kwargs=}, args[0] shape: {arg_shape}")
else:
print(f"TORCH_FUNC {func=}, {types=}, args[0] shape: {arg_shape}")
return func(*args, **kwargs)


from torch.utils._python_dispatch import TorchDispatchMode
class TorchDispatchLoggingMode(TorchDispatchMode):
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
arg_shape = args[0].shape if len(args) > 0 and isinstance(args[0], torch.Tensor) else None
if PRINT_ARGS:
print(f"ATEN_FUNC {func=}, {types=}, {args=}, {kwargs=}, args[0] shape: {arg_shape}")
else:
print(f"ATEN_FUNC {func=}, {types=}, args[0] shape: {arg_shape}")

return func(*args, **kwargs)

# NOTE: Modify this with your own model
@@ -33,3 +51,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

print()
print("all linear shapes (M, K, N):", linear_shapes)

# check all aten ops that are called in the model
# with TorchDispatchLoggingMode():
# m(*example_inputs)
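The two logging modes sit at different levels of the stack: `TorchFunctionLoggingMode` intercepts public torch functions such as `F.linear`, while `TorchDispatchLoggingMode` sees the aten ops they ultimately dispatch to (for a linear, typically `aten.addmm`/`aten.mm`). A small sketch of driving both, using a toy stand-in for the model above:

```python
import torch

m = torch.nn.Sequential(torch.nn.Linear(10, 10))
example_inputs = (torch.randn(2, 10),)

# torch-function level: logs F.linear along with its (M, K, N) shapes
with TorchFunctionLoggingMode():
    m(*example_inputs)

# aten dispatch level: logs the decomposed aten ops instead
with TorchDispatchLoggingMode():
    m(*example_inputs)
```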
