Commit
add apex test (Lightning-AI#2921)
* add apex test

* rename

* level

* events

* wrap

* evt

* miss

* apex

* apex

* apex

* apex

* apex

* apex

* Update tests/models/test_amp.py

Co-authored-by: William Falcon <waf2107@columbia.edu>

* notes

* notes

Co-authored-by: William Falcon <waf2107@columbia.edu>
Borda and williamFalcon authored Aug 13, 2020
1 parent 6c5a0a1 commit 4354690
Showing 22 changed files with 191 additions and 101 deletions.
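The bulk of this diff renames the Trainer argument `amp_type` to `amp_backend` (and the matching internal attribute) and adds an Apex test. As a quick orientation, a minimal usage sketch of the renamed flag, based on the docs hunk further down (the `precision` and `amp_level` values are only illustrative defaults):

    from pytorch_lightning import Trainer

    # 16-bit training through NVIDIA Apex (apex must be installed)
    trainer = Trainer(precision=16, amp_backend='apex', amp_level='O2')

    # 16-bit training through PyTorch's built-in AMP (torch >= 1.6)
    trainer = Trainer(precision=16, amp_backend='native')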
2 changes: 1 addition & 1 deletion .github/workflows/tpu-testing.yml
@@ -14,7 +14,7 @@ env:
GKE_CLUSTER: lightning-cluster
GKE_ZONE: us-central1-a
IMAGE: gcr.io/${{ secrets.GKE_PROJECT }}/tpu-testing-image
MAX_CHECKS: 240
MAX_CHECKS: 360
CHECK_SPEEP: 5

jobs:
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Tracks all outputs including TBPTT and multiple optimizers ([#2890](https://github.com/PyTorchLightning/pytorch-lightning/pull/2890))

- Added GPU Usage Logger ([#2932](https://github.com/PyTorchLightning/pytorch-lightning/pull/2932))

### Changed

- Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
@@ -351,7 +353,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Deprecated

- Deprecated `tags_csv` in favor of `hparams_file` ([#1271](https://github.com/PyTorchLightning/pytorch-lightning/pull/1271))
- Deprecated `amp_level` in favor of native AMP ([#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561))

### Fixed

1 change: 1 addition & 0 deletions environment.yml
@@ -17,6 +17,7 @@ dependencies:
- future>=0.17.1
- PyYAML>=5.1
- tqdm>=4.41.0
- nvidia-apex

# For dev and testing
- black==19.10b0
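The new `nvidia-apex` entry backs the `APEX_AVAILABLE` guard used in `auto_mix_precision.py` below. A sketch of the usual detection pattern (illustrative; the actual constant lives in Lightning's utilities module):

    # Apex is optional: import it if present, otherwise fall back gracefully.
    try:
        from apex import amp  # noqa: F401
        APEX_AVAILABLE = True
    except ImportError:
        APEX_AVAILABLE = False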
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/cpu_backend.py
@@ -22,7 +22,7 @@ def __init__(self, trainer):

def setup(self, model):
# run through amp wrapper
if self.trainer.amp_type:
if self.trainer.amp_backend:
raise MisconfigurationException('amp + cpu is not supported. Please use a GPU option')

# call setup after the ddp process has connected
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp2_backend.py
@@ -134,7 +134,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
self.trainer.copy_trainer_model_properties(model)

# AMP - run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_backend.py
@@ -201,7 +201,7 @@ def ddp_train(self, process_idx, mp_queue, model, is_master=False, proc_offset=0
self.trainer.copy_trainer_model_properties(model)

# AMP - run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/ddp_spawn_backend.py
@@ -134,7 +134,7 @@ def ddp_train(self, process_idx, mp_queue, model):

# AMP -
# run through amp wrapper before going to distributed DP
if self.trainer.amp_type == AMPType.APEX:
if self.trainer.amp_backend == AMPType.APEX:
model, optimizers = model.configure_apex(amp, model, self.trainer.optimizers, self.trainer.amp_level)
self.trainer.optimizers = optimizers
self.trainer.reinit_scheduler_properties(self.trainer.optimizers, self.trainer.lr_schedulers)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/dp_backend.py
@@ -49,7 +49,7 @@ def setup(self, model):
self.model_autocast_original_forward = model.forward

# init half precision
if self.trainer.amp_type:
if self.trainer.amp_backend:
model = self.__init_half_precision(model)

# init torch data parallel
@@ -69,7 +69,7 @@ def __init_torch_data_parallel(self, model):
return model

def __init_half_precision(self, model):
if self.trainer.amp_type == AMPType.NATIVE:
if self.trainer.amp_backend == AMPType.NATIVE:
self.__init_native_amp(model)
else:
model = self.__init_nvidia_apex(model)
4 changes: 2 additions & 2 deletions pytorch_lightning/accelerators/gpu_backend.py
@@ -22,7 +22,7 @@


class GPUBackend(object):
amp_type: AMPType
amp_backend: AMPType

def __init__(self, trainer):
self.trainer = trainer
@@ -41,7 +41,7 @@ def setup(self, model):
self.trainer.lr_schedulers = lr_schedulers
self.trainer.optimizer_frequencies = optimizer_frequencies

if self.trainer.amp_type == AMPType.APEX:
if self.trainer.amp_backend == AMPType.APEX:
model = self._setup_nvidia_apex(model)
return model

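`_setup_nvidia_apex` ultimately defers to Apex's `amp.initialize`, which patches the model and optimizers for mixed precision. A hedged sketch of that wrapping, not the exact Lightning implementation:

    from apex import amp

    def setup_nvidia_apex(model, optimizers, amp_level='O2'):
        # opt_level mirrors the Trainer's `amp_level` argument (O1, O2, ...).
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers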
4 changes: 2 additions & 2 deletions pytorch_lightning/core/hooks.py
@@ -309,8 +309,8 @@ def backward(self, trainer, loss, optimizer, optimizer_idx):
"""
loss.backward()

def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_type: AMPType):
if amp_type == AMPType.NATIVE:
def amp_scale_loss(self, unscaled_loss, optimizer, optimizer_idx, amp_backend: AMPType):
if amp_backend == AMPType.NATIVE:
scaled_loss = self.trainer.scaler.scale(unscaled_loss)
else:
scaled_loss = amp.scale_loss(unscaled_loss, optimizer)
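The renamed `amp_scale_loss` hook branches on the backend: native AMP scales through the trainer's `GradScaler`, while Apex returns the `amp.scale_loss` context manager that the training loop enters and exits further down. A standalone sketch of the two paths (assuming `scaler` is a `torch.cuda.amp.GradScaler`):

    from apex import amp

    def scale_loss(unscaled_loss, optimizer, scaler, use_native_amp):
        if use_native_amp:
            # Native AMP: multiply the loss by the scaler's current scale factor.
            return scaler.scale(unscaled_loss)
        # Apex: returns a context manager that yields the scaled loss.
        return amp.scale_loss(unscaled_loss, optimizer)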
2 changes: 1 addition & 1 deletion pytorch_lightning/core/memory.py
@@ -207,7 +207,7 @@ def _forward_example_input(self) -> None:
input_ = model.example_input_array
input_ = model.transfer_batch_to_device(input_, model.device)

if trainer is not None and trainer.amp_type == AMPType.NATIVE and not trainer.use_tpu:
if trainer is not None and trainer.amp_backend == AMPType.NATIVE and not trainer.use_tpu:
model.forward = torch.cuda.amp.autocast()(model.forward)

mode = model.training
8 changes: 4 additions & 4 deletions pytorch_lightning/trainer/__init__.py
@@ -900,18 +900,18 @@ def on_train_end(self, trainer, pl_module):
trainer = Trainer(sync_batchnorm=True)
amp_type
^^^^^^^^
amp_backend
^^^^^^^^^^^
Define the preferred mixed-precision backend: NVIDIA Apex ("apex") or PyTorch's built-in AMP ("native"), which is available from PyTorch 1.6.
.. testcode::
# using NVIDIA Apex
trainer = Trainer(amp_type='apex')
trainer = Trainer(amp_backend='apex')
# using PyTorch built-in AMP
trainer = Trainer(amp_type='native')
trainer = Trainer(amp_backend='native')
val_percent_check
^^^^^^^^^^^^^^^^^
14 changes: 4 additions & 10 deletions pytorch_lightning/trainer/auto_mix_precision.py
@@ -10,8 +10,7 @@ class TrainerAMPMixin(ABC):
# the proper values/initialisation should be done in child class
precision: int

def _setup_amp_type(self, amp_type: str):
self.amp_type = None
def _setup_amp_backend(self, amp_type: str):
if self.precision != 16:
# no AMP requested, so we can leave now
return
@@ -25,25 +24,20 @@ def _setup_amp_type(self, amp_type: str):
amp_type = 'apex'
else:
log.info('Using native 16bit precision.')
self.amp_type = AMPType.NATIVE
self.amp_backend = AMPType.NATIVE
if amp_type == 'apex':
if not APEX_AVAILABLE:
rank_zero_warn('You have asked for Apex AMP but you have not installed it yet.'
' Install apex first using this guide: https://github.com/NVIDIA/apex#linux')
else:
log.info('Using APEX 16bit precision.')
self.amp_type = AMPType.APEX
if not self.amp_type:
self.amp_backend = AMPType.APEX
if not self.amp_backend:
raise ModuleNotFoundError(
f'You have asked for AMP support {amp_type}, but there is no support on your side yet.'
f' Consider installing torch >= 1.6 or NVIDIA Apex.'
)

def init_amp(self, amp_type: str):
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'

self._setup_amp_type(amp_type)

@property
def use_amp(self) -> bool:
return self.precision == 16
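The selection logic above reduces to: do nothing unless `precision == 16`, prefer native AMP when `torch.cuda.amp` is available, otherwise fall back to Apex if installed, and raise if neither backend exists. A condensed, standalone sketch of that decision (names are illustrative, not the exact Lightning internals):

    def choose_amp_backend(precision, amp_type='native',
                           native_available=True, apex_available=False):
        """Return 'native', 'apex', or None, mirroring _setup_amp_backend."""
        if precision != 16:
            return None  # no AMP requested
        if amp_type == 'native' and not native_available:
            amp_type = 'apex'  # fall back, as the hunk above does
        if amp_type == 'native':
            return 'native'
        if amp_type == 'apex' and apex_available:
            return 'apex'
        raise ModuleNotFoundError(
            f'AMP backend {amp_type!r} was requested but is not available;'
            ' install torch >= 1.6 or NVIDIA Apex.'
        )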
6 changes: 3 additions & 3 deletions pytorch_lightning/trainer/distrib_parts.py
@@ -80,7 +80,7 @@ class TrainerDPMixin(ABC):
on_colab_kaggle: str
save_spawn_weights: Callable
logger: ...
amp_type: AMPType
amp_backend: AMPType

@abstractmethod
def call_setup_hook(self, *args):
@@ -124,7 +124,7 @@ def copy_trainer_model_properties(self, model):
m.use_dp = self.use_dp
m.use_ddp2 = self.use_ddp2
m.use_ddp = self.use_ddp
m.use_amp = self.amp_type is not None
m.use_amp = self.amp_backend is not None
m.testing = self.testing
m.use_single_gpu = self.use_single_gpu
m.use_tpu = self.use_tpu
@@ -209,7 +209,7 @@ def horovod_train(self, model):
if isinstance(scheduler, _LRScheduler):
scheduler.base_lrs = [lr * hvd.size() for lr in scheduler.base_lrs]

if self.amp_type:
if self.amp_backend:
model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
self.optimizers = optimizers
self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -180,7 +180,7 @@ class TrainerEvaluationLoopMixin(ABC):
tpu_id: int
verbose_test: bool
running_sanity_check: bool
amp_type: AMPType
amp_backend: AMPType

# Callback system
on_validation_batch_start: Callable
@@ -325,7 +325,7 @@ def _evaluate(
# -----------------
# RUN EVALUATION STEP
# -----------------
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
with torch.cuda.amp.autocast():
output = self.evaluation_forward(model, batch, batch_idx, dataloader_idx, test_mode)
else:
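With the native backend, the evaluation forward pass (and the training forward pass in `training_loop.py` below) runs under `torch.cuda.amp.autocast()`. A minimal illustration of that branch, assuming torch >= 1.6:

    import torch

    def evaluation_forward(model, batch, use_native_amp=True, use_tpu=False):
        # Mirror the branch in _evaluate: autocast only for native AMP off TPU.
        if use_native_amp and not use_tpu:
            with torch.cuda.amp.autocast():
                return model(batch)
        return model(batch)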
12 changes: 8 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -199,7 +199,7 @@ def __init__(
terminate_on_nan: bool = False,
auto_scale_batch_size: Union[str, bool] = False,
prepare_data_per_node: bool = True,
amp_type: str = 'native',
amp_backend: str = 'native',
amp_level: str = 'O2', # backward compatible, todo: remove in v1.0.0
val_percent_check: float = None, # backward compatible, todo: remove in v0.10.0
test_percent_check: float = None, # backward compatible, todo: remove in v0.10.0
@@ -312,7 +312,6 @@ def __init__(
amp_level: The optimization level to use (O1, O2, etc...).
.. warning:: .. deprecated:: v0.7.4
num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
Set it to `-1` to run all batches in all validation dataloaders. Default: 2
@@ -593,7 +592,7 @@ def __init__(
self.scaler = None

self.amp_level = amp_level
self.init_amp(amp_type)
self.init_amp(amp_backend)

self.on_colab_kaggle = os.getenv('COLAB_GPU') or os.getenv('KAGGLE_URL_BASE')

@@ -1141,7 +1140,7 @@ def run_pretrain_routine(self, model: LightningModule):
self.copy_trainer_model_properties(ref_model)

# init amp. Must be done here instead of __init__ to allow ddp to work
if self.amp_type == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
if self.amp_backend == AMPType.NATIVE and self.precision == 16 and not self.use_tpu:
self.scaler = torch.cuda.amp.GradScaler()

# log hyper-parameters
@@ -1427,6 +1426,11 @@ def call_setup_hook(self, model):
self.setup(stage_name)
model.setup(stage_name)

def init_amp(self, amp_type: str):
assert self.precision in (16, 32), 'only 32 or 16 bit precision supported'
self.amp_backend = None
self._setup_amp_backend(amp_type)


class _PatchDataLoader(object):
r"""
14 changes: 7 additions & 7 deletions pytorch_lightning/trainer/training_io.py
@@ -158,7 +158,7 @@ class TrainerIOMixin(ABC):
accumulate_grad_batches: int
scaler: ...
use_tpu: bool
amp_type: AMPType
amp_backend: AMPType

def get_model(self):
is_dp_module = isinstance(self.model, (LightningDistributedDataParallel, LightningDataParallel))
@@ -301,9 +301,9 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
model.cuda(self.root_gpu)

# restore amp scaling
if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
if self.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint:
elif self.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
amp.load_state_dict(checkpoint['amp_scaling_state'])

# load training state (affects trainer only)
@@ -354,9 +354,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
checkpoint['lr_schedulers'] = lr_schedulers

# save native amp scaling
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
checkpoint['native_amp_scaling_state'] = self.scaler.state_dict()
elif self.amp_type == AMPType.APEX:
elif self.amp_backend == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()

# add the module_arguments and state_dict from the model
@@ -513,9 +513,9 @@ def hpc_load(self, folderpath, on_gpu):
model.load_state_dict(checkpoint['state_dict'])

# restore amp scaling
if self.amp_type == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
if self.amp_backend == AMPType.NATIVE and 'native_amp_scaling_state' in checkpoint:
self.scaler.load_state_dict(checkpoint['native_amp_scaling_state'])
elif self.amp_type == AMPType.APEX and 'amp_scaling_state' in checkpoint:
elif self.amp_backend == AMPType.APEX and 'amp_scaling_state' in checkpoint:
amp.load_state_dict(checkpoint['amp_scaling_state'])

if self.root_gpu is not None:
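The checkpoint hunks persist AMP state under backend-specific keys: `native_amp_scaling_state` for the `GradScaler` state dict, `amp_scaling_state` for Apex. A sketch of the native round trip:

    import torch

    scaler = torch.cuda.amp.GradScaler()

    # save, as in dump_checkpoint
    checkpoint = {'native_amp_scaling_state': scaler.state_dict()}

    # restore, as in restore / hpc_load
    if 'native_amp_scaling_state' in checkpoint:
        scaler.load_state_dict(checkpoint['native_amp_scaling_state'])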
19 changes: 10 additions & 9 deletions pytorch_lightning/trainer/training_loop.py
@@ -255,7 +255,7 @@ class TrainerTrainLoopMixin(ABC):
tpu_id: int
interactive_ddp_procs: ...
state: TrainerState
amp_type: AMPType
amp_backend: AMPType
on_tpu: bool

# Callback system
@@ -935,7 +935,7 @@ def run_batch_backward_pass(self, split_batch, batch_idx, opt_idx, optimizer):
# ------------------
# CLIP GRADS
# ------------------
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
self.scaler.unscale_(optimizer)
self.clip_gradients(optimizer)

@@ -969,7 +969,7 @@ def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):
elif isinstance(optimizer, torch.optim.LBFGS):

# native amp + lbfgs is a no go right now
if self.amp_type == AMPType.NATIVE:
if self.amp_backend == AMPType.NATIVE:
raise MisconfigurationException(
'native PyTorch amp and lbfgs are not compatible.'
' To request, please file a Github issue in PyTorch and tag @mcarilli')
@@ -978,12 +978,12 @@ def call_optimizer_step(self, optimizer, opt_idx, batch_idx, split_batch):

# when using 16-bit
else:
native_amp = self.amp_type == AMPType.NATIVE
native_amp = self.amp_backend == AMPType.NATIVE
model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx, lambda_closure,
using_native_amp=native_amp)

# in native 16-bit we need to update scaler after optimizer step
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
self.scaler.update()

# model hook
Expand All @@ -1000,7 +1000,7 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
# FORWARD (TRAINING STEP + TRAIN STEP END)
# ---------------------------
with self.profiler.profile('model_forward'):
if self.amp_type == AMPType.NATIVE and not self.use_tpu:
if self.amp_backend == AMPType.NATIVE and not self.use_tpu:
with torch.cuda.amp.autocast():
training_step_output = self.training_forward(split_batch, batch_idx,
opt_idx, hiddens)
@@ -1058,18 +1058,19 @@ def optimizer_closure(self, split_batch, batch_idx, opt_idx, optimizer, hiddens)
with self.profiler.profile('model_backward'):
# scale loss for 16 bit
if self.precision == 16 and not self.on_tpu:
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_type=self.amp_type)
closure_loss = model_ref.amp_scale_loss(closure_loss, optimizer, opt_idx, amp_backend=self.amp_backend)

# enter amp context
if self.amp_type == AMPType.APEX:
if self.amp_backend == AMPType.APEX:
self.dev_debugger.track_event('AMP', str(AMPType.APEX))
context = closure_loss
closure_loss = closure_loss.__enter__()

# do backward pass
model_ref.backward(self, closure_loss, optimizer, opt_idx)

# exit amp context
if self.precision == 16 and self.amp_type == AMPType.APEX and not self.on_tpu:
if self.precision == 16 and self.amp_backend == AMPType.APEX and not self.on_tpu:
a, b, c = None, None, None
error = context.__exit__(a, b, c)
if error:
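On the Apex path, `amp_scale_loss` returns a context manager; the loop above enters it, records a dev_debugger 'AMP' event, runs `model_ref.backward`, and then exits it manually. Written as a plain `with` block, the equivalent user-level Apex pattern (standard Apex API, not Lightning internals) is:

    from apex import amp

    def apex_backward(loss, optimizer):
        # scale_loss yields the loss multiplied by Apex's loss scale; calling
        # backward on the scaled loss produces correctly scaled gradients.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()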
