Skip to content

Commit

Permalink
add test for mindir export
Browse files Browse the repository at this point in the history
  • Loading branch information
SamitHuang committed Apr 15, 2023
1 parent 648e4e4 commit d2b847e
Showing 1 changed file with 46 additions and 0 deletions.
46 changes: 46 additions & 0 deletions tests/ut/test_mindir_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import sys
sys.path.append('.')
import mindspore as ms
import pytest
import numpy as np
from mindocr import list_models, build_model
from tools.export import export

@pytest.mark.parametrize('name', ['dbnet_resnet50', 'crnn_resnet34'])
def test_mindir_infer(name):
task = 'rec'
if 'db' in name:
task = 'det'

export(name, task, pretrained=True)

fn = f"{name}.mindir"

ms.set_context(mode=ms.GRAPH_MODE)
graph = ms.load(fn)
model = ms.nn.GraphCell(graph)

if task=='rec':
c, h, w = 3, 32, 100
else:
c, h, w = 3, 640, 640

bs = 1
x = ms.Tensor(np.ones([bs, c, h, w]), dtype=ms.float32)

outputs_mindir = model(x)

# get original ckpt outputs
net = build_model(name, pretrained=True)
outputs_ckpt = net(x)

for i, o in enumerate(outputs_mindir):
print('mindir net out: ', outputs_mindir[i].sum(), outputs_mindir[i].shape)
print('ckpt net out: ', outputs_ckpt[i].sum(), outputs_mindir[i].shape)
assert outputs_mindir[i].sum()==outputs_ckpt[i].sum()


if __name__ == '__main__':
names = list_models()
test_mindir_infer(names[0])

0 comments on commit d2b847e

Please sign in to comment.