Skip to content

Commit 481f454

Browse files
authored
Merge c0a95aa into 47f5430
2 parents 47f5430 + c0a95aa commit 481f454

File tree

4 files changed

+441
-138
lines changed

4 files changed

+441
-138
lines changed

.dev_scripts/benchmark_valid_flops.py

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
import logging
2+
import re
3+
import tempfile
4+
from argparse import ArgumentParser
5+
from collections import OrderedDict
6+
from functools import partial
7+
from pathlib import Path
8+
9+
import numpy as np
10+
import pandas as pd
11+
import torch
12+
from mmengine import Config, DictAction
13+
from mmengine.analysis import get_model_complexity_info
14+
from mmengine.analysis.print_helper import _format_size
15+
from mmengine.fileio import FileClient
16+
from mmengine.logging import MMLogger
17+
from mmengine.model import revert_sync_batchnorm
18+
from mmengine.registry import init_default_scope
19+
from mmengine.runner import Runner
20+
from modelindex.load_model_index import load
21+
from rich.console import Console
22+
from rich.table import Table
23+
from rich.text import Text
24+
from tqdm import tqdm
25+
26+
from mmocr.registry import MODELS
27+
28+
console = Console()
29+
MMOCR_ROOT = Path(__file__).absolute().parents[1]
30+
31+
32+
def parse_args():
33+
parser = ArgumentParser(description='Valid all models in model-index.yml')
34+
parser.add_argument(
35+
'--shape',
36+
type=int,
37+
nargs='+',
38+
default=[1280, 800],
39+
help='input image size')
40+
parser.add_argument(
41+
'--checkpoint_root',
42+
help='Checkpoint file root path. If set, load checkpoint before test.')
43+
parser.add_argument('--img', default='demo/demo.jpg', help='Image file')
44+
parser.add_argument('--models', nargs='+', help='models name to inference')
45+
parser.add_argument(
46+
'--batch-size',
47+
type=int,
48+
default=1,
49+
help='The batch size during the inference.')
50+
parser.add_argument(
51+
'--flops', action='store_true', help='Get Flops and Params of models')
52+
parser.add_argument(
53+
'--flops-str',
54+
action='store_true',
55+
help='Output FLOPs and params counts in a string form.')
56+
parser.add_argument(
57+
'--cfg-options',
58+
nargs='+',
59+
action=DictAction,
60+
help='override some settings in the used config, the key-value pair '
61+
'in xxx=yyy format will be merged into config file. If the value to '
62+
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
63+
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
64+
'Note that the quotation marks are necessary and that no white space '
65+
'is allowed.')
66+
parser.add_argument(
67+
'--size_divisor',
68+
type=int,
69+
default=32,
70+
help='Pad the input image, the minimum size that is divisible '
71+
'by size_divisor, -1 means do not pad the image.')
72+
args = parser.parse_args()
73+
return args
74+
75+
76+
def inference(config_file, checkpoint, work_dir, args, exp_name):
77+
logger = MMLogger.get_instance(name='MMLogger')
78+
logger.warning('if you want test flops, please make sure torch>=1.12')
79+
cfg = Config.fromfile(config_file)
80+
cfg.work_dir = work_dir
81+
cfg.load_from = checkpoint
82+
cfg.log_level = 'WARN'
83+
cfg.experiment_name = exp_name
84+
if args.cfg_options is not None:
85+
cfg.merge_from_dict(args.cfg_options)
86+
init_default_scope(cfg.get('default_scope', 'mmocr'))
87+
88+
# forward the model
89+
result = {'model': config_file.stem}
90+
91+
if args.flops:
92+
93+
if len(args.shape) == 1:
94+
h = w = args.shape[0]
95+
elif len(args.shape) == 2:
96+
h, w = args.shape
97+
else:
98+
raise ValueError('invalid input shape')
99+
divisor = args.size_divisor
100+
if divisor > 0:
101+
h = int(np.ceil(h / divisor)) * divisor
102+
w = int(np.ceil(w / divisor)) * divisor
103+
104+
input_shape = (3, h, w)
105+
result['resolution'] = input_shape
106+
107+
try:
108+
cfg = Config.fromfile(config_file)
109+
if args.cfg_options is not None:
110+
cfg.merge_from_dict(args.cfg_options)
111+
112+
model = MODELS.build(cfg.model)
113+
input = torch.rand(1, *input_shape)
114+
if torch.cuda.is_available():
115+
model.cuda()
116+
input = input.cuda()
117+
model = revert_sync_batchnorm(model)
118+
inputs = (input, )
119+
model.eval()
120+
outputs = get_model_complexity_info(
121+
model, input_shape, inputs, show_table=False, show_arch=False)
122+
flops = outputs['flops']
123+
params = outputs['params']
124+
activations = outputs['activations']
125+
result['Get Types'] = 'Random input'
126+
except: # noqa 772
127+
logger = MMLogger.get_instance(name='MMLogger')
128+
logger.warning(
129+
'Direct get flops failed, try to get flops with data')
130+
cfg = Config.fromfile(config_file)
131+
data_loader = Runner.build_dataloader(cfg.val_dataloader)
132+
data_batch = next(iter(data_loader))
133+
model = MODELS.build(cfg.model)
134+
if torch.cuda.is_available():
135+
model = model.cuda()
136+
model = revert_sync_batchnorm(model)
137+
model.eval()
138+
_forward = model.forward
139+
data = model.data_preprocessor(data_batch)
140+
del data_loader
141+
model.forward = partial(
142+
_forward, data_samples=data['data_samples'])
143+
outputs = get_model_complexity_info(
144+
model,
145+
input_shape,
146+
data['inputs'],
147+
show_table=False,
148+
show_arch=False)
149+
flops = outputs['flops']
150+
params = outputs['params']
151+
activations = outputs['activations']
152+
result['Get Types'] = 'Dataloader'
153+
154+
if args.flops_str:
155+
flops = _format_size(flops)
156+
params = _format_size(params)
157+
activations = _format_size(activations)
158+
159+
result['flops'] = flops
160+
result['params'] = params
161+
162+
return result
163+
164+
165+
def show_summary(summary_data, args):
166+
table = Table(title='Validation Benchmark Regression Summary')
167+
table.add_column('Model')
168+
table.add_column('Validation')
169+
table.add_column('Resolution (c, h, w)')
170+
if args.flops:
171+
table.add_column('Flops', justify='right', width=11)
172+
table.add_column('Params', justify='right')
173+
174+
for model_name, summary in summary_data.items():
175+
row = [model_name]
176+
valid = summary['valid']
177+
color = 'green' if valid == 'PASS' else 'red'
178+
row.append(f'[{color}]{valid}[/{color}]')
179+
if valid == 'PASS':
180+
row.append(str(summary['resolution']))
181+
if args.flops:
182+
row.append(str(summary['flops']))
183+
row.append(str(summary['params']))
184+
table.add_row(*row)
185+
186+
console.print(table)
187+
table_data = {
188+
x.header: [Text.from_markup(y).plain for y in x.cells]
189+
for x in table.columns
190+
}
191+
table_pd = pd.DataFrame(table_data)
192+
table_pd.to_csv('./mmocr_flops.csv')
193+
194+
195+
# Sample test whether the inference code is correct
196+
def main(args):
197+
model_index_file = MMOCR_ROOT / 'model-index.yml'
198+
model_index = load(str(model_index_file))
199+
model_index.build_models_with_collections()
200+
models = OrderedDict({model.name: model for model in model_index.models})
201+
202+
logger = MMLogger(
203+
'validation',
204+
logger_name='validation',
205+
log_file='benchmark_test_image.log',
206+
log_level=logging.INFO)
207+
208+
if args.models:
209+
patterns = [
210+
re.compile(pattern.replace('+', '_')) for pattern in args.models
211+
]
212+
filter_models = {}
213+
for k, v in models.items():
214+
k = k.replace('+', '_')
215+
if any([re.match(pattern, k) for pattern in patterns]):
216+
filter_models[k] = v
217+
if len(filter_models) == 0:
218+
print('No model found, please specify models in:')
219+
print('\n'.join(models.keys()))
220+
return
221+
models = filter_models
222+
223+
summary_data = {}
224+
tmpdir = tempfile.TemporaryDirectory()
225+
for model_name, model_info in tqdm(models.items()):
226+
227+
if model_info.config is None:
228+
continue
229+
230+
model_info.config = model_info.config.replace('%2B', '+')
231+
config = Path(model_info.config)
232+
233+
try:
234+
config.exists()
235+
except: # noqa 722
236+
logger.error(f'{model_name}: {config} not found.')
237+
continue
238+
239+
logger.info(f'Processing: {model_name}')
240+
241+
http_prefix = 'https://download.openmmlab.com/mmocr/'
242+
if args.checkpoint_root is not None:
243+
root = args.checkpoint_root
244+
if 's3://' in args.checkpoint_root:
245+
from petrel_client.common.exception import AccessDeniedError
246+
file_client = FileClient.infer_client(uri=root)
247+
checkpoint = file_client.join_path(
248+
root, model_info.weights[len(http_prefix):])
249+
try:
250+
exists = file_client.exists(checkpoint)
251+
except AccessDeniedError:
252+
exists = False
253+
else:
254+
checkpoint = Path(root) / model_info.weights[len(http_prefix):]
255+
exists = checkpoint.exists()
256+
if exists:
257+
checkpoint = str(checkpoint)
258+
else:
259+
print(f'WARNING: {model_name}: {checkpoint} not found.')
260+
checkpoint = None
261+
else:
262+
checkpoint = None
263+
264+
try:
265+
# build the model from a config file and a checkpoint file
266+
result = inference(MMOCR_ROOT / config, checkpoint, tmpdir.name,
267+
args, model_name)
268+
result['valid'] = 'PASS'
269+
except Exception: # noqa 722
270+
import traceback
271+
logger.error(f'"{config}" :\n{traceback.format_exc()}')
272+
result = {'valid': 'FAIL'}
273+
274+
summary_data[model_name] = result
275+
276+
tmpdir.cleanup()
277+
show_summary(summary_data, args)
278+
279+
280+
if __name__ == '__main__':
281+
args = parse_args()
282+
main(args)

docs/en/user_guides/useful_tools.md

Lines changed: 20 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ python tools/visualizations/browse_dataset.py \
2929
| -t, --task | `auto`, `textdet`, `textrecog` | Specify the task type of the dataset. If `auto`, the task type will be inferred from the config. If the script is unable to infer the task type, you need to specify it manually. Defaults to `auto`. |
3030
| -n, --show-number | int | The number of samples to visualized. If not specified, display all images in the dataset. |
3131
| -i, --show-interval | float | Interval of visualization (s), defaults to 2. |
32-
| --cfg-options | float | Override configs. [Example](./config.md#command-line-modification) |
32+
| --cfg-options | str | Override configs. [Example](./config.md#command-line-modification) |
3333

3434
#### Examples
3535

@@ -110,30 +110,27 @@ python tools/analysis_tools/offline_eval.py configs/textdet/psenet/psenet_r50_fp
110110

111111
In addition, based on this tool, users can also convert predictions obtained from other libraries into MMOCR-supported formats, then use MMOCR's built-in metrics to evaluate them.
112112

113-
| ARGS | Type | Description |
114-
| ------------- | ----- | ------------------------------------------------------------------ |
115-
| config | str | (required) Path to the config. |
116-
| pkl_results | str | (required) The saved predictions. |
117-
| --cfg-options | float | Override configs. [Example](./config.md#command-line-modification) |
113+
| ARGS | Type | Description |
114+
| ------------- | ---- | ------------------------------------------------------------------ |
115+
| config | str | (required) Path to the config. |
116+
| pkl_results | str | (required) The saved predictions. |
117+
| --cfg-options | str | Override configs. [Example](./config.md#command-line-modification) |
118118

119119
### Calculate FLOPs and the Number of Parameters
120120

121-
We provide a method to calculate the FLOPs and the number of parameters, first we install the dependencies using the following command.
122-
123-
```shell
124-
pip install fvcore
125-
```
121+
We provide a method to calculate the FLOPs and the number of parameters.
126122

127123
The usage of the script to calculate FLOPs and the number of parameters is as follows.
128124

129125
```shell
130126
python tools/analysis_tools/get_flops.py ${config} --shape ${IMAGE_SHAPE}
131127
```
132128

133-
| ARGS | Type | Description |
134-
| ------- | ---- | ----------------------------------------------------------------------------------------- |
135-
| config | str | (required) Path to the config. |
136-
| --shape | int | Image size to use when calculating FLOPs, such as `--shape 320 320`. Default is `640 640` |
129+
| ARGS | Type | Description |
130+
| ------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
131+
| config | str | (required) Path to the config. |
132+
| --shape | int * \[1-3\] | Image size to use when calculating FLOPs, such as `--shape 320 320`. It can accept 1 to 3 arguments, representing `H&W`, `H, W` and `C, H, W` respectively (C = 3 by default). Default is `640 640` |
133+
| --cfg-options | str | Override configs. [Example](./config.md#command-line-modification) |
137134

138135
For example, you can run the following command to get FLOPs and the number of parameters of `dbnet_resnet18_fpnc_100k_synthtext.py`:
139136

@@ -144,51 +141,13 @@ python tools/analysis_tools/get_flops.py configs/textdet/dbnet/dbnet_resnet18_fp
144141
The output is as follows:
145142

146143
```shell
147-
input shape is (1, 3, 1024, 1024)
148-
| module | #parameters or shape | #flops |
149-
| :------------------------ | :------------------- | :------ |
150-
| model | 12.341M | 63.955G |
151-
| backbone | 11.177M | 38.159G |
152-
| backbone.conv1 | 9.408K | 2.466G |
153-
| backbone.conv1.weight | (64, 3, 7, 7) | |
154-
| backbone.bn1 | 0.128K | 83.886M |
155-
| backbone.bn1.weight | (64,) | |
156-
| backbone.bn1.bias | (64,) | |
157-
| backbone.layer1 | 0.148M | 9.748G |
158-
| backbone.layer1.0 | 73.984K | 4.874G |
159-
| backbone.layer1.1 | 73.984K | 4.874G |
160-
| backbone.layer2 | 0.526M | 8.642G |
161-
| backbone.layer2.0 | 0.23M | 3.79G |
162-
| backbone.layer2.1 | 0.295M | 4.853G |
163-
| backbone.layer3 | 2.1M | 8.616G |
164-
| backbone.layer3.0 | 0.919M | 3.774G |
165-
| backbone.layer3.1 | 1.181M | 4.842G |
166-
| backbone.layer4 | 8.394M | 8.603G |
167-
| backbone.layer4.0 | 3.673M | 3.766G |
168-
| backbone.layer4.1 | 4.721M | 4.837G |
169-
| neck | 0.836M | 14.887G |
170-
| neck.lateral_convs | 0.246M | 2.013G |
171-
| neck.lateral_convs.0.conv | 16.384K | 1.074G |
172-
| neck.lateral_convs.1.conv | 32.768K | 0.537G |
173-
| neck.lateral_convs.2.conv | 65.536K | 0.268G |
174-
| neck.lateral_convs.3.conv | 0.131M | 0.134G |
175-
| neck.smooth_convs | 0.59M | 12.835G |
176-
| neck.smooth_convs.0.conv | 0.147M | 9.664G |
177-
| neck.smooth_convs.1.conv | 0.147M | 2.416G |
178-
| neck.smooth_convs.2.conv | 0.147M | 0.604G |
179-
| neck.smooth_convs.3.conv | 0.147M | 0.151G |
180-
| det_head | 0.329M | 10.909G |
181-
| det_head.binarize | 0.164M | 10.909G |
182-
| det_head.binarize.0 | 0.147M | 9.664G |
183-
| det_head.binarize.1 | 0.128K | 20.972M |
184-
| det_head.binarize.3 | 16.448K | 1.074G |
185-
| det_head.binarize.4 | 0.128K | 83.886M |
186-
| det_head.binarize.6 | 0.257K | 67.109M |
187-
| det_head.threshold | 0.164M | |
188-
| det_head.threshold.0 | 0.147M | |
189-
| det_head.threshold.1 | 0.128K | |
190-
| det_head.threshold.3 | 16.448K | |
191-
| det_head.threshold.4 | 0.128K | |
192-
| det_head.threshold.6 | 0.257K | |
144+
145+
==============================
146+
Compute type: Random input
147+
Input shape: torch.Size([1024, 1024])
148+
Flops: 63.737G
149+
Params: 12.341M
150+
==============================
193151
!!!Please be cautious if you use the results in papers. You may need to check if all ops are supported and verify that the flops computation is correct.
152+
194153
```

0 commit comments

Comments
 (0)