Create torchtnt.utils.lr_scheduler.TLRScheduler
dmtrs committed Dec 22, 2022
1 parent ecb5d5b commit 08b1db8
Showing 8 changed files with 30 additions and 15 deletions.
5 changes: 5 additions & 0 deletions conftest.py
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
4 changes: 2 additions & 2 deletions examples/framework/auto_unit_example.py
@@ -17,7 +17,7 @@
from torch.utils.data.dataset import Dataset, TensorDataset
from torcheval.metrics import BinaryAccuracy
from torchtnt.framework import AutoUnit, fit, init_fit_state, State
-from torchtnt.utils import get_timer_summary, init_from_env, seed
+from torchtnt.utils import get_timer_summary, init_from_env, seed, TLRScheduler
from torchtnt.utils.loggers import TensorBoardLogger
from typing_extensions import Literal

@@ -63,7 +63,7 @@ def __init__(
*,
module: torch.nn.Module,
optimizer: torch.optim.Optimizer,
-lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+lr_scheduler: TLRScheduler,
device: Optional[torch.device],
log_frequency_steps: int = 1000,
precision: Optional[Union[str, torch.dtype]] = None,
4 changes: 2 additions & 2 deletions examples/framework/train_unit_example.py
@@ -17,7 +17,7 @@
from torch.utils.data.dataset import Dataset, TensorDataset
from torcheval.metrics import BinaryAccuracy
from torchtnt.framework import init_train_state, State, train, TrainUnit
-from torchtnt.utils import get_timer_summary, init_from_env, seed
+from torchtnt.utils import get_timer_summary, init_from_env, seed, TLRScheduler
from torchtnt.utils.loggers import TensorBoardLogger

_logger: logging.Logger = logging.getLogger(__name__)
@@ -60,7 +60,7 @@ def __init__(
self,
module: torch.nn.Module,
optimizer: torch.optim.Optimizer,
-lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
+lr_scheduler: TLRScheduler,
train_accuracy: BinaryAccuracy,
tb_logger: TensorBoardLogger,
log_frequency_steps: int,
3 changes: 2 additions & 1 deletion tests/framework/test_app_state_mixin.py
@@ -11,6 +11,7 @@
import torch
from torch import nn
from torchtnt.framework.unit import AppStateMixin
+from torchtnt.utils import TLRScheduler


class Dummy(AppStateMixin):
@@ -196,7 +197,7 @@ def tracked_optimizers(self) -> Dict[str, torch.optim.Optimizer]:

def tracked_lr_schedulers(
self,
-) -> Dict[str, torch.optim.lr_scheduler._LRScheduler]:
+) -> Dict[str, TLRScheduler]:
return {"lr_2": self.lr_2}

def tracked_misc_statefuls(self) -> Dict[str, Any]:
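For context, a hedged sketch of the pattern this test exercises (the class name, SGD, and ExponentialLR choices below are illustrative, not from the commit): a unit overrides tracked_lr_schedulers and can now annotate the return type with the public cross-version alias instead of the private _LRScheduler class.

from typing import Dict

import torch
from torchtnt.framework.unit import AppStateMixin
from torchtnt.utils import TLRScheduler


class SketchUnit(AppStateMixin):
    def __init__(self) -> None:
        super().__init__()
        module = torch.nn.Linear(2, 2)
        self.opt = torch.optim.SGD(module.parameters(), lr=0.1)
        self.lr_2 = torch.optim.lr_scheduler.ExponentialLR(self.opt, gamma=0.9)

    def tracked_lr_schedulers(self) -> Dict[str, TLRScheduler]:
        # Same shape as the test above; TLRScheduler replaces the
        # private torch.optim.lr_scheduler._LRScheduler annotation.
        return {"lr_2": self.lr_2}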
3 changes: 2 additions & 1 deletion torchtnt/framework/auto_unit.py
@@ -26,6 +26,7 @@
get_device_from_env,
is_torch_version_geq_1_12,
maybe_enable_tf32,
+TLRScheduler,
transfer_batch_norm_stats,
transfer_weights,
)
@@ -111,7 +112,7 @@ def __init__(
*,
module: torch.nn.Module,
optimizer: torch.optim.Optimizer,
-lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
+lr_scheduler: Optional[TLRScheduler] = None,
step_lr_interval: Literal["step", "epoch"] = "epoch",
device: Optional[torch.device] = None,
log_frequency_steps: int = 1000,
10 changes: 1 addition & 9 deletions torchtnt/framework/unit.py
@@ -12,19 +12,11 @@
from typing import Any, Dict, Generic, TypeVar

import torch
-from packaging.version import Version

from torchtnt.framework.state import State
-from torchtnt.utils.version import get_torch_version
+from torchtnt.utils import TLRScheduler
from typing_extensions import Protocol, runtime_checkable

-# This PR exposes LRScheduler as a public class
-# https://github.com/pytorch/pytorch/pull/88503
-if get_torch_version() > Version("1.13.0"):
-    TLRScheduler = torch.optim.lr_scheduler.LRScheduler
-else:
-    TLRScheduler = torch.optim.lr_scheduler._LRScheduler

"""
This file defines mixins and interfaces for users to customize hooks in training, evaluation, and prediction loops.
"""
2 changes: 2 additions & 0 deletions torchtnt/utils/__init__.py
@@ -23,6 +23,7 @@
from .early_stop_checker import EarlyStopChecker
from .env import init_from_env
from .fsspec import get_filesystem
+from .lr_scheduler import TLRScheduler
from .memory import get_tensor_size_bytes_map, measure_rss_deltas, RSSProfiler
from .misc import days_to_secs, transfer_batch_norm_stats, transfer_weights
from .oom import is_out_of_cpu_memory, is_out_of_cuda_memory, is_out_of_memory_error
@@ -85,6 +86,7 @@
"transfer_batch_norm_stats",
"transfer_weights",
"Timer",
"TLRScheduler",
"get_python_version",
"get_torch_version",
"is_torch_version_geq_1_10",
14 changes: 14 additions & 0 deletions torchtnt/utils/lr_scheduler.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch

# This PR exposes LRScheduler as a public class
# https://github.com/pytorch/pytorch/pull/88503
try:
    TLRScheduler = torch.optim.lr_scheduler.LRScheduler
except AttributeError:
    TLRScheduler = torch.optim.lr_scheduler._LRScheduler

__all__ = ["TLRScheduler"]
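
A minimal usage sketch (illustrative; make_scheduler and the StepLR choice are not part of this commit): the alias resolves to the public torch.optim.lr_scheduler.LRScheduler on newer torch releases and falls back to the private _LRScheduler on older ones, so a single annotation covers both. The try/except also sidesteps version-string parsing entirely, which is why the version check in torchtnt/framework/unit.py could be deleted.

import torch
from torchtnt.utils import TLRScheduler


def make_scheduler(optimizer: torch.optim.Optimizer) -> TLRScheduler:
    # Any concrete scheduler satisfies the annotation: the alias points
    # at whichever LR-scheduler base class the installed torch exposes.
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=30)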
