Skip to content

Commit

Permalink
Add OWLv2 non-square pixel fix
Browse files Browse the repository at this point in the history
  • Loading branch information
HonzaCuhel committed Oct 25, 2024
1 parent 07a58f0 commit 057a9b4
Showing 1 changed file with 18 additions and 2 deletions.
20 changes: 18 additions & 2 deletions datadreamer/dataset_annotation/owlv2_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,22 @@ def annotate_batch(
torch.cat(all_labels), num_classes=len(prompts)
)

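# Fix the bounding boxes for non-square images: the box components along the
# shorter image side are divided by the (shorter / longer) side ratio
# (assumes each box is ordered x, y, x, y - TODO confirm)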
width_ratio = 1
height_ratio = 1
width = images[i].width
height = images[i].height
if width > height:
height_ratio = height / width
elif height > width:
width_ratio = width / height

all_boxes = [
box
/ torch.tensor([width_ratio, height_ratio, width_ratio, height_ratio])
for box in all_boxes
]

# Apply NMS
# Transform the predictions to shape [N, 5 + num_classes] for the NMS function, where N is the number of bounding boxes
all_boxes_cat = torch.cat(
Expand Down Expand Up @@ -294,8 +310,8 @@ def release(self, empty_cuda_cache: bool = False) -> None:

url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
annotator = OWLv2Annotator(device="cpu", size="large")
annotator = OWLv2Annotator(device="cpu", size="base")
final_boxes, final_scores, final_labels = annotator.annotate_batch(
[im], ["robot", "horse"]
[im], ["bus", "person"]
)
annotator.release()

0 comments on commit 057a9b4

Please sign in to comment.