Reland "make sure dynamo doesn't inline DTensor __new__ or __torch_di…
Browse files Browse the repository at this point in the history
…spatch__ (pytorch#123347)" (pytorch#125288)

Re-land of pytorch#123347.

The original PR broke internal builds because of a circular import caused by importing dynamo in the DTensor code. The new version uses `torch._disable_dynamo` to work around the issue.

This reverts commit 9d88339.

Pull Request resolved: pytorch#125288
Approved by: https://github.com/ezyang, https://github.com/yanboliang, https://github.com/yoyoyocmu, https://github.com/anijain2305, https://github.com/fegin
ghstack dependencies: pytorch#124398, pytorch#124399, pytorch#124400
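
For readers unfamiliar with the helper: a minimal sketch of the behavior this commit relies on. `torch._disable_dynamo` is the private decorator used in the diff below; the function names in this sketch are made up for illustration and are not part of the PR.

```python
import torch

# Functions decorated with torch._disable_dynamo are treated as opaque by
# dynamo: they run eagerly instead of being inlined into the traced graph.
@torch._disable_dynamo
def built_eagerly(x):
    return x.sin() + 1

@torch.compile(backend="eager")
def f(x):
    # dynamo graph-breaks around built_eagerly rather than tracing its body
    return built_eagerly(x) * 2

f(torch.randn(4))
```

The commit applies the same treatment to `DTensor.__new__` and `DTensor.__torch_dispatch__` so that dynamo never attempts to inline them.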
bdhirsh authored and pytorchmergebot committed May 1, 2024
1 parent 9e9ba61 commit 599a2e2
Showing 2 changed files with 43 additions and 0 deletions.
41 changes: 41 additions & 0 deletions test/distributed/_tensor/test_dtensor_compile.py
@@ -191,6 +191,47 @@ def fn(x):
        res = opt_fn(x)
        self.assertEqual(res, ref)

    def test_dtensor_constructor_w_graph_break(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        # the print() forces a graph break; constructing a DTensor afterwards
        # should still work under torch.compile
        def fn(x):
            print("graph break!")
            return DTensor(
                x,
                mesh,
                (Replicate(), Shard(0)),
                shape=[128, 32],
                dtype=x.dtype,
                requires_grad=x.requires_grad,
                stride=[32, 1],
            )

        x = torch.randn(64, 32, requires_grad=True)
        out = fn(x)
        out2 = torch.compile(fn, backend="eager")(x)

    def test_dtensor_constructor_w_dynamo_disable(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        @torch._dynamo.disable(recursive=False)
        def fn(x):
            print("foo")
            return DTensor(
                x,
                mesh,
                (Replicate(),),
                shape=torch.Size([32]),
                dtype=x.dtype,
                requires_grad=x.requires_grad,
                stride=(1,),
            )

        x = torch.randn(32, requires_grad=True)
        out = fn(x)
        out2 = torch.compile(fn, backend="eager")(x)
        self.assertEqual(out, out2)

    def test_dtensor_noncontiguous_output(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

2 changes: 2 additions & 0 deletions torch/distributed/_tensor/api.py
@@ -198,6 +198,7 @@ class DTensor(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
    _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher()

    @staticmethod
    @torch._disable_dynamo
    def __new__(
        cls,
        local_tensor: torch.Tensor,
@@ -288,6 +289,7 @@ def __coerce_same_metadata_as_tangent__(self, metadata_tensor):
        )

    @classmethod
    @torch._disable_dynamo
    # pyre-fixme[3]: Return type must be annotated.
    # pyre-fixme[2]: Parameter must be annotated.
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
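
To illustrate why `__torch_dispatch__` gets the same guard: a toy subclass, not from this PR (`LoggingTensor` and its helpers are made up here), using the same decorator placement so dynamo executes the dispatch hook eagerly instead of trying to inline it.

```python
import torch
from torch.utils._pytree import tree_map

class LoggingTensor(torch.Tensor):
    @classmethod
    @torch._disable_dynamo  # same decorator placement as the DTensor diff above
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        print(f"dispatch: {func}")

        def unwrap(t):
            return t.as_subclass(torch.Tensor) if isinstance(t, LoggingTensor) else t

        def wrap(t):
            return t.as_subclass(LoggingTensor) if isinstance(t, torch.Tensor) else t

        # run the op on plain tensors, then re-wrap the result
        return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))

x = torch.randn(3).as_subclass(LoggingTensor)
y = x + 1  # prints the aten op (e.g. aten.add.Tensor), dispatched eagerly
```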
