Skip to content

Commit

Permalink
Remove deprecations for 0.18 (mosaicml#2935)
Browse files Browse the repository at this point in the history
* remove

* fix

* fix import

* fix default

* purge tests

* mass deprecation

* merge

* lint

* lint

* fix docs

* fix

* fix

* fix

* change dataset

* change to resnet for precision

* fix

* fix

* fix

* fix

* fix

* fix

* fix tests

* fix fixtures

* fix

* fix readme

* fix

---------

Co-authored-by: Charles Tang <j316chuck@users.noreply.github.com>
  • Loading branch information
mvpatel2000 and j316chuck authored Feb 1, 2024
1 parent 3c14906 commit efb03ff
Show file tree
Hide file tree
Showing 104 changed files with 274 additions and 8,334 deletions.
16 changes: 8 additions & 8 deletions .github/workflows/daily.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,42 +20,42 @@ jobs:
include:
- name: cpu-3.10-2.0
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not vision and not doctest
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: cpu-3.10-2.1
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not vision and not doctest
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: cpu-3.10-2.1-composer
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not vision and not doctest
markers: not daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: composer
- name: cpu-doctest
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: not daily and (remote or not remote) and not gpu and not vision and doctest
markers: not daily and (remote or not remote) and not gpu and doctest
pytest_command: coverage run -m pytest tests/test_docs.py
composer_package_name: mosaicml
- name: daily-cpu-3.10-2.0
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not vision and not doctest
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: daily-cpu-3.10-2.1
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not vision and not doctest
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: daily-cpu-3.10-2.1-composer
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not vision and not doctest
markers: daily and (remote or not remote) and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: composer
- name: daily-cpu-doctest
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: daily and (remote or not remote) and not gpu and not vision and doctest
markers: daily and (remote or not remote) and not gpu and doctest
pytest_command: coverage run -m pytest tests/test_docs.py
composer_package_name: mosaicml
name: ${{ matrix.name }}
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/pr-cpu.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,17 @@ jobs:
include:
- name: cpu-3.10-2.0
container: mosaicml/pytorch:2.0.1_cpu-python3.10-ubuntu20.04
markers: not daily and not remote and not gpu and not vision and not doctest
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: cpu-3.10-2.1
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: not daily and not remote and not gpu and not vision and not doctest
markers: not daily and not remote and not gpu and not doctest
pytest_command: coverage run -m pytest
composer_package_name: mosaicml
- name: cpu-doctest
container: mosaicml/pytorch:2.1.2_cpu-python3.10-ubuntu20.04
markers: not daily and not remote and not gpu and not vision and doctest
markers: not daily and not remote and not gpu and doctest
pytest_command: coverage run -m pytest tests/test_docs.py
composer_package_name: mosaicml
name: ${{ matrix.name }}
Expand Down
35 changes: 32 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -135,26 +135,55 @@ Here is a code snippet demonstrating our Trainer on the MNIST dataset.
<!--pytest.mark.filterwarnings(r'ignore:Some targets have less than 1 total probability:UserWarning')-->
<!--pytest.mark.filterwarnings('ignore:Cannot split tensor of length .* into batches of size 128.*:UserWarning')-->
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

from composer import Trainer
from composer.models import mnist_model
from composer.models import ComposerClassifier
from composer.algorithms import LabelSmoothing, CutMix, ChannelsLast

class Model(nn.Module):
"""Toy convolutional neural network architecture in pytorch for MNIST."""

def __init__(self, num_classes: int = 10):
super().__init__()

self.num_classes = num_classes

self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=0)
self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=0)
self.bn = nn.BatchNorm2d(32)
self.fc1 = nn.Linear(32 * 16, 32)
self.fc2 = nn.Linear(32, num_classes)

def forward(self, x):
out = self.conv1(x)
out = F.relu(out)
out = self.conv2(out)
out = self.bn(out)
out = F.relu(out)
out = F.adaptive_avg_pool2d(out, (4, 4))
out = torch.flatten(out, 1, -1)
out = self.fc1(out)
out = F.relu(out)
return self.fc2(out)

transform = transforms.Compose([transforms.ToTensor()])
dataset = datasets.MNIST("data", train=True, download=True, transform=transform)
train_dataloader = DataLoader(dataset, batch_size=128)

trainer = Trainer(
model=mnist_model(num_classes=10),
model=ComposerClassifier(module=Model(), num_classes=10),
train_dataloader=train_dataloader,
max_duration="2ep",
algorithms=[
LabelSmoothing(smoothing=0.1),
CutMix(alpha=1.0),
ChannelsLast(),
]
],
)
trainer.fit()
```
Expand Down
13 changes: 7 additions & 6 deletions STYLE_GUIDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,22 +227,23 @@ All imports in composer should be absolute -- that is, they do not begin with a
1. If a dependency is not core to Composer (e.g. it is for a model, dataset, algorithm, or some callbacks):
1. It must be specified in a entry of the `extra_deps` dictionary of [setup.py](setup.py).
This dictionary groups dependencies that can be conditionally installed. An entry named `foo`
can be installed with `pip install 'mosaicml[foo]'`. For example, running `pip install 'mosaicml[unet]'`
will install everything in `install_requires`, along with `monai` and `scikit-learn`.
can be installed with `pip install 'mosaicml[foo]'`. For example, running `pip install 'mosaicml[system_metrics_monitor]'`
will install everything in `install_requires`, along with `pynvml`.
1. It must also be specified in the `run_constrained` and the `test.requires` section.
1. The import must be conditionally imported in the code. For example:
<!--pytest-codeblocks:importorskip(monai)-->
<!--pytest-codeblocks:importorskip(scikit-learn)-->
```python
from composer import Callback
from composer.utils import MissingConditionalImportError
def unet():
class SystemMetricsMonitor(Callback)
try:
import monai
import pynvml
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group="unet",
conda_package="monai",
raise MissingConditionalImportError(extra_deps_group="system_metrics_monitor",
conda_package="pynvml",
conda_channel="conda-forge",) from e
```
Expand Down
4 changes: 1 addition & 3 deletions composer/algorithms/blurpool/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,7 @@ def training_loop(model, train_loader):
<!--
```python
from torch.utils.data import DataLoader
from tests.common import RandomImageDataset
from composer.models import composer_resnet
from tests.common import RandomImageDataset, composer_resnet
model = composer_resnet('resnet50')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,11 @@
from __future__ import annotations

import logging
import textwrap
import warnings
from typing import Dict, Optional, Sequence, Type, Union

import torch
import torch.nn.functional as F
from packaging import version
from torch.optim import Optimizer

from composer.algorithms.warnings import NoEffectWarning
Expand All @@ -22,12 +20,6 @@

log = logging.getLogger(__name__)

try:
from apex.normalization.fused_layer_norm import FusedLayerNorm as APEXFusedLayerNorm
APEX_INSTALLED = True
except ImportError as e:
APEX_INSTALLED = False


def apply_low_precision_layernorm(model,
precision: Optional[Precision] = None,
Expand All @@ -38,22 +30,6 @@ def apply_low_precision_layernorm(model,

policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {torch.nn.LayerNorm: _to_LPLayerNorm}

# Prior to v1.13, torch.nn.LayerNorm is slow in bf16 precision.
# We use FusedLayerNorm as a fallback.
if version.parse(torch.__version__) < version.parse('1.13') and precision == Precision.AMP_BF16:
warnings.warn(
DeprecationWarning(
textwrap.dedent(
'You are using Low Precision LayerNorm on PyTorch < v.1.13 with bfloat16 precision. '
'In this scenario, we fall back to Fused LayerNorm. '
'Fused LayerNorm has been deprecated and will be removed in Composer 0.18. '
'Please upgrade your PyTorch version to >=v.1.13 to use Low Precision LayerNorm without the Fused LayerNorm fallback.'
)))
check_if_apex_installed()
policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
torch.nn.LayerNorm: _to_FusedLayerNorm
}

replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
if len(replaced_instances) == 0:
warnings.warn(NoEffectWarning('No instances of torch.nn.LayerNorm found.'))
Expand Down Expand Up @@ -129,13 +105,6 @@ def _cast_if_autocast_enabled(tensor):
return tensor


def check_if_apex_installed():
if not APEX_INSTALLED:
raise ImportError(
'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied on PyTorch <1.13 without Apex. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
)


def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
"""Defines a replacement policy from a `torch.nn.LayerNorm` to a `LPLayerNorm`"""
if not isinstance(layer, torch.nn.LayerNorm):
Expand All @@ -153,22 +122,3 @@ def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
lp_layernorm.bias.copy_(layer.bias) # type: ignore

return lp_layernorm


def _to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
"""Defines a replacement policy from a `torch.nn.LayerNorm` to a `apex.normalization.fused_layer_norm`"""
if not isinstance(layer, torch.nn.LayerNorm):
raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
fused_layernorm = APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)

with torch.no_grad():
if layer.weight is None: # pyright: ignore[reportUnnecessaryComparison]
fused_layernorm.weight = None # pyright: ignore[reportGeneralTypeIssues]
else:
fused_layernorm.weight.copy_(layer.weight)
if layer.bias is None: # pyright: ignore[reportUnnecessaryComparison]
fused_layernorm.bias = None # pyright: ignore[reportGeneralTypeIssues]
else:
fused_layernorm.bias.copy_(layer.bias)

return fused_layernorm
6 changes: 2 additions & 4 deletions composer/algorithms/stochastic_depth/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
<!--
```python
from torch.utils.data import DataLoader
from tests.common import RandomImageDataset
from tests.common import RandomImageDataset, composer_resnet
train_dataloader = DataLoader(RandomImageDataset(size=2), batch_size=2)
```
Expand All @@ -27,7 +27,6 @@ import torch
import torch.nn.functional as F

import composer.functional as cf
from composer.models import composer_resnet

# Training

Expand Down Expand Up @@ -63,7 +62,7 @@ for epoch in range(1):
<!--
```python
from torch.utils.data import DataLoader
from tests.common import RandomImageDataset
from tests.common import RandomImageDataset, composer_resnet
train_dataloader = DataLoader(RandomImageDataset(size=2), batch_size=2)
eval_dataloader = DataLoader(RandomImageDataset(size=2), batch_size=2)
Expand All @@ -75,7 +74,6 @@ eval_dataloader = DataLoader(RandomImageDataset(size=2), batch_size=2)
# The trainer will automatically run it at the appropriate point in the training loop

from composer.algorithms import StochasticDepth
from composer.models import composer_resnet
from composer.trainer import Trainer

# Train model
Expand Down
6 changes: 2 additions & 4 deletions composer/algorithms/weight_standardization/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Weight Standardization is a reparametrization of convolutional weights such that
```python
from torchvision import models
from torch.utils.data import DataLoader
from tests.common import RandomImageDataset
from tests.common import RandomImageDataset, composer_resnet
my_train_dataloader = DataLoader(RandomImageDataset(size=2), batch_size=2)
my_cnn_model = models.resnet18()
Expand All @@ -31,7 +31,6 @@ my_cnn_model = models.resnet18()
import composer.functional as cf
import torch
import torch.nn.functional as F
from composer.models import composer_resnet

def training_loop(model, train_dataloader):
opt = torch.optim.Adam(model.parameters())
Expand All @@ -58,9 +57,8 @@ training_loop(my_cnn_model, my_train_dataloader)
<!--pytest.mark.gpu-->
<!--
```python
from composer.models import composer_resnet
from torch.utils.data import DataLoader
from tests.common import RandomImageDataset
from tests.common import RandomImageDataset, composer_resnet
cnn_composer_model = composer_resnet('resnet50')
my_train_dataloader = DataLoader(RandomImageDataset(size=2), batch_size=2)
Expand Down
2 changes: 0 additions & 2 deletions composer/callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from composer.callbacks.export_for_inference import ExportForInferenceCallback
from composer.callbacks.free_outputs import FreeOutputs
from composer.callbacks.generate import Generate
from composer.callbacks.health_checker import HealthChecker
from composer.callbacks.image_visualizer import ImageVisualizer
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
Expand All @@ -38,7 +37,6 @@
'ExportForInferenceCallback',
'ThresholdStopper',
'ImageVisualizer',
'HealthChecker',
'RuntimeEstimator',
'SystemMetricsMonitor',
'Generate',
Expand Down
Loading

0 comments on commit efb03ff

Please sign in to comment.