
Commit

add parameters of apex_amp.initialize
xcnick committed Jan 28, 2023
1 parent 9dc4210 commit 41d4370
Showing 1 changed file with 62 additions and 3 deletions.
65 changes: 62 additions & 3 deletions mmengine/optim/optimizer/apex_optimizer_wrapper.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from contextlib import contextmanager
from typing import Union
from typing import Optional, Union

import torch
import torch.nn as nn
@@ -35,6 +35,35 @@ class ApexOptimWrapper(OptimWrapper):
loss_scale (float or str, default=None): If passed as
a string, must be a string representing a number,
e.g., "128.0", or the string "dynamic".
enabled (bool, default=True): If False, renders all Amp calls no-ops,
so your script should run as if Amp were not present.
cast_model_type (torch.dtype, default=None): Casts the model's parameters
and buffers to the desired type.
patch_torch_functions (bool, default=None): Patch all Torch functions
and Tensor methods to perform Tensor Core-friendly ops like GEMMs
and convolutions in FP16,
and any ops that benefit from FP32 precision in FP32.
keep_batchnorm_fp32 (bool or str, default=None): To enhance precision
and enable cudnn batchnorm (which improves performance),
it's often beneficial to keep batchnorm weights in FP32
even if the rest of the model is FP16.
If passed as a string, must be the string "True" or "False".
master_weights (bool, default=None): Maintain FP32 master weights to
accompany any FP16 model weights. FP32 master weights are stepped
by the optimizer to enhance precision and capture small gradients.
cast_model_outputs (torch.dtype, default=None): Option to ensure that
the outputs of your model(s) are always cast to a particular type
regardless of ``opt_level``.
num_losses (int, default=1): Option to tell Amp in advance how many
losses/backward passes you plan to use.
verbosity (int, default=1): Set to 0 to suppress Amp-related output.
min_loss_scale (float, default=None): Sets a floor for the loss scale
values that can be chosen by dynamic loss scaling.
The default value of None means that no floor is imposed.
If dynamic loss scaling is not used, ``min_loss_scale`` is ignored.
max_loss_scale (float, default=2.**24): Sets a ceiling for the
loss scale values that can be chosen by dynamic loss scaling.
If dynamic loss scaling is not used, ``max_loss_scale`` is ignored.
**kwargs: Keyword arguments passed to OptimWrapper.
Note:
@@ -46,13 +75,33 @@ class ApexOptimWrapper(OptimWrapper):
def __init__(self,
opt_level: str = 'O1',
loss_scale: Union[float, str] = 'dynamic',
enabled: Optional[bool] = True,
cast_model_type: Optional[torch.dtype] = None,
patch_torch_functions: Optional[bool] = None,
keep_batchnorm_fp32: Optional[Union[bool, str]] = None,
master_weights: Optional[bool] = None,
cast_model_outputs: Optional[torch.dtype] = None,
num_losses: Optional[int] = 1,
verbosity: Optional[int] = 1,
min_loss_scale: Optional[float] = None,
max_loss_scale: Optional[float] = 2.**24,
**kwargs):
assert apex_amp is not None, \
'Apex is not installed. Please check ' \
'https://github.com/NVIDIA/apex#linux.'
super().__init__(**kwargs)
self.opt_level = opt_level
self.loss_scale = loss_scale
self.enabled = enabled
self.cast_model_type = cast_model_type
self.patch_torch_functions = patch_torch_functions
self.keep_batchnorm_fp32 = keep_batchnorm_fp32
self.master_weights = master_weights
self.cast_model_outputs = cast_model_outputs
self.num_losses = num_losses
self.verbosity = verbosity
self.min_loss_scale = min_loss_scale
self.max_loss_scale = max_loss_scale
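
For reference, a minimal construction sketch (not part of this commit), assuming apex is installed, a CUDA device is available, ApexOptimWrapper is exported from mmengine.optim, and the optimizer argument is forwarded to OptimWrapper through **kwargs:

import torch
import torch.nn as nn

from mmengine.optim import ApexOptimWrapper

# Placeholder model and optimizer; apex amp expects the model on a CUDA device.
model = nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

optim_wrapper = ApexOptimWrapper(
    optimizer=optimizer,        # passed through **kwargs to OptimWrapper
    opt_level='O1',
    loss_scale='dynamic',
    keep_batchnorm_fp32=None,   # let apex pick the default for this opt_level
    num_losses=1,
    verbosity=1,
    max_loss_scale=2.**24)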

def backward(self, loss: torch.Tensor, **kwargs) -> None:
"""Perform gradient back propagation with :attr:`loss_scaler`.
@@ -62,7 +111,7 @@ def backward(self, loss: torch.Tensor, **kwargs) -> None:
kwargs: Keyword arguments passed to :meth:`torch.Tensor.backward`
"""
with apex_amp.scale_loss(loss, self.optimizer) as scaled_loss:
scaled_loss.backward()
scaled_loss.backward(**kwargs)
self._inner_count += 1
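
Since **kwargs is now forwarded to scaled_loss.backward(), standard torch.Tensor.backward options can be passed through the wrapper; a hedged illustration, reusing the optim_wrapper from the sketch above:

# 'loss' stands for any scalar loss tensor produced under optim_context;
# retain_graph is a regular torch.Tensor.backward keyword argument.
optim_wrapper.backward(loss, retain_graph=True)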

def state_dict(self) -> dict:
@@ -116,5 +165,15 @@ def optim_context(self, model: nn.Module):
model,
self.optimizer,
opt_level=self.opt_level,
loss_scale=self.loss_scale)
loss_scale=self.loss_scale,
enabled=self.enabled,
cast_model_type=self.cast_model_type,
patch_torch_functions=self.patch_torch_functions,
keep_batchnorm_fp32=self.keep_batchnorm_fp32,
master_weights=self.master_weights,
cast_model_outputs=self.cast_model_outputs,
num_losses=self.num_losses,
verbosity=self.verbosity,
min_loss_scale=self.min_loss_scale,
max_loss_scale=self.max_loss_scale)
yield
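
A rough end-to-end usage sketch under the same assumptions as above, and additionally assuming update_params() is OptimWrapper's combined backward/step/zero_grad helper; entering optim_context is what triggers the apex_amp.initialize call shown in this hunk:

inputs = torch.randn(8, 4).cuda()
targets = torch.randn(8, 2).cuda()

with optim_wrapper.optim_context(model):  # apex_amp.initialize(...) is called here
    preds = model(inputs)
    loss = nn.functional.mse_loss(preds, targets)

optim_wrapper.update_params(loss)  # scaled backward + optimizer step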
