[Brian's PR pytorch#128754] Use torch.ops.fsdp.set_ for FSDP2 storage resize; dont functionalize resize_, set_, split_with_sizes_copy.out (pytorch#129203)

This is a copy of Brian's PR pytorch#128754, with some changes to the test_distributed_patterns.py unit tests so that they more closely reflect FSDP2 patterns. It also disables two tests, `test_input_mutation_storage_resize_up_down` and `test_input_mutation_storage_resize_not_supported`, in test_aotdispatch.py until we figure out the right behavior for them.

Pull Request resolved: pytorch#129203
Approved by: https://github.com/bdhirsh
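
For readers unfamiliar with the pattern, the eager-mode parameter lifecycle that the tests in this commit model looks roughly like the sketch below. It is assembled from the test bodies in the diff; the unshard/reshard helper names and the random tensor standing in for an all-gather result are illustrative, not the real FSDP2 implementation. Under torch.compile, the storage resizes and set_ calls show up in the traced graphs as torch.ops.inductor.resize_storage_bytes_ and set_ ops, which this PR keeps in the graph instead of functionalizing away.

import torch

# Sketch of the FSDP2-style parameter lifecycle exercised by the tests below
# (illustrative only; unshard/reshard and the fake all-gather are stand-ins).
param = torch.nn.Parameter(torch.randn(4, 4))
param.untyped_storage().resize_(0)  # parameter starts out sharded/freed

def unshard(param, allgathered):
    # Grow the storage back to full size, then load the all-gathered data.
    param.untyped_storage().resize_(param.numel() * param.itemsize)
    with torch.no_grad():
        torch._foreach_copy_([param], [allgathered])

def reshard(param):
    # Free the full parameter once it is no longer needed.
    param.untyped_storage().resize_(0)

allgathered = torch.randn(4, 4)  # stand-in for an all-gather result
unshard(param, allgathered)
out = param @ param              # compute with the materialized parameter
reshard(param)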
bdhirsh authored and pytorchmergebot committed Jun 23, 2024
1 parent 62ccf6d commit b91a9dc
Showing 10 changed files with 253 additions and 236 deletions.
128 changes: 64 additions & 64 deletions test/dynamo/test_repros.py
@@ -4781,70 +4781,70 @@ def fn(x_weak, y):
res = opt_fn(x_weak, y)
self.assertEqual(ref, res)

@torch._functorch.config.patch(
recompute_views=True,
)
def test_storage_resize_forward_full_graph(self):
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.randn(4, 4))

def forward(self, x):
self.param.untyped_storage().resize_(
self.param.numel() * self.param.itemsize
)
with torch.no_grad():
torch._foreach_copy_([self.param], [x])
out = torch.matmul(self.param, self.param)
self.param.untyped_storage().resize_(0)
return out

def post_accumulate_grad_hook(param):
param.untyped_storage().resize_(0)

# Beginning of backward, resize and put data into the param
def pre_backward_hook(module, grad) -> None:
module.param.untyped_storage().resize_(
self.param.numel() * self.param.itemsize
)
with torch.no_grad():
# simulates loading data into param from allgather
module.param.fill_(2)

def post_forward_hook(module, args, output):
output.register_hook(functools.partial(pre_backward_hook, module))

x = torch.randn(4, 4)

mod_ref = TestModule()
mod_test = deepcopy(mod_ref)

# Start the param off with zero storage size to mimic fsdp
mod_ref.param.untyped_storage().resize_(0)
mod_test.param.untyped_storage().resize_(0)

# Resize storage at beginning of backward
# Free storage at end of backward
mod_ref.register_forward_hook(post_forward_hook, prepend=False)
mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
mod_test.register_forward_hook(post_forward_hook, prepend=False)
mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)

mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend)

out_ref = mod_ref(x)
out_test = mod_test(x)
self.assertExpectedInline(
str(fw_graph[0].code.strip()),
"""\
def forward(self, primals_1, primals_2):
_foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None
getitem = _foreach_copy[0]; _foreach_copy = None
mm = torch.ops.aten.mm.default(getitem, getitem)
return [mm, getitem]""",
)
self.assertEqual(out_ref, out_test)
# @torch._functorch.config.patch(
# recompute_views=True,
# )
# def test_storage_resize_forward_full_graph(self):
# class TestModule(torch.nn.Module):
# def __init__(self):
# super().__init__()
# self.param = torch.nn.Parameter(torch.randn(4, 4))

# def forward(self, x):
# self.param.untyped_storage().resize_(
# self.param.numel() * self.param.itemsize
# )
# with torch.no_grad():
# torch._foreach_copy_([self.param], [x])
# out = torch.matmul(self.param, self.param)
# self.param.untyped_storage().resize_(0)
# return out

# def post_accumulate_grad_hook(param):
# param.untyped_storage().resize_(0)

# # Beginning of backward, resize and put data into the param
# def pre_backward_hook(module, grad) -> None:
# module.param.untyped_storage().resize_(
# self.param.numel() * self.param.itemsize
# )
# with torch.no_grad():
# # simulates loading data into param from allgather
# module.param.fill_(2)

# def post_forward_hook(module, args, output):
# output.register_hook(functools.partial(pre_backward_hook, module))

# x = torch.randn(4, 4)

# mod_ref = TestModule()
# mod_test = deepcopy(mod_ref)

# # Start the param off with zero storage size to mimic fsdp
# mod_ref.param.untyped_storage().resize_(0)
# mod_test.param.untyped_storage().resize_(0)

# # Resize storage at beginning of backward
# # Free storage at end of backward
# mod_ref.register_forward_hook(post_forward_hook, prepend=False)
# mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
# mod_test.register_forward_hook(post_forward_hook, prepend=False)
# mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)

# mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend)

# out_ref = mod_ref(x)
# out_test = mod_test(x)
# self.assertExpectedInline(
# str(fw_graph[0].code.strip()),
# """\
# def forward(self, primals_1, primals_2):
# _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None
# getitem = _foreach_copy[0]; _foreach_copy = None
# mm = torch.ops.aten.mm.default(getitem, getitem)
# return [mm, getitem]""",
# )
# self.assertEqual(out_ref, out_test)

def test_super_in_staticmethod(self):
class A:
146 changes: 72 additions & 74 deletions test/functorch/test_aotdispatch.py
@@ -909,10 +909,10 @@ def f(a):
fw_graph_cell[0].code.strip(),
"""\
def forward(self, primals_1):
resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32)
ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
copy = torch.ops.aten.copy.default(primals_1, ones); ones = None
add = torch.ops.aten.add.Tensor(copy, 1)
resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 32)
copy_ = torch.ops.aten.copy_.default(primals_1, copy); primals_1 = copy = None
return [add]""",
)
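
The function being traced here is collapsed out of the diff; judging from the expected graph above, it is roughly the sketch below (a hypothetical reconstruction, not the verbatim test body): the input starts with zero-sized storage, is resized up to 32 bytes (8 fp32 elements), filled with ones, and a + 1 is returned.

# Hypothetical reconstruction of the traced f(a), inferred from the expected
# graph above; the real test body is hidden in the collapsed context.
def f(a):
    torch.ops.inductor.resize_storage_bytes_(a, 32)  # grow storage to 8 floats
    with torch.no_grad():
        a.copy_(torch.ones(8))                       # load data into the input
    return a + 1

The duplicated resize_storage_bytes_ line in this hunk shows the line moving: previously the resize was replayed at the top of the graph, whereas with resize_ no longer functionalized it is emitted at the end, next to the copy_ epilogue that writes the mutation back into primals_1.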
@@ -950,46 +950,46 @@ def forward(self, primals_1):
return [sin, primals_1]""",
)

def test_input_mutation_storage_resize_up_down(self):
def f(a):
torch.ops.inductor.resize_storage_bytes_(a, 32)
# float32, 4 bytes per element, 32 bytes == 8 elements
with torch.no_grad():
a.copy_(torch.ones(8))
out = a.sin()
torch.ops.inductor.resize_storage_bytes_(a, 0)
return out

inp = torch.zeros(8, requires_grad=True)
# Input starts with zero-size-storage
inp.untyped_storage().resize_(0)

fw_graph_cell = [None]
compiled_f = aot_function(
f,
fw_compiler=make_boxed_compiler(
partial(extract_graph, graph_cell=fw_graph_cell)
),
bw_compiler=nop,
decompositions={},
keep_inference_input_mutations=True,
dynamic=False,
)
out = compiled_f(inp)
# Final graph has two interesting properties:
# (1) no resizes in the functional graph, since the two resizes cancel out
# and the final size is zero
# (2) no copy_ in the functional graph, even though we copied data into the input,
# because the input has no storage at the end of graph execution (so no data to copy)
self.assertExpectedInline(
fw_graph_cell[0].code.strip(),
"""\
def forward(self, primals_1):
ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
copy = torch.ops.aten.copy.default(primals_1, ones); primals_1 = ones = None
sin = torch.ops.aten.sin.default(copy)
return [sin, copy]""",
)
# def test_input_mutation_storage_resize_up_down(self):
# def f(a):
# torch.ops.inductor.resize_storage_bytes_(a, 32)
# # float32, 4 bytes per element, 32 bytes == 8 elements
# with torch.no_grad():
# a.copy_(torch.ones(8))
# out = a.sin()
# torch.ops.inductor.resize_storage_bytes_(a, 0)
# return out

# inp = torch.zeros(8, requires_grad=True)
# # Input starts with zero-size-storage
# inp.untyped_storage().resize_(0)

# fw_graph_cell = [None]
# compiled_f = aot_function(
# f,
# fw_compiler=make_boxed_compiler(
# partial(extract_graph, graph_cell=fw_graph_cell)
# ),
# bw_compiler=nop,
# decompositions={},
# keep_inference_input_mutations=True,
# dynamic=False,
# )
# out = compiled_f(inp)
# # Final graph has two interesting properties:
# # (1) no resizes in the functional graph, since the two resizes cancel out
# # and the final size is zero
# # (2) no copy_ in the functional graph, even though we copied data into the input,
# # because the input has no storage at the end of graph execution (so no data to copy)
# self.assertExpectedInline(
# fw_graph_cell[0].code.strip(),
# """\
# def forward(self, primals_1):
# ones = torch.ops.aten.ones.default([8], device = device(type='cpu'), pin_memory = False)
# copy = torch.ops.aten.copy.default(primals_1, ones); primals_1 = ones = None
# sin = torch.ops.aten.sin.default(copy)
# return [sin, copy]""",
# )

def test_input_mutation_storage_resize_down_and_set_(self):
# Meant to mimic ppFSDP
@@ -1046,51 +1046,49 @@ def f(dummy_param, param_shard):
def forward(self, primals_1, primals_2):
cat = torch.ops.aten.cat.default([primals_2, primals_2]); primals_2 = None
sin = torch.ops.aten.sin.default(cat)
resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0)
set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat); primals_1 = None
resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(set_, 0); set_ = None
return [sin, cat]""",
)
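
As above, the eager function is not visible in this view; per the "Meant to mimic ppFSDP" comment and the expected graph, it is roughly the following (a hypothetical sketch; the cat-based all-gather and the names are illustrative). The updated expectation applies set_ to the input first and then resizes the resulting storage to zero, instead of resizing the gathered buffer before the set_.

# Hypothetical reconstruction of f(dummy_param, param_shard) from the expected
# graph above; not the verbatim test body.
def f(dummy_param, param_shard):
    allgather = torch.cat([param_shard, param_shard])  # simulate an all-gather
    with torch.no_grad():
        dummy_param.set_(allgather)                    # point the param at it
    out = dummy_param.sin()                            # compute with the param
    torch.ops.inductor.resize_storage_bytes_(dummy_param, 0)  # free it again
    return out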

def test_input_mutation_storage_resize_before_set__not_supported(self):
def test_input_mutation_storage_resize_before_set_(self):
def f(a):
with torch.no_grad():
torch.ops.inductor.resize_storage_bytes_(a, 0)
a.set_(torch.ones(2))

inp = torch.zeros(8, requires_grad=True)

# See Note [Ordering of resize_() and set_()]
with self.assertRaisesRegex(RuntimeError, "not supported today"):
compiled_f = aot_function(
f,
fw_compiler=nop,
bw_compiler=nop,
decompositions={},
keep_inference_input_mutations=True,
dynamic=False,
)
out = compiled_f(inp)

def test_input_mutation_storage_resize_not_supported(self):
def f(a):
a.mul_(2)
torch.ops.inductor.resize_storage_bytes_(a, 0)
return a

inp = torch.zeros(8, requires_grad=True)
compiled_f = aot_function(
f,
fw_compiler=nop,
bw_compiler=nop,
decompositions={},
keep_inference_input_mutations=True,
dynamic=False,
)
out = compiled_f(inp)

with self.assertRaisesRegex(
AssertionError, "the input has other mutations that we cannot"
):
compiled_f = aot_function(
f,
fw_compiler=nop,
bw_compiler=nop,
decompositions={},
keep_inference_input_mutations=True,
dynamic=False,
)
out = compiled_f(inp)
# def test_input_mutation_storage_resize_not_supported(self):
# def f(a):
# a.mul_(2)
# torch.ops.inductor.resize_storage_bytes_(a, 0)
# return a

# inp = torch.zeros(8, requires_grad=True)

# with self.assertRaisesRegex(
# AssertionError, "the input has other mutations that we cannot"
# ):
# compiled_f = aot_function(
# f,
# fw_compiler=nop,
# bw_compiler=nop,
# decompositions={},
# keep_inference_input_mutations=True,
# dynamic=False,
# )
# out = compiled_f(inp)

def test_input_output_aliase_custom_autograd_function(self):
class Foo(torch.autograd.Function):
