
Commit 897a8dd

anijain2305 and stas00 authored
Support compilation via Torchdynamo, AOT Autograd, NVFuser (#17308)
* Support compilation via Torchdynamo, AOT Autograd, NVFuser
* Address comments
* Lint
* Stas comments - missing quality test
* Linter
* Quality test
* Doc lint
* Reset CUDA peak mem
* Add CustomTrainer
* require a single gpu

Co-authored-by: Stas Bekman <stas@stason.org>
1 parent 31484af commit 897a8dd

File tree

7 files changed: +155, -4 lines changed


src/transformers/testing_utils.py

Lines changed: 6 additions & 0 deletions
@@ -70,6 +70,7 @@
     is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
+    is_torchdynamo_available,
     is_vision_available,
 )
@@ -464,6 +465,11 @@ def require_torch_tpu(test_case):
 jax_device = None


+def require_torchdynamo(test_case):
+    """Decorator marking a test that requires TorchDynamo"""
+    return unittest.skipUnless(is_torchdynamo_available(), "test requires TorchDynamo")(test_case)
+
+
 def require_torch_gpu(test_case):
     """Decorator marking a test that requires CUDA and PyTorch."""
     return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)
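For reference, a minimal usage sketch of the new decorator (the test class and test name below are hypothetical, not part of this commit):

import unittest

from transformers.testing_utils import require_torchdynamo


class DynamoSmokeTest(unittest.TestCase):
    @require_torchdynamo
    def test_runs_with_dynamo(self):
        # Exercise a TorchDynamo-specific code path here; the test is skipped
        # when torchdynamo is not installed.
        ...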

src/transformers/trainer.py

Lines changed: 31 additions & 3 deletions
@@ -139,8 +139,10 @@
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
     is_torch_tpu_available,
+    is_torchdynamo_available,
     logging,
 )
+from .utils.generic import ContextManagers


 _is_torch_generator_available = False
@@ -2172,6 +2174,32 @@ def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[s

         return inputs

+    def compute_loss_context_manager(self):
+        """
+        A helper wrapper to group together context managers.
+        """
+        return ContextManagers(
+            [
+                self.torchdynamo_smart_context_manager(),
+                self.autocast_smart_context_manager(),
+            ]
+        )
+
+    def torchdynamo_smart_context_manager(self):
+        """
+        A helper wrapper that creates an appropriate context manager for `torchdynamo`.
+        """
+        ctx_manager = contextlib.nullcontext()
+        if is_torchdynamo_available():
+            import torchdynamo
+            from torchdynamo.optimizations.training import aot_autograd_speedup_strategy
+
+            if self.args.torchdynamo == "eager":
+                ctx_manager = torchdynamo.optimize("eager")
+            elif self.args.torchdynamo == "nvfuser":
+                ctx_manager = torchdynamo.optimize(aot_autograd_speedup_strategy)
+        return ctx_manager
+
     def autocast_smart_context_manager(self):
         """
         A helper wrapper that creates an appropriate context manager for `autocast` while feeding it the desired
@@ -2213,7 +2241,7 @@ def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor,
             loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
             return loss_mb.reduce_mean().detach().to(self.args.device)

-        with self.autocast_smart_context_manager():
+        with self.compute_loss_context_manager():
             loss = self.compute_loss(model, inputs)

         if self.args.n_gpu > 1:
@@ -2907,7 +2935,7 @@ def prediction_step(
                 logits = smp_nested_concat(logits_mb)
             else:
                 if has_labels:
-                    with self.autocast_smart_context_manager():
+                    with self.compute_loss_context_manager():
                         loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                     loss = loss.mean().detach()

@@ -2917,7 +2945,7 @@ def prediction_step(
                     logits = outputs[1:]
             else:
                 loss = None
-                with self.autocast_smart_context_manager():
+                with self.compute_loss_context_manager():
                     outputs = model(**inputs)
                 if isinstance(outputs, dict):
                     logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
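The net effect of this change is that loss computation now runs under a stack of context managers: the TorchDynamo one (when enabled) plus the existing autocast one. As a rough sketch of the behaviour, `ContextManagers` can be thought of as an ExitStack-based helper like the one below (an illustrative equivalent, not the actual implementation in `utils/generic.py`):

import contextlib


@contextlib.contextmanager
def stacked(context_managers):
    # Enter each context manager in order; exit them in reverse on the way out.
    with contextlib.ExitStack() as stack:
        for cm in context_managers:
            stack.enter_context(cm)
        yield


# Usage sketch, mirroring compute_loss_context_manager():
with stacked([contextlib.nullcontext(), contextlib.nullcontext()]):
    pass  # e.g. loss = self.compute_loss(model, inputs)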

src/transformers/trainer_seq2seq.py

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ def prediction_step(
             generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_kwargs["max_length"])

         with torch.no_grad():
-            with self.autocast_smart_context_manager():
+            with self.compute_loss_context_manager():
                 outputs = model(**inputs)
             if has_labels:
                 if self.label_smoother is not None:

src/transformers/training_args.py

Lines changed: 17 additions & 0 deletions
@@ -450,6 +450,9 @@ class TrainingArguments:
         full_determinism (`bool`, *optional*, defaults to `False`)
             If `True`, [`enable_full_determinism`] is called instead of [`set_seed`] to ensure reproducible results in
             distributed training
+        torchdynamo (`str`, *optional*):
+            The token that is used to set the backend compiler for TorchDynamo. Possible choices are ["eager",
+            "nvfuser"]. This is an experimental API and subject to change.
     """

     output_dir: str = field(
@@ -881,6 +884,20 @@ class TrainingArguments:
             )
         },
     )
+    torchdynamo: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Sets up the backend compiler for TorchDynamo. TorchDynamo is a Python level JIT compiler designed to"
+                " make unmodified PyTorch programs faster. TorchDynamo dynamically modifies the Python bytecode right"
+                " before it's executed. It rewrites Python bytecode to extract sequences of PyTorch operations"
+                " and lifts them up into an FX graph. We can then pass these FX graphs to other backend compilers."
+                " There are two options - eager and nvfuser. Eager defaults to PyTorch eager and is useful for"
+                " debugging. The nvfuser path uses AOT Autograd and the nvFuser compiler to optimize the models."
+            ),
+            "choices": ["eager", "nvfuser"],
+        },
+    )

     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
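With the new field in place, the backend can be selected when building the training arguments, or as `--torchdynamo nvfuser` when arguments are parsed from the command line. A minimal sketch (the output directory name is arbitrary; model and datasets are assumed to exist elsewhere):

from transformers import TrainingArguments

# "eager" is mainly for debugging; "nvfuser" routes through AOT Autograd + nvFuser.
args = TrainingArguments(output_dir="output", torchdynamo="nvfuser")
# trainer = Trainer(model=model, args=args, ...)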

src/transformers/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@
     is_torch_tf32_available,
     is_torch_tpu_available,
     is_torchaudio_available,
+    is_torchdynamo_available,
     is_training_run_on_sagemaker,
     is_vision_available,
     requires_backends,

src/transformers/utils/import_utils.py

Lines changed: 4 additions & 0 deletions
@@ -376,6 +376,10 @@ def is_torch_tpu_available():
     return importlib.util.find_spec("torch_xla.core.xla_model") is not None


+def is_torchdynamo_available():
+    return importlib.util.find_spec("torchdynamo") is not None
+
+
 def is_datasets_available():
     return _datasets_available
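Since the helper is re-exported from `transformers.utils`, callers can probe for the optional dependency before opting in. A small sketch (variable names are illustrative only):

from transformers.utils import is_torchdynamo_available

# Only request the TorchDynamo backend when the package is importable.
backend = "nvfuser" if is_torchdynamo_available() else None
# args = TrainingArguments(output_dir="output", torchdynamo=backend)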

tests/trainer/test_trainer.py

Lines changed: 95 additions & 0 deletions
@@ -62,6 +62,7 @@
     require_torch_non_multi_gpu,
     require_torch_tf32,
     require_torch_up_to_2_gpus,
+    require_torchdynamo,
     require_wandb,
     slow,
 )
@@ -1594,6 +1595,100 @@ def test_fp16_full_eval(self):
         # perfect world: fp32_init/2 == fp16_eval
         self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)

+    @require_torch_non_multi_gpu
+    @require_torchdynamo
+    def test_torchdynamo_full_eval(self):
+        # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
+        n_gpus = get_gpu_count()
+
+        bs = 8
+        eval_len = 16 * n_gpus
+        # make the params somewhat big so that there will be enough RAM consumed to be able to
+        # measure things. We should get about 64KB for a+b in fp32
+        a = torch.ones(1000, bs) + 0.001
+        b = torch.ones(1000, bs) - 0.001
+
+        # 1. Default - without TorchDynamo
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len)
+        metrics = trainer.evaluate()
+        original_eval_loss = metrics["eval_loss"]
+        del trainer
+
+        # 2. TorchDynamo eager
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="eager")
+        metrics = trainer.evaluate()
+        self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+        del trainer
+
+        # 3. TorchDynamo nvfuser
+        trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, torchdynamo="nvfuser")
+        metrics = trainer.evaluate()
+        self.assertAlmostEqual(metrics["eval_loss"], original_eval_loss)
+
+    @require_torch_non_multi_gpu
+    @require_torchdynamo
+    def test_torchdynamo_memory(self):
+        # torchdynamo at the moment doesn't support DP/DDP, therefore require a single gpu
+        class CustomTrainer(Trainer):
+            def compute_loss(self, model, inputs, return_outputs=False):
+                x = inputs["x"]
+                output = model(x)
+                if self.args.n_gpu == 1:
+                    return output.mean()
+                return output
+
+        class MyModule(torch.nn.Module):
+            """Simple module that does aggressive fusion"""
+
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                for _ in range(20):
+                    x = torch.nn.functional.relu(x)
+                return x
+
+        mod = MyModule()
+
+        # 1. Default - without TorchDynamo
+        a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
+        a.grad = None
+        trainer = CustomTrainer(model=mod)
+        # warmup
+        for _ in range(10):
+            orig_loss = trainer.training_step(mod, {"x": a})
+
+        torch.cuda.reset_peak_memory_stats()
+        orig_loss = trainer.training_step(mod, {"x": a})
+        orig_peak_mem = torch.cuda.max_memory_allocated()
+        del trainer
+
+        # Reset the peak for another measurement
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats()
+
+        # 2. TorchDynamo nvfuser
+        a = torch.ones(1024, 1024, device="cuda", requires_grad=True)
+        a.grad = None
+        args = TrainingArguments(output_dir="None", torchdynamo="nvfuser")
+        trainer = CustomTrainer(model=mod, args=args)
+        # warmup
+        for _ in range(10):
+            loss = trainer.training_step(mod, {"x": a})
+
+        torch.cuda.reset_peak_memory_stats()
+        loss = trainer.training_step(mod, {"x": a})
+        peak_mem = torch.cuda.max_memory_allocated()
+        del trainer
+
+        # Functional check
+        self.assertAlmostEqual(loss, orig_loss)
+
+        # AOT Autograd recomputation and the nvfuser recomputation optimization
+        # aggressively fuse the operations and reduce the memory footprint.
+        self.assertGreater(orig_peak_mem, peak_mem * 2)
+
     @require_torch_gpu
     @require_torch_bf16
     def test_bf16_full_eval(self):
