Commit

AOTDispatcher: properly bump version counter on input mutations in inference graphs (pytorch#131665)

This ensures that in an inference setting, we properly bump the version counter (VC) of mutated graph inputs. Previously, the VC was only bumped correctly for training graphs.

Pull Request resolved: pytorch#131665
Approved by: https://github.com/ezyang, https://github.com/zou3519
ghstack dependencies: pytorch#131403, pytorch#131482
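For background, a minimal eager-mode sketch (illustrative, not this commit's code) of what the version counter tracks and why a missing bump matters:

    import torch

    # Eager in-place ops bump a tensor's version counter (VC) automatically.
    x = torch.randn(4)
    print(x._version)  # 0
    x.mul_(2)
    print(x._version)  # 1

    # Autograd uses the VC to detect that a tensor saved for backward was
    # later mutated in place.
    a = torch.randn(4, requires_grad=True)
    b = a * a           # mul saves `a` for the backward pass
    a.detach().mul_(2)  # mutate `a` behind autograd's back; bumps its VC
    try:
        b.sum().backward()
    except RuntimeError as e:
        print(e)  # "... has been modified by an inplace operation ..."

If a compiled inference graph mutates an input without bumping its VC, this check cannot fire, which is the bug fixed here.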
bdhirsh authored and pytorchmergebot committed Jul 26, 2024
1 parent 5570a0d commit e4ace1a
Showing 3 changed files with 24 additions and 22 deletions.
test/dynamo/test_repros.py (11 additions, 0 deletions)

@@ -4920,6 +4920,17 @@ def fn(obj):
             compiled_str = str(e)
             self.assertEqual(orig_str, compiled_str)
 
+    def test_vc_bumped_in_inference_graph(self):
+        @torch.compile
+        def f(x):
+            return x.mul_(2)
+
+        x = torch.randn(4)
+        vc_before = x._version
+        f(x)
+        vc_after = x._version
+        self.assertTrue(vc_after > vc_before)
+
     def test_nn_module_callable(self):
         class M(nn.Module):
             def forward(self, x):
test/inductor/test_torchinductor.py (6 additions, 1 deletion)

@@ -4402,8 +4402,13 @@ def forward(self, x):
         out_test = compiled_f(inp_test.clone())
 
         eager_version_counters_after = [
-            buffer._version for _, buffer in model_for_eager.named_buffers()
+            # TODO: remove the + 1 after https://github.com/pytorch/pytorch/issues/120622 is fixed
+            buffer._version + 1
+            if k in ["m.running_mean", "m.running_var"]
+            else buffer._version
+            for k, buffer in model_for_eager.named_buffers()
         ]
 
         compile_version_counters_after = [
             buffer._version for _, buffer in model_for_compile.named_buffers()
         ]
torch/_functorch/_aot_autograd/runtime_wrappers.py (7 additions, 21 deletions)

@@ -298,6 +298,12 @@ def runtime_wrapper(args: List[Any]):
         # stash a ref to each input tensor we plan to use after the compiled function
         orig_inputs = {i: args[i] for i in epilogue_args_idx}
 
+        if keep_input_mutations:
+            for i in runtime_metadata.mutated_graph_handled_indices_seen_by_autograd:
+                arg = args[i]
+                if not arg.is_inference():  # inference tensors have no VC
+                    torch.autograd.graph.increment_version(arg)
+
         if trace_joint:
             args_ = list(args)
             # See Note [Detaching inputs that never need gradients]
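The wrapper addition above leans on two runtime facts, sketched here in eager mode (assumes a recent PyTorch where torch.autograd.graph.increment_version is public):

    import torch

    # increment_version manually records a mutation that bypassed normal
    # eager dispatch, which is exactly the compiled-graph situation:
    x = torch.randn(4)
    print(x._version)  # 0
    torch.autograd.graph.increment_version(x)
    print(x._version)  # 1

    # Inference tensors carry no version counter at all, hence the
    # is_inference() guard before bumping:
    with torch.inference_mode():
        t = torch.randn(4)
    print(t.is_inference())  # True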
@@ -329,14 +335,6 @@ def runtime_wrapper(args: List[Any]):
         num_mutated_runtime_inps = runtime_metadata.num_mutated_inp_runtime_indices
         num_intermediate_bases = runtime_metadata.num_intermediate_bases
 
-        if keep_input_mutations and trace_joint:
-            num_input_mutations_handled_by_autograd = (
-                runtime_metadata.num_mutated_graph_handled_indices_seen_by_autograd
-            )
-            # autograd.Function requires us to return the mutated inputs as extra outputs to the autograd.Function.forward
-            if num_input_mutations_handled_by_autograd > 0:
-                all_outs = all_outs[:-num_input_mutations_handled_by_autograd]
-
         assert (
             len(all_outs)
             == num_mutated_runtime_inps
@@ -1493,13 +1491,6 @@ def forward(ctx, *deduped_flat_tensor_args):
                 assert isinstance(bw_state, BackwardState)
                 ctx._compiled_autograd_backward_state = bw_state
 
-            marked_dirty_inps = []
-            for i in fw_metadata.mutated_graph_handled_indices_seen_by_autograd:
-                arg = deduped_flat_tensor_args[i]
-                if not (arg.requires_grad and arg.is_leaf):  # would error
-                    ctx.mark_dirty(arg)
-                    marked_dirty_inps.append(arg)
-
             # There is a pretty complicated calling convention around what the compiled fw returns.
             # The full list of outputs and their relative order is:
             # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
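For background on the removed path: ctx.mark_dirty is the stock autograd.Function way to declare an in-place input mutation; it requires returning the dirty tensor as an output and errors on leaf tensors that require grad, which is what the deleted guard worked around. A minimal illustrative sketch, not this file's code:

    import torch

    class InplaceDouble(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            x.mul_(2)          # mutate the input inside forward
            ctx.mark_dirty(x)  # declare the mutation to autograd
            return x           # dirty tensors must also be returned as outputs

        @staticmethod
        def backward(ctx, grad_out):
            return grad_out * 2  # y = 2x, so dy/dx = 2

    x = torch.randn(4, requires_grad=True).clone()  # non-leaf: mutation allowed
    y = InplaceDouble.apply(x)
    print(x._version >= 1)  # True: the mutation shows up on the version counter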
@@ -1614,7 +1605,7 @@ def forward(ctx, *deduped_flat_tensor_args):
             ]
             ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
             ctx._materialize_non_diff_grads = False
-            return tuple(raw_returns) + tuple(marked_dirty_inps)
+            return tuple(raw_returns)
 
         @staticmethod
         def backward(ctx, *flat_args):
@@ -1630,9 +1621,6 @@ def backward(ctx, *flat_args):
             num_intermediate_bases = (
                 CompiledFunction.metadata.num_intermediate_bases
             )
-            num_graph_handled_inputs = (
-                CompiledFunction.metadata.num_mutated_graph_handled_indices_seen_by_autograd
-            )
             num_mutated_runtime_inps = (
                 CompiledFunction.metadata.num_mutated_inp_runtime_indices
             )
@@ -1656,8 +1644,6 @@ def backward(ctx, *flat_args):
                 ),
             )
 
-            if num_graph_handled_inputs > 0:
-                flat_args = flat_args[:-num_graph_handled_inputs]
             assert len(flat_args) == expected_grad_outs
             out_info = CompiledFunction.metadata.output_info
