Add checkpoint manager to save latest or top k checkpoints in history #279


Merged: 2 commits, May 12, 2023
2 changes: 2 additions & 0 deletions configs/rec/crnn/crnn_icdar15.yaml
@@ -3,6 +3,8 @@ system:
distribute: False
amp_level: 'O3'
seed: 42
ckpt_save_policy: top_k # top_k or latest_k
ckpt_max_keep: 5
log_interval: 10
val_while_train: True
drop_overflow_update: False
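For orientation, a small sketch of what the two policies mean for a config like the one above (not part of this diff; the perf values and helper names are invented): with ckpt_save_policy top_k and ckpt_max_keep 5, the five checkpoints with the best validation metric are retained, while latest_k keeps the five most recent ones.

def retained_top_k(history, k=5, prefer_low_perf=False):
    # keep the k (perf, name) pairs with the best metric, best first
    return sorted(history, key=lambda x: x[0], reverse=not prefer_low_perf)[:k]

def retained_latest_k(history, k=5):
    # keep only the k most recently saved checkpoint names
    return [name for _, name in history][-k:]

history = [(0.71, "e1.ckpt"), (0.80, "e2.ckpt"), (0.78, "e3.ckpt"),
           (0.83, "e4.ckpt"), (0.82, "e5.ckpt"), (0.79, "e6.ckpt")]
print(retained_top_k(history))     # e4, e5, e2, e6, e3 -- the five best by metric
print(retained_latest_k(history))  # e2 ... e6 -- the five most recent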
4 changes: 2 additions & 2 deletions mindocr/data/constants.py
@@ -3,5 +3,5 @@
"""

DEFAULT_CROP_PCT = 0.875
IMAGENET_DEFAULT_MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255]
IMAGENET_DEFAULT_STD = [0.229 * 255, 0.224 * 255, 0.225 * 255]
IMAGENET_DEFAULT_MEAN = [0.485 * 255, 0.456 * 255, 0.406 * 255] # RGB
IMAGENET_DEFAULT_STD = [0.229 * 255, 0.224 * 255, 0.225 * 255] # RGB
18 changes: 18 additions & 0 deletions mindocr/utils/callbacks.py
@@ -9,6 +9,7 @@
from mindspore.common import dtype as mstype
from mindspore import save_checkpoint
from mindspore.train.callback._callback import Callback, _handle_loss
from .checkpoint import CheckpointManager
from .visualize import draw_bboxes, show_imgs, recover_image
from .recorder import PerfRecorder

@@ -178,6 +179,8 @@ def __init__(self,
val_interval=1,
val_start_epoch=1,
log_interval=1,
ckpt_save_policy='top_k',
ckpt_max_keep=10,
):
self.rank_id = rank_id
self.is_main_device = rank_id in [0, None]
@@ -214,6 +217,13 @@ def __init__(self,
# lambda expressions are not supported in jit
self._loss_reduce = self._reduce if device_num is not None else lambda x: x

if self.is_main_device:
self.ckpt_save_policy = ckpt_save_policy
self.ckpt_manager = CheckpointManager(ckpt_save_dir,
ckpt_save_policy,
k=ckpt_max_keep,
prefer_low_perf=(self.main_indicator=='train_loss'))

@jit
def _reduce(self, x):
return self._reduce_sum(x) / self._device_num # average value across all devices
@@ -307,6 +317,14 @@ def on_train_epoch_end(self, run_context):

self.logger(f'=> Best {self.main_indicator}: {self.best_perf}, checkpoint saved.')

# save history checkpoints
self.ckpt_manager.save(self.network, perf, ckpt_name=f'e{cur_epoch}.ckpt')
if self.ckpt_save_policy=='top_k' and cur_epoch >= self.val_start_epoch:
log_str = f'Top K checkpoints:\n{self.main_indicator}\tcheckpoint\n'
for p, ckpt_name in self.ckpt_manager.get_ckpt_queue():
log_str += f'{p:.4f}\t{os.path.join(self.ckpt_save_dir, ckpt_name)}\n'
self.logger(log_str)

# record results
if cur_epoch == 1:
if self.loader_eval is not None:
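As a rough illustration of the top-k table the callback logs above (the names, metric values, and the hmean indicator are hypothetical, not taken from a real run):

import os

# Hypothetical queue contents; in the callback these come from CheckpointManager.get_ckpt_queue().
main_indicator = "hmean"
ckpt_save_dir = "./ckpt"
queue = [(0.6123, "e12.ckpt"), (0.6050, "e9.ckpt"), (0.5987, "e15.ckpt")]

log_str = f"Top K checkpoints:\n{main_indicator}\tcheckpoint\n"
for p, ckpt_name in queue:
    log_str += f"{p:.4f}\t{os.path.join(ckpt_save_dir, ckpt_name)}\n"
print(log_str)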
98 changes: 98 additions & 0 deletions mindocr/utils/checkpoint.py
Collaborator @hadipash commented on May 12, 2023:
I think using Python's queue and heap to manage checkpoints would be easier, but this is not critical.

The PR author (Collaborator) replied:
Also thought about it, but a heap doesn't keep the queue sorted.

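For context on this exchange, a minimal sketch (an assumption of what the reviewer means, not code from this PR) of heapq-based bookkeeping: a min-heap keeps the k best entries cheaply, but the heap itself is only partially ordered, so printing a ranked table still needs an explicit sort.

import heapq

# Hypothetical sketch of the heap-based alternative; perf values are invented
# and only the "higher is better" case is handled.
k = 3
topk = []  # min-heap of (perf, name); the worst of the kept entries sits at topk[0]
for perf, name in [(0.71, "e1.ckpt"), (0.80, "e2.ckpt"), (0.78, "e3.ckpt"), (0.83, "e4.ckpt")]:
    if len(topk) < k:
        heapq.heappush(topk, (perf, name))
    else:
        evicted = heapq.heappushpop(topk, (perf, name))  # the returned entry is the checkpoint to delete
print(sorted(topk, reverse=True))  # a full sort is still needed to display a ranked list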
@@ -0,0 +1,98 @@
"""checkpoint manager """
import os
import stat

import numpy as np

import mindspore as ms
from mindspore import log as logger


class CheckpointManager:
"""
Manage checkpoint files according to ckpt_save_policy of checkpoint.
Args:
ckpt_save_dir (str): directory to save the checkpoints
ckpt_save_policy (str): Checkpoint saving strategy. Option: None, "top_k", or "latest_k".
None means to save each checkpoint, top_k means to save K checkpoints with the best performance,
and latest_k means saving the latest K checkpoint. Default: top_k.
k (int): top k value
prefer_low_perf (bool): standard for selecting the top k performance. If False, pick top k checkpoints with highest performance e.g. accuracy. If True, pick top k checkpoints with the lowest performance, e.g. loss.

"""

def __init__(self, ckpt_save_dir, ckpt_save_policy='top_k', k=10, prefer_low_perf=False, del_past=True):
self.ckpt_save_dir = ckpt_save_dir
self._ckpt_filelist = []
self.ckpt_save_policy = ckpt_save_policy
self.k = k

self.ckpt_queue = []
self.del_past = del_past
self.prefer_low_perf = prefer_low_perf

def get_ckpt_queue(self):
"""Get all the related checkpoint files managed here."""
return self.ckpt_queue

@property
def ckpt_num(self):
"""Get the number of the related checkpoint files managed here."""
return len(self.ckpt_queue)

def remove_ckpt_file(self, file_name):
"""Remove the specified checkpoint file from this checkpoint manager and also from the directory."""
try:
if os.path.exists(file_name):
os.chmod(file_name, stat.S_IWRITE)
os.remove(file_name)
except OSError:
logger.warning("OSError, failed to remove the older ckpt file %s.", file_name)
except ValueError:
logger.warning("ValueError, failed to remove the older ckpt file %s.", file_name)


def save_top_k(self, network, perf, ckpt_name, verbose=True):
"""Save the checkpoint if it ranks among the top K by performance and evict the worst one."""
self.ckpt_queue.append((perf, ckpt_name))
self.ckpt_queue = sorted(self.ckpt_queue, key=lambda x: x[0], reverse=not self.prefer_low_perf) # best first: descending order unless prefer_low_perf
if len(self.ckpt_queue) > self.k:
to_del = self.ckpt_queue.pop(-1)
# save only if the new checkpoint ranks above the worst one in the queue
if to_del[1] != ckpt_name:
ms.save_checkpoint(network, os.path.join(self.ckpt_save_dir, ckpt_name))
# delete the evicted (worst) checkpoint file
self.remove_ckpt_file(os.path.join(self.ckpt_save_dir, to_del[1]))
else:
ms.save_checkpoint(network, os.path.join(self.ckpt_save_dir, ckpt_name))

def save_latest_k(self, network, ckpt_name):
"""Save latest K checkpoint."""
ms.save_checkpoint(network, os.path.join(self.ckpt_save_dir, ckpt_name))
self.ckpt_queue.append(ckpt_name)
if len(self.ckpt_queue) > self.k:
to_del = self.ckpt_queue.pop(0)
if self.del_past:
self.remove_ckpt_file(os.path.join(self.ckpt_save_dir, to_del))

def save_single(self, network, ckpt_path):
ms.save_checkpoint(network, ckpt_path)

def save(self, network, perf=None, ckpt_name=None):
"""Save checkpoint according to different save strategy.
"""
if self.ckpt_save_policy is None:
ms.save_checkpoint(network, os.path.join(self.ckpt_save_dir, ckpt_name))
elif self.ckpt_save_policy == "top_k":
if perf is None:
raise ValueError(f"The expected 'metric' is not None, but got: {metric}.")
self.save_top_k(network, perf, ckpt_name)
return self.ckpt_queue
elif self.ckpt_save_policy == "latest_k":
self.save_latest_k(network, ckpt_name)
return self.ckpt_queue
else:
raise ValueError(
f"The expected 'ckpt_save_policy' is None, top_k or latest_k, but got: {self.ckpt_save_policy}."
)
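A minimal usage sketch of the class above (the tiny network, directory, and metric values are placeholders for illustration, not part of this PR):

import os

import mindspore.nn as nn

from mindocr.utils.checkpoint import CheckpointManager

# Placeholder network and metrics, just to exercise the manager.
net = nn.Dense(8, 2)
save_dir = "./ckpt"
os.makedirs(save_dir, exist_ok=True)

manager = CheckpointManager(save_dir, ckpt_save_policy="top_k", k=3, prefer_low_perf=False)
for epoch, perf in enumerate([0.71, 0.80, 0.78, 0.83], start=1):
    manager.save(net, perf=perf, ckpt_name=f"e{epoch}.ckpt")

# After four epochs only the three best checkpoints remain on disk;
# e1.ckpt (0.71) has been evicted and removed.
print(manager.get_ckpt_queue())  # [(0.83, 'e4.ckpt'), (0.80, 'e2.ckpt'), (0.78, 'e3.ckpt')]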

6 changes: 4 additions & 2 deletions tools/train.py
@@ -45,7 +45,7 @@

def main(cfg):
# init env
ms.set_context(mode=cfg.system.mode)
ms.set_context(mode=cfg.system.mode, device_id=6)
if cfg.system.distribute:
init()
device_num = get_group_size()
@@ -158,7 +158,9 @@ def main(cfg):
meta_data_indices=cfg.eval.dataset.pop('meta_data_column_index', None),
val_interval=cfg.system.get('val_interval', 1),
val_start_epoch=cfg.system.get('val_start_epoch', 1),
log_interval=cfg.system.get('log_interval', 100)
log_interval=cfg.system.get('log_interval', 100),
ckpt_save_policy=cfg.system.get('ckpt_save_policy', 'top_k'),
ckpt_max_keep=cfg.system.get('ckpt_max_keep', 10),
)

# log