Skip to content

Commit

Permalink
Add feature to display confidence score (#426, #399)
Browse files Browse the repository at this point in the history
  • Loading branch information
CVHub520 committed May 27, 2024
1 parent 08a1081 commit e6de5e6
Show file tree
Hide file tree
Showing 13 changed files with 55 additions and 23 deletions.
1 change: 1 addition & 0 deletions anylabeling/configs/auto_labeling/ch_ppocr_v4.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ display_name: ch_PP-OCRv4 (PaddleOCR)
det_model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.2/ch_PP-OCRv4_det_infer.onnx
rec_model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.2/ch_PP-OCRv4_rec_infer.onnx
cls_model_path: https://github.com/CVHub520/X-AnyLabeling/releases/download/v0.2.2/ch_ppocr_mobile_v2.0_cls_infer.onnx
drop_score: 0.5
use_angle_cls: True
10 changes: 7 additions & 3 deletions anylabeling/services/auto_labeling/__base__/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,13 @@ def predict_shapes(self, image, image_path=None):
if self.task == "track":
image_shape = image.shape[:2][::-1]
results = np.concatenate((boxes, scores, class_ids), axis=1)
boxes, track_ids, _, class_ids = self.tracker.track(
boxes, track_ids, scores, class_ids = self.tracker.track(
results, image_shape
)

shapes = []
for box, class_id, point, track_id in zip(
boxes, class_ids, points, track_ids
for box, class_id, point, track_id, score in zip(
boxes, class_ids, points, track_ids, scores
):
if (
self.show_boxes and self.task != "track"
Expand All @@ -318,6 +318,7 @@ def predict_shapes(self, image, image_path=None):
shape.line_color = "#000000"
shape.line_width = 1
shape.label = str(self.classes[int(class_id)])
shape.score = float(score)
shape.selected = False
shapes.append(shape)
if self.task == "seg":
Expand All @@ -330,6 +331,7 @@ def predict_shapes(self, image, image_path=None):
shape.line_color = "#000000"
shape.line_width = 1
shape.label = str(self.classes[int(class_id)])
shape.score = float(score)
shape.selected = False
shapes.append(shape)
if self.task == "track":
Expand All @@ -346,6 +348,7 @@ def predict_shapes(self, image, image_path=None):
shape.line_color = "#000000"
shape.line_width = 1
shape.label = str(self.classes[int(class_id)])
shape.score = float(score)
shape.selected = False
shapes.append(shape)
if self.task == "obb":
Expand All @@ -367,6 +370,7 @@ def predict_shapes(self, image, image_path=None):
shape.line_color = "#000000"
shape.line_width = 1
shape.label = str(self.classes[int(class_id)])
shape.score = float(score)
shape.selected = False
shapes.append(shape)

Expand Down
2 changes: 2 additions & 0 deletions anylabeling/services/auto_labeling/damo_yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def postprocess(self, predictions, ratio_hw):
"xmax": x + w,
"ymax": y + h,
"label": str(self.classes[int(class_ids[i])]),
"score": float(confidences[i])
}
output_infos.append(output_info)

Expand Down Expand Up @@ -135,6 +136,7 @@ def predict_shapes(self, image, image_path=None):
for result in results:
shape = Shape(
label=result["label"],
score=result["score"],
shape_type="rectangle",
)
xmin = result["xmin"]
Expand Down
8 changes: 6 additions & 2 deletions anylabeling/services/auto_labeling/ppocr_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ def __init__(self, model_config, on_message) -> None:
self.det_net = self.load_model("det_model_path")
self.rec_net = self.load_model("rec_model_path")
self.cls_net = self.load_model("cls_model_path")
self.drop_score = self.config.get("drop_score", 0.5)
self.use_angle_cls = self.config["use_angle_cls"]
self.current_dir = os.path.dirname(__file__)

Expand Down Expand Up @@ -133,7 +134,7 @@ def parse_args(self):
self.current_dir, "configs", "ppocr_keys_v1.txt"
),
use_space_char=True,
drop_score=0.5,
drop_score=self.drop_score,
# params for e2e
e2e_algorithm="PGNet",
e2e_model_dir="",
Expand Down Expand Up @@ -182,25 +183,28 @@ def predict_shapes(self, image, image_path=None):

args = self.parse_args()
text_sys = TextSystem(args)
dt_boxes, rec_res = text_sys(image)
dt_boxes, rec_res, scores = text_sys(image)

results = [
{
"description": rec_res[i][0],
"points": np.array(dt_boxes[i]).astype(np.int32).tolist(),
"score": float(scores[i])
}
for i in range(len(dt_boxes))
]

shapes = []
for i, res in enumerate(results):
score = res["score"]
points = res["points"]
description = res["description"]
pt1, pt2, pt3, pt4 = points
pt2 = [pt3[0], pt1[1]]
pt4 = [pt1[0], pt3[1]]
shape = Shape(
label="text",
score=score,
shape_type="rectangle",
group_id=int(i),
description=description,
Expand Down
28 changes: 16 additions & 12 deletions anylabeling/services/auto_labeling/rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,26 +120,26 @@ def postprocess(self, input_image, outputs):
).astype(int)
boxes = np.stack([x1, y1, x2, y2], axis=1)

output_boxes = []
results = []
for box, index, score in zip(boxes, indexs, scores):
x1 = box[0]
y1 = box[1]
x2 = box[2]
y2 = box[3]
label = str(self.classes[index])

output_box = {
result = {
"x1": x1,
"y1": y1,
"x2": x2,
"y2": y2,
"label": label,
"score": score,
"score": float(score),
}

output_boxes.append(output_box)
results.append(result)

return output_boxes
return results

def predict_shapes(self, image, image_path=None):
"""
Expand All @@ -160,15 +160,19 @@ def predict_shapes(self, image, image_path=None):
detections = self.net.get_ort_inference(
blob, extract=True, squeeze=True
)
boxes = self.postprocess(image, detections)
results = self.postprocess(image, detections)
shapes = []

for box in boxes:
xmin = box["x1"]
ymin = box["y1"]
xmax = box["x2"]
ymax = box["y2"]
shape = Shape(label=box["label"], shape_type="rectangle")
for result in results:
xmin = result["x1"]
ymin = result["y1"]
xmax = result["x2"]
ymax = result["y2"]
shape = Shape(
label=result["label"],
score=result["score"],
shape_type="rectangle"
)
shape.add_point(QtCore.QPointF(xmin, ymin))
shape.add_point(QtCore.QPointF(xmax, ymin))
shape.add_point(QtCore.QPointF(xmax, ymax))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -983,14 +983,15 @@ def __call__(self, img, cls=True):

rec_res = self.text_recognizer(img_crop_list)

filter_boxes, filter_rec_res = [], []
filter_boxes, filter_rec_res, scores = [], [], []
for box, rec_result in zip(dt_boxes, rec_res):
text, score = rec_result
if score >= self.drop_score:
filter_boxes.append(box)
filter_rec_res.append(rec_result)
scores.append(score)

return filter_boxes, filter_rec_res
return filter_boxes, filter_rec_res, scores


def build_post_process(config, global_config=None):
Expand Down
3 changes: 2 additions & 1 deletion anylabeling/services/auto_labeling/yolo_nas.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ def predict_shapes(self, image, image_path=None):

shapes = []
for i in selected:
score = float(scores[i])
label = str(self.config["classes"][classes[i]])
if self.filter_classes and label not in self.filter_classes:
continue
Expand All @@ -296,7 +297,7 @@ def predict_shapes(self, image, image_path=None):
ymin = y
xmax = x + w
ymax = y + h
shape = Shape(label=label, shape_type="rectangle", flags={})
shape = Shape(label=label, score=score, shape_type="rectangle", flags={})
shape.add_point(QtCore.QPointF(xmin, ymin))
shape.add_point(QtCore.QPointF(xmax, ymin))
shape.add_point(QtCore.QPointF(xmax, ymax))
Expand Down
1 change: 1 addition & 0 deletions anylabeling/services/auto_labeling/yolow.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def predict_shapes(self, image, image_path=None):
continue
xmin, ymin, xmax, ymax = bbox
rectangle_shape = Shape(
score=float(score),
label=str(self.classes[int(cls_id)]),
shape_type="rectangle",
)
Expand Down
3 changes: 2 additions & 1 deletion anylabeling/services/auto_labeling/yolox.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ def predict_shapes(self, image, image_path=None):
if score < self.config["confidence_threshold"]:
continue
x1, y1, x2, y2 = box
score = float(score)
label = str(self.classes[int(cls_inds)])
rectangle_shape = Shape(label=label, shape_type="rectangle")
rectangle_shape = Shape(label=label, score=score, shape_type="rectangle")
rectangle_shape.add_point(QtCore.QPointF(x1, y1))
rectangle_shape.add_point(QtCore.QPointF(x2, y1))
rectangle_shape.add_point(QtCore.QPointF(x2, y2))
Expand Down
2 changes: 2 additions & 0 deletions anylabeling/views/labeling/label_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def load(self, filename):
]
shape_keys = [
"label",
"score",
"points",
"group_id",
"difficult",
Expand Down Expand Up @@ -115,6 +116,7 @@ def load(self, filename):
shapes = [
{
"label": s["label"],
"score": s.get("score", None),
"points": s["points"],
"shape_type": s.get("shape_type", "polygon"),
"flags": s.get("flags", {}),
Expand Down
3 changes: 3 additions & 0 deletions anylabeling/views/labeling/label_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -2587,6 +2587,7 @@ def load_labels(self, shapes):
s = []
for shape in shapes:
label = shape["label"]
score = shape.get("score", None)
points = shape["points"]
shape_type = shape["shape_type"]
flags = shape["flags"]
Expand All @@ -2603,6 +2604,7 @@ def load_labels(self, shapes):

shape = Shape(
label=label,
score=score,
shape_type=shape_type,
group_id=group_id,
description=description,
Expand Down Expand Up @@ -2657,6 +2659,7 @@ def format_shape(s):
data = s.other_data.copy()
info = {
"label": s.label,
"score": s.score,
"points": [(p.x(), p.y()) for p in s.points],
"group_id": s.group_id,
"description": s.description,
Expand Down
2 changes: 2 additions & 0 deletions anylabeling/views/labeling/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Shape:
def __init__(
self,
label=None,
score=None,
line_color=None,
shape_type=None,
flags=None,
Expand All @@ -56,6 +57,7 @@ def __init__(
attributes={},
):
self.label = label
self.score = score
self.group_id = group_id
self.description = description
self.difficult = difficult
Expand Down
10 changes: 8 additions & 2 deletions anylabeling/views/labeling/widgets/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1185,7 +1185,10 @@ def paintEvent(self, event): # noqa: C901
pen = QtGui.QPen(QtGui.QColor("#FFA500"), 8, Qt.SolidLine)
p.setPen(pen)
for shape in self.shapes:
label = shape.label
if shape.score is not None:
label = f"{shape.label} {shape.score:.2f}"
else:
label = shape.label
d = shape.point_size / shape.scale
if label:
try:
Expand All @@ -1207,7 +1210,10 @@ def paintEvent(self, event): # noqa: C901
p.setPen(pen)
for shape in self.shapes:
d = 1.5 # default shape scale
label = shape.label
if shape.score is not None:
label = f"{shape.label} {shape.score:.2f}"
else:
label = shape.label
if label:
try:
bbox = shape.bounding_rect()
Expand Down

0 comments on commit e6de5e6

Please sign in to comment.