AsyncCollectiveTensor: prevent wait_tensor() calls on graph inputs from getting DCEd (pytorch#125677)

@wanchaol was seeing the loss eventually become NaN when compiling individual transformer blocks in torchtitan; with this patch, I no longer see the NaN loss.

The problem is the following:

(1) It is possible for graph inputs to a compiled region to be AsyncCollectiveTensors. In particular, when we compile individual transformer blocks in the llama model, the first layer (the embedding layer) runs in eager mode and outputs an AsyncCollectiveTensor that is fed into the first transformer block.

(2) Ideally, that AsyncCollectiveTensor graph input would desugar into a `wait_tensor()` op that shows up at the beginning of the graph.

(3) The way this is supposed to happen: AOTAutograd traces through the `__torch_dispatch__` of AsyncCollectiveTensor, tracing out a `wait_tensor()` call before dispatching to any of the other ops in the function being traced.

(4) However, `trigger_wait()` was being called in a way that ignored its output (returning `self.elem` directly instead), so the traced `wait_tensor` ops had no users and got DCE'd. A minimal sketch of this failure mode is shown below.
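
To make the failure mode concrete, here is a minimal, self-contained sketch (not the actual PyTorch internals) of why a pure op whose result is ignored gets dead-code-eliminated from a traced graph. `fake_wait` is a hypothetical stand-in for `wait_tensor()`:

import torch
import torch.fx

def fake_wait(t):
    # Hypothetical stand-in for wait_tensor(): to the tracer it is just a pure op.
    return t + 0

def buggy(x):
    fake_wait(x)       # result ignored, like the old trigger_wait() call sites
    return x.sin()

def fixed(x):
    x = fake_wait(x)   # result threaded through, like the patched trigger_wait()
    return x.sin()

for fn in (buggy, fixed):
    gm = torch.fx.symbolic_trace(fn)
    gm.graph.eliminate_dead_code()  # removes pure nodes that have no users
    gm.recompile()
    print(f"--- {fn.__name__} ---")
    print(gm.code)  # in buggy the fake_wait node is gone; in fixed it survives and feeds sin()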

Pull Request resolved: pytorch#125677
Approved by: https://github.com/wanchaol, https://github.com/yifuwang
ghstack dependencies: pytorch#125676
bdhirsh authored and pytorchmergebot committed May 8, 2024
1 parent 5d97c22 commit e28d994
Showing 2 changed files with 35 additions and 10 deletions.
28 changes: 27 additions & 1 deletion test/distributed/_tensor/test_dtensor_compile.py
@@ -60,7 +60,7 @@ def forward(self, input):


 def extract_graph(fx_g, _, graph_cell):
-    graph_cell[0] = fx_g
+    graph_cell[0] = fx_g.code
     return fx_g


@@ -481,6 +481,32 @@ def fn(x_dt):
         res = opt_fn(x_dt)
         self.assertEqual(ref, res)

+    def test_graph_input_is_async(self):
+        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
+
+        def fn(x):
+            return x.sin().sin()
+
+        opt_fn = torch.compile(fn, backend=aot_eager_graph, fullgraph=True)
+
+        x = torch.randn(4, 4, requires_grad=True)
+        x_dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
+        x2 = x_dt.redistribute(mesh, [Replicate()], async_op=True)
+        x2 = x2.to_local()
+        out = opt_fn(x2)
+        # The important part: we get a wait_tensor() in the graph.
+        # At runtime, the input to the graph is an AsyncCollectiveTensor,
+        # and inside the graph we need to issue a wait() to synchronize.
+        self.assertExpectedInline(
+            str(fw_graph_cell[0]).strip(),
+            """\
+def forward(self, primals_1):
+    wait_tensor = torch.ops._c10d_functional.wait_tensor.default(primals_1)
+    sin = torch.ops.aten.sin.default(wait_tensor)
+    sin_1 = torch.ops.aten.sin.default(sin); sin = None
+    return [sin_1, primals_1, wait_tensor]""",
+        )
+
     @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
     def test_dtensor_partial_placement_graph_output(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
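
Note: the new test leans on module-level helpers that sit outside this hunk (`fw_graph_cell` and the `aot_eager_graph` backend, built on `extract_graph` from the hunk above). A rough sketch of that wiring, reconstructed as an assumption rather than taken from this diff:

import functools

from functorch.compile import min_cut_rematerialization_partition
from torch._dynamo.backends.common import aot_autograd

fw_graph_cell = [None]
bw_graph_cell = [None]

# extract_graph(fx_g, _, graph_cell) stores each compiled graph's code into a cell
# so the test can assert on the captured forward graph afterwards.
aot_eager_graph = aot_autograd(
    fw_compiler=functools.partial(extract_graph, graph_cell=fw_graph_cell),
    bw_compiler=functools.partial(extract_graph, graph_cell=bw_graph_cell),
    partition_fn=min_cut_rematerialization_partition,
)
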
17 changes: 8 additions & 9 deletions torch/distributed/_functional_collectives.py
@@ -590,8 +590,7 @@ def __tensor_flatten__(self):
         return ["elem"], None

     def tolist(self):
-        self.trigger_wait()
-        return self.elem.tolist()
+        return self.trigger_wait().tolist()

     @staticmethod
     def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
@@ -600,18 +599,18 @@ def __tensor_unflatten__(inner_tensors, meta, outer_size, outer_stride):
         return AsyncCollectiveTensor(elem)

     def __repr__(self):
-        self.trigger_wait()
-        return f"AsyncCollectiveTensor({self.elem})"
+        return f"AsyncCollectiveTensor({self.trigger_wait()})"

     def trigger_wait(self):
         if not self.completed:
-            wait_tensor(self.elem)
+            out = wait_tensor(self.elem)
             self.completed = True
-        return self.elem
+            return out
+        else:
+            return self.elem

     def wait(self) -> torch.Tensor:
-        wait_tensor(self.elem)
-        return self.elem
+        return wait_tensor(self.elem)

     def _get_acs_underlying_tensor(self):
         """This method enables _functional_collectives_impl to test if a tensor is an ACS"""
@@ -631,7 +630,7 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         def unwrap(e: AsyncCollectiveTensor):
             # wait_tensor is idepotent and will do stream sync only once
             if not is_view_op:
-                e.trigger_wait()
+                return e.trigger_wait()
             return e.elem

         def wrap(e: torch.Tensor):
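
For context, a minimal eager-mode sketch of what the patched methods do with the waited tensor. It assumes a default process group is already initialized (e.g. under torchrun), so treat it as illustrative rather than a standalone script:

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol

# Assumption: dist.init_process_group(...) has already been called elsewhere.
x = torch.ones(4)
act = funcol.all_reduce(x, "sum", group=dist.group.WORLD)  # AsyncCollectiveTensor in eager mode

y = act.wait()   # explicit sync: wait() now returns the result of wait_tensor() directly

# Implicit sync: a non-view op dispatched on the subclass goes through unwrap(),
# which (after this patch) returns the tensor produced by trigger_wait(), so the
# wait_tensor() call keeps a user and survives DCE when traced.
z = act + 1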
