diff --git a/pytext/models/model.py b/pytext/models/model.py
index f0268b0fe..eed66a5ba 100644
--- a/pytext/models/model.py
+++ b/pytext/models/model.py
@@ -47,6 +47,8 @@ class BaseModel(nn.Module, Component):
     __EXPANSIBLE__ = True
     __COMPONENT_TYPE__ = ComponentType.MODEL
 
+    SUPPORT_FP16_OPTIMIZER = False
+
     class Config(Component.Config):
         class ModelInput(ModelInputBase):
             pass
diff --git a/pytext/task/new_task.py b/pytext/task/new_task.py
index 08fa8d3cf..0702c2399 100644
--- a/pytext/task/new_task.py
+++ b/pytext/task/new_task.py
@@ -123,6 +123,8 @@ def from_config(cls, config: Config, unused_metadata=None, model_state=None):
         model = create_component(ComponentType.MODEL, config.model, tensorizers)
         if model_state:
             model.load_state_dict(model_state)
+
+        precision.activate(model)
         if cuda.CUDA_ENABLED:
             model = model.cuda()
         # This is the only place right now that the task actually cares about which
@@ -173,8 +175,8 @@ def export(self, model, export_path, metric_channels=None, export_onnx_path=None
         # Make sure to put the model on CPU and disable CUDA before exporting to
         # ONNX to disable any data_parallel pieces
         cuda.CUDA_ENABLED = False
-        precision.deactivate()
         model = model.cpu()
+        precision.deactivate(model)
 
         batch = next(iter(self.data.batches(Stage.TRAIN)))
         print("Saving caffe2 model to: " + export_path)
@@ -186,8 +188,8 @@ def torchscript_export(self, model, export_path):
         # Make sure to put the model on CPU and disable CUDA before exporting to
         # ONNX to disable any data_parallel pieces
         cuda.CUDA_ENABLED = False
-        precision.deactivate()
         model.cpu()
+        precision.deactivate(model)
         # Trace needs eval mode, to disable dropout etc
         model.eval()
 
diff --git a/pytext/task/task.py b/pytext/task/task.py
index 5558f5b61..9f69dcdea 100644
--- a/pytext/task/task.py
+++ b/pytext/task/task.py
@@ -93,6 +93,8 @@ def from_config(cls, task_config, metadata=None, model_state=None):
         model = create_model(task_config.model, task_config.features, metadata)
         if model_state:
             model.load_state_dict(model_state)
+
+        precision.activate(model)
         if cuda.CUDA_ENABLED:
             model = model.cuda()
         metric_reporter = create_metric_reporter(task_config.metric_reporter, metadata)
@@ -177,9 +179,9 @@ def export(self, model, export_path, metric_channels=None, export_onnx_path=None
         # Make sure to put the model on CPU and disable CUDA before exporting to
         # ONNX to disable any data_parallel pieces
         cuda.CUDA_ENABLED = False
-        precision.deactivate()
-
         model = model.cpu()
+        precision.deactivate(model)
+
         if self.exporter:
             if metric_channels:
                 print("Exporting metrics")
diff --git a/pytext/trainers/trainer.py b/pytext/trainers/trainer.py
index 88199145a..a82a7531a 100644
--- a/pytext/trainers/trainer.py
+++ b/pytext/trainers/trainer.py
@@ -160,8 +160,8 @@ def backprop(self, state, loss):
         state.scheduler.step_batch()
 
         if self.config.max_clip_norm is not None:
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                state.model.parameters(), self.config.max_clip_norm
+            grad_norm = precision.clip_grad_norm(
+                state.model, self.optimizer, self.config.max_clip_norm
             )
         else:
             grad_norm = None
diff --git a/pytext/utils/precision.py b/pytext/utils/precision.py
index cff74604f..369500d5f 100644
--- a/pytext/utils/precision.py
+++ b/pytext/utils/precision.py
@@ -2,12 +2,14 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 from sys import stderr
 
+import torch
+
 from . import cuda
 
 
 _APEX_DISABLED = False
 try:
-    from apex import amp
+    from apex import amp, fp16_utils
 except ImportError:
     print("Install apex from https://github.com/NVIDIA/apex/.", file=stderr)
     _APEX_DISABLED = True
@@ -52,50 +54,96 @@
 _FP16_ENABLED = False
+_USE_FP16_OPTIMIZER = False
 _amp_handle = None
 
 
 def set_fp16(fp16_enabled: bool):
     global _FP16_ENABLED
-    global _amp_handle
 
     if _APEX_DISABLED:
         return
 
-    _FP16_ENABLED = fp16_enabled
-    if _FP16_ENABLED:
+    if fp16_enabled:
         if not cuda.CUDA_ENABLED:
             raise RuntimeError("Cuda is not available, should not running fp16...")
-        _amp_handle = amp.init(enabled=fp16_enabled)
+        _FP16_ENABLED = fp16_enabled
+
+
+def activate(model):
+    # Warning: this function should be called before train.
+
+    global _amp_handle
+    global _USE_FP16_OPTIMIZER
+
+    if _FP16_ENABLED:
+        _USE_FP16_OPTIMIZER = model.SUPPORT_FP16_OPTIMIZER
+
+        if _USE_FP16_OPTIMIZER:
+            model.half()
+        else:
+            _amp_handle = amp.init(enabled=_FP16_ENABLED)
 
 
 def wrap_optimizer(optimizer):
     if _FP16_ENABLED:
-        return _amp_handle.wrap_optimizer(optimizer)
+        if _USE_FP16_OPTIMIZER:
+            return fp16_utils.FP16_Optimizer(optimizer, dynamic_loss_scale=True)
+        else:
+            return _amp_handle.wrap_optimizer(optimizer)
     else:
         return optimizer
 
 
+def unwrap_optimizer(wrapped_optimizer):
+    if _FP16_ENABLED:
+        if _USE_FP16_OPTIMIZER:
+            return wrapped_optimizer.optimizer
+        else:
+            return wrapped_optimizer._optimizer
+    else:
+        return wrapped_optimizer
+
+
 def backward(optimizer, loss):
     if _FP16_ENABLED:
-        # 1. Use automatic loss scaling to best use fp16 range (skip step if overflow)
-        # 2. Clear handle's cache of casted parameters before the next optimizer step
-        with optimizer.scale_loss(loss) as scaled_loss:
-            scaled_loss.backward()
+        if _USE_FP16_OPTIMIZER:
+            # 1. Manage master weights update
+            # 2. Manage dynamic loss scaling
+            optimizer.backward(loss)
+        else:
+            # 1. Use automatic loss scaling to best use fp16 range
+            # 2. Clear handle's cache of casted parameters
+            with optimizer.scale_loss(loss) as scaled_loss:
+                scaled_loss.backward()
     else:
         loss.backward()
 
 
-def deactivate():
+def clip_grad_norm(model, optimizer, max_clip_norm):
+    if _FP16_ENABLED and _USE_FP16_OPTIMIZER:
+        return optimizer.clip_master_grads(max_clip_norm)
+    else:
+        return torch.nn.utils.clip_grad_norm_(model.parameters(), max_clip_norm)
+
+
+def deactivate(model):
+    # Warning: this function is expected to be called after train finished.
+    # In case need to deactivate before train, should invoke unwrap_optimizer first.
+    global _FP16_ENABLED
+    global _USE_FP16_OPTIMIZER
+
     if _FP16_ENABLED:
-        # restoring uncasted versions of functions
-        _amp_handle._deactivate()
+        if _USE_FP16_OPTIMIZER:
+            # convert model parameters back to fp32
+            model.float()
+            _USE_FP16_OPTIMIZER = False
+        else:
+            # restoring uncasted versions of functions
+            _amp_handle._deactivate()
         _FP16_ENABLED = False
-    else:
-        pass
 
 
 def maybe_float(tensor):
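
Usage sketch (not part of the patch): a minimal example of the call order the new precision helpers expect around a training loop, assuming a CUDA machine with apex installed. TinyModel, the SGD optimizer, the random inputs, and the explicit cuda.CUDA_ENABLED assignment are illustrative stand-ins, not pytext components; only the precision.* functions and the SUPPORT_FP16_OPTIMIZER class flag come from this diff.

# Illustrative only: exercises the precision helpers added in this patch.
# Assumes a CUDA device and apex installed; TinyModel is a hypothetical stand-in,
# not a pytext model class.
import torch
import torch.nn as nn

from pytext.utils import cuda, precision


class TinyModel(nn.Module):
    # Opt in to the apex FP16_Optimizer path instead of the amp handle.
    SUPPORT_FP16_OPTIMIZER = True

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 2)

    def forward(self, x):
        return self.linear(x)


cuda.CUDA_ENABLED = torch.cuda.is_available()  # normally set during task setup
precision.set_fp16(fp16_enabled=True)          # records the request; no-op if apex is missing

model = TinyModel()
precision.activate(model)                      # halves the parameters when SUPPORT_FP16_OPTIMIZER is set
model = model.cuda()

optimizer = precision.wrap_optimizer(torch.optim.SGD(model.parameters(), lr=0.1))

for _ in range(10):
    inputs = torch.randn(8, 16, device="cuda").half()
    loss = model(inputs).sum()
    optimizer.zero_grad()
    precision.backward(optimizer, loss)              # dynamic loss scaling / master-weight grads
    precision.clip_grad_norm(model, optimizer, 1.0)  # clips master grads under FP16_Optimizer
    optimizer.step()

precision.deactivate(model)                    # restore fp32 weights after training, before export

The two branches inside wrap_optimizer, backward, and clip_grad_norm mean callers never touch amp or FP16_Optimizer directly: setting SUPPORT_FP16_OPTIMIZER = True on a model class is the only opt-in needed to switch from the amp handle to the master-weight optimizer path.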