Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mindir export for trained models and related docs, tests; Rename dbnet_r50 -> dbnet_resnet50, crnn_r34 -> crnn_resnet34 for consistency #184

Merged
merged 7 commits into from
Apr 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions configs/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@

## About configs

This folder contains the configurations including
- model definition
- training recipes
- pretrained weights
- reported performance
for all models trained with MindOCR.

## Model Export

To convert a pretrained model from mindspore checkpoint format to [MindIR](https://www.mindspore.cn/docs/zh-CN/r2.0.0-alpha/design/mindir.html) format for deployment, please use the `tools/export.py` script.

``` shell
# convert dbnet_resnet50 with pretrained weights to MindIR format
python tools/export.py --model_name dbnet_resnet50 --pretrained

# convert dbnet_resnet50 loaded with weights to MindIR format
python tools/export.py --model_name dbnet_resnet50 --ckpt_load_path /path/to/checkpoint
```

For more usage, run `python tools/export.py -h`.

I
18 changes: 5 additions & 13 deletions mindocr/models/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

3. Define your model in two ways

a. Write a model py file, which includes the model class and specification functions. Please follow the [model format guideline](#format-guideline-for-model-py-file). It is to allows users to invoke a pre-defined model easily, such as `model = build_model('dbnet_r50', pretrained=True)` .
a. Write a model py file, which includes the model class and specification functions. Please follow the [model format guideline](#format-guideline-for-model-py-file). It is to allows users to invoke a pre-defined model easily, such as `model = build_model('dbnet_resnet50', pretrained=True)` .

b. Config the architecture in a yaml file. Please follow the [yaml format guideline](#format-guideline-for-yaml-file) . It is to allows users to modify a base architecture quickly in yaml file.

Expand Down Expand Up @@ -60,24 +60,16 @@ python tests/ut/test_model.py --config /path/to/yaml_config_file
* File naming: `models/{task}_{model_class_name}.py`, e.g., `det_dbnet.py`
* Class naming: {ModelName}, e.g., `class DBNet`
* Class MUST inherent from `BaseModel`, e.g., `class DBNet(BaseModel)`
* Spec. function naming: `{model_class_name}_{specifiation}.py`, e.g. `def dbnet_r50()` (Note: no need to add task prefix assuming no one model can solve any two tasks)
* Spec. function args: (pretrained=False, **kwargs), e.g. `def dbnet_r50(pretrained=False, **kwargs)`.
* Spec. function naming: `{model_class_name}_{specifiation}.py`, e.g. `def dbnet_resnet50()` (Note: no need to add task prefix assuming no one model can solve any two tasks)
* Spec. function args: (pretrained=False, **kwargs), e.g. `def dbnet_resnet50(pretrained=False, **kwargs)`.
* Spec. function return: model (nn.Cell), which is the model instance
* Spec. function decorator: MUST add @register_model decorator, which is to register the model to the supported model list.

**Note:** Once you finish writing the model specification function, you should be able to use it in the yaml file for training or inference as follows,

``` python
# in a yaml file
model:
name: dbnet_r50 # model specificatio function name
pretrained: False
```

or, use it via the `build_model` func.
After writing and registration, model can be created via the `build_model` func.
``` python
# in a python script
model = build_model('dbnet_r50', pretrained=False)
model = build_model('dbnet_resnet50', pretrained=False)
```

## Format Guideline for Yaml File
Expand Down
42 changes: 19 additions & 23 deletions mindocr/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,54 +8,50 @@

__all__ = ['build_model']

#def build_model(config: Union[dict,str]):
def build_model(config: Union[dict, str], **kwargs): #config: Union[dict,str]):

def build_model(name_or_config: Union[str, dict], **kwargs):
'''
There are two ways to build a model.
1. load a predefined model according the given model name.
2. build the model according to the detailed configuration of the each module (transform, backbone, neck and head), for lower-level architecture customization.

Args:
config (Union[dict, str]): if it is a str, config is the model name. Predefined model with weights will be returned.
if dict, config is a dictionary and the available keys are:
model_name: string, model name in the registered models
pretrained: bool, if True, download the pretrained weight for the preset url and load to the network.
backbone: dict, a dictionary containing the backbone config, the available keys are defined in backbones/builder.py
neck: dict,
head: dict,
kwargs: if config is a str of model name, kwargs contains the args for the model.

name_or_config (Union[dict, str]): model name or config
if it's a string, it should be a model name (which can be found by mindocr.list_models())
if it's a dict, it should be an architecture configuration defining the backbone/neck/head components (e.g., parsed from yaml config).

kwargs (dict): options
if name_or_config is a model name, supported args in kwargs are:
- pretrained (bool): if True, pretrained checkpoint will be downloaded and loaded into the network.
- ckpt_load_path (str): path to checkpoint file. if a non-empty string given, the local checkpoint will loaded into the network.
if name_or_config is an architecture definition dict, supported args are:
- ckpt_load_path (str): path to checkpoint file.

Return:
nn.Cell

Example:
>>> from mindocr.models import build_model
>>> net = build_model(cfg['model'])
>>> net = build_model(cfg['model'], ckpt_load_path='./r50_fpn_dbhead.ckpt') # build network and load checkpoint
>>> net = build_model('dbnet_r50', pretrained=True)
>>> net = build_model('dbnet_resnet50', pretrained=True)

'''
if isinstance(config, str):
if isinstance(name_or_config, str):
# build model by specific model name
model_name = config #config['name']
model_name = name_or_config
if is_model(model_name):
create_fn = model_entrypoint(model_name)
'''
kwargs = {}
for k, v in config.items():
if k!=model_name and v is not None:
kwargs[k] = v
'''
network = create_fn(**kwargs)
else:
raise ValueError(f'Invalid model name: {model_name}. Supported models are {list_models()}')

elif isinstance(config, dict):
elif isinstance(name_or_config, dict):
# build model by given architecture config dict
network = BaseModel(config)
network = BaseModel(name_or_config)
else:
raise ValueError('Type error for config')

# load checkpoint
if 'ckpt_load_path' in kwargs:
ckpt_path = kwargs['ckpt_load_path']
Expand Down
8 changes: 4 additions & 4 deletions mindocr/models/det_dbnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from .base_model import BaseModel
from ._registry import register_model

__all__ = ['DBNet', 'dbnet_r50']
__all__ = ['DBNet', 'dbnet_resnet50']

def _cfg(url='', **kwargs):
return {
Expand All @@ -14,7 +14,7 @@ def _cfg(url='', **kwargs):


default_cfgs = {
'dbnet_r50': _cfg(
'dbnet_resnet50': _cfg(
url='https://download.mindspore.cn/toolkits/mindocr/dbnet/dbnet_resnet50-db1df47a.ckpt')
}

Expand All @@ -25,7 +25,7 @@ def __init__(self, config):


@register_model
def dbnet_r50(pretrained=False, **kwargs):
def dbnet_resnet50(pretrained=False, **kwargs):
model_config = {
"backbone": {
'name': 'det_resnet50',
Expand All @@ -48,7 +48,7 @@ def dbnet_r50(pretrained=False, **kwargs):

# load pretrained weights
if pretrained:
default_cfg = default_cfgs['dbnet_r50']
default_cfg = default_cfgs['dbnet_resnet50']
load_pretrained(model, default_cfg)

return model
Expand Down
8 changes: 4 additions & 4 deletions mindocr/models/rec_crnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from .backbones.mindcv_models.utils import load_pretrained


__all__ = ['CRNN', 'crnn_r34', 'crnn_vgg7']
__all__ = ['CRNN', 'crnn_resnet34', 'crnn_vgg7']

def _cfg(url='', **kwargs):
return {
Expand All @@ -16,7 +16,7 @@ def _cfg(url='', **kwargs):


default_cfgs = {
'crnn_r34': _cfg(
'crnn_resnet34': _cfg(
url='https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_resnet34-83f37f07.ckpt'),
'crnn_vgg7': _cfg(
url='https://download.mindspore.cn/toolkits/mindocr/crnn/crnn_vgg7-ea7e996c.ckpt'),
Expand All @@ -30,7 +30,7 @@ def __init__(self, config):


@register_model
def crnn_r34(pretrained=False, **kwargs):
def crnn_resnet34(pretrained=False, **kwargs):
model_config = {
"backbone": {
'name': 'rec_resnet34',
Expand All @@ -51,7 +51,7 @@ def crnn_r34(pretrained=False, **kwargs):

# load pretrained weights
if pretrained:
default_cfg = default_cfgs['crnn_r34']
default_cfg = default_cfgs['crnn_resnet34']
load_pretrained(model, default_cfg)

return model
Expand Down
48 changes: 48 additions & 0 deletions test_mindir_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import mindspore as ms
import numpy as np
from mindocr import list_models, build_model


def test_mindir_infer(name, task='rec'):
fn = f"{name}.mindir"

ms.set_context(mode=ms.GRAPH_MODE)
graph = ms.load(fn)
model = ms.nn.GraphCell(graph)

task = 'rec'
if 'db' in fn:
task = 'det'

if task=='rec':
c, h, w = 3, 32, 100
else:
c, h, w = 3, 640, 640

bs = 1
x = ms.Tensor(np.ones([bs, c, h, w]), dtype=ms.float32)

outputs_mindir = model(x)

# get original ckpt outputs
net = build_model(name, pretrained=True)
outputs_ckpt = net(x)

for i, o in enumerate(outputs_mindir):
print('mindir net out: ', outputs_mindir[i].sum(), outputs_mindir[i].shape)
print('ckpt net out: ', outputs_ckpt[i].sum(), outputs_mindir[i].shape)
assert outputs_mindir[i].sum()==outputs_ckpt[i].sum()


if __name__ == '__main__':
names = list_models()
for n in names:
task = 'rec'
if 'db' in n:
task = 'det'
print(n)
test_mindir_infer(n, task)




46 changes: 46 additions & 0 deletions tests/ut/test_mindir_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import sys
sys.path.append('.')
import mindspore as ms
import pytest
import numpy as np
from mindocr import list_models, build_model
from tools.export import export

@pytest.mark.parametrize('name', ['dbnet_resnet50', 'crnn_resnet34'])
def test_mindir_infer(name):
task = 'rec'
if 'db' in name:
task = 'det'

export(name, task, pretrained=True)

fn = f"{name}.mindir"

ms.set_context(mode=ms.GRAPH_MODE)
graph = ms.load(fn)
model = ms.nn.GraphCell(graph)

if task=='rec':
c, h, w = 3, 32, 100
else:
c, h, w = 3, 640, 640

bs = 1
x = ms.Tensor(np.ones([bs, c, h, w]), dtype=ms.float32)

outputs_mindir = model(x)

# get original ckpt outputs
net = build_model(name, pretrained=True)
outputs_ckpt = net(x)

for i, o in enumerate(outputs_mindir):
print('mindir net out: ', outputs_mindir[i].sum(), outputs_mindir[i].shape)
print('ckpt net out: ', outputs_ckpt[i].sum(), outputs_mindir[i].shape)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo?
outputs_mindir -> outputs_ckpt

assert outputs_mindir[i].sum()==outputs_ckpt[i].sum()


if __name__ == '__main__':
names = list_models()
test_mindir_infer(names[0])

87 changes: 87 additions & 0 deletions tools/export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
'''
Usage:
To export all trained models from ckpt to mindir as listed in configs/, run
$ python tools/export.py

To export a sepecific model, taking dbnet for example, run
$ python tools/export.py --model_name dbnet_resnet50 --save_dir
'''
import sys
import os
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import argparse
import mindspore as ms
from mindocr import list_models, build_model
import numpy as np


def export(name, task='rec', pretrained=True, ckpt_load_path="", save_dir=""):
ms.set_context(mode=ms.GRAPH_MODE) #, device_target='Ascend')

net = build_model(name, pretrained=True)
net.set_train(False)

# TODO: extend input shapes for more models
if task=='rec':
c, h, w = 3, 32, 100
else:
c, h, w = 3, 640, 640

bs = 1
x = ms.Tensor(np.ones([bs, c, h, w]), dtype=ms.float32)

output_path = os.path.join(save_dir, name) + '.mindir'
ms.export(net, x, file_name=output_path, file_format='MINDIR')

print(f'=> Finish exporting {name} to {output_path}')


def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ("yes", "true", "1"):
return True
elif v.lower() in ("no", "false", "0"):
return False
else:
raise argparse.ArgumentTypeError("Boolean value expected.")


if __name__ == '__main__':
parser = argparse.ArgumentParser("Convert model checkpoint to mindir format.")
parser.add_argument(
'--model_name',
type=str,
default="",
help='Name of the model to convert, choices: [crnn_resnet34, crnn_vgg7, dbnet_resnet50, ""]. You can lookup the name by calling mindocr.list_models(). If "", all models in list_models() will be converted.')
parser.add_argument(
'--pretrained',
type=str2bool, nargs='?', const=True,
default=True,
help='Whether download and load the pretrained checkpoint into network.')
parser.add_argument(
'--ckpt_load_path',
type=str,
default="",
help='Path to a local checkpoint. No need to set it if pretrained is True. If set, network weights will be loaded using this checkpoint file')
parser.add_argument(
'--save_dir',
type=str,
default="",
help='Dir to save the exported model')

args = parser.parse_args()
mn = args.model_name

if mn =="":
names = list_models()
else:
names = [mn]

for n in names:
task = 'rec'
if 'db' in n:
task = 'det'
export(n, task, args.pretrained, args.ckpt_load_path, args.save_dir)