From 08b1db8dc7a3a30e7704dcc94befcf835cf7bb15 Mon Sep 17 00:00:00 2001
From: Dmtrs Flaco Meng
Date: Thu, 22 Dec 2022 13:29:50 +0200
Subject: [PATCH] Create torchtnt.utils.lr_scheduler.TLRScheduler

---
 conftest.py                              |  5 +++++
 examples/framework/auto_unit_example.py  |  4 ++--
 examples/framework/train_unit_example.py |  4 ++--
 tests/framework/test_app_state_mixin.py  |  3 ++-
 torchtnt/framework/auto_unit.py          |  3 ++-
 torchtnt/framework/unit.py               | 10 +---------
 torchtnt/utils/__init__.py               |  2 ++
 torchtnt/utils/lr_scheduler.py           | 16 ++++++++++++++++
 8 files changed, 32 insertions(+), 15 deletions(-)
 create mode 100644 conftest.py
 create mode 100644 torchtnt/utils/lr_scheduler.py

diff --git a/conftest.py b/conftest.py
new file mode 100644
index 0000000000..2e41cd717f
--- /dev/null
+++ b/conftest.py
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/examples/framework/auto_unit_example.py b/examples/framework/auto_unit_example.py
index d702c3da17..f72cd6bd36 100644
--- a/examples/framework/auto_unit_example.py
+++ b/examples/framework/auto_unit_example.py
@@ -17,7 +17,7 @@
 from torch.utils.data.dataset import Dataset, TensorDataset
 from torcheval.metrics import BinaryAccuracy
 from torchtnt.framework import AutoUnit, fit, init_fit_state, State
-from torchtnt.utils import get_timer_summary, init_from_env, seed
+from torchtnt.utils import get_timer_summary, init_from_env, seed, TLRScheduler
 from torchtnt.utils.loggers import TensorBoardLogger
 from typing_extensions import Literal
 
@@ -63,7 +63,7 @@ def __init__(
         *,
         module: torch.nn.Module,
         optimizer: torch.optim.Optimizer,
-        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+        lr_scheduler: TLRScheduler,
         device: Optional[torch.device],
         log_frequency_steps: int = 1000,
         precision: Optional[Union[str, torch.dtype]] = None,
diff --git a/examples/framework/train_unit_example.py b/examples/framework/train_unit_example.py
index 797162e04e..7036319cc5 100644
--- a/examples/framework/train_unit_example.py
+++ b/examples/framework/train_unit_example.py
@@ -17,7 +17,7 @@
 from torch.utils.data.dataset import Dataset, TensorDataset
 from torcheval.metrics import BinaryAccuracy
 from torchtnt.framework import init_train_state, State, train, TrainUnit
-from torchtnt.utils import get_timer_summary, init_from_env, seed
+from torchtnt.utils import get_timer_summary, init_from_env, seed, TLRScheduler
 from torchtnt.utils.loggers import TensorBoardLogger
 
 _logger: logging.Logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ def __init__(
         self,
         module: torch.nn.Module,
         optimizer: torch.optim.Optimizer,
-        lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+        lr_scheduler: TLRScheduler,
         train_accuracy: BinaryAccuracy,
         tb_logger: TensorBoardLogger,
         log_frequency_steps: int,
diff --git a/tests/framework/test_app_state_mixin.py b/tests/framework/test_app_state_mixin.py
index 843ba06b6f..083cd49273 100644
--- a/tests/framework/test_app_state_mixin.py
+++ b/tests/framework/test_app_state_mixin.py
@@ -11,6 +11,7 @@
 import torch
 from torch import nn
 from torchtnt.framework.unit import AppStateMixin
+from torchtnt.utils import TLRScheduler
 
 
 class Dummy(AppStateMixin):
@@ -196,7 +197,7 @@ def tracked_optimizers(self) -> Dict[str, torch.optim.Optimizer]:
 
     def tracked_lr_schedulers(
         self,
-    ) -> Dict[str, torch.optim.lr_scheduler._LRScheduler]:
+    ) -> Dict[str, TLRScheduler]:
         return {"lr_2": self.lr_2}
 
     def tracked_misc_statefuls(self) -> Dict[str, Any]:
diff --git a/torchtnt/framework/auto_unit.py b/torchtnt/framework/auto_unit.py
index 3ae7e8ed22..338782d9b9 100644
--- a/torchtnt/framework/auto_unit.py
+++ b/torchtnt/framework/auto_unit.py
@@ -26,6 +26,7 @@
     get_device_from_env,
     is_torch_version_geq_1_12,
     maybe_enable_tf32,
+    TLRScheduler,
     transfer_batch_norm_stats,
     transfer_weights,
 )
@@ -111,7 +112,7 @@ def __init__(
         *,
         module: torch.nn.Module,
         optimizer: torch.optim.Optimizer,
-        lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
+        lr_scheduler: Optional[TLRScheduler] = None,
         step_lr_interval: Literal["step", "epoch"] = "epoch",
         device: Optional[torch.device] = None,
         log_frequency_steps: int = 1000,
diff --git a/torchtnt/framework/unit.py b/torchtnt/framework/unit.py
index 19f21bf24b..c6542ea113 100644
--- a/torchtnt/framework/unit.py
+++ b/torchtnt/framework/unit.py
@@ -12,19 +12,11 @@
 from typing import Any, Dict, Generic, TypeVar
 
 import torch
-from packaging.version import Version
 from torchtnt.framework.state import State
-from torchtnt.utils.version import get_torch_version
+from torchtnt.utils import TLRScheduler
 from typing_extensions import Protocol, runtime_checkable
 
-# This PR exposes LRScheduler as a public class
-# https://github.com/pytorch/pytorch/pull/88503
-if get_torch_version() > Version("1.13.0"):
-    TLRScheduler = torch.optim.lr_scheduler.LRScheduler
-else:
-    TLRScheduler = torch.optim.lr_scheduler._LRScheduler
-
 """
 This file defines mixins and interfaces for users to customize hooks in training, evaluation, and prediction loops.
 """
 
diff --git a/torchtnt/utils/__init__.py b/torchtnt/utils/__init__.py
index fd60f62e55..3f9c0e8782 100644
--- a/torchtnt/utils/__init__.py
+++ b/torchtnt/utils/__init__.py
@@ -23,6 +23,7 @@
 from .early_stop_checker import EarlyStopChecker
 from .env import init_from_env
 from .fsspec import get_filesystem
+from .lr_scheduler import TLRScheduler
 from .memory import get_tensor_size_bytes_map, measure_rss_deltas, RSSProfiler
 from .misc import days_to_secs, transfer_batch_norm_stats, transfer_weights
 from .oom import is_out_of_cpu_memory, is_out_of_cuda_memory, is_out_of_memory_error
@@ -85,6 +86,7 @@
     "transfer_batch_norm_stats",
     "transfer_weights",
     "Timer",
+    "TLRScheduler",
     "get_python_version",
     "get_torch_version",
     "is_torch_version_geq_1_10",
diff --git a/torchtnt/utils/lr_scheduler.py b/torchtnt/utils/lr_scheduler.py
new file mode 100644
index 0000000000..22c1ab5203
--- /dev/null
+++ b/torchtnt/utils/lr_scheduler.py
@@ -0,0 +1,16 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+# This PR exposes LRScheduler as a public class
+# https://github.com/pytorch/pytorch/pull/88503
+try:
+    TLRScheduler = torch.optim.lr_scheduler.LRScheduler
+except AttributeError:
+    TLRScheduler = torch.optim.lr_scheduler._LRScheduler
+
+__all__ = ["TLRScheduler"]
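
A minimal usage sketch of the new alias (illustrative, not part of the patch;
`make_scheduler` is a hypothetical helper):

    import torch
    from torchtnt.utils import TLRScheduler

    # TLRScheduler resolves to torch.optim.lr_scheduler.LRScheduler on
    # torch >= 1.13, where the class was made public
    # (https://github.com/pytorch/pytorch/pull/88503), and falls back to the
    # private _LRScheduler on older releases, so the annotation below works
    # across torch versions.
    def make_scheduler(optimizer: torch.optim.Optimizer) -> TLRScheduler:
        # StepLR subclasses (_)LRScheduler, so it satisfies the alias.
        return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)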