Commit

Lint
anijain2305 committed May 19, 2022
1 parent 0b4c279 commit 52503f2
Showing 3 changed files with 3 additions and 9 deletions.
src/transformers/trainer.py (4 changes: 2 additions & 2 deletions)
@@ -2198,8 +2198,8 @@ def torchdynamo_smart_context_manager(self):
                 ctx_manager = torchdynamo.optimize("eager")
             elif self.args.torchdynamo == "nvfuser":
                 ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
-            else:
-                ctx_manager = contextlib.nullcontext()
+            elif self.args.torchdynamo is not None:
+                raise ValueError("torchdynamo training arg can be eager/nvfuser")
         return ctx_manager
 
     def autocast_smart_context_manager(self):
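
For context, the branch structure this hunk leaves behind can be sketched as a standalone helper. This is an illustrative reconstruction, not the Trainer method itself: the helper name pick_dynamo_ctx, the backend parameter, and the import path for aot_autograd_speedup_strategy are assumptions; only torchdynamo.optimize(...), contextlib.nullcontext(), and the error message come from the diff.

import contextlib


def pick_dynamo_ctx(backend):
    """Hypothetical standalone sketch of the post-commit selection logic."""
    if backend is None:
        # No torchdynamo requested: fall back to a no-op context manager.
        return contextlib.nullcontext()
    import torchdynamo  # assumes torchdynamo is installed

    if backend == "eager":
        return torchdynamo.optimize("eager")
    if backend == "nvfuser":
        # Import path assumed for the torchdynamo version of that era.
        from torchdynamo.optimizations.training import aot_autograd_speedup_strategy

        return torchdynamo.optimize(aot_autograd_speedup_strategy)
    # Unknown value: fail loudly instead of silently using nullcontext(),
    # mirroring the ValueError added in this commit.
    raise ValueError("torchdynamo training arg can be eager/nvfuser")

Used in a with-statement, pick_dynamo_ctx(None) is a no-op, which is the behaviour kept for runs that never set the torchdynamo training argument; an unrecognised value now raises instead of being ignored.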
src/transformers/utils/import_utils.py (7 changes: 1 addition & 6 deletions)
@@ -377,12 +377,7 @@ def is_torch_tpu_available():
 
 
 def is_torchdynamo_available():
-    try:
-        import torchdynamo
-
-        return True
-    except ImportError:
-        return False
+    return importlib.util.find_spec("torchdynamo") is not None
 
 
 def is_datasets_available():
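
The new one-liner relies on importlib.util.find_spec, which asks the import system whether a module can be located without actually importing it, so the check has no import side effects. A minimal, generic version of the same pattern, with the helper name is_package_available chosen here purely for illustration:

import importlib.util


def is_package_available(name):
    # find_spec returns a ModuleSpec if the import system can locate `name`,
    # and None otherwise; the package is never actually imported.
    return importlib.util.find_spec(name) is not None


print(is_package_available("torchdynamo"))  # True only if torchdynamo is installed

One caveat: find_spec only confirms the package can be found; unlike the removed try/except, it will not notice a package that is present but fails to import.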
tests/trainer/test_trainer.py (1 change: 0 additions & 1 deletion)
@@ -1598,7 +1598,6 @@ def test_fp16_full_eval(self):
     @require_torch_gpu
     @require_torchdynamo
     def test_torchdynamo_full_eval(self):
-        debug = 0
         n_gpus = get_gpu_count()
 
         bs = 8
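
The test keeps its @require_torchdynamo guard; the removed debug = 0 was a leftover scratch variable. For readers unfamiliar with the require_* decorators, a rough sketch of how such a guard could be built on top of is_torchdynamo_available follows; the real require_torchdynamo lives in transformers' testing utilities and may differ in detail.

import unittest

from transformers.utils.import_utils import is_torchdynamo_available


def require_torchdynamo_sketch(test_case):
    # Illustrative stand-in for the require_torchdynamo decorator:
    # skip the decorated test unless torchdynamo can be located.
    return unittest.skipUnless(is_torchdynamo_available(), "test requires torchdynamo")(test_case)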

0 comments on commit 52503f2
