
fix bugs in mixed precision training #247


Merged · 1 commit, merged on May 2, 2023
2 changes: 1 addition & 1 deletion mindocr/models/backbones/rec_resnet.py
@@ -143,7 +143,7 @@ def __init__(self, in_channels=3, layers=34, **kwargs):
)
shortcut = True
self.block_list.append(basic_block)

self.block_list = nn.SequentialCell(self.block_list)
self.maxpool2d_2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')

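
Context for the `pad_mode='same'` max-pool shown above: in MindSpore, `pad_mode` decides the pooled size for odd-sized inputs, with 'same' yielding ceil(in/stride) and 'valid' yielding floor((in - kernel)/stride) + 1. That matters here because the recognition neck (rnn.py, below) asserts that the feature height is exactly 1. A quick shape check, assuming MindSpore's documented `nn.MaxPool2d` semantics:

```python
import numpy as np
import mindspore as ms
from mindspore import nn

x = ms.Tensor(np.zeros((1, 1, 3, 7), dtype=np.float32))  # odd spatial dims

same = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')
valid = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')

print(same(x).shape)   # (1, 1, 2, 4): ceil(3/2), ceil(7/2)
print(valid(x).shape)  # (1, 1, 1, 3): floor((3-2)/2)+1, floor((7-2)/2)+1
```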
2 changes: 1 addition & 1 deletion mindocr/models/necks/rnn.py
@@ -51,7 +51,7 @@ def construct(self, features):
            Tensor: Encoded features. Shape :math:`(W, N, 2*C)` where
        """
        x = features[0]
-        assert x.shape[2]==1, 'Feature height must be 1'
+        assert x.shape[2]==1, f'Feature height must be 1, but got {x.shape[2]} from x.shape {x.shape}'
        x = ops.squeeze(x, axis=2)  # [N, C, W]
        x = ops.transpose(x, (2, 0, 1))  # [W, N, C]

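
The reworked assert only improves the error message; the contract is unchanged: the backbone must pool feature height down to 1 before the sequence encoder. A minimal NumPy sketch of the squeeze/transpose that follows (shapes are illustrative):

```python
import numpy as np

# Hypothetical backbone output: [N, C, H, W] with H already pooled to 1.
features = np.zeros((8, 512, 1, 64), dtype=np.float32)

assert features.shape[2] == 1, f'Feature height must be 1, but got {features.shape[2]}'
x = np.squeeze(features, axis=2)  # [N, C, W]
x = np.transpose(x, (2, 0, 1))    # [W, N, C], time-major for the RNN encoder
print(x.shape)                    # (64, 8, 512)
```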
23 changes: 19 additions & 4 deletions mindocr/utils/callbacks.py
@@ -6,6 +6,8 @@

import numpy as np
import mindspore as ms
+from mindspore.ops import functional as F
+from mindspore.common import dtype as mstype
from mindspore import save_checkpoint
from mindspore.train.callback._callback import Callback, _handle_loss
from .visualize import draw_bboxes, show_imgs, recover_image
@@ -22,14 +24,19 @@ class Evaluator:
        loss_fn: loss function
        postprocessor: post-processor
        metrics: metrics to evaluate network performance
+        pred_cast_fp32: whether to cast network predictions to float32. Set True if AMP is used.
        num_columns_to_net: number of inputs to the network in the dataset output columns. Default is 1, i.e., the first column is the image.
        num_columns_of_labels: number of labels in the dataset output columns. Default is None, assuming the columns after the image (data[1:]) are labels.
            If not None, the num_columns_of_labels columns after the image (data[1:1+num_columns_of_labels]) are labels, and the remaining columns are additional info such as image_path.
    """

    def __init__(self, network, dataloader, loss_fn=None, postprocessor=None, metrics=None,
-                 num_columns_to_net=1, num_columns_of_labels=None, num_epochs=-1,
-                 visualize=False, verbose=False,
+                 pred_cast_fp32=False,
+                 num_columns_to_net=1,
+                 num_columns_of_labels=None,
+                 num_epochs=-1,
+                 visualize=False,
+                 verbose=False,
                 **kwargs):
        self.net = network
        self.postprocessor = postprocessor
@@ -39,7 +46,8 @@ def __init__(self, network, dataloader, loss_fn=None, postprocessor=None, metrics=None,
            assert hasattr(m, 'metric_names') and isinstance(m.metric_names,
                List), f'Metric object must contain `metric_names` attribute to indicate the metric names as a List type, but not found in {m.__class__.__name__}'
            self.metric_names += m.metric_names

+        self.pred_cast_fp32 = pred_cast_fp32
        self.visualize = visualize
        self.verbose = verbose
        eval_loss = False
@@ -80,6 +88,12 @@ def eval(self):

            net_preds = self.net(*inputs)

+            if self.pred_cast_fp32:
+                if isinstance(net_preds, ms.Tensor):
+                    net_preds = F.cast(net_preds, mstype.float32)
+                else:
+                    net_preds = [F.cast(p, mstype.float32) for p in net_preds]

            if self.postprocessor is not None:
                # additional info such as image path, original image size, pad shape, extracted in data processing
                meta_info = data[(self.num_inputs+self.num_labels):] if (self.num_labels is not None) else []
@@ -125,6 +139,7 @@ def __init__(self,
                 loss_fn=None,
                 postprocessor=None,
                 metrics=None,
+                 pred_cast_fp32=False,
                 rank_id=0,
                 logger=None,
                 batch_size=20,
@@ -146,7 +161,7 @@ def __init__(self,
        self.log_interval = log_interval
        self.batch_size = batch_size
        if self.loader_eval is not None:
-            self.net_evaluator = Evaluator(network, loader, loss_fn, postprocessor, metrics, num_columns_to_net=num_columns_to_net, num_columns_of_labels=num_columns_of_labels)
+            self.net_evaluator = Evaluator(network, loader, loss_fn, postprocessor, metrics, pred_cast_fp32=pred_cast_fp32, num_columns_to_net=num_columns_to_net, num_columns_of_labels=num_columns_of_labels)
            self.main_indicator = main_indicator
            self.best_perf = -1e8
        else:
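
Why the cast: with `ms.amp.auto_mixed_precision` at O2/O3 the network computes largely in float16, so its raw outputs can come back as float16; casting them to float32 before loss, postprocessing, and metric computation avoids dtype mismatches and precision loss. A standalone sketch of the pattern added above (the helper name `cast_pred_to_fp32` is ours, not part of mindocr):

```python
import numpy as np
import mindspore as ms
from mindspore.common import dtype as mstype
from mindspore.ops import functional as F

def cast_pred_to_fp32(pred):
    # Network outputs may be a single Tensor or a list/tuple of Tensors.
    if isinstance(pred, ms.Tensor):
        return F.cast(pred, mstype.float32)
    return [F.cast(p, mstype.float32) for p in pred]

pred = ms.Tensor(np.zeros((2, 4), dtype=np.float16))
print(cast_pred_to_fp32(pred).dtype)                       # Float32
print([p.dtype for p in cast_pred_to_fp32([pred, pred])])  # [Float32, Float32]
```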
45 changes: 11 additions & 34 deletions mindocr/utils/model_wrapper.py
@@ -1,6 +1,9 @@
from mindspore import nn
+import mindspore as ms
import mindspore.ops as ops
from mindspore.communication import get_group_size
+from mindspore.common import dtype as mstype
+from mindspore.ops import functional as F


class NetWithLossWrapper(nn.Cell):
@@ -15,13 +18,14 @@ class NetWithLossWrapper(nn.Cell):
        num_net_inputs: number of network inputs, e.g. 1
        num_labels: number of labels used for loss fn computation. If None, all the remaining args will be fed into the loss func.
    '''
-    def __init__(self, net, loss_fn, num_net_inputs=1, num_labels=None):
+    def __init__(self, net, loss_fn, pred_cast_fp32=False, num_net_inputs=1, num_labels=None):
        super().__init__(auto_prefix=False)
        self._net = net
        self._loss_fn = loss_fn
        # TODO: get this automatically from net and loss func
        self.num_net_inputs = num_net_inputs
        self.num_labels = num_labels
+        self.pred_cast_fp32 = pred_cast_fp32
        #self.net_forward_input = ['img']
        #self.loss_forward_input = ['gt', 'gt_mask', 'thresh_map', 'thresh_mask']

@@ -33,13 +37,19 @@ def construct(self, *args):
            loss_val (Tensor): loss value
        '''
        pred = self._net(*args[:self.num_net_inputs])
+        if self.pred_cast_fp32:
+            if isinstance(pred, ms.Tensor):
+                pred = F.cast(pred, mstype.float32)
+            else:
+                pred = [F.cast(p, mstype.float32) for p in pred]
        if self.num_labels is None:
            loss_val = self._loss_fn(pred, *args[self.num_net_inputs:])
        else:
            loss_val = self._loss_fn(pred, *args[self.num_net_inputs:self.num_net_inputs+self.num_labels])

        return loss_val


class NetWithEvalWrapper(nn.Cell):
'''
A universal wrapper for any network with any loss for evaluation pipeline.
@@ -83,36 +93,3 @@ def construct(self, *args):
            loss_val = None

        return loss_val, pred, labels
-
-
-class DBNetWithLossCell(nn.Cell):
-    """
-    Wrap the network with loss function to compute loss.
-
-    Args:
-        net (Cell): The target network to wrap.
-        loss_fn (Cell): The loss function used to compute loss.
-    """
-
-    def __init__(self, net, loss_fn):
-        super().__init__(auto_prefix=False)
-
-        self._net = net
-        self._loss_fn = loss_fn
-
-    # Note: this order should be consistent with the dataloader output
-    def construct(self, img, gt, gt_mask, thresh_map, thresh_mask):
-        pred = self._net(img)
-        loss = self._loss_fn(pred, gt, gt_mask, thresh_map, thresh_mask)
-
-        return loss
-
-    @property
-    def backbone_network(self):
-        """
-        Get the backbone network.
-
-        Returns:
-            Cell, return backbone network.
-        """
-        return self._network
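
Usage sketch: how the new `pred_cast_fp32` flag is meant to be wired up when AMP is enabled (the toy network and loss are stand-ins for mindocr's `build_model`/`build_loss` outputs):

```python
import mindspore as ms
from mindspore import nn
from mindocr.utils.model_wrapper import NetWithLossWrapper

# Toy stand-ins for build_model()/build_loss().
network = nn.Dense(16, 4)
loss_fn = nn.MSELoss()

amp_level = 'O2'  # illustrative; normally read from the YAML config
ms.amp.auto_mixed_precision(network, amp_level=amp_level)

# Predictions come back as fp16 under O2/O3, so cast before the fp32 loss.
net_with_loss = NetWithLossWrapper(network, loss_fn,
                                   pred_cast_fp32=(amp_level != 'O0'))
```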
8 changes: 6 additions & 2 deletions tools/train.py
@@ -87,12 +87,15 @@ def main(cfg):

    # create model
    network = build_model(cfg.model)
-    ms.amp.auto_mixed_precision(network, amp_level=cfg.system.amp_level)
+    amp_level = cfg.system.get('amp_level', 'O0')
+    ms.amp.auto_mixed_precision(network, amp_level=amp_level)

    # create loss
    loss_fn = build_loss(cfg.loss.pop('name'), **cfg['loss'])

-    net_with_loss = NetWithLossWrapper(network, loss_fn)  # wrap train-one-step cell
+    net_with_loss = NetWithLossWrapper(network, loss_fn,
+                                       pred_cast_fp32=(amp_level!='O0'),
+                                       )  # wrap train-one-step cell

    # get loss scale setting for mixed precision training
    loss_scale_manager, optimizer_loss_scale = get_loss_scales(cfg)
@@ -130,6 +133,7 @@
        loader_eval,
        postprocessor=postprocessor,
        metrics=[metric],
+        pred_cast_fp32=(amp_level!='O0'),
        rank_id=rank_id,
        logger=logger,
        batch_size=cfg.train.loader.batch_size,
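
The move from `cfg.system.amp_level` to `cfg.system.get('amp_level', 'O0')` makes the key optional: configs that omit it fall back to plain fp32 training instead of raising an error. A plain-dict sketch of the fallback (the parsed config is assumed to support dict-style `.get`):

```python
# Stand-in for the parsed cfg.system section of the YAML config.
system_cfg = {}                                # 'amp_level' not set
amp_level = system_cfg.get('amp_level', 'O0')  # -> 'O0', fp32 training
pred_cast_fp32 = (amp_level != 'O0')           # -> False, no extra casts

system_cfg = {'amp_level': 'O2'}
amp_level = system_cfg.get('amp_level', 'O0')  # -> 'O2'
pred_cast_fp32 = (amp_level != 'O0')           # -> True: cast fp16 preds back to fp32
```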