🐛 Describe the bug
I'm working on an example for quantized tensor subclass + DTensor (tensor parallel) + compile: pytorch/ao#785
The test works in eager mode, but currently fails with a shape mismatch under compile.
Input shape: (128, 1024); linear weight shape: (512, 1024) (out * in).
The error occurs in the torch.mm op with fake tensors:
[rank2]: result = fn(*args, is_out=(out is not None), **kwargs) # type: ignore[arg-type]
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 4333, in matmul
[rank2]: return torch.mm(tensor1, tensor2)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank2]: return disable_fn(*args, **kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank2]: return fn(*args, **kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank2]: return DTensor._op_dispatcher.dispatch(
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 215, in dispatch
[rank2]: local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
[rank2]: return self._op(*args, **kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torchao-0.6.0+gitbd264f91-py3.10-linux-x86_64.egg/torchao/utils.py", line 372, in _dispatch__torch_function__
[rank2]: return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torchao-0.6.0+gitbd264f91-py3.10-linux-x86_64.egg/torchao/utils.py", line 355, in wrapper
[rank2]: return func(f, types, args, kwargs)
[rank2]: File "/data/users/jerryzh/ao/tutorials/developer_api_guide/tensor_parallel.py", line 86, in _
[rank2]: return aten.mm(input_tensor, weight_tensor)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
[rank2]: return self._op(*args, **(kwargs or {}))
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank2]: return fn(*args, **kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
[rank2]: return self.dispatch(func, types, args, kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank2]: return self._cached_dispatch_impl(func, types, args, kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
[rank2]: output = self._dispatch_impl(func, types, args, kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2039, in _dispatch_impl
[rank2]: r = func(*args, **kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
[rank2]: return self._op(*args, **kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
[rank2]: result = fn(*args, **kwargs)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_meta_registrations.py", line 2100, in meta_mm
[rank2]: torch._check(
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/__init__.py", line 1565, in _check
[rank2]: _check_with(RuntimeError, cond, message)
[rank2]: File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/__init__.py", line 1547, in _check_with
[rank2]: raise error_type(message_evaluated)
[rank2]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(512, 1024)), shape=torch.Size([512, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
[rank2]: a and b must have same reduction dim, but got [128, 1024] X [512, 1024].
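For context, this is the shape rule that meta_mm is enforcing: torch.mm requires the inner (reduction) dimensions to match, and F.linear satisfies it by transposing the (out, in) weight before the matmul. A minimal sketch with plain tensors (the shapes here match the ones in the error, but this is only an illustration, not the failing code path):

```python
import torch

a = torch.randn(128, 1024)  # input: (M, K)
w = torch.randn(512, 1024)  # linear weight: (out_features, in_features)

# torch.mm(a, b) requires a.shape[1] == b.shape[0] (the reduction dim).
# F.linear computes a @ w.t(), so the weight must be transposed from
# (512, 1024) to (1024, 512) before the mm; if the transpose does not
# change the reported shape, mm sees [128, 1024] x [512, 1024] and fails.
out = torch.mm(a, w.t())
print(out.shape)  # torch.Size([128, 512])
```

So the error above is consistent with the weight's transpose not taking effect on the shape the fake tensor reports.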
The transpose implementation looks like the following:
@implements(aten.t.default)
def _(func, types, args, kwargs):
    tensor = args[0]
    print("before transpose, ", tensor.shape)
    shape = tensor.shape[::-1]
    new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
    print("after transpose:", new.shape)
    return return_and_correct_aliasing(func, args, kwargs, new)
It seems that the fake tensor did not pick up the shape change from this transpose in this case.
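For comparison, on a plain tensor aten.t does flip the shape under fake tensor mode, which suggests the shape is being lost specifically on the subclass path. A minimal sketch (FakeTensorMode is a private API, imported from the same module that appears in the traceback):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under FakeTensorMode, ops run against fake tensors that only carry
# metadata (shape, dtype, device). A plain tensor's t() propagates the
# flipped shape as expected here.
with FakeTensorMode():
    w = torch.empty(512, 1024)
    print(w.t().shape)  # torch.Size([1024, 512])
```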
Repro:
- check out "Adding example for quantized tensor + tensor parallelism" ao#785
- build torchao (python setup.py install or python setup.py develop)
- run:
  with-proxy torchrun --standalone --nnodes=1 --nproc-per-node=4 tutorials/developer_api_guide/tensor_parallel.py
Versions
main
cc @ezyang @albanD @chauhang @penguinwu @eellison @zou3519 @bdhirsh