Support compilation via Torchdynamo, AOT Autograd, NVFuser #17308

Merged: 10 commits merged on May 25, 2022
Changes from 1 commit
Address comments
anijain2305 committed May 19, 2022
commit 0b4c279a72a9c33dab00410f8194b107271f67a2
6 changes: 6 additions & 0 deletions src/transformers/testing_utils.py
@@ -70,6 +70,7 @@
     is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
+    is_torchdynamo_available,
     is_vision_available,
 )

@@ -464,6 +465,11 @@ def require_torch_tpu(test_case):
     jax_device = None
 
 
+def require_torchdynamo(test_case):
+    """Decorator marking a test that requires TorchDynamo"""
+    return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
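For context, the decorator added above follows the same `unittest.skipUnless` pattern as the existing `require_torch_gpu` helper. A minimal sketch of how a test module would typically use it; the test class and test names here are hypothetical and not part of this PR:

import unittest

from transformers.testing_utils import require_torchdynamo


class TorchDynamoSmokeTest(unittest.TestCase):
    @require_torchdynamo
    def test_torchdynamo_importable(self):
        # The decorator skips this test automatically when the `torchdynamo`
        # package is not installed.
        import torchdynamo  # noqa: F401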
45 changes: 30 additions & 15 deletions src/transformers/trainer.py
@@ -34,6 +34,7 @@

 from tqdm.auto import tqdm
 
+
 # Integrations must be imported before ML frameworks:
 from .integrations import (  # isort: split
     default_hp_search_backend,
@@ -138,8 +139,10 @@
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_tpu_available,
+    is_torchdynamo_available,
     logging,
 )
+from .utils.generic import ContextManagers
 
 
 _is_torch_generator_available = False
@@ -2171,17 +2174,32 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s

         return inputs
 
+    def compute_loss_context_manager(self):
+        """
+        A helper wrapper to group together context managers.
+        """
+        return ContextManagers(
+            [
+                self.torchdynamo_smart_context_manager(),
+                self.autocast_smart_context_manager(),
+            ]
+        )
+
     def torchdynamo_smart_context_manager(self):
         """
         A helper wrapper that creates an appropriate context manager for `torchdynamo`.
         """
-        import torchdynamo
-        from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
+        ctx_manager = contextlib.nullcontext()
+        if is_torchdynamo_available():
+            import torchdynamo
+            from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
 
-        if self.args.use_torchdynamo:
-            ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
-        else:
-            ctx_manager = contextlib.nullcontext()
+            if self.args.torchdynamo == "eager":
+                ctx_manager = torchdynamo.optimize("eager")
+            elif self.args.torchdynamo == "nvfuser":
+                ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
+            else:
+                ctx_manager = contextlib.nullcontext()
Collaborator (review comment on the final `else` branch): Not sure that else is needed since we set it to this by default.

         return ctx_manager
 
     def autocast_smart_context_manager(self):
@@ -2225,9 +2243,8 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)
 
-        with self.torchdynamo_smart_context_manager():
-            with self.autocast_smart_context_manager():
-                loss = self.compute_loss(model, inputs)
+        with self.compute_loss_context_manager():
+            loss = self.compute_loss(model, inputs)
 
         if self.args.n_gpu > 1:
             loss = loss.mean()  # mean() to average on multi-gpu parallel training
@@ -2920,9 +2937,8 @@ def prediction_step(
                 logits = smp_nested_concat(logits_mb)
             else:
                 if has_labels:
-                    with self.torchdynamo_smart_context_manager():
-                        with self.autocast_smart_context_manager():
-                            loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
+                    with self.compute_loss_context_manager():
+                        loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     loss = loss.mean().detach()
 
                     if isinstance(outputs, dict):
@@ -2931,9 +2947,8 @@
                         logits = outputs[1:]
                 else:
                     loss = None
-                    with self.torchdynamo_smart_context_manager():
-                        with self.autocast_smart_context_manager():
-                            outputs = model(**inputs)
+                    with self.compute_loss_context_manager():
+                        outputs = model(**inputs)
                     if isinstance(outputs, dict):
                         logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
                     else:
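As context for the `from .utils.generic import ContextManagers` import and the new `compute_loss_context_manager` above: `ContextManagers` enters a list of context managers as one combined context. Below is a minimal sketch of that idea built on `contextlib.ExitStack`; it is an illustration under that assumption, not the exact transformers implementation.

from contextlib import ExitStack, nullcontext


class ContextManagers:
    """Enter a list of context managers as one combined context (sketch)."""

    def __init__(self, context_managers):
        self.context_managers = context_managers
        self.stack = ExitStack()

    def __enter__(self):
        for context_manager in self.context_managers:
            self.stack.enter_context(context_manager)

    def __exit__(self, *exc_info):
        self.stack.__exit__(*exc_info)


# Mirrors compute_loss_context_manager: the TorchDynamo and autocast context
# managers are entered together; nullcontext stands in for both here.
with ContextManagers([nullcontext(), nullcontext()]):
    pass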
5 changes: 2 additions & 3 deletions src/transformers/trainer_seq2seq.py
@@ -183,9 +183,8 @@ def prediction_step(
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])
 
         with torch.no_grad():
-            with self.torchdynamo_smart_context_manager():
-                with self.autocast_smart_context_manager():
-                    outputs = model(**inputs)
+            with self.compute_loss_context_manager():
+                outputs = model(**inputs)
             if has_labels:
                 if self.label_smoother is not None:
                     loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
13 changes: 7 additions & 6 deletions src/transformers/training_args.py
@@ -450,7 +450,7 @@ class TrainingArguments:
         full_determinism (`bool`, *optional*, defaults to `False`)
             If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
             distributed training
-        use_torchdynamo ('bool`, `str`, defaults to `False`):
+        torchdynamo (`str`, *optional*):
             If `True`, TorchDynamo is called with AOT Autograd and nvfuser compiler to compile the appropriate portions
             of the model.
     """

Contributor (review comment on the description): this doesn't match the actual usage. Definitely not True/False but the actual choices: eager, nvfuser

Contributor (suggested change on "of the model."):
-            of the model.
+            of the model. This is an experimental API and it may change.
@@ -884,15 +884,16 @@ class TrainingArguments:
             )
         },
     )
-    use_torchdynamo: bool = field(
-        default=False,
+    torchdynamo: Optional[str] = field(
+        default=None,
         metadata={
             "help": (
-                "Whether or not to use TorchDynamo. TorchDynamo is a Python level JIT compilers designed to make"
+                "Whether or not to use TorchDynamo. TorchDynamo is a Python level JIT compiler designed to make"
                 " unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
                 " before its executed. It rewrites Python bytecode in order to extract sequences of PyTorch operations"
-                " and lift them up into Fx fraph. We can then pass these Fx graphs to other backend compilers. Here"
-                " we use AOT Autograd and nvfuser compiler."
+                " and lift them up into Fx graph. We can then pass these Fx graphs to other backend compilers. There"
+                " are two options - eager and nvfuser. Eager defaults to pytorch eager and is useful for debugging."
+                " nvfuser path uses AOT Autograd and nvfuser compiler to optimize the models."
             ),
         },
     )

Contributor (review comment on the help text): same as the comment above - not whether or not, but how - via choices
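To illustrate the switch from a boolean flag to named choices, here is a minimal usage sketch, assuming the `torchdynamo` field lands on `TrainingArguments` exactly as in this commit; `output_dir` is an arbitrary placeholder:

from transformers import TrainingArguments

# Pick the TorchDynamo backend by name rather than a True/False switch:
#   None (default) -> TorchDynamo disabled
#   "eager"        -> torchdynamo.optimize("eager"), useful for debugging
#   "nvfuser"      -> AOT Autograd + nvfuser via aot_autograd_speedup_strategy
args = TrainingArguments(output_dir="tmp_trainer", torchdynamo="nvfuser")
print(args.torchdynamo)

In HfArgumentParser-based example scripts, the same field would surface as a `--torchdynamo nvfuser` command-line argument.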
1 change: 1 addition & 0 deletions src/transformers/utils/__init__.py
@@ -130,6 +130,7 @@
     is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
+    is_torchdynamo_available,
     is_training_run_on_sagemaker,
     is_vision_available,
     requires_backends,
9 changes: 9 additions & 0 deletions src/transformers/utils/import_utils.py
@@ -376,6 +376,15 @@ def is_torch_tpu_available():
     return importlib.util.find_spec("torch_xla.core.xla_model") is not None
 
 
+def is_torchdynamo_available():
+    try:
+        import torchdynamo
+
+        return True
+    except ImportError:
+        return False
+
+
 def is_datasets_available():
     return _datasets_available

31 changes: 31 additions & 0 deletions tests/trainer/test_trainer.py
@@ -62,6 +62,7 @@
     require_torch_non_multi_gpu,
     require_torch_tf32,
     require_torch_up_to_2_gpus,
+    require_torchdynamo,
     require_wandb,
     slow,
 )
@@ -1594,6 +1595,36 @@ def test_fp16_full_eval(self):
         # perfect world: fp32_init/2 == fp16_eval
         self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
 
+    @require_torch_gpu
+    @require_torchdynamo
+    def test_torchdynamo_full_eval(self):
+        debug = 0
+        n_gpus = get_gpu_count()
+
+        bs = 8
+        eval_len = 16 * n_gpus
+        # make the params somewhat big so that there will be enough RAM consumed to be able to
+        # measure things. We should get about 64KB for a+b in fp32
+        a = torch.ones(1000, bs) + 0.001
+        b = torch.ones(1000, bs) - 0.001
+
+        # 1. Default - without TorchDynamo
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
+        metrics = trainer.evaluate()
+        original_eval_loss = metrics["eval_loss"]
+        del trainer
+
+        # 2. TorchDynamo eager
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager")
+        metrics = trainer.evaluate()
+        self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+        del trainer
+
+        # 3. TorchDynamo nvfuser
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
+        metrics = trainer.evaluate()
+        self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)

anijain2305 marked this conversation as resolved.
     @require_torch_gpu
     @require_torch_bf16
     def test_bf16_full_eval(self):