From 9d3913a2e6b23c8b75607f7aa043de0c4f237c5e Mon Sep 17 00:00:00 2001
From: Jason Senthil
Date: Wed, 21 Dec 2022 14:04:36 -0800
Subject: [PATCH] Offer torchdynamo integration (#287)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/287

Add torchdynamo support to auto unit

Reviewed By: ananthsub

Differential Revision: D41315198

fbshipit-source-id: cf5e2f4cd76e3318d37eaae5d8c6334d154aa25a
---
 tests/framework/test_auto_unit.py | 158 +++++++++++++++++++++++++++++-
 tests/utils/test_version.py       |   4 +
 torchtnt/framework/auto_unit.py   |  80 ++++++++++++---
 torchtnt/utils/__init__.py        |   2 +
 torchtnt/utils/version.py         |   4 +
 5 files changed, 234 insertions(+), 14 deletions(-)

diff --git a/tests/framework/test_auto_unit.py b/tests/framework/test_auto_unit.py
index 02195e7f47..9ddd690292 100644
--- a/tests/framework/test_auto_unit.py
+++ b/tests/framework/test_auto_unit.py
@@ -11,6 +11,7 @@
 from unittest.mock import MagicMock, patch
 
 import torch
+import torch._dynamo
 from parameterized import parameterized
 from torch.distributed import launcher
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -20,7 +21,7 @@
     generate_random_iterable_dataloader,
 )
 
-from torchtnt.framework.auto_unit import AutoUnit, SWAParams
+from torchtnt.framework.auto_unit import AutoUnit, SWAParams, TorchDynamoParams
 from torchtnt.framework.evaluate import evaluate, init_eval_state
 from torchtnt.framework.predict import init_predict_state, predict
 from torchtnt.framework.state import State
@@ -307,6 +308,156 @@ def forward(self, x):
             torch.allclose(orig_module.l2.weight, swa_module.module.l2.weight)
         )
 
+    def test_dynamo_eager(self) -> None:
+        """
+        e2e torchdynamo test
+        """
+
+        device = init_from_env()
+        my_module = torch.nn.Linear(2, 2, device=device)
+        my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)
+
+        input_dim = 2
+        dataset_len = 16
+        batch_size = 2
+        max_epochs = 1
+
+        auto_unit = DummyAutoUnit(
+            module=my_module,
+            optimizer=my_optimizer,
+            torchdynamo_params=TorchDynamoParams("eager"),
+        )
+
+        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        state = init_train_state(dataloader=train_dl, max_epochs=max_epochs)
+        self.assertFalse(auto_unit._dynamo_used)
+        train(state, auto_unit)
+        self.assertTrue(auto_unit._dynamo_used)
+
+    @unittest.skipUnless(
+        condition=cuda_available, reason="This test needs a GPU host to run."
+    )
+    def test_dynamo_train(self) -> None:
+        """
+        e2e torchdynamo on train
+        """
+
+        device = init_from_env()
+        my_module = torch.nn.Linear(2, 2, device=device)
+        my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)
+
+        input_dim = 2
+        dataset_len = 16
+        batch_size = 2
+        max_epochs = 1
+
+        auto_unit = DummyAutoUnit(
+            module=my_module,
+            optimizer=my_optimizer,
+            torchdynamo_params=TorchDynamoParams("inductor"),
+            device=device,
+        )
+
+        train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        state = init_train_state(dataloader=train_dl, max_epochs=max_epochs)
+
+        self.assertFalse(auto_unit._dynamo_used)
+        train(state, auto_unit)
+        self.assertTrue(auto_unit._dynamo_used)
+
+    @unittest.skipUnless(
+        condition=cuda_available, reason="This test needs a GPU host to run."
+    )
+    def test_dynamo_eval(self) -> None:
+        """
+        e2e torchdynamo on eval
+        """
+
+        device = init_from_env()
+        my_module = torch.nn.Linear(2, 2, device=device)
+        my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)
+
+        input_dim = 2
+        dataset_len = 16
+        batch_size = 2
+
+        auto_unit = DummyAutoUnit(
+            module=my_module,
+            optimizer=my_optimizer,
+            torchdynamo_params=TorchDynamoParams("inductor"),
+            device=device,
+        )
+
+        input_dim = 2
+        dataset_len = 8
+        batch_size = 2
+
+        eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        state = init_eval_state(dataloader=eval_dl)
+        self.assertFalse(auto_unit._dynamo_used)
+        evaluate(state, auto_unit)
+        self.assertTrue(auto_unit._dynamo_used)
+
+    @unittest.skipUnless(
+        condition=cuda_available, reason="This test needs a GPU host to run."
+    )
+    def test_dynamo_predict(self) -> None:
+        """
+        e2e torchdynamo on predict
+        """
+        device = init_from_env()
+        my_module = torch.nn.Linear(2, 2, device=device)
+        my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)
+        auto_unit = DummyAutoUnit(
+            module=my_module,
+            optimizer=my_optimizer,
+            torchdynamo_params=TorchDynamoParams("inductor"),
+            device=device,
+        )
+
+        input_dim = 2
+        dataset_len = 8
+        batch_size = 2
+
+        predict_dl = generate_random_iterable_dataloader(
+            dataset_len, input_dim, batch_size
+        )
+        state = init_predict_state(dataloader=predict_dl)
+        self.assertFalse(auto_unit._dynamo_used)
+        predict(state, auto_unit)
+
+    def test_dynamo_invalid_backend(self) -> None:
+        """
+        verify error is thrown on invalid backend
+        """
+
+        class Net(torch.nn.Module):
+            def __init__(self):
+                super(Net, self).__init__()
+                self.l1 = torch.nn.Linear(2, 2)
+                self.b1 = torch.nn.BatchNorm1d(2)
+                self.l2 = torch.nn.Linear(2, 2)
+
+            def forward(self, x):
+                x = self.l1(x)
+                x = self.b1(x)
+                x = self.l2(x)
+                return x
+
+        my_module = Net()
+        my_dynamo_params = TorchDynamoParams(backend="foo")
+        my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)
+
+        self.assertRaises(
+            RuntimeError,
+            DummyAutoUnit,
+            **{
+                "module": my_module,
+                "optimizer": my_optimizer,
+                "torchdynamo_params": my_dynamo_params,
+            },
+        )
+
     def test_log_frequency_steps_exception(self) -> None:
         """
         Test that an exception is raised when log_frequency_steps is < 1
@@ -539,7 +690,12 @@ def test_move_data_to_device(self) -> None:
 
 
 class DummyAutoUnit(AutoUnit[Batch]):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._dynamo_used = False
+
     def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
+        self._dynamo_used = torch._dynamo.is_compiling()
         inputs, targets = data
         outputs = self.module(inputs)
         loss = torch.nn.functional.cross_entropy(outputs, targets)
diff --git a/tests/utils/test_version.py b/tests/utils/test_version.py
index cce1e8f9bb..46fc042b09 100644
--- a/tests/utils/test_version.py
+++ b/tests/utils/test_version.py
@@ -87,3 +87,7 @@ def test_torch_version_comparators(self) -> None:
             self.assertTrue(version.is_torch_version_geq_1_10())
             self.assertTrue(version.is_torch_version_geq_1_11())
             self.assertTrue(version.is_torch_version_geq_1_12())
+
+        with patch.object(torch, "__version__", "2.0.0a0"):
+            self.assertTrue(version.is_torch_version_ge_1_13_1())
+            self.assertFalse(version.is_torch_version_geq_2_0())
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 3ae7e8ed22..4e1dbe4c09 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -29,6 +29,7 @@
     transfer_batch_norm_stats,
     transfer_weights,
 )
+from torchtnt.utils.version import is_torch_version_ge_1_13_1
 from typing_extensions import Literal
 
 TSWA_avg_fn = Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
@@ -54,6 +55,20 @@ class SWAParams:
     avg_fn: Optional[TSWA_avg_fn] = None
 
 
+@dataclass
+class TorchDynamoParams:
+    """
+    Dataclass to store parameters for torchdynamo.
+
+    Args:
+        backend: a string backend name in `torch._dynamo.list_backends()`
+    """
+
+    backend: str
+
+
+# pyre-ignore: Invalid type parameters [24]
+TSelf = TypeVar("TSelf", bound="AutoUnit")
 TData = TypeVar("TData")
@@ -88,6 +103,7 @@ class AutoUnit(TrainUnit[TData], EvalUnit[TData], PredictUnit[Any], ABC):
         clip_grad_norm: max norm of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
         clip_grad_value: max value of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html
         swa_params: params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
+        torchdynamo_params: params for TorchDynamo
 
     Attributes:
         module: module to be used during training.
@@ -104,6 +120,10 @@ class AutoUnit(TrainUnit[TData], EvalUnit[TData], PredictUnit[Any], ABC):
         clip_grad_norm: max norm of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
         clip_grad_value: max value of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html
         swa_params: params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
+        torchdynamo_params: params for TorchDynamo
+
+    Note:
+        TorchDynamo support is only available in PyTorch 2.0 or higher.
     """
 
     def __init__(
@@ -121,6 +141,7 @@ def __init__(
         clip_grad_norm: Optional[float] = None,
         clip_grad_value: Optional[float] = None,
         swa_params: Optional[SWAParams] = None,
+        torchdynamo_params: Optional[TorchDynamoParams] = None,
     ) -> None:
         super().__init__()
         self.module = module
@@ -190,6 +211,20 @@ def __init__(
                 anneal_strategy=swa_params.anneal_strategy,
             )
 
+        if torchdynamo_params:
+            if not is_torch_version_ge_1_13_1():
+                raise RuntimeError(
+                    "TorchDynamo support is available only in PyTorch 2.0 or higher. "
+                    "Please install PyTorch 2.0 or higher to continue: https://pytorch.org/get-started/locally/"
+                )
+            # pyre-ignore
+            self.compute_loss = _dynamo_wrapper(self.compute_loss, torchdynamo_params)
+            # pyre-ignore
+            self._forward_and_backward = _dynamo_wrapper(
+                self._forward_and_backward, torchdynamo_params
+            )
+            self.module = _dynamo_wrapper(self.module, torchdynamo_params)
+
     # TODO: Make AutoTrainUnit work when data type is Iterator
 
     @abstractmethod
@@ -252,11 +287,27 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
         data = self.move_data_to_device(state, data)
         train_state = none_throws(state.train_state)
-
         should_update_weights = (
             train_state.progress.num_steps_completed_in_epoch + 1
         ) % self.gradient_accumulation_steps == 0 or train_state.is_last_batch
 
+        loss, outputs = self._forward_and_backward(state, data, should_update_weights)
+
+        # users can override this, by default this is a no-op
+        self.update_metrics(state, data, loss, outputs)
+
+        if should_update_weights:
+            # TODO try to use dynamo here
+            self._run_optimizer_lr_scheduler_step(state)
+
+            # log metrics only after an optimizer step
+            if self.num_optimizer_steps_completed % self.log_frequency_steps == 0:
+                self.log_metrics(state, self.num_optimizer_steps_completed - 1, "step")
+        return loss, outputs
+
+    def _forward_and_backward(
+        self, state: State, data: TData, should_update_weights: bool
+    ):
         # if using gradient accumulation and DDP or FSDP, when in a step where we will not update the weights,
         # run forward and backward in no_sync context
         # https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
@@ -280,17 +331,6 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
             if grad_scaler:
                 loss = grad_scaler.scale(loss)
             loss.backward()
-
-        # users can override this, by default this is a no-op
-        self.update_metrics(state, data, loss, outputs)
-
-        if should_update_weights:
-            self._run_optimizer_lr_scheduler_step(state)
-
-            # log metrics only after an optimizer step
-            if self.num_optimizer_steps_completed % self.log_frequency_steps == 0:
-                self.log_metrics(state, self.num_optimizer_steps_completed - 1, "step")
-
         return loss, outputs
 
     def _run_optimizer_lr_scheduler_step(self, state: State) -> None:
@@ -399,7 +439,6 @@ def predict_step(self, state: State, data: Any) -> Any:
 
         with self.maybe_autocast_precision:
             outputs = self.module(data)
-
         return outputs
 
     @property
@@ -443,3 +482,18 @@ def _get_grad_scaler_from_precision(
         else:
             return GradScaler()
     return None
+
+
+# pyre-ignore
+def _dynamo_wrapper(fn: Callable, torchdynamo_params: TorchDynamoParams):
+    backend = torchdynamo_params.backend
+    try:
+        return torch.compile(fn, backend=backend)
+    except KeyError as e:
+        raise RuntimeError(
+            f"Torchdynamo backend {torchdynamo_params.backend} is not supported."
+        ) from e
+    except Exception as e:
+        raise RuntimeError(
+            f"The following error was encountered when calling torch.compile for dynamo: {e}"
+        ) from e
diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py
index fd60f62e55..88ec5fb0f3 100644
--- a/torchtnt/utils/__init__.py
+++ b/torchtnt/utils/__init__.py
@@ -39,6 +39,7 @@
 from .version import (
     get_python_version,
     get_torch_version,
+    is_torch_version_ge_1_13_1,
     is_torch_version_geq_1_10,
     is_torch_version_geq_1_11,
     is_torch_version_geq_1_12,
@@ -87,6 +88,7 @@
     "Timer",
     "get_python_version",
     "get_torch_version",
+    "is_torch_version_ge_1_13_1",
     "is_torch_version_geq_1_10",
     "is_torch_version_geq_1_11",
     "is_torch_version_geq_1_12",
diff --git a/torchtnt/utils/version.py b/torchtnt/utils/version.py
index be0ddae4f7..26b79da832 100644
--- a/torchtnt/utils/version.py
+++ b/torchtnt/utils/version.py
@@ -78,6 +78,10 @@ def is_torch_version_geq_1_13() -> bool:
     return get_torch_version() >= Version("1.13.0")
 
 
+def is_torch_version_ge_1_13_1() -> bool:
+    return get_torch_version() > Version("1.13.1")
+
+
 def is_torch_version_geq_1_14() -> bool:
     return get_torch_version() >= Version("1.14.0")
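
Usage sketch, mirroring what the new tests exercise: an AutoUnit subclass opts into TorchDynamo by passing torchdynamo_params. Only AutoUnit, TorchDynamoParams, compute_loss, State, init_train_state and train are taken from torchtnt as used in this patch; MyUnit, the Linear module, and the synthetic TensorDataset below are illustrative placeholders, and the torchtnt.framework.train import path is assumed from the test module's usage of init_train_state/train. On PyTorch versions below the gate added above, the AutoUnit constructor raises a RuntimeError.

from typing import Any, Tuple

import torch
from torch.utils.data import DataLoader, TensorDataset

from torchtnt.framework.auto_unit import AutoUnit, TorchDynamoParams
from torchtnt.framework.state import State
from torchtnt.framework.train import init_train_state, train

Batch = Tuple[torch.Tensor, torch.Tensor]


class MyUnit(AutoUnit[Batch]):
    # compute_loss is the only method a subclass must implement,
    # just like DummyAutoUnit in the tests above
    def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs


module = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(module.parameters(), lr=0.01)

# TorchDynamoParams("eager") mirrors test_dynamo_eager: the unit's compute_loss,
# _forward_and_backward, and module get wrapped with torch.compile, but the
# "eager" backend does no codegen, so this also runs on a CPU-only host
unit = MyUnit(
    module=module,
    optimizer=optimizer,
    torchdynamo_params=TorchDynamoParams("eager"),
)

# synthetic classification data: 16 samples, 2 features, 2 classes
dataset = TensorDataset(torch.rand(16, 2), torch.randint(0, 2, (16,)))
dataloader = DataLoader(dataset, batch_size=2)

state = init_train_state(dataloader=dataloader, max_epochs=1)
train(state, unit)

The "eager" backend exercises dynamo tracing without a compiling backend, which is why test_dynamo_eager is the only dynamo test above that does not require a GPU host; the other tests use "inductor" and are gated behind cuda_available.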