Skip to content

Add scripts to evaluate models in the zoo on different datasets #69

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/config/image_classification_mobilenetv1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ Benchmark:
Model:
name: "MobileNetV1"
modelPath: "models/image_classification_mobilenet/image_classification_mobilenetv1_2022apr.onnx"
labelPath: "models/image_classification_mobilenet/imagenet_labels.txt"

2 changes: 1 addition & 1 deletion benchmark/config/image_classification_mobilenetv2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ Benchmark:
Model:
name: "MobileNetV2"
modelPath: "models/image_classification_mobilenet/image_classification_mobilenetv2_2022apr.onnx"
labelPath: "models/image_classification_mobilenet/imagenet_labels.txt"

2 changes: 1 addition & 1 deletion benchmark/config/image_classification_ppresnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ Benchmark:
Model:
name: "PPResNet"
modelPath: "models/image_classification_ppresnet/image_classification_ppresnet50_2022jan.onnx"
labelPath: "models/image_classification_ppresnet/imagenet_labels.txt"

12 changes: 12 additions & 0 deletions models/image_classification_mobilenet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@ MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applicatio

MobileNetV2: Inverted Residuals and Linear Bottlenecks

Results of accuracy evaluation with [tools/eval](../../tools/eval).

| Models | Top-1 Accuracy | Top-5 Accuracy |
| ------ | -------------- | -------------- |
| MobileNet V1 | 67.64 | 87.97 |
| MobileNet V1 quant | 55.53 | 78.74 |
| MobileNet V2 | 69.44 | 89.23 |
| MobileNet V2 quant | 68.37 | 88.56 |

\*: 'quant' stands for 'quantized'.

## Demo

Run the following command to try the demo:
Expand All @@ -24,3 +35,4 @@ All files in this directory are licensed under [Apache 2.0 License](./LICENSE).
- MobileNet V2: https://arxiv.org/abs/1801.04381
- MobileNet V1 weight and scripts for training: https://github.com/wjc852456/pytorch-mobilenet-v1
- MobileNet V2 weight: https://github.com/onnx/models/tree/main/vision/classification/mobilenet

28 changes: 20 additions & 8 deletions models/image_classification_mobilenet/mobilenet_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import cv2 as cv

class MobileNetV1:
def __init__(self, modelPath, labelPath, backendId=0, targetId=0):
def __init__(self, modelPath, labelPath=None, topK=1, backendId=0, targetId=0):
self.model_path = modelPath
self.label_path = labelPath
assert topK >= 1
self.top_k = topK
self.backend_id = backendId
self.target_id = targetId

Expand All @@ -23,9 +25,10 @@ def __init__(self, modelPath, labelPath, backendId=0, targetId=0):

def _load_labels(self):
labels = []
with open(self.label_path, 'r') as f:
for line in f:
labels.append(line.strip())
if self.label_path is not None:
with open(self.label_path, 'r') as f:
for line in f:
labels.append(line.strip())
return labels

@property
Expand Down Expand Up @@ -61,9 +64,18 @@ def infer(self, image):
return results

def _postprocess(self, output_blob):
predicted_labels = []
batched_class_id_list = []
for o in output_blob:
class_id = np.argmax(o)
predicted_labels.append(self.labels[class_id])
return predicted_labels
class_id_list = o.argsort()[::-1][:self.top_k]
batched_class_id_list.append(class_id_list)
if len(self.labels) > 0:
batched_predicted_labels = []
for class_id_list in batched_class_id_list:
predicted_labels = []
for class_id in class_id_list:
predicted_labels.append(self._labels[class_id])
batched_predicted_labels.append(predicted_labels)
return batched_predicted_labels
else:
return batched_class_id_list

28 changes: 20 additions & 8 deletions models/image_classification_mobilenet/mobilenet_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import cv2 as cv

class MobileNetV2:
def __init__(self, modelPath, labelPath, backendId=0, targetId=0):
def __init__(self, modelPath, labelPath=None, topK=1, backendId=0, targetId=0):
self.model_path = modelPath
self.label_path = labelPath
assert topK >= 1
self.top_k = topK
self.backend_id = backendId
self.target_id = targetId

Expand All @@ -23,9 +25,10 @@ def __init__(self, modelPath, labelPath, backendId=0, targetId=0):

def _load_labels(self):
labels = []
with open(self.label_path, 'r') as f:
for line in f:
labels.append(line.strip())
if self.label_path is not None:
with open(self.label_path, 'r') as f:
for line in f:
labels.append(line.strip())
return labels

@property
Expand Down Expand Up @@ -61,9 +64,18 @@ def infer(self, image):
return results

def _postprocess(self, output_blob):
predicted_labels = []
batched_class_id_list = []
for o in output_blob:
class_id = np.argmax(o)
predicted_labels.append(self.labels[class_id])
return predicted_labels
class_id_list = o.argsort()[::-1][:self.top_k]
batched_class_id_list.append(class_id_list)
if len(self.labels) > 0:
batched_predicted_labels = []
for class_id_list in batched_class_id_list:
predicted_labels = []
for class_id in class_id_list:
predicted_labels.append(self._labels[class_id])
batched_predicted_labels.append(predicted_labels)
return batched_predicted_labels
else:
return batched_class_id_list

12 changes: 11 additions & 1 deletion models/image_classification_ppresnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@ Deep Residual Learning for Image Recognition

This model is ported from [PaddleHub](https://github.com/PaddlePaddle/PaddleHub) using [this script from OpenCV](https://github.com/opencv/opencv/blob/master/samples/dnn/dnn_model_runner/dnn_conversion/paddlepaddle/paddle_resnet50.py).

Results of accuracy evaluation with [tools/eval](../../tools/eval).

| Models | Top-1 Accuracy | Top-5 Accuracy |
| ------ | -------------- | -------------- |
| PP-ResNet | 82.28 | 96.15 |
| PP-ResNet quant | 0.22 | 0.96 |

\*: 'quant' stands for 'quantized'.

## Demo

Run the following command to try the demo:
Expand All @@ -19,4 +28,5 @@ All files in this directory are licensed under [Apache 2.0 License](./LICENSE).

- https://arxiv.org/abs/1512.03385
- https://github.com/opencv/opencv/tree/master/samples/dnn/dnn_model_runner/dnn_conversion/paddlepaddle
- https://github.com/PaddlePaddle/PaddleHub
- https://github.com/PaddlePaddle/PaddleHub

29 changes: 22 additions & 7 deletions models/image_classification_ppresnet/ppresnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@
import cv2 as cv

class PPResNet:
def __init__(self, modelPath, labelPath, backendId=0, targetId=0):
def __init__(self, modelPath, labelPath=None, topK=1, backendId=0, targetId=0):
self._modelPath = modelPath
self._labelPath = labelPath
assert topK >= 1
self._topK = topK
self._backendId = backendId
self._targetId = targetId

Expand All @@ -30,9 +32,10 @@ def __init__(self, modelPath, labelPath, backendId=0, targetId=0):

def _load_labels(self):
labels = []
with open(self._labelPath, 'r') as f:
for line in f:
labels.append(line.strip())
if self._labelPath is not None:
with open(self._labelPath, 'r') as f:
for line in f:
labels.append(line.strip())
return labels

@property
Expand Down Expand Up @@ -65,11 +68,23 @@ def infer(self, image):
outputBlob = self._model.forward(self._outputNames)

# Postprocess
results = self._postprocess(outputBlob)
results = self._postprocess(outputBlob[0])

return results

def _postprocess(self, outputBlob):
class_id = np.argmax(outputBlob[0])
return self._labels[class_id]
batched_class_id_list = []
for ob in outputBlob:
class_id_list = ob.argsort()[::-1][:self._topK]
batched_class_id_list.append(class_id_list)
if len(self._labels) > 0:
batched_predicted_labels = []
for class_id_list in batched_class_id_list:
predicted_labels = []
for class_id in class_id_list:
predicted_labels.append(self._labels[class_id])
batched_predicted_labels.append(predicted_labels)
return batched_predicted_labels
else:
return batched_class_id_list

55 changes: 55 additions & 0 deletions tools/eval/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Accuracy evaluation of models in OpenCV Zoo

Make sure you have the following packages installed:

```shell
pip install tqdm
```

Generally speaking, evaluation can be done with the following command:

```shell
python eval.py -m model_name -d dataset_name -dr dataset_root_dir
```

Supported datasets:
- [ImageNet](./datasets/imagenet.py)

## ImageNet

### Prepare data

Please visit https://image-net.org/ to download the ImageNet dataset and [the labels from caffe](http://dl.caffe.berkeleyvision.org/caffe_ilsvrc12.tar.gz). Organize files as follow:

```shell
$ tree -L 2 /path/to/imagenet
.
├── caffe_ilsvrc12
│   ├── det_synset_words.txt
│   ├── imagenet.bet.pickle
│   ├── imagenet_mean.binaryproto
│   ├── synsets.txt
│   ├── synset_words.txt
│   ├── test.txt
│   ├── train.txt
│   └── val.txt
├── caffe_ilsvrc12.tar.gz
├── ILSVRC
│   ├── Annotations
│   ├── Data
│   └── ImageSets
├── imagenet_object_localization_patched2019.tar.gz
├── LOC_sample_submission.csv
├── LOC_synset_mapping.txt
├── LOC_train_solution.csv
└── LOC_val_solution.csv
```

### Evaluation

Run evaluation with the following command:

```shell
python eval.py -m mobilenet -d imagenet -dr /path/to/imagenet
```

15 changes: 15 additions & 0 deletions tools/eval/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from .imagenet import ImageNet

class Registery:
def __init__(self, name):
self._name = name
self._dict = dict()

def get(self, key):
return self._dict[key]

def register(self, item):
self._dict[item.__name__] = item

DATASETS = Registery("Datasets")
DATASETS.register(ImageNet)
64 changes: 64 additions & 0 deletions tools/eval/datasets/imagenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import os

import numpy as np
import cv2 as cv

from tqdm import tqdm

class ImageNet:
def __init__(self, root, size=224):
self.root = root
self.size = size
self.top1_acc = -1
self.top5_acc = -1

self.root_val = os.path.join(self.root, "ILSVRC", "Data", "CLS-LOC", "val")
self.val_label_file = os.path.join(self.root, "caffe_ilsvrc12", "val.txt")

self.val_label = self.load_label(self.val_label_file)

@property
def name(self):
return self.__class__.__name__

def load_label(self, label_file):
label = list()
with open(label_file, "r") as f:
for line in f:
line = line.strip()
key, value = line.split()

key = os.path.join(self.root_val, key)
value = int(value)

label.append([key, value])

return label

def eval(self, model):
top_1_hits = 0
top_5_hits = 0
pbar = tqdm(self.val_label)
for fn, label in pbar:
pbar.set_description("Evaluating {} with {} val set".format(model.name, self.name))

img = cv.imread(fn)
img = cv.cvtColor(img, cv.COLOR_BGR2RGB)
img = cv.resize(img, dsize=(256, 256))
img = img[16:240, 16:240, :]

pred = model.infer(img)
if label == pred[0][0]:
top_1_hits += 1
if label in pred[0]:
top_5_hits += 1

self.top1_acc = top_1_hits/(len(self.val_label) * 1.0)
self.top5_acc = top_5_hits/(len(self.val_label) * 1.0)

def get_result(self):
return self.top1_acc, self.top5_acc

def print_result(self):
print("Top-1 Accuracy: {:.2f}%; Top-5 Accuracy: {:.2f}%".format(self.top1_acc*100, self.top5_acc*100))

Loading