Add customized text logger that prints out both lr and layer_0_lr
HannaMao committed Jan 14, 2022
1 parent b2d5a2d commit e4e7eb2
Showing 6 changed files with 296 additions and 2 deletions.
16 changes: 16 additions & 0 deletions object_detection/configs/_base_/default_runtime.py
@@ -0,0 +1,16 @@
checkpoint_config = dict(interval=1)
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='CustomizedTextLoggerHook'),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
custom_hooks = [dict(type='NumClassCheckHook')]

dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
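For reference, mmdet-style configs consume a runtime file like this through _base_ inheritance; a minimal hypothetical downstream config might look like the sketch below (the model and dataset base paths are illustrative, not part of this commit):

_base_ = [
    '../_base_/models/mask_rcnn_r50_fpn.py',  # assumed model base, for illustration
    '../_base_/datasets/coco_instance.py',    # assumed dataset base, for illustration
    '../_base_/default_runtime.py',           # the runtime file added here
]
# with this in place, CustomizedTextLoggerHook emits both lr and layer_0_lr
# every 50 iterations (the interval configured above)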
3 changes: 2 additions & 1 deletion object_detection/mmcv_custom/__init__.py
@@ -10,5 +10,6 @@

from .checkpoint import load_checkpoint
from .layer_decay_optimizer_constructor import LearningRateDecayOptimizerConstructor
from .customized_text import CustomizedTextLoggerHook

-__all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor']
+__all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'CustomizedTextLoggerHook']
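Re-exporting the hook here matters because @HOOKS.register_module() only runs when the module is imported; after mmcv_custom is imported, the string type used in log_config resolves through the registry. A quick sanity check, as a sketch assuming mmcv_custom is on the import path:

import mmcv_custom  # noqa: F401 -- the import side effect registers the hook
from mmcv.runner import HOOKS

# the registry can now resolve the string used in log_config
assert 'CustomizedTextLoggerHook' in HOOKS.module_dict
print(HOOKS.get('CustomizedTextLoggerHook'))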
131 changes: 131 additions & 0 deletions object_detection/mmcv_custom/customized_text.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import datetime
from collections import OrderedDict

import torch

import mmcv
from mmcv.runner import HOOKS
from mmcv.runner import TextLoggerHook


@HOOKS.register_module()
class CustomizedTextLoggerHook(TextLoggerHook):
"""Customized Text Logger hook.
This logger prints out both lr and layer_0_lr.
"""

def _log_info(self, log_dict, runner):
# print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(runner)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
runner.logger.info(exp_info)

if log_dict['mode'] == 'train':
lr_str = {}
for lr_type in ['lr', 'layer_0_lr']:
if isinstance(log_dict[lr_type], dict):
lr_str[lr_type] = []
for k, val in log_dict[lr_type].items():
                    lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
                lr_str[lr_type] = ' '.join(lr_str[lr_type])
else:
lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'

# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
if self.by_epoch:
log_str = f'Epoch [{log_dict["epoch"]}]' \
f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
else:
log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '

if 'time' in log_dict.keys():
self.time_sec_tot += (log_dict['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
log_str += f'eta: {eta_str}, '
log_str += f'time: {log_dict["time"]:.3f}, ' \
f'data_time: {log_dict["data_time"]:.3f}, '
                # report peak GPU memory
if torch.cuda.is_available():
log_str += f'memory: {log_dict["memory"]}, '
else:
# val/test time
# here 1000 is the length of the val dataloader
# by epoch: Epoch[val] [4][1000]
# by iter: Iter[val] [1000]
if self.by_epoch:
log_str = f'Epoch({log_dict["mode"]}) ' \
f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
else:
log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

log_items = []
for name, val in log_dict.items():
# TODO: resolve this hack
            # these items have already been included in log_str
if name in [
'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time', 'data_time',
'memory', 'epoch'
]:
continue
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)

runner.logger.info(log_str)


def log(self, runner):
if 'eval_iter_num' in runner.log_buffer.output:
            # this does not modify runner.iter and is independent of by_epoch
cur_iter = runner.log_buffer.output.pop('eval_iter_num')
else:
cur_iter = self.get_iter(runner, inner_iter=True)

log_dict = OrderedDict(
mode=self.get_mode(runner),
epoch=self.get_epoch(runner),
iter=cur_iter)

# record lr and layer_0_lr
cur_lr = runner.current_lr()
if isinstance(cur_lr, list):
log_dict['layer_0_lr'] = min(cur_lr)
log_dict['lr'] = max(cur_lr)
else:
assert isinstance(cur_lr, dict)
log_dict['lr'], log_dict['layer_0_lr'] = {}, {}
for k, lr_ in cur_lr.items():
assert isinstance(lr_, list)
log_dict['layer_0_lr'].update({k: min(lr_)})
log_dict['lr'].update({k: max(lr_)})

if 'time' in runner.log_buffer.output:
            # record peak GPU memory
if torch.cuda.is_available():
log_dict['memory'] = self._get_max_memory(runner)

log_dict = dict(log_dict, **runner.log_buffer.output)

self._log_info(log_dict, runner)
self._dump_log(log_dict, runner)
return log_dict
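The min/max convention in log() assumes layer-wise lr decay, where the parameter groups form a monotonically scaled list and layer 0 (the most heavily decayed backbone layer) receives the smallest lr. A rough sketch of that assumption, with an illustrative decay rate and depth (not values from this commit):

# illustrative only: per-layer lrs as a layer-wise decay scheme such as
# LearningRateDecayOptimizerConstructor would produce them
base_lr = 1e-3
decay = 0.9        # assumed layer-wise decay rate
num_layers = 12    # assumed backbone depth

lrs = [base_lr * decay ** (num_layers - i) for i in range(num_layers + 1)]

assert min(lrs) == lrs[0]    # logged as layer_0_lr
assert max(lrs) == lrs[-1]   # logged as lr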
14 changes: 14 additions & 0 deletions semantic_segmentation/configs/_base_/default_runtime.py
@@ -0,0 +1,14 @@
# yapf:disable
log_config = dict(
interval=50,
hooks=[
dict(type='CustomizedTextLoggerHook', by_epoch=False),
# dict(type='TensorboardLoggerHook')
])
# yapf:enable
dist_params = dict(backend='nccl')
log_level = 'INFO'
load_from = None
resume_from = None
workflow = [('train', 1)]
cudnn_benchmark = True
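Unlike the detection runtime above, this hook is constructed with by_epoch=False, matching the iteration-based schedules used for segmentation; the only effect is on the log prefix built in _log_info. A sketch of that branch with illustrative numbers:

by_epoch = False
log_dict = {'epoch': 4, 'iter': 100}
max_iters, loader_len = 100000, 1000

if by_epoch:
    prefix = f'Epoch [{log_dict["epoch"]}][{log_dict["iter"]}/{loader_len}]\t'
else:
    prefix = f'Iter [{log_dict["iter"]}/{max_iters}]\t'
print(prefix)  # Iter [100/100000]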
3 changes: 2 additions & 1 deletion semantic_segmentation/mmcv_custom/__init__.py
@@ -13,5 +13,6 @@
from .resize_transform import SETR_Resize
from .apex_runner.optimizer import DistOptimizerHook
from .train_api import train_segmentor
from .customized_text import CustomizedTextLoggerHook

-__all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'SETR_Resize', 'DistOptimizerHook', 'train_segmentor']
+__all__ = ['load_checkpoint', 'LearningRateDecayOptimizerConstructor', 'SETR_Resize', 'DistOptimizerHook', 'train_segmentor', 'CustomizedTextLoggerHook']
131 changes: 131 additions & 0 deletions semantic_segmentation/mmcv_custom/customized_text.py
@@ -0,0 +1,131 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.

# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import datetime
from collections import OrderedDict

import torch

import mmcv
from mmcv.runner import HOOKS
from mmcv.runner import TextLoggerHook


@HOOKS.register_module()
class CustomizedTextLoggerHook(TextLoggerHook):
"""Customized Text Logger hook.
This logger prints out both lr and layer_0_lr.
"""

def _log_info(self, log_dict, runner):
# print exp name for users to distinguish experiments
# at every ``interval_exp_name`` iterations and the end of each epoch
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(runner)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
runner.logger.info(exp_info)

if log_dict['mode'] == 'train':
lr_str = {}
for lr_type in ['lr', 'layer_0_lr']:
if isinstance(log_dict[lr_type], dict):
lr_str[lr_type] = []
for k, val in log_dict[lr_type].items():
                    lr_str[lr_type].append(f'{lr_type}_{k}: {val:.3e}')
                lr_str[lr_type] = ' '.join(lr_str[lr_type])
else:
lr_str[lr_type] = f'{lr_type}: {log_dict[lr_type]:.3e}'

# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
if self.by_epoch:
log_str = f'Epoch [{log_dict["epoch"]}]' \
f'[{log_dict["iter"]}/{len(runner.data_loader)}]\t'
else:
log_str = f'Iter [{log_dict["iter"]}/{runner.max_iters}]\t'
log_str += f'{lr_str["lr"]}, {lr_str["layer_0_lr"]}, '

if 'time' in log_dict.keys():
self.time_sec_tot += (log_dict['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
log_str += f'eta: {eta_str}, '
log_str += f'time: {log_dict["time"]:.3f}, ' \
f'data_time: {log_dict["data_time"]:.3f}, '
                # report peak GPU memory
if torch.cuda.is_available():
log_str += f'memory: {log_dict["memory"]}, '
else:
# val/test time
# here 1000 is the length of the val dataloader
# by epoch: Epoch[val] [4][1000]
# by iter: Iter[val] [1000]
if self.by_epoch:
log_str = f'Epoch({log_dict["mode"]}) ' \
f'[{log_dict["epoch"]}][{log_dict["iter"]}]\t'
else:
log_str = f'Iter({log_dict["mode"]}) [{log_dict["iter"]}]\t'

log_items = []
for name, val in log_dict.items():
# TODO: resolve this hack
            # these items have already been included in log_str
if name in [
'mode', 'Epoch', 'iter', 'lr', 'layer_0_lr', 'time', 'data_time',
'memory', 'epoch'
]:
continue
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)

runner.logger.info(log_str)


def log(self, runner):
if 'eval_iter_num' in runner.log_buffer.output:
            # this does not modify runner.iter and is independent of by_epoch
cur_iter = runner.log_buffer.output.pop('eval_iter_num')
else:
cur_iter = self.get_iter(runner, inner_iter=True)

log_dict = OrderedDict(
mode=self.get_mode(runner),
epoch=self.get_epoch(runner),
iter=cur_iter)

# record lr and layer_0_lr
cur_lr = runner.current_lr()
if isinstance(cur_lr, list):
log_dict['layer_0_lr'] = min(cur_lr)
log_dict['lr'] = max(cur_lr)
else:
assert isinstance(cur_lr, dict)
log_dict['lr'], log_dict['layer_0_lr'] = {}, {}
for k, lr_ in cur_lr.items():
assert isinstance(lr_, list)
log_dict['layer_0_lr'].update({k: min(lr_)})
log_dict['lr'].update({k: max(lr_)})

if 'time' in runner.log_buffer.output:
            # record peak GPU memory
if torch.cuda.is_available():
log_dict['memory'] = self._get_max_memory(runner)

log_dict = dict(log_dict, **runner.log_buffer.output)

self._log_info(log_dict, runner)
self._dump_log(log_dict, runner)
return log_dict
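As a worked example of the ETA arithmetic in _log_info (all numbers illustrative): with 50 logged iterations averaging 0.5 s each out of 100000 total iterations,

import datetime

time_sec_tot = 50 * 0.5                  # accumulated wall time, illustrative
runner_iter, start_iter, max_iters = 49, 0, 100000

time_sec_avg = time_sec_tot / (runner_iter - start_iter + 1)  # 0.5 s/iter
eta_sec = time_sec_avg * (max_iters - runner_iter - 1)        # 49975.0 s
print(str(datetime.timedelta(seconds=int(eta_sec))))          # 13:52:55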
