Commit 810b445

ref: apex plugin (Lightning-AI#3502)
* ref: apex plugin
* ref: apex plugin
* ref: apex plugin
1 parent 61b31d9 commit 810b445

File tree

8 files changed: +57 -37 lines changed


docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -141,3 +141,4 @@ Indices and tables
    api/pytorch_lightning.trainer
    api/pytorch_lightning.utilities
    api/pytorch_lightning.tuner
+   api/pytorch_lightning.plugins

pytorch_lightning/accelerators/ddp2_backend.py

Lines changed: 2 additions & 5 deletions
@@ -22,6 +22,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.accelerators.ddp_base_backend import DDPBase
+from pytorch_lightning.plugins.apex import ApexPlugin

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -31,17 +32,13 @@
 else:
     HYDRA_AVAILABLE = True

-try:
-    from apex import amp
-except ImportError:
-    amp = None
-

 class DDP2Backend(DDPBase):

     def __init__(self, trainer):
         super().__init__(trainer)
         self.task_idx = None
+        self.precision_backend = None

     def setup(self, model):
         self._resolve_task_idx()

pytorch_lightning/accelerators/ddp_base_backend.py

Lines changed: 4 additions & 8 deletions
@@ -22,6 +22,7 @@
 from pytorch_lightning.utilities.cloud_io import atomic_save
 from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only
 from pytorch_lightning import _logger as log
+from pytorch_lightning.plugins.apex import ApexPlugin

 try:
     from hydra.utils import to_absolute_path, get_original_cwd
@@ -31,16 +32,12 @@
 else:
     HYDRA_AVAILABLE = True

-try:
-    from apex import amp
-except ImportError:
-    amp = None
-

 class DDPBase(Accelerator):

     def __init__(self, trainer):
         super().__init__(trainer)
+        self.precision_backend = None

     def training_step(self, args):
         if self.trainer.amp_backend == AMPType.NATIVE:
@@ -155,9 +152,8 @@ def ddp_train_tmp(self, process_idx, mp_queue, model, is_master=False, proc_offs
         # AMP -
         # run through amp wrapper before going to distributed DP
         if self.trainer.amp_backend == AMPType.APEX:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.trainer.optimizers = optimizers
-            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)

         # device ids change depending on the DDP setup
         device_ids = self.get_device_ids()
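Across the accelerator backends touched by this commit, the inline Apex setup is replaced by the same two-line pattern. A condensed before/after sketch of that pattern (indentation and surrounding context vary per backend):

Before:
    model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
    self.trainer.optimizers = optimizers
    self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)

After:
    self.precision_backend = ApexPlugin(self.trainer)
    model, optimizers = self.precision_backend._init(model)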

pytorch_lightning/accelerators/dp_backend.py

Lines changed: 4 additions & 2 deletions
@@ -20,6 +20,7 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.accelerators.base_backend import Accelerator
+from pytorch_lightning.plugins.apex import ApexPlugin

 try:
     from apex import amp
@@ -32,6 +33,7 @@ class DataParallelBackend(Accelerator):
     def __init__(self, trainer):
         super().__init__(trainer)
         self.model_autocast_original_forward = None
+        self.precision_backend = None

     def setup(self, model):
         # call setup after the ddp process has connected
@@ -89,8 +91,8 @@ def __init_nvidia_apex(self, model):
                 f' See this note from NVIDIA for more info: https://github.com/NVIDIA/apex/issues/227.'
                 f' We recommend you switch to ddp if you want to use amp')
         else:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.reinit_scheduler_properties(optimizers, self.trainer.lr_schedulers)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)

         return model

pytorch_lightning/accelerators/gpu_backend.py

Lines changed: 4 additions & 13 deletions
@@ -13,21 +13,17 @@
 # limitations under the License.

 import torch
-from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.accelerators.base_backend import Accelerator
-
-try:
-    from apex import amp
-except ImportError:
-    amp = None
+from pytorch_lightning.plugins.apex import ApexPlugin


 class GPUBackend(Accelerator):
     amp_backend: AMPType

     def __init__(self, trainer):
         super().__init__(trainer)
+        self.precision_backend = None

     def setup(self, model):

@@ -45,7 +41,8 @@ def setup(self, model):
         self.trainer.optimizer_frequencies = optimizer_frequencies

         if self.trainer.amp_backend == AMPType.APEX:
-            model = self._setup_nvidia_apex(model)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)

         self.trainer.model = model

@@ -117,9 +114,3 @@ def to_device(self, batch):
         # be referenced from and if there are multiple optimizers the batch will
         # wind up copying it to the same device repeatedly.
         return self.batch_to_device(batch, gpu_id)
-
-    def _setup_nvidia_apex(self, model: LightningModule):
-        model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-        self.trainer.optimizers = optimizers
-        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
-        return model

pytorch_lightning/accelerators/horovod_backend.py

Lines changed: 4 additions & 9 deletions
@@ -18,12 +18,7 @@
 from pytorch_lightning.accelerators.base_backend import Accelerator
 from pytorch_lightning.utilities.distributed import rank_zero_only
 from torch.optim.lr_scheduler import _LRScheduler
-
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
+from pytorch_lightning.plugins.apex import ApexPlugin

 try:
     import horovod.torch as hvd
@@ -38,6 +33,7 @@ class HorovodBackend(Accelerator):

     def __init__(self, trainer):
         super().__init__(trainer)
+        self.precision_backend = None

     def setup(self, model):
         # call setup after the ddp process has connected
@@ -88,9 +84,8 @@ def filter_named_parameters(model, optimizer):
         ]

         if self.trainer.amp_backend == AMPType.APEX:
-            model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-            self.trainer.optimizers = optimizers
-            self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)

         # Update logger rank info from Horovod to avoid race conditions from different ranks
         # creating directories / writing files in the same locations.

pytorch_lightning/plugins/__init__.py

Whitespace-only changes.

pytorch_lightning/plugins/apex.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try:
+    from apex import amp
+except ImportError:
+    amp = None
+
+
+class ApexPlugin:
+
+    def __init__(self, trainer):
+        self.trainer = trainer
+
+    def _init(self, model):
+        model, optimizers = self.configure_apex(model, self.trainer.optimizers, self.trainer.amp_level)
+        self.trainer.optimizers = optimizers
+        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
+        return model, optimizers
+
+    def configure_apex(self, model, optimizers, amp_level):
+        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
+        return model, optimizers
+
+    def training_step(self, fx, args):
+        output = fx(args)
+        return output
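For orientation, a minimal usage sketch of the new plugin from a backend's point of view. The _StubTrainer below is a hypothetical stand-in for illustration only; the real Trainer supplies optimizers, lr_schedulers, amp_level and reinit_scheduler_properties, and NVIDIA Apex must be installed (with CUDA parameters) for amp.initialize to run.

    import torch
    from pytorch_lightning.plugins.apex import ApexPlugin

    class _StubTrainer:
        # Hypothetical stand-in for the real Trainer, illustration only.
        # The plugin only touches optimizers, lr_schedulers, amp_level and
        # reinit_scheduler_properties, so that is all the stub provides.
        def __init__(self, model):
            self.optimizers = [torch.optim.SGD(model.parameters(), lr=0.1)]
            self.lr_schedulers = []
            self.amp_level = 'O1'

        def reinit_scheduler_properties(self, optimizers, schedulers):
            # the real Trainer re-binds scheduler attributes to the patched optimizers
            pass

    model = torch.nn.Linear(4, 2).cuda()  # apex.amp expects CUDA parameters for O1/O2
    trainer = _StubTrainer(model)

    # The same two calls the refactored backends now make:
    precision_backend = ApexPlugin(trainer)
    model, optimizers = precision_backend._init(model)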
