Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor e2e_eval function and add unit tests #11193

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions .github/ISSUE_TEMPLATE/sweep-template.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
name: Sweep Issue
title: 'Sweep: '
description: For small bugs, features, refactors, and tests to be handled by Sweep, an AI-powered junior developer.
labels: sweep
body:
- type: textarea
id: description
attributes:
label: Details
description: Tell Sweep where and what to edit and provide enough context for a new developer to the codebase
placeholder: |
Unit Tests: Write unit tests for <FILE>. Test each function in the file. Make sure to test edge cases.
Bugs: The bug might be in <FILE>. Here are the logs: ...
Features: the new endpoint should use the ... class from <FILE> because it contains ... logic.
Refactors: We are migrating this function to ... version because ...
43 changes: 43 additions & 0 deletions sweep.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Sweep AI turns bugs & feature requests into code changes (https://sweep.dev)
# For details on our config file, check out our docs at https://docs.sweep.dev/usage/config

# This setting contains a list of rules that Sweep will check for. If any of these rules are broken in a new commit, Sweep will create a pull request to fix the broken rule.
rules:
- "All docstrings and comments should be up to date."
- "There should be no trailing whitespace at the end of lines."
- "Indentation should be consistent throughout the codebase."
- "Variable and function names should be descriptive and follow a consistent naming convention."
- "There should be no commented out code in the codebase."
- "There should be no unused imports in the codebase."

# This is the branch that Sweep will develop from and make pull requests to. Most people use 'main' or 'master' but some users also use 'dev' or 'staging'.
branch: 'main'

# By default Sweep will read the logs and outputs from your existing GitHub Actions. To disable this, set this to false.
gha_enabled: True

# This is the description of your project. It will be used by sweep when creating PRs. You can tell Sweep what's unique about your project, what frameworks you use, or anything else you want.
#
# Example:
#
# description: sweepai/sweep is a python project. The main api endpoints are in sweepai/api.py. Write code that adheres to PEP8.
description: ''

# This sets whether to create pull requests as drafts. If this is set to True, then all pull requests will be created as drafts and GitHub Actions will not be triggered.
draft: False

# This is a list of directories that Sweep will not be able to edit.
blocked_dirs: []

# This is a list of documentation links that Sweep will use to help it understand your code. You can add links to documentation for any packages you use here.
#
# Example:
#
# docs:
# - PyGitHub: ["https://pygithub.readthedocs.io/en/latest/", "We use pygithub to interact with the GitHub API"]
docs: []

# Sandbox executes commands in a sandboxed environment to validate code changes after every edit to guarantee pristine code. For more details, see the [Sandbox](./sandbox) page.
sandbox:
install: []
check: []
143 changes: 98 additions & 45 deletions tools/end2end/eval_end2end.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,59 +112,112 @@ def e2e_eval(gt_dir, res_dir, ignore_blank=False):
else:
dts.append(parts)

dt_match = [False] * len(dts)
gt_match = [False] * len(gts)
all_ious = defaultdict(tuple)
for index_gt, gt in enumerate(gts):
gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
gt_poly = polygon_from_str(gt_coors)
for index_dt, dt in enumerate(dts):
dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
dt_poly = polygon_from_str(dt_coors)
iou = polygon_iou(dt_poly, gt_poly)
if iou >= iou_thresh:
all_ious[(index_gt, index_dt)] = iou
sorted_ious = sorted(
all_ious.items(), key=operator.itemgetter(1), reverse=True)
sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
sorted_gt_dt_pairs, gt_match, dt_match = handle_unmatched_dt_gt(dts, gts, iou_thresh)

# matched gt and dt
for gt_dt_pair in sorted_gt_dt_pairs:
index_gt, index_dt = gt_dt_pair
if gt_match[index_gt] == False and dt_match[index_dt] == False:
gt_match[index_gt] = True
dt_match[index_dt] = True
if ignore_blank:
gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
else:
gt_str = strQ2B(gts[index_gt][8])
dt_str = strQ2B(dts[index_dt][8])
if ignore_masks[index_gt] == '0':
ed_sum += ed(gt_str, dt_str)
num_gt_chars += len(gt_str)
if gt_str == dt_str:
hit += 1
gt_count += 1
dt_count += 1
ed_sum, num_gt_chars, hit, gt_count, dt_count = match_gt_dt_pairs(sorted_gt_dt_pairs, gt_match, dt_match, ignore_blank, gts, dts, ignore_masks, ed_sum, num_gt_chars, hit, gt_count, dt_count)

# unmatched dt
for tindex, dt_match_flag in enumerate(dt_match):
if dt_match_flag == False:
dt_str = dts[tindex][8]
gt_str = ''
ed_sum += ed(dt_str, gt_str)
dt_count += 1
ed_sum, dt_count, gt_count, num_gt_chars = calculate_iou_pairs(dt_match, dts, ed_sum, dt_count, gt_match, ignore_masks, gts, num_gt_chars, gt_count)

read_gt_dt_data(hit, dt_count, gt_count, ed_sum, val_names, num_gt_chars)

# Parse the tab-separated ground-truth file (under gt_dir) and result file
# (under res_dir) for each name in val_names, and return the parsed
# detections, ground truths and ignore masks.
# NOTE(review): the name suggests metric calculation/printing, but the visible
# body only reads and parses files -- confirm the intended name.
# NOTE(review): indentation is not visible in this view, so it cannot be told
# whether `return` sits inside or after the `for val_name` loop; either way
# only one file's data is returned -- confirm against the caller.
def calculate_print_metrics(val_names, gt_dir, res_dir):
# `i` is not used anywhere in the visible body.
for i, val_name in enumerate(val_names):
with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
gt_lines = [o.strip() for o in f.readlines()]
gts = []
ignore_masks = []
for line in gt_lines:
parts = line.strip().split('\t')
# ignore illegal data
# (lines with fewer than 9 tab-separated fields are skipped)
if len(parts) < 9:
continue
# A ground-truth line may have at most 10 fields.
assert (len(parts) < 11)
# With exactly 9 fields the text slot (index 8 of the stored record) is
# set to ''; otherwise the last field of the line is stored there.
if len(parts) == 9:
gts.append(parts[:8] + [''])
else:
gts.append(parts[:8] + [parts[-1]])

# Field 8 of the raw line is kept separately as the ignore mask.
ignore_masks.append(parts[8])

# unmatched gt
# NOTE(review): the next four lines read `gt_match`, which is not defined
# anywhere in this function as visible here, and assign dt_str/gt_str without
# using them. They look like leftovers of removed code interleaved by the
# diff view -- confirm and drop if so.
for tindex, gt_match_flag in enumerate(gt_match):
if gt_match_flag == False and ignore_masks[tindex] == '0':
dt_str = ''
gt_str = gts[tindex][8]
val_path = os.path.join(res_dir, val_name)
# A missing result file is treated as "no detections" rather than an error.
if not os.path.exists(val_path):
dt_lines = []
else:
with open(val_path, encoding='utf-8') as f:
dt_lines = [o.strip() for o in f.readlines()]
dts = []
for line in dt_lines:
# print(line)
parts = line.strip().split("\t")
# A detection line may have at most 9 fields.
assert (len(parts) < 10), "line error: {}".format(line)
# Lines with exactly 8 fields get an empty ninth field appended.
if len(parts) == 8:
dts.append(parts + [''])
else:
dts.append(parts)
return dts, gts, ignore_masks

def handle_unmatched_dt_gt(dts, gts, iou_thresh):
    """Rank candidate ground-truth/detection pairs by polygon IoU.

    Despite its name, this function does not process unmatched boxes. It
    creates the (all ``False``) match flags and the IoU-ordered list of
    candidate pairs.

    Args:
        dts: Detection records. The first eight fields of each record are
            converted with ``float`` and handed to ``polygon_from_str``.
        gts: Ground-truth records with the same eight leading fields.
        iou_thresh: Minimum IoU (inclusive) for a pair to be kept.

    Returns:
        A tuple ``(sorted_gt_dt_pairs, gt_match, dt_match)``.
        ``sorted_gt_dt_pairs`` is a list of ``(index_gt, index_dt)`` tuples
        ordered by descending IoU (ties keep discovery order because the
        sort is stable). ``gt_match`` and ``dt_match`` hold one ``False``
        flag per ground truth and per detection respectively.
    """
    dt_match = [False] * len(dts)
    gt_match = [False] * len(gts)

    # A plain dict is sufficient: entries are only ever assigned, never
    # read through a default factory.
    all_ious = {}

    # Detection polygons do not depend on the ground truth, so they are
    # built once instead of once per ground truth. They are built lazily,
    # after the first ground-truth polygon, so that detections are still
    # not parsed at all when `gts` is empty.
    # NOTE(review): assumes polygon_iou does not mutate its arguments --
    # confirm before relying on the shared polygons.
    dt_polys = None
    for index_gt, gt in enumerate(gts):
        gt_poly = polygon_from_str([float(coor) for coor in gt[0:8]])
        if dt_polys is None:
            dt_polys = [
                polygon_from_str([float(coor) for coor in dt[0:8]])
                for dt in dts
            ]
        for index_dt, dt_poly in enumerate(dt_polys):
            iou = polygon_iou(dt_poly, gt_poly)
            if iou >= iou_thresh:
                all_ious[(index_gt, index_dt)] = iou

    sorted_gt_dt_pairs = sorted(all_ious, key=all_ious.get, reverse=True)
    return sorted_gt_dt_pairs, gt_match, dt_match

def match_gt_dt_pairs(sorted_gt_dt_pairs, gt_match, dt_match, ignore_blank,
                      gts, dts, ignore_masks, ed_sum, num_gt_chars, hit,
                      gt_count, dt_count):
    """Greedily match ground truths to detections and update the counters.

    Pairs are visited in the given order. A pair is accepted only when
    neither its ground truth nor its detection has been matched already.

    Args:
        sorted_gt_dt_pairs: Iterable of ``(index_gt, index_dt)`` tuples.
        gt_match: Per-ground-truth match flags. Mutated in place.
        dt_match: Per-detection match flags. Mutated in place.
        ignore_blank: When true, spaces are removed from both texts before
            they are compared.
        gts: Ground-truth records; field 8 is the text.
        dts: Detection records; field 8 is the text.
        ignore_masks: Per-ground-truth mask strings; only ``'0'`` entries
            contribute to the counters.
        ed_sum: Running edit-distance total.
        num_gt_chars: Running count of ground-truth characters.
        hit: Running count of exact text matches.
        gt_count: Running count of ground truths taken into account.
        dt_count: Running count of detections taken into account.

    Returns:
        The updated ``(ed_sum, num_gt_chars, hit, gt_count, dt_count)``.
    """
    # matched gt and dt
    for index_gt, index_dt in sorted_gt_dt_pairs:
        if gt_match[index_gt] or dt_match[index_dt]:
            continue
        gt_match[index_gt] = True
        dt_match[index_dt] = True

        gt_str = strQ2B(gts[index_gt][8])
        dt_str = strQ2B(dts[index_dt][8])
        if ignore_blank:
            gt_str = gt_str.replace(" ", "")
            dt_str = dt_str.replace(" ", "")

        if ignore_masks[index_gt] == '0':
            ed_sum += ed(gt_str, dt_str)
            num_gt_chars += len(gt_str)
            if gt_str == dt_str:
                hit += 1
            gt_count += 1
            # NOTE(review): the nesting of this increment is not recoverable
            # from the whitespace-stripped source; it is kept inside the
            # ignore-mask branch, next to gt_count -- confirm.
            dt_count += 1
    return ed_sum, num_gt_chars, hit, gt_count, dt_count

def calculate_iou_pairs(dt_match, dts, ed_sum, dt_count, gt_match,
                        ignore_masks, gts, num_gt_chars, gt_count):
    """Add unmatched detections and ground truths to the running counters.

    Despite its name, this function computes no IoU. It only accounts for
    the records whose match flag is still false.

    Args:
        dt_match: Per-detection match flags.
        dts: Detection records; field 8 is the text.
        ed_sum: Running edit-distance total.
        dt_count: Running count of detections taken into account.
        gt_match: Per-ground-truth match flags.
        ignore_masks: Per-ground-truth mask strings; unmatched ground truths
            are counted only when their mask is ``'0'``.
        gts: Ground-truth records; field 8 is the text.
        num_gt_chars: Running count of ground-truth characters.
        gt_count: Running count of ground truths taken into account.

    Returns:
        The updated ``(ed_sum, dt_count, gt_count, num_gt_chars)``.
    """
    # unmatched dt: scored against an empty ground-truth string
    for tindex, dt_matched in enumerate(dt_match):
        if not dt_matched:
            ed_sum += ed(dts[tindex][8], '')
            dt_count += 1

    # unmatched gt: scored against an empty detection string
    for tindex, gt_matched in enumerate(gt_match):
        if not gt_matched and ignore_masks[tindex] == '0':
            gt_str = gts[tindex][8]
            ed_sum += ed(gt_str, '')
            num_gt_chars += len(gt_str)
            gt_count += 1
    return ed_sum, dt_count, gt_count, num_gt_chars

def read_gt_dt_data(hit, dt_count, gt_count, ed_sum, val_names, num_gt_chars):
eps = 1e-9
print('hit, dt_count, gt_count', hit, dt_count, gt_count)
precision = hit / (dt_count + eps)
Expand Down
31 changes: 31 additions & 0 deletions tools/end2end/eval_end2end_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import unittest
from shapely.geometry import Polygon
from tools.end2end.eval_end2end import calculate_iou, calculate_edit_distance, calculate_metrics, match_gt_and_dt

class TestEvalEnd2End(unittest.TestCase):
    """Unit tests for helpers imported from ``tools.end2end.eval_end2end``.

    NOTE(review): the helpers exercised here (``calculate_iou``,
    ``calculate_edit_distance``, ``calculate_metrics``, ``match_gt_and_dt``)
    are imported at module level -- confirm that the target module actually
    defines them under these names.
    """

    def test_calculate_iou(self):
        """IoU of disjoint, partially overlapping and identical squares."""
        poly1 = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
        poly2 = Polygon([(1, 1), (2, 1), (2, 2), (1, 2)])
        poly3 = Polygon([(0.5, 0.5), (1.5, 0.5), (1.5, 1.5), (0.5, 1.5)])
        # poly1 and poly2 only touch at a corner, so they share no area.
        self.assertAlmostEqual(calculate_iou(poly1, poly2), 0)
        # Two unit squares overlapping in a 0.5 x 0.5 patch: the intersection
        # is 0.25 and the union is 1 + 1 - 0.25 = 1.75, so the IoU is
        # 0.25 / 1.75 (0.25 is the intersection area, not the IoU).
        self.assertAlmostEqual(calculate_iou(poly1, poly3), 0.25 / 1.75)
        self.assertAlmostEqual(calculate_iou(poly1, poly1), 1)

    def test_calculate_edit_distance(self):
        """Edit distance for equal strings, substitutions and insertions."""
        self.assertEqual(calculate_edit_distance('test', 'test'), 0)
        self.assertEqual(calculate_edit_distance('test', 'tent'), 1)
        self.assertEqual(calculate_edit_distance('test', 'best'), 1)
        self.assertEqual(calculate_edit_distance('test', 'tests'), 1)
        self.assertEqual(calculate_edit_distance('test', 'testing'), 3)

    def test_calculate_metrics(self):
        """All three returned metrics equal the expected ratio."""
        cases = (
            ((5, 10, 10), 0.5),
            ((7, 10, 10), 0.7),
            ((10, 10, 10), 1),
        )
        for args, expected in cases:
            with self.subTest(args=args):
                metrics = calculate_metrics(*args)
                self.assertEqual(len(metrics), 3)
                # Floating-point ratios are compared approximately; exact
                # equality is brittle for values such as 0.7.
                for value in metrics:
                    self.assertAlmostEqual(value, expected)

    def test_match_gt_and_dt(self):
        """Placeholder for the matching test; reported as skipped."""
        # TODO: Implement this test based on the function's behavior and input/output
        # A method whose body is only a comment is a syntax error, so the
        # placeholder skips explicitly until the test is written.
        self.skipTest("match_gt_and_dt test is not implemented yet")

if __name__ == '__main__':
unittest.main()