Offer torchdynamo integration (pytorch#287)
Summary:
Pull Request resolved: pytorch#287

Add torchdynamo support to auto unit

Reviewed By: ananthsub

Differential Revision: D41315198

fbshipit-source-id: cf5e2f4cd76e3318d37eaae5d8c6334d154aa25a
JKSenthil authored and facebook-github-bot committed Dec 21, 2022
1 parent ad1593d commit 9d3913a
Showing 5 changed files with 234 additions and 14 deletions.
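
At a glance: the commit lets an AutoUnit subclass opt into TorchDynamo by passing torchdynamo_params=TorchDynamoParams(backend), after which compute_loss, _forward_and_backward, and the wrapped module are compiled with torch.compile. Below is a minimal usage sketch modeled on the tests in this diff; the MyUnit subclass, the inline dataset, and the torchtnt.framework.train import path are illustrative assumptions (inferred from the evaluate/predict imports in the test file), not part of the commit.

import torch
from torch.utils.data import DataLoader

from torchtnt.framework.auto_unit import AutoUnit, TorchDynamoParams
from torchtnt.framework.train import init_train_state, train


class MyUnit(AutoUnit):
    # illustrative subclass; the tests below use DummyAutoUnit instead.
    # compute_loss is the abstract method every AutoUnit subclass provides.
    def compute_loss(self, state, data):
        inputs, targets = data
        outputs = self.module(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        return loss, outputs


module = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(module.parameters(), lr=0.01)

# "eager" runs the captured graphs without a compiler backend;
# the GPU tests in this diff use "inductor" instead.
unit = MyUnit(
    module=module,
    optimizer=optimizer,
    torchdynamo_params=TorchDynamoParams("eager"),
)

dataset = [(torch.randn(2), torch.randint(2, ())) for _ in range(16)]
dataloader = DataLoader(dataset, batch_size=2)
state = init_train_state(dataloader=dataloader, max_epochs=1)
train(state, unit)

Leaving torchdynamo_params as None keeps the unit on the existing plain eager path.
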
158 changes: 157 additions & 1 deletion tests/framework/test_auto_unit.py
@@ -11,6 +11,7 @@
from unittest.mock import MagicMock, patch

import torch
import torch._dynamo
from parameterized import parameterized
from torch.distributed import launcher
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
@@ -20,7 +21,7 @@
generate_random_iterable_dataloader,
)

from torchtnt.framework.auto_unit import AutoUnit, SWAParams
from torchtnt.framework.auto_unit import AutoUnit, SWAParams, TorchDynamoParams
from torchtnt.framework.evaluate import evaluate, init_eval_state
from torchtnt.framework.predict import init_predict_state, predict
from torchtnt.framework.state import State
@@ -307,6 +308,156 @@ def forward(self, x):
torch.allclose(orig_module.l2.weight, swa_module.module.l2.weight)
)

def test_dynamo_eager(self) -> None:
"""
e2e torchdynamo test
"""

device = init_from_env()
my_module = torch.nn.Linear(2, 2, device=device)
my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)

input_dim = 2
dataset_len = 16
batch_size = 2
max_epochs = 1

auto_unit = DummyAutoUnit(
module=my_module,
optimizer=my_optimizer,
torchdynamo_params=TorchDynamoParams("eager"),
)

train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_train_state(dataloader=train_dl, max_epochs=max_epochs)
self.assertFalse(auto_unit._dynamo_used)
train(state, auto_unit)
self.assertTrue(auto_unit._dynamo_used)

@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
def test_dynamo_train(self) -> None:
"""
e2e torchdynamo on train
"""

device = init_from_env()
my_module = torch.nn.Linear(2, 2, device=device)
my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)

input_dim = 2
dataset_len = 16
batch_size = 2
max_epochs = 1

auto_unit = DummyAutoUnit(
module=my_module,
optimizer=my_optimizer,
torchdynamo_params=TorchDynamoParams("inductor"),
device=device,
)

train_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_train_state(dataloader=train_dl, max_epochs=max_epochs)

self.assertFalse(auto_unit._dynamo_used)
train(state, auto_unit)
self.assertTrue(auto_unit._dynamo_used)

@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
def test_dynamo_eval(self) -> None:
"""
e2e torchdynamo on eval
"""

device = init_from_env()
my_module = torch.nn.Linear(2, 2, device=device)
my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)

input_dim = 2
dataset_len = 16
batch_size = 2

auto_unit = DummyAutoUnit(
module=my_module,
optimizer=my_optimizer,
torchdynamo_params=TorchDynamoParams("inductor"),
device=device,
)

input_dim = 2
dataset_len = 8
batch_size = 2

eval_dl = generate_random_dataloader(dataset_len, input_dim, batch_size)
state = init_eval_state(dataloader=eval_dl)
self.assertFalse(auto_unit._dynamo_used)
evaluate(state, auto_unit)
self.assertTrue(auto_unit._dynamo_used)

@unittest.skipUnless(
condition=cuda_available, reason="This test needs a GPU host to run."
)
def test_dynamo_predict(self) -> None:
"""
e2e torchdynamo on predict
"""
device = init_from_env()
my_module = torch.nn.Linear(2, 2, device=device)
my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)
auto_unit = DummyAutoUnit(
module=my_module,
optimizer=my_optimizer,
torchdynamo_params=TorchDynamoParams("inductor"),
device=device,
)

input_dim = 2
dataset_len = 8
batch_size = 2

predict_dl = generate_random_iterable_dataloader(
dataset_len, input_dim, batch_size
)
state = init_predict_state(dataloader=predict_dl)
self.assertFalse(auto_unit._dynamo_used)
predict(state, auto_unit)

def test_dynamo_invalid_backend(self) -> None:
"""
verify error is thrown on invalid backend
"""

class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(2, 2)
self.b1 = torch.nn.BatchNorm1d(2)
self.l2 = torch.nn.Linear(2, 2)

def forward(self, x):
x = self.l1(x)
x = self.b1(x)
x = self.l2(x)
return x

my_module = Net()
my_dynamo_params = TorchDynamoParams(backend="foo")
my_optimizer = torch.optim.SGD(my_module.parameters(), lr=0.01)

self.assertRaises(
RuntimeError,
DummyAutoUnit,
module=my_module,
optimizer=my_optimizer,
torchdynamo_params=my_dynamo_params,
)

def test_log_frequency_steps_exception(self) -> None:
"""
Test that an exception is raised when log_frequency_steps is < 1
@@ -539,7 +690,12 @@ def test_move_data_to_device(self) -> None:


class DummyAutoUnit(AutoUnit[Batch]):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._dynamo_used = False

def compute_loss(self, state: State, data: Batch) -> Tuple[torch.Tensor, Any]:
self._dynamo_used = torch._dynamo.is_compiling()
inputs, targets = data
outputs = self.module(inputs)
loss = torch.nn.functional.cross_entropy(outputs, targets)
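
A note on how these tests detect dynamo: DummyAutoUnit records torch._dynamo.is_compiling() inside compute_loss, and that call returns True only while TorchDynamo is tracing the enclosing frame. A standalone sketch of the same probe pattern, assuming PyTorch 2.0+ (the Probe class is illustrative, not part of the commit):

import torch
import torch._dynamo


class Probe:
    def __init__(self) -> None:
        self.dynamo_used = False

    def step(self, x: torch.Tensor) -> torch.Tensor:
        # True only while TorchDynamo traces this frame, mirroring
        # DummyAutoUnit.compute_loss above
        self.dynamo_used = torch._dynamo.is_compiling()
        return x * 2


probe = Probe()
probe.step(torch.ones(2))
assert not probe.dynamo_used  # plain eager call

# same pattern as _dynamo_wrapper: compile the bound method
compiled_step = torch.compile(probe.step, backend="eager")
compiled_step(torch.ones(2))
assert probe.dynamo_used  # the wrapped call went through dynamo
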
4 changes: 4 additions & 0 deletions tests/utils/test_version.py
@@ -87,3 +87,7 @@ def test_torch_version_comparators(self) -> None:
self.assertTrue(version.is_torch_version_geq_1_10())
self.assertTrue(version.is_torch_version_geq_1_11())
self.assertTrue(version.is_torch_version_geq_1_12())

with patch.object(torch, "__version__", "2.0.0a0"):
self.assertTrue(version.is_torch_version_ge_1_13_1())
self.assertFalse(version.is_torch_version_geq_2_0())
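
Why "2.0.0a0" passes one gate but not the other: the version helpers compare packaging Version objects (as the Version("1.13.1") comparison in torchtnt/utils/version.py below shows), and a pre-release sorts above any earlier release but below its own final release. A quick illustration, assuming packaging's PEP 440 ordering:

from packaging.version import Version

# a 2.0 dev build clears the post-1.13.1 gate while the >= 2.0 gate stays False
v = Version("2.0.0a0")
print(v > Version("1.13.1"))   # True  -> is_torch_version_ge_1_13_1()
print(v >= Version("2.0.0"))   # False -> is_torch_version_geq_2_0()
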
80 changes: 67 additions & 13 deletions torchtnt/framework/auto_unit.py
@@ -29,6 +29,7 @@
transfer_batch_norm_stats,
transfer_weights,
)
from torchtnt.utils.version import is_torch_version_ge_1_13_1
from typing_extensions import Literal

TSWA_avg_fn = Callable[[torch.Tensor, torch.Tensor, int], torch.Tensor]
@@ -54,6 +55,20 @@ class SWAParams:
avg_fn: Optional[TSWA_avg_fn] = None


@dataclass
class TorchDynamoParams:
"""
Dataclass to store parameters for torchdynamo.
Args:
backend: a string backend name in `torch._dynamo.list_backends()`
"""

backend: str


# pyre-ignore: Invalid type parameters [24]
TSelf = TypeVar("TSelf", bound="AutoUnit")
TData = TypeVar("TData")


@@ -88,6 +103,7 @@ class AutoUnit(TrainUnit[TData], EvalUnit[TData], PredictUnit[Any], ABC):
clip_grad_norm: max norm of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
clip_grad_value: max value of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html
swa_params: params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
torchdynamo_params: params for TorchDynamo
Attributes:
module: module to be used during training.
@@ -104,6 +120,10 @@ class AutoUnit(TrainUnit[TData], EvalUnit[TData], PredictUnit[Any], ABC):
clip_grad_norm: max norm of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_norm_.html
clip_grad_value: max value of the gradients for clipping https://pytorch.org/docs/stable/generated/torch.nn.utils.clip_grad_value_.html
swa_params: params for stochastic weight averaging https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
torchdynamo_params: params for TorchDynamo
Note:
TorchDynamo support is only available in PyTorch 2.0 or higher.
"""

def __init__(
@@ -121,6 +141,7 @@ def __init__(
clip_grad_norm: Optional[float] = None,
clip_grad_value: Optional[float] = None,
swa_params: Optional[SWAParams] = None,
torchdynamo_params: Optional[TorchDynamoParams] = None,
) -> None:
super().__init__()
self.module = module
@@ -190,6 +211,20 @@ def __init__(
anneal_strategy=swa_params.anneal_strategy,
)

if torchdynamo_params:
if not is_torch_version_ge_1_13_1():
raise RuntimeError(
"TorchDynamo support is available only in PyTorch 2.0 or higher. "
"Please install PyTorch 2.0 or higher to continue: https://pytorch.org/get-started/locally/"
)
# pyre-ignore
self.compute_loss = _dynamo_wrapper(self.compute_loss, torchdynamo_params)
# pyre-ignore
self._forward_and_backward = _dynamo_wrapper(
self._forward_and_backward, torchdynamo_params
)
self.module = _dynamo_wrapper(self.module, torchdynamo_params)

# TODO: Make AutoTrainUnit work when data type is Iterator

@abstractmethod
@@ -252,11 +287,27 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
data = self.move_data_to_device(state, data)

train_state = none_throws(state.train_state)

should_update_weights = (
train_state.progress.num_steps_completed_in_epoch + 1
) % self.gradient_accumulation_steps == 0 or train_state.is_last_batch

loss, outputs = self._forward_and_backward(state, data, should_update_weights)

# users can override this, by default this is a no-op
self.update_metrics(state, data, loss, outputs)

if should_update_weights:
# TODO try to use dynamo here
self._run_optimizer_lr_scheduler_step(state)

# log metrics only after an optimizer step
if self.num_optimizer_steps_completed % self.log_frequency_steps == 0:
self.log_metrics(state, self.num_optimizer_steps_completed - 1, "step")
return loss, outputs

def _forward_and_backward(
self, state: State, data: TData, should_update_weights: bool
):
# if using gradient accumulation and DDP or FSDP, when in a step where we will not update the weights,
# run forward and backward in no_sync context
# https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel.no_sync
@@ -280,17 +331,6 @@ def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
if grad_scaler:
loss = grad_scaler.scale(loss)
loss.backward()

# users can override this, by default this is a no-op
self.update_metrics(state, data, loss, outputs)

if should_update_weights:
self._run_optimizer_lr_scheduler_step(state)

# log metrics only after an optimizer step
if self.num_optimizer_steps_completed % self.log_frequency_steps == 0:
self.log_metrics(state, self.num_optimizer_steps_completed - 1, "step")

return loss, outputs

def _run_optimizer_lr_scheduler_step(self, state: State) -> None:
@@ -399,7 +439,6 @@ def predict_step(self, state: State, data: Any) -> Any:

with self.maybe_autocast_precision:
outputs = self.module(data)

return outputs

@property
@@ -443,3 +482,18 @@ def _get_grad_scaler_from_precision(
else:
return GradScaler()
return None


# pyre-ignore
def _dynamo_wrapper(fn: Callable, torchdynamo_params: TorchDynamoParams):
backend = torchdynamo_params.backend
try:
return torch.compile(fn, backend=backend)
except KeyError as e:
raise RuntimeError(
f"Torchdynamo backend {torchdynamo_params.backend} is not supported."
) from e
except Exception as e:
raise RuntimeError(
f"The following error encountered when calling torch.compile for dynamo: {e}"
) from e
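
_dynamo_wrapper is a thin layer over torch.compile, so the backend string behaves exactly as in a direct torch.compile call; valid names can be enumerated with torch._dynamo.list_backends(), as the TorchDynamoParams docstring notes. A minimal standalone sketch, assuming PyTorch 2.0+ (the add_one function is illustrative):

import torch
import torch._dynamo


def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1


# a valid backend yields a compiled callable with the same signature
compiled = torch.compile(add_one, backend="eager")
print(compiled(torch.zeros(2)))  # tensor([1., 1.])

# backend names accepted by TorchDynamoParams, e.g. "eager", "aot_eager", "inductor"
print(torch._dynamo.list_backends())
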
2 changes: 2 additions & 0 deletions torchtnt/utils/__init__.py
@@ -39,6 +39,7 @@
from .version import (
get_python_version,
get_torch_version,
is_torch_version_ge_1_13_1,
is_torch_version_geq_1_10,
is_torch_version_geq_1_11,
is_torch_version_geq_1_12,
@@ -87,6 +88,7 @@
"Timer",
"get_python_version",
"get_torch_version",
"is_torch_version_ge_1_13_1",
"is_torch_version_geq_1_10",
"is_torch_version_geq_1_11",
"is_torch_version_geq_1_12",
4 changes: 4 additions & 0 deletions torchtnt/utils/version.py
@@ -78,6 +78,10 @@ def is_torch_version_geq_1_13() -> bool:
return get_torch_version() >= Version("1.13.0")


def is_torch_version_ge_1_13_1() -> bool:
return get_torch_version() > Version("1.13.1")


def is_torch_version_geq_1_14() -> bool:
return get_torch_version() >= Version("1.14.0")

