Skip to content

Commit c823c91

Browse files
committed
Added SotaBench
1 parent 761ac94 commit c823c91

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Custom
22
tmp
3+
*.pkl
34

45
# Byte-compiled / optimized / DLL files
56
__pycache__/

sotabench.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import os
2+
import numpy as np
3+
import PIL
4+
import torch
5+
from torch.utils.data import DataLoader
6+
import torchvision.transforms as transforms
7+
from torchvision.datasets import ImageNet
8+
9+
from efficientnet_pytorch import EfficientNet
10+
11+
from sotabencheval.image_classification import ImageNetEvaluator
12+
from sotabencheval.utils import is_server
13+
14+
if is_server():
15+
DATA_ROOT = './.data/vision/imagenet'
16+
else: # local settings
17+
DATA_ROOT = os.environ['IMAGENET_DIR']
18+
assert bool(DATA_ROOT), 'please set IMAGENET_DIR environment variable'
19+
print('Local data root: ', DATA_ROOT)
20+
21+
model_name = 'EfficientNet-B5'
22+
model = EfficientNet.from_pretrained(model_name.lower())
23+
image_size = EfficientNet.get_image_size(model_name.lower())
24+
25+
input_transform = transforms.Compose([
26+
transforms.Resize(image_size, PIL.Image.BICUBIC),
27+
transforms.CenterCrop(image_size),
28+
transforms.ToTensor(),
29+
transforms.Normalize(
30+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
31+
])
32+
33+
test_dataset = ImageNet(
34+
DATA_ROOT,
35+
split="val",
36+
transform=input_transform,
37+
target_transform=None,
38+
)
39+
40+
test_loader = DataLoader(
41+
test_dataset,
42+
batch_size=128,
43+
shuffle=False,
44+
num_workers=4,
45+
pin_memory=True,
46+
)
47+
48+
model = model.cuda()
49+
model.eval()
50+
51+
evaluator = ImageNetEvaluator(model_name=model_name,
52+
paper_arxiv_id='1905.11946')
53+
54+
def get_img_id(image_name):
55+
return image_name.split('/')[-1].replace('.JPEG', '')
56+
57+
with torch.no_grad():
58+
for i, (input, target) in enumerate(test_loader):
59+
input = input.to(device='cuda', non_blocking=True)
60+
target = target.to(device='cuda', non_blocking=True)
61+
output = model(input)
62+
image_ids = [get_img_id(img[0]) for img in test_loader.dataset.imgs[i*test_loader.batch_size:(i+1)*test_loader.batch_size]]
63+
evaluator.add(dict(zip(image_ids, list(output.cpu().numpy()))))
64+
if evaluator.cache_exists:
65+
break
66+
67+
if not is_server():
68+
print("Results:")
69+
print(evaluator.get_results())
70+
71+
evaluator.save()

0 commit comments

Comments
 (0)