Skip to content

Commit

Permalink
Minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Jan 23, 2022
1 parent 43d4f6e commit 437cb85
Showing 1 changed file with 7 additions and 13 deletions.
20 changes: 7 additions & 13 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,7 @@
("yolov5s", "r6.0", "v6.0", "c3b140f3", False),
],
)
def test_load_from_ultralytics(
arch: str,
version: str,
upstream_version: str,
hash_prefix: str,
use_p6: bool,
):
def test_load_from_ultralytics(arch, version, upstream_version, hash_prefix, use_p6):
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)
Expand All @@ -53,7 +47,7 @@ def test_load_from_ultralytics(
"arch, version, upstream_version, hash_prefix",
[("yolov5s-VOC", "r4.0", "v5.0", "23818cff")],
)
def test_load_from_ultralytics_voc(arch: str, version: str, upstream_version: str, hash_prefix: str):
def test_load_from_ultralytics_voc(arch, version, upstream_version, hash_prefix):
img_path = "test/assets/bus.jpg"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
Expand All @@ -76,17 +70,17 @@ def test_load_from_ultralytics_voc(arch: str, version: str, upstream_version: st
with torch.no_grad():
outs = model_yolov5(img[None])[0]
outs = non_max_suppression(outs, conf, iou)
out_from_yolov5 = outs[0]
out_yolov5 = outs[0]

# Define yolort model
model_yolort = YOLO.load_from_yolov5(checkpoint_path, score_thresh=conf, version=version)
model_yolort.eval()
with torch.no_grad():
out_from_yolort = model_yolort(img[None])
out_yolort = model_yolort(img[None])

torch.testing.assert_allclose(out_from_yolort[0]["boxes"], out_from_yolov5[:, :4])
torch.testing.assert_allclose(out_from_yolort[0]["scores"], out_from_yolov5[:, 4])
torch.testing.assert_allclose(out_from_yolort[0]["labels"], out_from_yolov5[:, 5].to(dtype=torch.int64))
torch.testing.assert_allclose(out_yolort[0]["boxes"], out_yolov5[:, :4])
torch.testing.assert_allclose(out_yolort[0]["scores"], out_yolov5[:, 4])
torch.testing.assert_allclose(out_yolort[0]["labels"], out_yolov5[:, 5].to(dtype=torch.int64))


def test_read_image_to_tensor():
Expand Down

0 comments on commit 437cb85

Please sign in to comment.