Skip to content

Commit

Permalink
[AOT][CUDAGraphs] torchdynamo -> torch._dynamo (pytorch#87243)
Browse files Browse the repository at this point in the history
Fixes lingering issues from the torchdynamo -> torch._dynamo migration
Pull Request resolved: pytorch#87243
Approved by: https://github.com/suo, https://github.com/voznesenskym, https://github.com/jansel
  • Loading branch information
soumith authored and pytorchmergebot committed Oct 21, 2022
1 parent 13ab819 commit ff43288
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions torch/cuda/_dynamo_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
from torch.fx.passes.backends.cudagraphs import partition_cudagraphs
from torch.multiprocessing.reductions import StorageWeakRef
from torch.utils._pytree import tree_map
import torchdynamo # type: ignore[import]
from torchdynamo.optimizations.training import AOTAutogradStrategy # type: ignore[import]
import torch._dynamo # type: ignore[import]
from torch._dynamo.optimizations.training import AotAutogradStrategy # type: ignore[import]

import operator
from collections import defaultdict
from typing import Set

# TODO: maybe this should live in torchdynamo instead
# TODO: maybe this should live in torch._dynamo instead

__all__ = ['aot_autograd_cudagraphs']

Expand Down Expand Up @@ -140,8 +140,8 @@ def raw_aot_autograd_cudagraphs(model, inputs):
}

def _wrapped_bw_compiler(*args, **kwargs):
# stop TorchDynamo from trying to compile our generated backwards pass
return torchdynamo.disable(bw_compiler(*args, **kwargs)) # type: ignore[operator]
# stop dynamo from trying to compile our generated backwards pass
return torch._dynamo.disable(bw_compiler(*args, **kwargs)) # type: ignore[operator]

bw_compiler = kwargs.get("bw_compiler") or kwargs["fw_compiler"]
kwargs["bw_compiler"] = _wrapped_bw_compiler
Expand All @@ -151,7 +151,7 @@ def _wrapped_bw_compiler(*args, **kwargs):
return aot_module_simplified(model, **kwargs)


class AOTAutogradCudaGraphs(AOTAutogradStrategy):
class AOTAutogradCudaGraphs(AotAutogradStrategy):
def candidate(self):
return raw_aot_autograd_cudagraphs(self.gm, self.example_inputs)

Expand Down

0 comments on commit ff43288

Please sign in to comment.