From 45a73c23d39f6e0b598371cd0356c4a2f50d1f1f Mon Sep 17 00:00:00 2001
From: Daniel Bolya
Date: Mon, 6 Aug 2018 15:41:40 -0700
Subject: [PATCH] Added the ability to select the config to use in the eval
script.
---
data/config.py | 17 ++++++-----------
eval.py | 8 +++++++-
yolact.py | 7 +++++--
3 files changed, 18 insertions(+), 14 deletions(-)
diff --git a/data/config.py b/data/config.py
index 81e61db1c..2480278c0 100644
--- a/data/config.py
+++ b/data/config.py
@@ -66,13 +66,9 @@ def replace(self, new_config_dict):
'type': ResNetBackbone,
'args': ([3, 4, 23, 3],),
- 'selected_layers': list(range(2, 7)),
- 'pred_scales': [[1, 2], [2], [2], [2], [2]],
- 'pred_aspect_ratios': [[[1], [1.05, 0.62]],
- [[1.29, 0.79, 0.47, 2.33, 0.27]],
- [[1.19, 0.72, 0.43, 2.13, 0.25]],
- [[1.34, 0.84, 0.52, 2.38, 0.30]],
- [[1.40, 0.95, 0.64, 2.16]]],
+ 'selected_layers': list(range(2, 8)),
+ 'pred_scales': [[1]]*6,
+ 'pred_aspect_ratios': [ [[0.66685089, 1.7073535, 0.87508774, 1.16524493, 0.49059086]] ] * 6,
})
resnet50_backbone = resnet101_backbone.copy({
@@ -143,11 +139,10 @@ def replace(self, new_config_dict):
'name': 'yolact_resnet101',
'backbone': resnet101_backbone,
- 'min_size': 400,
- 'max_size': 600,
+ 'max_size': 550,
'train_masks': True,
- 'preserve_aspect_ratio': True,
+ 'preserve_aspect_ratio': False,
'use_prediction_module': True,
'use_yolo_regressors': True,
})
@@ -161,7 +156,7 @@ def replace(self, new_config_dict):
'pred_aspect_ratios': [ [[1], [1, sqrt(2), 1/sqrt(2), sqrt(3), 1/sqrt(3)][:n]] for n in [3, 5, 5, 5, 3, 3] ],
}),
- 'max_size': 600,
+ 'max_size': 550,
'train_masks': True,
'preserve_aspect_ratio': False,
diff --git a/eval.py b/eval.py
index 639127463..aa71ba2c8 100644
--- a/eval.py
+++ b/eval.py
@@ -7,7 +7,7 @@
from utils.functions import sanitize_coordinates, SavePath
import pycocotools
-from data import cfg
+from data import cfg, set_cfg
import numpy as np
import torch
@@ -73,11 +73,16 @@ def str2bool(v):
help='The output file for coco mask results if --coco_results is set.')
parser.add_argument('--max_num_detections', default=100, type=int,
help='The maximum number of detections to consider for each image for mAP scoring. COCO uses 100.')
+parser.add_argument('--config', default=None,
+ help='The config object to use.')
parser.set_defaults(display=False, resume=False, output_coco_json=False, shuffle=False)
args = parser.parse_args()
+if args.config is not None:
+ set_cfg(args.config)
+
iou_thresholds = [x / 100 for x in range(50, 100, 5)]
coco_cats = [] # Call prep_coco_cats to fill this
@@ -577,6 +582,7 @@ def evaluate(net, dataset):
if it > 1:
print('Avg FPS: %.4f' % (1 / frame_times.get_avg()))
plt.imshow(np.clip(img_numpy, 0, 1))
+ plt.title(str(dataset.ids[image_idx]))
plt.show()
else:
if it > 1: fps = 1 / frame_times.get_avg()
diff --git a/yolact.py b/yolact.py
index 629ad93ab..14c9af753 100644
--- a/yolact.py
+++ b/yolact.py
@@ -118,7 +118,7 @@ def make_priors(self, conv_h, conv_w):
# +0.5 because priors are in center-size notation
x = (i + 0.5) / conv_w
y = (j + 0.5) / conv_h
-
+
for scale, ars in zip(self.scales, self.aspect_ratios):
for ar in ars:
w = scale * ar / conv_w
@@ -229,8 +229,11 @@ def forward(self, x):
# torch.set_default_tensor_type('torch.cuda.FloatTensor')
x = torch.zeros((1, 3, cfg.max_size, cfg.max_size))
-
y = net(x)
+
+ for p in net.prediction_layers:
+ print(p.last_conv_size)
+
print()
for a in y:
print(a.size(), torch.sum(a))