Skip to content

Commit

Permalink
use apex fp16 optimizer for bert_classification_model (facebookresear…
Browse files Browse the repository at this point in the history
…ch#390)

Summary:
Pull Request resolved: facebookresearch#390

use apex fp16 optimizer for bert_classification_model, it proves that fp16_optimizer is faster compared to amp

Reviewed By: m3rlin45

Differential Revision: D14434112

fbshipit-source-id: 71a3921adc8a7701fdfb59c7b4ba319701b9ec7e
  • Loading branch information
chenyangyu1988 authored and facebook-github-bot committed Mar 19, 2019
1 parent a22981a commit d72f977
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 21 deletions.
2 changes: 2 additions & 0 deletions pytext/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ class BaseModel(nn.Module, Component):
__EXPANSIBLE__ = True
__COMPONENT_TYPE__ = ComponentType.MODEL

SUPPORT_FP16_OPTIMIZER = False

class Config(Component.Config):
class ModelInput(ModelInputBase):
pass
Expand Down
6 changes: 4 additions & 2 deletions pytext/task/new_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@ def from_config(cls, config: Config, unused_metadata=None, model_state=None):
model = create_component(ComponentType.MODEL, config.model, tensorizers)
if model_state:
model.load_state_dict(model_state)

precision.activate(model)
if cuda.CUDA_ENABLED:
model = model.cuda()
# This is the only place right now that the task actually cares about which
Expand Down Expand Up @@ -173,8 +175,8 @@ def export(self, model, export_path, metric_channels=None, export_onnx_path=None
# Make sure to put the model on CPU and disable CUDA before exporting to
# ONNX to disable any data_parallel pieces
cuda.CUDA_ENABLED = False
precision.deactivate()
model = model.cpu()
precision.deactivate(model)

batch = next(iter(self.data.batches(Stage.TRAIN)))
print("Saving caffe2 model to: " + export_path)
Expand All @@ -186,8 +188,8 @@ def torchscript_export(self, model, export_path):
# Make sure to put the model on CPU and disable CUDA before exporting to
# ONNX to disable any data_parallel pieces
cuda.CUDA_ENABLED = False
precision.deactivate()
model.cpu()
precision.deactivate(model)
# Trace needs eval mode, to disable dropout etc
model.eval()

Expand Down
6 changes: 4 additions & 2 deletions pytext/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def from_config(cls, task_config, metadata=None, model_state=None):
model = create_model(task_config.model, task_config.features, metadata)
if model_state:
model.load_state_dict(model_state)

precision.activate(model)
if cuda.CUDA_ENABLED:
model = model.cuda()
metric_reporter = create_metric_reporter(task_config.metric_reporter, metadata)
Expand Down Expand Up @@ -177,9 +179,9 @@ def export(self, model, export_path, metric_channels=None, export_onnx_path=None
# Make sure to put the model on CPU and disable CUDA before exporting to
# ONNX to disable any data_parallel pieces
cuda.CUDA_ENABLED = False
precision.deactivate()

model = model.cpu()
precision.deactivate(model)

if self.exporter:
if metric_channels:
print("Exporting metrics")
Expand Down
4 changes: 2 additions & 2 deletions pytext/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,8 @@ def backprop(self, state, loss):
state.scheduler.step_batch()

if self.config.max_clip_norm is not None:
grad_norm = torch.nn.utils.clip_grad_norm_(
state.model.parameters(), self.config.max_clip_norm
grad_norm = precision.clip_grad_norm(
state.model, self.optimizer, self.config.max_clip_norm
)
else:
grad_norm = None
Expand Down
78 changes: 63 additions & 15 deletions pytext/utils/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from sys import stderr

import torch

from . import cuda


_APEX_DISABLED = False
try:
from apex import amp
from apex import amp, fp16_utils
except ImportError:
print("Install apex from https://github.com/NVIDIA/apex/.", file=stderr)
_APEX_DISABLED = True
Expand Down Expand Up @@ -52,50 +54,96 @@


_FP16_ENABLED = False
_USE_FP16_OPTIMIZER = False
_amp_handle = None


def set_fp16(fp16_enabled: bool):
global _FP16_ENABLED
global _amp_handle

if _APEX_DISABLED:
return

_FP16_ENABLED = fp16_enabled
if _FP16_ENABLED:
if fp16_enabled:
if not cuda.CUDA_ENABLED:
raise RuntimeError("Cuda is not available, should not running fp16...")

_amp_handle = amp.init(enabled=fp16_enabled)
_FP16_ENABLED = fp16_enabled


def activate(model):
# Warning: this function should be called before train.

global _amp_handle
global _USE_FP16_OPTIMIZER

if _FP16_ENABLED:
_USE_FP16_OPTIMIZER = model.SUPPORT_FP16_OPTIMIZER

if _USE_FP16_OPTIMIZER:
model.half()
else:
_amp_handle = amp.init(enabled=_FP16_ENABLED)


def wrap_optimizer(optimizer):
if _FP16_ENABLED:
return _amp_handle.wrap_optimizer(optimizer)
if _USE_FP16_OPTIMIZER:
return fp16_utils.FP16_Optimizer(optimizer, dynamic_loss_scale=True)
else:
return _amp_handle.wrap_optimizer(optimizer)
else:
return optimizer


def unwrap_optimizer(wrapped_optimizer):
if _FP16_ENABLED:
if _USE_FP16_OPTIMIZER:
return wrapped_optimizer.optimizer
else:
return wrapped_optimizer._optimizer
else:
return wrapped_optimizer


def backward(optimizer, loss):
if _FP16_ENABLED:
# 1. Use automatic loss scaling to best use fp16 range (skip step if overflow)
# 2. Clear handle's cache of casted parameters before the next optimizer step
with optimizer.scale_loss(loss) as scaled_loss:
scaled_loss.backward()
if _USE_FP16_OPTIMIZER:
# 1. Manage master weights update
# 2. Manage dynamic loss scaling
optimizer.backward(loss)
else:
# 1. Use automatic loss scaling to best use fp16 range
# 2. Clear handle's cache of casted parameters
with optimizer.scale_loss(loss) as scaled_loss:
scaled_loss.backward()
else:
loss.backward()


def deactivate():
def clip_grad_norm(model, optimizer, max_clip_norm):
if _FP16_ENABLED and _USE_FP16_OPTIMIZER:
return optimizer.clip_master_grads(max_clip_norm)
else:
return torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)


def deactivate(model):
# Warning: this function is expected to be called after train finished.
# In case need to deactivate before train, should invoke unwrap_optimizer first.

global _FP16_ENABLED
global _USE_FP16_OPTIMIZER

if _FP16_ENABLED:
# restoring uncasted versions of functions
_amp_handle._deactivate()
if _USE_FP16_OPTIMIZER:
# convert model parameters back to fp32
model.float()
_USE_FP16_OPTIMIZER = False
else:
# restoring uncasted versions of functions
_amp_handle._deactivate()
_FP16_ENABLED = False
else:
pass


def maybe_float(tensor):
Expand Down

0 comments on commit d72f977

Please sign in to comment.