 # limitations under the License.

 import torch
-from pytorch_lightning.core import LightningModule
 from pytorch_lightning.utilities import AMPType
 from pytorch_lightning.accelerators.base_backend import Accelerator
-
-try:
-    from apex import amp
-except ImportError:
-    amp = None
+from pytorch_lightning.plugins.apex import ApexPlugin


 class GPUBackend(Accelerator):
     amp_backend: AMPType

     def __init__(self, trainer):
         super().__init__(trainer)
+        self.precision_backend = None

     def setup(self, model):

@@ -45,7 +41,8 @@ def setup(self, model):
         self.trainer.optimizer_frequencies = optimizer_frequencies

         if self.trainer.amp_backend == AMPType.APEX:
-            model = self._setup_nvidia_apex(model)
+            self.precision_backend = ApexPlugin(self.trainer)
+            model, optimizers = self.precision_backend._init(model)

         self.trainer.model = model

@@ -117,9 +114,3 @@ def to_device(self, batch):
         # be referenced from and if there are multiple optimizers the batch will
         # wind up copying it to the same device repeatedly.
         return self.batch_to_device(batch, gpu_id)
-
-    def _setup_nvidia_apex(self, model: LightningModule):
-        model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
-        self.trainer.optimizers = optimizers
-        self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
-        return model
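
Note: the body of `ApexPlugin._init` is not part of this hunk. The sketch below is an assumption of what the plugin does, reconstructed from the removed `GPUBackend._setup_nvidia_apex` method; it is not the actual `pytorch_lightning/plugins/apex.py` implementation.

# Hypothetical sketch only -- mirrors the logic that was removed from GPUBackend.
try:
    from apex import amp
except ImportError:
    amp = None


class ApexPlugin:

    def __init__(self, trainer):
        self.trainer = trainer

    def _init(self, model):
        # Delegate to the LightningModule hook, exactly as the old
        # _setup_nvidia_apex did, so users can still override apex setup.
        model, optimizers = model.configure_apex(
            amp, model, self.trainer.optimizers, self.trainer.amp_level
        )
        self.trainer.optimizers = optimizers
        # Schedulers must be re-bound to the new amp-wrapped optimizers.
        self.trainer.reinit_scheduler_properties(
            self.trainer.optimizers, self.trainer.lr_schedulers
        )
        return model, optimizers

Under this assumption, the accelerator no longer needs to import apex or know about `configure_apex`; it only instantiates the plugin and calls `_init`, which is what the `setup` hunk above shows.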