Description
Describe the feature
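Pass the runner's current iteration (`self._iter`) and maximum number of iterations (`self._max_iters`) as keyword arguments to `self.model.train_step()` (and to `self.model.val_step()` / `self.batch_processor`) in the runner.
For example: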
def run_iter(self, data_batch, train_mode, **kwargs):
    """Run one training or validation iteration on ``data_batch``.

    The runner's current iteration (``self._iter``) and total iteration
    budget (``self._max_iters``) are injected into ``kwargs`` under the keys
    ``iter`` and ``max_iter``. This lets the step function use the training
    progress, e.g. to schedule a loss weight. Caller-supplied values for
    these two keys are overwritten.

    Args:
        data_batch: One batch from the data loader, forwarded unchanged.
        train_mode (bool): Call ``self.model.train_step`` if True, otherwise
            ``self.model.val_step``. When ``self.batch_processor`` is set,
            it is called instead and receives ``train_mode`` itself.
        **kwargs: Extra keyword arguments forwarded to the step function.

    Raises:
        TypeError: If the step function does not return a dict.
    """
    kwargs['iter'] = self._iter
    kwargs['max_iter'] = self._max_iters
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        # Adjacent literals are concatenated, so the separating space must
        # live inside one of them.
        raise TypeError('"batch_processor()" or "model.train_step()" '
                        'and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
Motivation
The current global iteration and the maximum number of iterations may be needed by the network's forward pass during training and/or validation. I recently came across at least two studies that require this.
One is *Temporally Distributed Networks for Fast Video Semantic Segmentation* (CVPR'20); the other is *Regularizing Deep Networks with Semantic Data Augmentation* (PAMI'20).
Related resources
The following two links show how these two values are used during training.
https://github.com/feinanshan/TDNet/blob/3f8b5378fcc7f97c26b3760ddaf3d4402cf477d1/Training/train.py#L118
https://github.com/blackfeather-wang/ISDA-for-Deep-Networks/blob/318c30976d0c412a7dd10250b0164beac6d4fbeb/Semantic%20segmentation%20on%20Cityscapes/train_isda.py#L363
Additional context
I was able to implement implicit semantic data augmentation (ISDA) in mmseg with the following workaround. Instead of receiving the two values from the runner, the head keeps its own iteration counter, and every call to `self.forward_train` increments `self._iter` by one.
class FCNHead(BaseDecodeHead):
    """FCN decode head with optional implicit semantic data augmentation.

    The head is a stack of ``num_convs`` ConvModules. It is optionally
    followed by concatenation with the input and a fusing ConvModule, and
    then by ``cls_seg``.

    When ``is_use_isda`` is True, ``forward_train`` passes the head's last
    feature map and logits to an ``ISDALoss`` augmentor before computing the
    losses.

    The runner does not pass the global iteration to the model, so the head
    keeps its own counter ``self._iter``. ``forward_train`` increments it
    once per call, and only when ISDA is enabled.

    NOTE(review): the counter is a plain attribute, not a registered buffer,
    so it is not part of ``state_dict()`` and restarts at ``start_iters``
    when training is resumed.

    Args:
        num_convs (int): Number of convs in the head. Must be >= 0. With 0,
            ``in_channels`` must equal ``channels``.
        kernel_size (int): Kernel size of every conv in the head.
        concat_input (bool): Whether to concatenate the input with the
            output of the convs before classification.
        is_use_isda (bool): Whether to use implicit semantic data
            augmentation during training.
        isda_lambda (float): The ISDA hyper-parameter lambda_0, selected
            from {1, 2.5, 5, 7.5, 10}.
        start_iters (int): Initial value of the internal iteration counter.
        max_iters (int | float): Total number of training iterations, used
            to compute the ISDA ratio ``isda_lambda * iter / max_iters``.
        **kwargs: Forwarded to ``BaseDecodeHead``.
    """

    def __init__(self,
                 num_convs=2,
                 kernel_size=3,
                 concat_input=True,
                 is_use_isda=False,
                 isda_lambda=2.5,
                 start_iters=1,
                 max_iters=4e5,
                 **kwargs):
        assert isinstance(num_convs, int) and num_convs >= 0
        self.num_convs = num_convs
        self.concat_input = concat_input
        self.kernel_size = kernel_size
        super(FCNHead, self).__init__(**kwargs)

        # Every ConvModule in this head shares these settings.
        conv_kwargs = dict(
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if num_convs == 0:
            # Without convs the input itself is the feature map, so it must
            # already have ``channels`` channels for conv_cat / cls_seg.
            assert self.in_channels == self.channels
            self.convs = nn.Identity()
        else:
            convs = [ConvModule(self.in_channels, self.channels, **conv_kwargs)]
            convs.extend(
                ConvModule(self.channels, self.channels, **conv_kwargs)
                for _ in range(num_convs - 1))
            self.convs = nn.Sequential(*convs)
        if self.concat_input:
            self.conv_cat = ConvModule(
                self.in_channels + self.channels, self.channels, **conv_kwargs)

        self.is_use_isda = is_use_isda
        self.isda_lambda = isda_lambda
        self._iter = start_iters
        self._max_iters = max_iters
        if is_use_isda:
            # The feature map produced by ``forward`` always has
            # ``self.channels`` channels.
            self.isda_augmentor = ISDALoss(self.channels, self.num_classes)

    def forward(self, inputs):
        """Compute segmentation logits.

        Returns:
            The logits from ``cls_seg``. When ISDA is enabled and the module
            is in training mode, returns ``(logits, feat_map.detach())``
            instead, so that ``forward_train`` can feed the features to the
            ISDA augmentor.
        """
        x = self._transform_inputs(inputs)
        feat_map = self.convs(x) if self.num_convs > 0 else x
        if self.concat_input:
            feat_map = self.conv_cat(torch.cat([x, feat_map], dim=1))
        output = self.cls_seg(feat_map)
        if self.is_use_isda and self.training:
            # NOTE(review): the features are detached, so gradients from the
            # ISDA term do not flow back into the convs/backbone through
            # them. Confirm this is intended.
            return output, feat_map.detach()
        return output

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Run the training forward pass and return the loss dict.

        With ISDA enabled, the logits are augmented by
        ``self.isda_augmentor`` using ``ratio = isda_lambda * iter /
        max_iters``, and the internal iteration counter is advanced by one.
        ``img_metas`` and ``train_cfg`` are accepted for interface
        compatibility and are not used here.
        """
        if self.is_use_isda:
            # lambda_0 scaled by training progress (fraction of max_iters).
            ratio = self.isda_lambda * self._iter / self._max_iters
            seg_logits, feat_map = self.forward(inputs)
            # NOTE(review): assumes ISDALoss accepts gt_semantic_seg in the
            # shape mmseg supplies. Confirm against ISDALoss.
            seg_logits = self.isda_augmentor(feat_map, self.conv_seg,
                                             seg_logits, gt_semantic_seg,
                                             ratio)
            self._iter += 1
        else:
            seg_logits = self.forward(inputs)
        return self.losses(seg_logits, gt_semantic_seg)
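However, this is only a workaround, not a proper solution. The counter lives in the head rather than in the runner. Because it is a plain attribute, it is also not saved in checkpoints, so it restarts from `start_iters` when training is resumed. I plan to implement TDNet in mmseg in the future and expect to run into the same issue there.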