Skip to content

Commit cc1e2ec

Browse files
committed
add infer.py
1 parent 2a59d80 commit cc1e2ec

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

tests/yamrt/infer.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
import onnxruntime
2+
import numpy as np
3+
import cv2
4+
import json
5+
import os
6+
import time
7+
8+
providers = ['CPUExecutionProvider']
9+
session = onnxruntime.InferenceSession('resnet18v1/resnet18v1.onnx', providers=providers)
10+
input_name = session.get_inputs()[0].name
11+
output_name = session.get_outputs()[0].name
12+
13+
def load_labels(path):
14+
with open(path) as f:
15+
data = json.load(f)
16+
return np.asarray(data)
17+
18+
def softmax(x):
19+
x = x.reshape(-1)
20+
e_x = np.exp(x - np.max(x))
21+
return e_x / e_x.sum(axis=0)
22+
23+
labels = load_labels('labels.json')
24+
25+
path = '/home/remloveh/.mxnet/datasets/imagenet/val/'
26+
li = os.listdir(path)
27+
li.sort()
28+
29+
total = 0
30+
right = 0
31+
label_num = 0
32+
start = time.time()
33+
for di in li:
34+
ll = os.listdir(path + di)
35+
for dd in ll:
36+
img = cv2.imread(path + di + '/' + dd)
37+
img = cv2.resize(img, (224, 224))
38+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
39+
40+
data = np.array(img).transpose(2, 0, 1)
41+
data = data.astype('float32')
42+
mean_vec = np.array([0.485, 0.456, 0.406])
43+
stddev_vec = np.array([0.229, 0.224, 0.225])
44+
norm_data = np.zeros(data.shape).astype('float32')
45+
46+
for i in range(data.shape[0]):
47+
norm_data[i,:,:] = (data[i,:,:]/255 - mean_vec[i]) / stddev_vec[i]
48+
norm_data = norm_data.reshape(1, 3, 224, 224).astype('float32')
49+
50+
result = session.run([output_name],{input_name:norm_data})
51+
res = softmax(np.array(result)).tolist()
52+
idx = np.argmax(res)
53+
54+
if idx == label_num:
55+
right = right + 1
56+
total = total + 1
57+
label_num = label_num + 1
58+
end = time.time()
59+
60+
print('time: ', int(end - start)/ total, 's')
61+
62+
63+
print('accuracy: ', right/total)

0 commit comments

Comments
 (0)