
How does fake tensor work with tensor subclasses in torch.compile? #136287

Closed
@jerryzh168

Description

🐛 Describe the bug

I'm working on an example for quantized tensor subclass + DTensor (tensor parallel) + compile: pytorch/ao#785

The test works in eager mode, but currently fails under compile with a shape mismatch.

Input shape: (128, 1024); linear weight shape: (512, 1024) (out_features x in_features).
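For reference, the same shapes work in plain eager mode without any subclass or DTensor wrapping. A minimal sketch just to pin down the expected shapes (the variable names here are made up):

import torch
import torch.nn.functional as F

x = torch.randn(128, 1024)  # input: (batch, in_features)
w = torch.randn(512, 1024)  # weight: (out_features, in_features)

# linear computes x @ w.t(), i.e. mm of (128, 1024) with (1024, 512)
out = F.linear(x, w)
print(out.shape)  # torch.Size([128, 512])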

The error happens in the torch.mm op under fake tensor:

[rank2]:     result = fn(*args, is_out=(out is not None), **kwargs)  # type: ignore[arg-type]
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_decomp/decompositions.py", line 4333, in matmul
[rank2]:     return torch.mm(tensor1, tensor2)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_compile.py", line 32, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 632, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank2]:     return DTensor._op_dispatcher.dispatch(
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/distributed/tensor/_dispatch.py", line 215, in dispatch
[rank2]:     local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
[rank2]:     return self._op(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torchao-0.6.0+gitbd264f91-py3.10-linux-x86_64.egg/torchao/utils.py", line 372, in _dispatch__torch_function__
[rank2]:     return cls._ATEN_OP_OR_TORCH_FN_TABLE[func](func, types, args, kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torchao-0.6.0+gitbd264f91-py3.10-linux-x86_64.egg/torchao/utils.py", line 355, in wrapper
[rank2]:     return func(f, types, args, kwargs)
[rank2]:   File "/data/users/jerryzh/ao/tutorials/developer_api_guide/tensor_parallel.py", line 86, in _
[rank2]:     return aten.mm(input_tensor, weight_tensor)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 1116, in __call__
[rank2]:     return self._op(*args, **(kwargs or {}))
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/utils/_stats.py", line 21, in wrapper
[rank2]:     return fn(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1238, in __torch_dispatch__
[rank2]:     return self.dispatch(func, types, args, kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1692, in dispatch
[rank2]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1339, in _cached_dispatch_impl
[rank2]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 2039, in _dispatch_impl
[rank2]:     r = func(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_ops.py", line 716, in __call__
[rank2]:     return self._op(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 273, in _fn
[rank2]:     result = fn(*args, **kwargs)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/_meta_registrations.py", line 2100, in meta_mm
[rank2]:     torch._check(
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/__init__.py", line 1565, in _check
[rank2]:     _check_with(RuntimeError, cond, message)
[rank2]:   File "/home/jerryzh/.conda/envs/ao/lib/python3.10/site-packages/torch/__init__.py", line 1547, in _check_with
[rank2]:     raise error_type(message_evaluated)
[rank2]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(128, 1024)), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Replicate(),)), DTensor(local_tensor=MyDTypeTensorTP(data=FakeTensor(..., device='cuda:0', size=(512, 1024)), shape=torch.Size([512, 1024]), device=cuda:0, dtype=torch.float32, requires_grad=False), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3]), placements=(Shard(dim=0),)), None), **{}):
[rank2]: a and b must have same reduction dim, but got [128, 1024] X [512, 1024].
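The failing check is the reduction-dim check in the mm meta kernel: matmul decomposes linear(x, w) into mm(x, w.t()), so after the transpose the second operand should be (1024, 512), yet fake tensor is still reporting (512, 1024) for the subclass weight. Roughly, the check that fires (a paraphrase of meta_mm in torch/_meta_registrations.py, not a verbatim copy; the function name is illustrative):

import torch

def meta_mm_check(a, b):
    # a: (N, M1), b: (M2, P); mm requires the reduction dims M1 and M2 to match
    N, M1 = a.shape
    M2, P = b.shape
    torch._check(
        M1 == M2,
        lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
    )
    # here: a is (128, 1024) and b is (512, 1024) instead of (1024, 512), so 1024 != 512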

The transpose implementation looks like the following:

@implements(aten.t.default)
def _(func, types, args, kwargs):
    tensor = args[0]
    print("before transpose, ", tensor.shape)
    shape = tensor.shape[::-1]
    new = tensor.__class__(tensor.layout_tensor.t(), shape, tensor.dtype)
    print("after transpose:", new.shape)
    return return_and_correct_aliasing(func, args, kwargs, new)

It seems that fake tensor did not pick up the shape change from this transpose in this case.
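One place where the new shape can get lost (a hedged sketch, not the torchao implementation): under torch.compile, wrapper subclasses are fake-ified and rebuilt through the __tensor_flatten__ / __tensor_unflatten__ protocol, so the outer shape that fake tensor sees is whatever that round trip (and the constructor called in the transpose impl above) produces, not necessarily the shape computed in eager. The class and attribute names below are illustrative only:

import torch

class MySubclass(torch.Tensor):
    @staticmethod
    def __new__(cls, layout_tensor, shape, dtype):
        # the outer (wrapper) shape is fixed here at construction time
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dtype=dtype, device=layout_tensor.device
        )

    def __init__(self, layout_tensor, shape, dtype):
        self.layout_tensor = layout_tensor

    def __tensor_flatten__(self):
        # inner tensor attribute names + non-tensor metadata needed to rebuild
        return ["layout_tensor"], (list(self.shape), self.dtype)

    @classmethod
    def __tensor_unflatten__(cls, inner_tensors, meta, outer_size, outer_stride):
        shape, dtype = meta
        # if this reuses the stored shape and ignores outer_size, the rebuilt
        # (fake) subclass keeps the pre-transpose shape
        size = outer_size if outer_size is not None else shape
        return cls(inner_tensors["layout_tensor"], size, dtype)

If the subclass's unflatten (or its constructor) derives the shape from stale stored metadata rather than from the transposed result, fake tensor would keep reporting (512, 1024), which matches the error above.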

Repro: tutorials/developer_api_guide/tensor_parallel.py from pytorch/ao#785.

Versions

main

cc @ezyang @albanD @chauhang @penguinwu @eellison @zou3519 @bdhirsh

Labels

module: fakeTensor, module: pt2-dispatcher, oncall: pt2, tensor subclass, triaged
