Improve log, save origin yaml, and fix adan #272

Merged: 1 commit, May 10, 2023
22 changes: 13 additions & 9 deletions mindocr/data/builder.py
@@ -93,7 +93,7 @@ def build_dataset(
NUM_WORKERS_MAP = int(cores / num_devices - NUM_WORKERS_BATCH) # optimal num workers assuming all cpu cores are used in this job
num_workers = loader_config.get("num_workers", NUM_WORKERS_MAP)
if num_workers > int(cores / num_devices):
print(f'WARNING: num_workers is adjusted to {int(cores / num_devices)} since {num_workers}x{num_devices} exceeds the number of CPU cores {cores}')
print(f'WARNING: `num_workers` is adjusted to {int(cores / num_devices)} since {num_workers}x{num_devices} exceeds the number of CPU cores {cores}')
num_workers = int(cores / num_devices)
## prefetch_size: the length of the cache queue in the data pipeline for each worker, used to reduce waiting time. Larger value leads to more memory consumption. Default: 16
prefetch_size = loader_config.get("prefetch_size", 16) #
@@ -113,7 +113,7 @@ def build_dataset(
dataset = dataset_class(**dataset_args)

dataset_column_names = dataset.get_output_columns()
print('==> Dataset output columns: \n\t', dataset_column_names)
#print('=> Dataset output columns: \n\t', dataset_column_names)

## Generate source dataset (source w.r.t. the dataset.map pipeline) based on python callable numpy dataset in parallel
ds = ms.dataset.GeneratorDataset(
@@ -134,17 +134,21 @@ def build_dataset(
# get batch of dataset by collecting batch_size consecutive data rows and apply batch operations
num_samples = ds.get_dataset_size()
batch_size = loader_config['batch_size']
print(f'INFO: num_samples: {num_samples}, batch_size: {batch_size}')

device_id = 0 if shard_id is None else shard_id
is_main_device = device_id == 0
print(f'INFO: Creating dataloader (training={is_train}) for device {device_id}. Number of data samples: {num_samples}')

if 'refine_batch_size' in kwargs:
batch_size = _check_batch_size(num_samples, batch_size, refine=kwargs['refine_batch_size'])

drop_remainder = loader_config.get('drop_remainder', is_train)
if is_train and drop_remainder == False:
print('WARNING: drop_remainder should be True for training, otherwise the last batch may lead to training fail in Graph mode')
if is_train and drop_remainder == False and is_main_device:
print('WARNING: `drop_remainder` should be True for training, otherwise the last batch may lead to training fail in Graph mode')

if not is_train:
if drop_remainder:
print("WARNING: drop_remainder is forced to be False for evaluation to include the last batch for accurate evaluation." )
if drop_remainder and is_main_device:
print("WARNING: `drop_remainder` is forced to be False for evaluation to include the last batch for accurate evaluation." )
drop_remainder = False

dataloader = ds.batch(
@@ -158,6 +162,7 @@

return dataloader


def _check_dataset_paths(dataset_config):
if 'dataset_root' in dataset_config:
if isinstance(dataset_config['data_dir'], str):
@@ -181,7 +186,6 @@ def _check_batch_size(num_samples, ori_batch_size=32, refine=True):
for bs in range(ori_batch_size - 1, 0, -1):
Collaborator:
If the number of samples in the evaluation set is a prime number, the batch size will be refined down to 1, which can significantly increase evaluation time. Can we just set drop_remainder to False for the evaluation set and leave the batch size as it is?
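
(For illustration only — a minimal sketch, not code from this PR, of how an exact-divisor search like the one in _check_batch_size degrades to a batch size of 1 when the number of eval samples is prime; refine_batch_size is a hypothetical stand-in:)

def refine_batch_size(num_samples, ori_batch_size=32):
    # walk down from the requested batch size to the largest exact divisor
    for bs in range(ori_batch_size, 0, -1):
        if num_samples % bs == 0:
            return bs
    return ori_batch_size

print(refine_batch_size(1000))  # 25 -> an exact divisor exists, throughput stays reasonable
print(refine_batch_size(997))   # 1  -> 997 is prime, so the only exact divisor is 1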

Collaborator:
Agreed. Running with batch size 1 usually takes longer than running with two different batch sizes (full batches plus one smaller remainder batch), and the extra model compilation time caused by the differing batch size is negligible.

Collaborator (Author):
Even when we set drop_remainder to False, the last batch will be padded to batch_size, leading to an inaccurate evaluation result.

Collaborator:
As I remember, it is not padded; the remainder is output as a smaller final batch:

from mindspore.dataset import GeneratorDataset

dataset = GeneratorDataset(range(10), 'data').batch(4)
for x in dataset.create_tuple_iterator(num_epochs=1):
    print(x[0].shape)

>>>(4,)
>>>(4,)
>>>(2,)

if num_samples % bs == 0:
print(
f"WARNING: num eval samples {num_samples} can not be divided by "
f"the input batch size {ori_batch_size}. The batch size is refined to {bs}"
f"INFO: Batch size for evaluation is refined to {bs} to ensure the last batch will not be dropped/padded in graph mode."
)
return bs
2 changes: 1 addition & 1 deletion mindocr/data/transforms/rec_transforms.py
@@ -45,7 +45,7 @@ def __init__(self,
char_list = [c for c in "0123456789abcdefghijklmnopqrstuvwxyz"]

self.lower = True
print("INFO: The character_dict_path is None, model can only recognize number and lower letters")
#print("INFO: The character_dict_path is None, model can only recognize number and lower letters")
else:
# TODO: this is commonly used in other modules, wrap into a func or class.
# parse char dictionary
3 changes: 1 addition & 2 deletions mindocr/losses/builder.py
@@ -29,7 +29,6 @@ def build_loss(name, **kwargs):

loss_fn = eval(name)(**kwargs)

# print('loss func inputs: ', loss_fn.construct.__code__.co_varnames)
print('==> Loss func input args: \n\t', inspect.signature(loss_fn.construct))
# print('=> Loss func input args: \n\t', inspect.signature(loss_fn.construct))

return loss_fn
2 changes: 1 addition & 1 deletion mindocr/losses/rec_loss.py
@@ -34,7 +34,7 @@ def __init__(self, pred_seq_len=26, max_label_len=25, batch_size=32, reduction='
self.ctc_loss = ops.CTCLoss(ctc_merge_repeated=True)

self.reduction = reduction
print('D: ', self.label_indices.shape)
#print('D: ', self.label_indices.shape)

# TODO: diff from paddle, paddle takes `label_length` as input too.
def construct(self, pred: Tensor, label: Tensor):
3 changes: 1 addition & 2 deletions mindocr/models/backbones/builder.py
@@ -54,7 +54,6 @@ def build_backbone(name, **kwargs):
pretrained = kwargs['pretrained']
if not isinstance(pretrained, bool):
load_model(backbone, pretrained)
else:
print(f'Backbone weights are already loaded from default url defined in {name} python file')
# No need to load again if pretrained is bool and True, because pretrained backbone is already loaded in the backbone definition function.')

return backbone
2 changes: 1 addition & 1 deletion mindocr/models/utils.py
@@ -17,4 +17,4 @@ def load_model(network, load_from: str):
params = load_checkpoint(load_from)
load_param_into_net(network, params)

print(f'==> Finish loading checkoint from {load_from}.')
print(f'INFO: Finish loading model checkoint from {load_from}. If no parameter fail-load warning displayed, all checkpoint params have been successfully loaded.')
111 changes: 36 additions & 75 deletions mindocr/optim/adan.py
@@ -1,31 +1,12 @@
# Copyright 2020-2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""adan"""
from __future__ import absolute_import

import mindspore as ms
from mindspore import ops
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function
from mindspore.common.tensor import Tensor
from mindspore.nn.optim.optimizer import Optimizer, opt_init_args_register
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P

_adan_opt = C.MultitypeFuncGraph("adan_opt")
_scaler_one = Tensor(1, mstype.int32)
_scaler_ten = Tensor(10, mstype.float32)
_adan_opt = ops.MultitypeFuncGraph("adan_opt")


@_adan_opt.register(
@@ -73,67 +54,55 @@ def _update_run_op(
Returns:
Tensor, the new value of v after updating.
"""
op_cast = P.Cast()
op_mul = P.Mul()
op_square = P.Square()
op_sqrt = P.Sqrt()
op_cast = P.Cast()
op_reshape = P.Reshape()
op_shape = P.Shape()

success = True

# if global_step == 0.0: # init
# TODO: use global_step==0 as the condition to init prev_gradient as gradient
# if (F.reduce_min(prev_gradient) == 0.0) and (F.reduce_max(prev_gradient) == 0.0):
if F.reduce_sum(prev_gradient) == 0.0:
success = F.depend(success, F.assign(prev_gradient, gradient))

# TODO: is casting needed?
op_mul = ops.Mul()
op_square = ops.Square()
op_sqrt = ops.Sqrt()
op_cast = ops.Cast()
op_reshape = ops.Reshape()
op_shape = ops.Shape()

success = ms.Tensor(True, dtype=ms.bool_)

if ops.reduce_sum(prev_gradient) == 0.0:
success = ops.depend(success, ops.assign(prev_gradient, gradient))

param_fp32 = op_cast(param, mstype.float32)
m_fp32 = op_cast(m, mstype.float32)
v_fp32 = op_cast(v, mstype.float32)
n_fp32 = op_cast(n, mstype.float32)
gradient_fp32 = op_cast(gradient, mstype.float32)
prev_gradient_fp32 = op_cast(prev_gradient, mstype.float32)

next_m = op_mul(F.tuple_to_array((1.0,)) - beta1, m_fp32) + op_mul(beta1, gradient_fp32)
next_m = op_mul(ops.tuple_to_array((1.0,)) - beta1, m_fp32) + op_mul(beta1, gradient_fp32)

next_v = op_mul(F.tuple_to_array((1.0,)) - beta2, v_fp32) + op_mul(beta2, gradient_fp32 - prev_gradient_fp32)
next_v = op_mul(ops.tuple_to_array((1.0,)) - beta2, v_fp32) + op_mul(beta2, gradient_fp32 - prev_gradient_fp32)

next_n = op_mul(F.tuple_to_array((1.0,)) - beta3, n_fp32) + op_mul(
beta3, op_square(gradient + op_mul(F.tuple_to_array((1.0,)) - beta2, gradient_fp32 - prev_gradient_fp32))
next_n = op_mul(ops.tuple_to_array((1.0,)) - beta3, n_fp32) + op_mul(
beta3, op_square(gradient + op_mul(ops.tuple_to_array((1.0,)) - beta2, gradient_fp32 - prev_gradient_fp32))
)

lr_t = lr / (eps + op_sqrt(next_n))

update = next_m + op_mul(F.tuple_to_array((1.0,)) - beta2, next_v)

# if decay_flag:
# update = op_mul(weight_decay, param_fp32) + update
update = next_m + op_mul(ops.tuple_to_array((1.0,)) - beta2, next_v)

next_param = param_fp32 - op_reshape(op_mul(lr_t, update), op_shape(param_fp32))

next_param = next_param / (Tensor(1.0, mstype.float32) + op_mul(weight_decay, lr_t))

success = F.depend(success, F.assign(param, op_cast(next_param, F.dtype(param))))
success = F.depend(success, F.assign(m, op_cast(next_m, F.dtype(m))))
success = F.depend(success, F.assign(v, op_cast(next_v, F.dtype(v))))
success = F.depend(success, F.assign(n, op_cast(next_n, F.dtype(n))))
success = F.depend(success, F.assign(prev_gradient, gradient))
success = ops.depend(success, ops.assign(param, op_cast(next_param, ops.dtype(param))))
success = ops.depend(success, ops.assign(m, op_cast(next_m, ops.dtype(m))))
success = ops.depend(success, ops.assign(v, op_cast(next_v, ops.dtype(v))))
success = ops.depend(success, ops.assign(n, op_cast(next_n, ops.dtype(n))))
success = ops.depend(success, ops.assign(prev_gradient, gradient))

return op_cast(next_param, F.dtype(param))
return op_cast(next_param, ops.dtype(param))


def _check_param_value(beta1, beta2, eps, use_locking, prim_name):
def _check_param_value(beta1, beta2, eps, prim_name):
"""Check the type of inputs."""
assert isinstance(beta1, float), f"For '{prim_name}', the type of 'beta1' must be 'float', but got type '{type(beta1).__name__}'."
assert isinstance(beta2, float), f"For '{prim_name}', the type of 'beta2' must be 'float', but got type '{type(beta2).__name__}'."
assert isinstance(eps, float), f"For '{prim_name}', the type of 'eps' must be 'float', but got type '{type(eps).__name__}'."
assert 0.0 < beta1 < 1.0, f"For '{prim_name}', the range of 'beta1' must be (0.0, 1.0), but got {beta1}."
assert 0.0 < beta2 < 1.0, f"For '{prim_name}', the range of 'beta2' must be (0.0, 1.0), but got {beta2}."
assert eps > 0, f"For '{prim_name}', the 'eps' must be positive, but got {eps}."
assert isinstance(use_locking, bool), f"For '{prim_name}', the type of 'use_locking' must be 'bool', but got type '{type(use_locking).__name__}'."
assert isinstance(beta1, float) and 0 <= beta1 <= 1.0, f"For {prim_name}, beta1 should between 0 and 1"
Collaborator:
Update the docstring of function _update_run_op() in lines 44 and 45: (0.0, 1.0) -> [0.0, 1.0].

assert isinstance(beta2, float) and 0 <= beta2 <= 1.0, f"For {prim_name}, beta2 should between 0 and 1"
assert isinstance(eps, float) and eps > 0, f"For {prim_name}, eps should be bigger than 0"


class Adan(Optimizer):
@@ -153,28 +122,26 @@ def __init__(
beta3=0.99,
eps=1e-8,
use_locking=False,
weight_decay=1e-6,
weight_decay=0.0,
loss_scale=1.0,
):
super().__init__(
learning_rate, params, weight_decay=weight_decay, loss_scale=loss_scale
) # Optimized inherit weight decay is bloaked. weight decay is computed in this py.

_check_param_value(beta1, beta2, eps, use_locking, self.cls_name)
_check_param_value(beta1, beta2, eps, self.cls_name)
assert isinstance(use_locking, bool), f"For {self.cls_name}, use_looking should be bool"

self.beta1 = Tensor(beta1, mstype.float32)
self.beta2 = Tensor(beta2, mstype.float32)
self.beta3 = Tensor(beta3, mstype.float32)
# self.beta1_power = Parameter(initializer(1, [1], mstype.float32), name="beta1_power")
# self.beta2_power = Parameter(initializer(1, [1], mstype.float32), name="beta2_power")
# self.beta3_power = Parameter(initializer(1, [1], mstype.float32), name="beta3_power")

self.eps = Tensor(eps, mstype.float32)
self.use_locking = use_locking
self.moment1 = self._parameters.clone(prefix="moment1", init="zeros") # m
self.moment2 = self._parameters.clone(prefix="moment2", init="zeros") # v
self.moment3 = self._parameters.clone(prefix="moment3", init="zeros") # n
self.prev_gradient = self._parameters.clone(prefix="prev_gradient", init="zeros")
# print('prev g: ', type(self.prev_gradient))

self.weight_decay = Tensor(weight_decay, mstype.float32)

@@ -184,29 +151,23 @@ def construct(self, gradients):
moment1 = self.moment1
moment2 = self.moment2
moment3 = self.moment3
# vhat = self.vhat

gradients = self.flatten_gradients(gradients)
# gradients = self.decay_weight(gradients) # we decay weight in adan_opt func
gradients = self.gradients_centralization(gradients)
gradients = self.scale_grad(gradients)
gradients = self._grad_sparse_indices_deduplicate(gradients)
lr = self.get_lr()
# weight_decay = self.get_weight_decay()

# if self.global_step == 0:
# success = F.depend(True, F.assign(self.prev_gradient, gradients))

# TODO: currently not support dist
success = self.map_(
F.partial(_adan_opt, self.beta1, self.beta2, self.beta3, self.eps, lr, self.weight_decay),
ops.partial(_adan_opt, self.beta1, self.beta2, self.beta3, self.eps, lr, self.weight_decay),
params,
moment1,
moment2,
moment3,
gradients,
self.prev_gradient,
)
# params, moment1, moment2, moment3, gradients, self.prev_gradient, self.global_step)

return success

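For reference, a transcription of the per-parameter update that the refactored _update_run_op above computes (read directly from the diff; g_k is `gradient`, g_{k-1} is `prev_gradient`, λ is `weight_decay`, and all operations are element-wise):

$$m_k = (1-\beta_1)\,m_{k-1} + \beta_1\,g_k$$
$$v_k = (1-\beta_2)\,v_{k-1} + \beta_2\,(g_k - g_{k-1})$$
$$n_k = (1-\beta_3)\,n_{k-1} + \beta_3\,\big[g_k + (1-\beta_2)(g_k - g_{k-1})\big]^2$$
$$\eta_k = \frac{lr}{\epsilon + \sqrt{n_k}}, \qquad \theta_k = \frac{\theta_{k-1} - \eta_k\,\big(m_k + (1-\beta_2)\,v_k\big)}{1 + \lambda\,\eta_k}$$
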
2 changes: 1 addition & 1 deletion mindocr/optim/param_grouping.py
@@ -69,7 +69,7 @@ def create_group_params(params, weight_decay=0, grouping_strategy=None, no_weigh
# TODO: assert valid arg names
gp = grouping_strategy

print(f'INFO: param grouping startegy: {grouping_strategy}, no_weight_decay_params: ', no_weight_decay_params)
#print(f'INFO: param grouping startegy: {grouping_strategy}, no_weight_decay_params: ', no_weight_decay_params)
if gp is not None:
if weight_decay == 0:
print("WARNING: weight decay is 0 in param grouping, which is meaningless. Please check config setting.")
2 changes: 1 addition & 1 deletion mindocr/postprocess/rec_postprocess.py
@@ -38,7 +38,7 @@ def __init__(self,
if character_dict_path is None:
char_list = [c for c in "0123456789abcdefghijklmnopqrstuvwxyz"]
self.lower = True
print("INFO: The character_dict_path is None, model can only recognize number and lower letters")
print("INFO: `character_dict_path` for RecCTCLabelDecode is not given. Default dict \"0123456789abcdefghijklmnopqrstuvwxyz\" is applied. Only number and English letters (regardless of lower/upper case) will be recognized and evaluated.")
else:
# parse char dictionary
char_list = []
2 changes: 1 addition & 1 deletion mindocr/utils/callbacks.py
@@ -126,7 +126,7 @@ def eval(self):

# visualize
if self.verbose:
print('Eval data info: ', data_info)
print('Data meta info: ', data_info)

if self.visualize:
img = img[0].asnumpy()
11 changes: 7 additions & 4 deletions mindocr/utils/logger.py
@@ -12,10 +12,12 @@ class Logger(logging.Logger):
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0, is_main_device=False, log_fn=None):
def __init__(self, logger_name, rank=0, log_fn=None):
super(Logger, self).__init__(logger_name)
self.rank = rank
self.rank = rank or 0
self.log_fn = log_fn
is_main_device = not rank

if is_main_device:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
@@ -58,8 +60,9 @@ def important_info(self, msg, *args, **kwargs):
self.info(important_msg, *args, **kwargs)


def get_logger(log_dir, rank, is_main_device):
def get_logger(log_dir, rank):
"""Get Logger."""
logger = Logger('mindocr', rank, is_main_device)
logger = Logger('mindocr', rank)
logger.setup_logging_file(log_dir)

return logger
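
A minimal usage sketch of the updated interface (the module path and the rank-to-main-device mapping are taken from the diff above; how setup_logging_file attaches file handlers is not shown here and is assumed):

from mindocr.utils.logger import get_logger

main_logger = get_logger("./ckpt_save_dir", rank=0)    # rank 0 -> is_main_device, console handler attached
worker_logger = get_logger("./ckpt_save_dir", rank=3)  # rank > 0 -> no console handler, avoids duplicated console output
main_logger.info("Dataloader created")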
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,6 +5,7 @@ addict
matplotlib
addict
numpy
shutils
Collaborator:
This is not shutil. shutil is a built-in Python library, so there is no need to install it with pip.

imgaug>=0.4.0
tqdm>=4.64.1
opencv-python-headless>=3.4.18.65
2 changes: 1 addition & 1 deletion tests/ut/test_logger.py
@@ -18,4 +18,4 @@ def test_logger(task):
cfg = yaml.safe_load(fp)
cfg = Dict(cfg)

logger = get_logger(cfg.train.ckpt_save_dir, rank=0, is_main_device=True)
logger = get_logger(cfg.train.ckpt_save_dir, rank=0)