Dynamo compiled_autograd crashes when a leaf tensor is created on the CPU device and moved to XLA later #5676

🐛 Bug

To Reproduce

import torch
import torch_xla.core.xla_model as xm
from torch._dynamo import compiled_autograd

def compiler_fn(gm):
    return torch.compile(
        gm, backend='eager', fullgraph=True, dynamic=False
    )

def fn_simple(x, y):
    a = torch.cos(x)
    b = torch.sin(y)
    c = a + b
    loss = c + 0.1
    return loss

device = xm.xla_device()
x = torch.tensor(100.0, requires_grad=True)   # leaf tensor on CPU
y = torch.tensor(200.0, requires_grad=True)   # leaf tensor on CPU
xla_x = x.to(device)                          # non-leaf copy on the XLA device
xla_y = y.to(device)                          # non-leaf copy on the XLA device
fn_simple_dynamo = torch.compile(fn_simple, backend="eager")
loss = fn_simple_dynamo(xla_x, xla_y)
# The crash happens when running the backward pass under compiled_autograd.
with compiled_autograd.enable(compiler_fn):
    loss.backward()

Steps to reproduce the behavior:

  1. Run the above code using the nightly PyTorch and torch_xla builds.

Expected behavior

The code finishes without crashing and computes the gradients correctly.

Environment

  • Reproducible on XLA backend [CPU/TPU]: both
  • torch_xla version: nightly

Additional context

If I change the code to

x = torch.tensor(100.0, requires_grad=True, device=device)
y = torch.tensor(200.0, requires_grad=True, device=device)

it actually passes. My observation is that when we first initialize the tensors and then move them to the device, x and y are the leaf nodes instead of xla_x and xla_y, and the gradients on x and y also end up on the CPU device. My current guess is that x.grad and y.grad become part of the compiled_autograd backward graph, and XLA cannot handle mixed devices (XLA and CPU in this case).
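
For illustration, here is a minimal sketch (assuming the xm.xla_device() API, and not verified against this particular crash) of how the leaf/device split shows up, plus a possible workaround that makes the device copy itself the leaf so gradient accumulation stays on the XLA device:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Leaf created on CPU, then moved: x stays the leaf and xla_x is a non-leaf
# copy, so x.grad accumulates on the CPU device.
x = torch.tensor(100.0, requires_grad=True)
xla_x = x.to(device)
print(x.is_leaf, xla_x.is_leaf)  # True False

# Possible workaround (assumption): detach the device copy and mark it as
# requiring grad, so the leaf itself lives on the XLA device.
xla_x_leaf = x.to(device).detach().requires_grad_()
print(xla_x_leaf.is_leaf, xla_x_leaf.device)  # True xla:0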
