Skip to content

Commit

Permalink
added <difficult> option
Browse files Browse the repository at this point in the history
  • Loading branch information
Cartucho committed Jul 25, 2018
1 parent 01162fd commit 7d146c6
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 30 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,14 @@ In the [extra](https://github.com/Cartucho/mAP/tree/master/extra) folder you can
- Use **matching names** (e.g. image: "image_1.jpg", ground-truth: "image_1.txt"; "image_2.jpg", "image_2.txt"...).
- In these files, each line should be in the following format:
```
<class_name> <left> <top> <right> <bottom>
<class_name> <left> <top> <right> <bottom> [<difficult>]
```
- The `difficult` parameter is optional, use it if you want to ignore a specific prediction.
- E.g. "image_1.txt":
```
tvmonitor 2 10 173 238
book 439 157 556 241
book 437 246 518 351
book 437 246 518 351 difficult
pottedplant 272 190 316 259
```
#### Create the predicted objects files
Expand Down
5 changes: 2 additions & 3 deletions extra/remove_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,15 @@ def rename_class(current_class_name, new_class_name):
with open('class_list.txt') as f:
for line in f:
current_class_name = line.rstrip("\n")
new_class_name = line.replace(' ', args.delimiter).rstrip("\n)
if line == new_class_name:
new_class_name = line.replace(' ', args.delimiter).rstrip("\n")
if current_class_name == new_class_name:
continue
y_n_message = ("Are you sure you want "
"to rename the class "
"\"" + current_class_name + "\" "
"into \"" + new_class_name + "\"?"
)


if query_yes_no(y_n_message, bypass=args.yes):
os.chdir("../ground-truth")
rename_class(current_class_name, new_class_name)
Expand Down
60 changes: 35 additions & 25 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,27 +306,36 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
lines_list = file_lines_to_list(txt_file)
# create ground-truth dictionary
bounding_boxes = []
is_difficult = False
for line in lines_list:
try:
class_name, left, top, right, bottom = line.split()
if "difficult" in line:
class_name, left, top, right, bottom, _difficult = line.split()
is_difficult = True
else:
class_name, left, top, right, bottom = line.split()
except ValueError:
error_msg = "Error: File " + txt_file + " in the wrong format.\n"
error_msg += " Expected: <class_name> <left> <top> <right> <bottom>\n"
error_msg += " Expected: <class_name> <left> <top> <right> <bottom> ['difficult']\n"
error_msg += " Received: " + line
error_msg += "\n\nIf you have a <class_name> with spaces between words you should remove them\n"
error_msg += "by running the script \"rename_class.py\" in the \"extra/\" folder."
error_msg += "by running the script \"remove_space.py\" or \"rename_class.py\" in the \"extra/\" folder."
error(error_msg)
# check if class is in the ignore list, if yes skip
if class_name in args.ignore:
continue
bbox = left + " " + top + " " + right + " " +bottom
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
# count that object
if class_name in gt_counter_per_class:
gt_counter_per_class[class_name] += 1
if is_difficult:
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
is_difficult = False
else:
# if class didn't exist yet
gt_counter_per_class[class_name] = 1
bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
# count that object
if class_name in gt_counter_per_class:
gt_counter_per_class[class_name] += 1
else:
# if class didn't exist yet
gt_counter_per_class[class_name] = 1
# dump bounding_boxes into a ".json" file
with open(tmp_files_path + "/" + file_id + "_ground_truth.json", 'w') as outfile:
json.dump(bounding_boxes, outfile)
Expand Down Expand Up @@ -466,7 +475,7 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
ovmax = ov
gt_match = obj

# assign prediction as true positive or false positive
# assign prediction as true positive/don't care/false positive
if show_animation:
status = "NO MATCH FOUND!" # status is only used in the animation
# set minimum overlap
Expand All @@ -476,21 +485,22 @@ def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, out
index = specific_iou_classes.index(class_name)
min_overlap = float(iou_list[index])
if ovmax >= min_overlap:
if not bool(gt_match["used"]):
# true positive
tp[idx] = 1
gt_match["used"] = True
count_true_positives[class_name] += 1
# update the ".json" file
with open(gt_file, 'w') as f:
f.write(json.dumps(ground_truth_data))
if show_animation:
status = "MATCH!"
else:
# false positive (multiple detection)
fp[idx] = 1
if show_animation:
status = "REPEATED MATCH!"
if "difficult" not in gt_match:
if not bool(gt_match["used"]):
# true positive
tp[idx] = 1
gt_match["used"] = True
count_true_positives[class_name] += 1
# update the ".json" file
with open(gt_file, 'w') as f:
f.write(json.dumps(ground_truth_data))
if show_animation:
status = "MATCH!"
else:
# false positive (multiple detection)
fp[idx] = 1
if show_animation:
status = "REPEATED MATCH!"
else:
# false positive
fp[idx] = 1
Expand Down

0 comments on commit 7d146c6

Please sign in to comment.