Closed
Labels: good first issue, module: decompositions, oncall: pt2, triaged
Description
🐛 Describe the bug
Repro:
```python
import torch
import torch._dynamo
@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    return torch.functional.split(x, 0)
fn(torch.empty((0,)))
```
Fails with:
```
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 1997, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 2042, in get_fake_value
    raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 1974, in get_fake_value
    ret_val = wrap_fake_exception(
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 1459, in wrap_fake_exception
    return fn()
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 1975, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 2110, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 2092, in run_node
    return node.target(*args, **kwargs)
  File "/data/users/ezyang/b/pytorch/torch/functional.py", line 207, in split
    return tensor.split(split_size_or_sections, dim)
  File "/data/users/ezyang/b/pytorch/torch/_tensor.py", line 922, in split
    return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <function split at 0x7f4376e096c0>(*(FakeTensor(..., size=(0,)), 0), **{}):
View operation returned a tensor that is the same as the input base tensor.  This is no longer allowed; you must explicitly create a new tensor (e.g., using .detach()). As a user, you could have made a mistake implementing __torch_dispatch__ or a Python operator decomposition or meta registration; if that's not the case, please report a bug to PyTorch or the backend you are using.
from user code:
   File "/data/users/ezyang/b/pytorch/a.py", line 6, in fn
    return torch.functional.split(x, 0)
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```
This does not fail in eager mode, so something is wrong with the decomposition.
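For context, eager `torch.split` accepts `split_size=0` only when the split dimension is empty, and in that case returns a single empty chunk, which is consistent with the repro running fine outside `torch.compile`. The pure-Python sketch below mirrors that chunk-size rule as an illustration only; `split_sizes` is a hypothetical helper, not a PyTorch API, and it models output lengths, not tensor aliasing.

```python
def split_sizes(dim_size: int, split_size: int) -> list[int]:
    """Hypothetical helper: lengths of the chunks torch.split would return along one dim."""
    if split_size < 0:
        raise ValueError("split_size must be non-negative")
    if split_size == 0:
        # Zero-sized splits are only legal on an empty dimension: exactly one empty chunk.
        if dim_size != 0:
            raise ValueError("split_size can only be 0 if dimension size is 0")
        return [0]
    # Ceil-divide, but always produce at least one chunk (even when dim_size == 0).
    num_splits = max((dim_size + split_size - 1) // split_size, 1)
    sizes = [split_size] * (num_splits - 1)
    # The last chunk takes whatever remains (possibly shorter, possibly empty).
    sizes.append(dim_size - split_size * (num_splits - 1))
    return sizes


print(split_sizes(5, 2))  # [2, 2, 1]
print(split_sizes(0, 0))  # [0]
```

Under this rule the repro's call produces one empty output. A decomposition that implemented that single-chunk case by returning the input tensor object itself, rather than a new view of it, would trip the "same as the input base tensor" check quoted in the error.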
Internal xref: https://fb.workplace.com/groups/6829516587176185/posts/7694722327322269/
Versions
main
bearzx