Skip to content

Commit

Permalink
rm load_dyg_pretrain
Browse files Browse the repository at this point in the history
  • Loading branch information
littletomatodonkey committed Jun 5, 2021
1 parent bd1820b commit 48d8537
Show file tree
Hide file tree
Showing 10 changed files with 27 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ Global:
save_epoch_step: 3
eval_batch_step: [0, 2000]
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: null
pretrained_model:
checkpoints:
save_inference_dir:
use_visualdl: false
infer_img: doc/imgs_words/ch/word_1.jpg
character_dict_path: ppocr/utils/ppocr_keys_v1.txt
Expand Down Expand Up @@ -38,7 +38,7 @@ Architecture:
algorithm: Distillation
Models:
Student:
pretrained: null
pretrained:
freeze_params: false
return_all_feats: true
model_type: rec
Expand All @@ -57,7 +57,7 @@ Architecture:
name: CTCHead
fc_decay: 0.00001
Teacher:
pretrained: null
pretrained:
freeze_params: false
return_all_feats: true
model_type: rec
Expand Down Expand Up @@ -118,8 +118,8 @@ Train:
- DecodeImage:
img_mode: BGR
channel_first: false
- RecAug: null
- CTCLabelEncode: null
- RecAug:
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
Expand All @@ -143,7 +143,7 @@ Eval:
- DecodeImage:
img_mode: BGR
channel_first: false
- CTCLabelEncode: null
- CTCLabelEncode:
- RecResizeImg:
image_shape: [3, 32, 320]
- KeepKeys:
Expand Down
4 changes: 2 additions & 2 deletions ppocr/modeling/architectures/distillation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from ppocr.modeling.necks import build_neck
from ppocr.modeling.heads import build_head
from .base_model import BaseModel
from ppocr.utils.save_load import load_dygraph_pretrain
from ppocr.utils.save_load import init_model

__all__ = ['DistillationModel']

Expand All @@ -46,7 +46,7 @@ def __init__(self, config):
pretrained = model_config.pop("pretrained")
model = BaseModel(model_config)
if pretrained is not None:
load_dygraph_pretrain(model, path=pretrained)
init_model(model, path=pretrained)
if freeze_params:
for param in model.parameters():
param.trainable = False
Expand Down
22 changes: 10 additions & 12 deletions ppocr/utils/save_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import paddle

from ppocr.utils.logging import get_logger

__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']


Expand All @@ -42,19 +44,11 @@ def _mkdir_if_not_exist(path, logger):
raise OSError('Failed to mkdir {}'.format(path))


def load_dygraph_pretrain(model, logger=None, path=None):
if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(path))
param_state_dict = paddle.load(path + '.pdparams')
model.set_state_dict(param_state_dict)
return


def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
def init_model(config, model, optimizer=None, lr_scheduler=None):
"""
load model from checkpoint or pretrained_model
"""
logger = get_logger()
global_config = config['Global']
checkpoints = global_config.get('checkpoints')
pretrained_model = global_config.get('pretrained_model')
Expand All @@ -77,13 +71,17 @@ def init_model(config, model, logger, optimizer=None, lr_scheduler=None):
best_model_dict = states_dict.get('best_model_dict', {})
if 'epoch' in states_dict:
best_model_dict['start_epoch'] = states_dict['epoch'] + 1

logger.info("resume from {}".format(checkpoints))
elif pretrained_model:
if not isinstance(pretrained_model, list):
pretrained_model = [pretrained_model]
for pretrained in pretrained_model:
load_dygraph_pretrain(model, logger, path=pretrained)
if not (os.path.isdir(pretrained) or
os.path.exists(pretrained + '.pdparams')):
raise ValueError("Model pretrain path {} does not "
"exists.".format(pretrained))
param_state_dict = paddle.load(pretrained + '.pdparams')
model.set_state_dict(param_state_dict)
logger.info("load pretrained model from {}".format(
pretrained_model))
else:
Expand Down
2 changes: 1 addition & 1 deletion tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def main():
model = build_model(config['Architecture'])
use_srn = config['Architecture']['algorithm'] == "SRN"

best_model_dict = init_model(config, model, logger)
best_model_dict = init_model(config, model)
if len(best_model_dict):
logger.info('metric in ckpt ***************')
for k, v in best_model_dict.items():
Expand Down
2 changes: 1 addition & 1 deletion tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def main():
else: # base rec model
config["Architecture"]["Head"]["out_channels"] = char_num
model = build_model(config["Architecture"])
init_model(config, model, logger)
init_model(config, model)
model.eval()

save_path = config["Global"]["save_inference_dir"]
Expand Down
2 changes: 1 addition & 1 deletion tools/infer_cls.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def main():
# build model
model = build_model(config['Architecture'])

init_model(config, model, logger)
init_model(config, model)

# create data ops
transforms = []
Expand Down
2 changes: 1 addition & 1 deletion tools/infer_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def main():
# build model
model = build_model(config['Architecture'])

init_model(config, model, logger)
init_model(config, model)

# build post process
post_process_class = build_post_process(config['PostProcess'])
Expand Down
2 changes: 1 addition & 1 deletion tools/infer_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def main():
# build model
model = build_model(config['Architecture'])

init_model(config, model, logger)
init_model(config, model)

# build post process
post_process_class = build_post_process(config['PostProcess'],
Expand Down
2 changes: 1 addition & 1 deletion tools/infer_rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def main():

model = build_model(config['Architecture'])

init_model(config, model, logger)
init_model(config, model)

# create data ops
transforms = []
Expand Down
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def main(config, device, logger, vdl_writer):
# build metric
eval_class = build_metric(config['Metric'])
# load pretrain model
pre_best_model_dict = init_model(config, model, logger, optimizer)
pre_best_model_dict = init_model(config, model, optimizer)

logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
if valid_dataloader is not None:
Expand Down

0 comments on commit 48d8537

Please sign in to comment.