Skip to content

Commit

Permalink
[run-slow] grounding_dino
Browse files Browse the repository at this point in the history
  • Loading branch information
SangbumChoi committed Oct 11, 2024
1 parent c2a9a62 commit 27625a0
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions tests/models/grounding_dino/test_modeling_grounding_dino.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,34 +57,34 @@


def generate_fake_bounding_boxes(n_boxes):
"""Generate bounding boxes in the format (cx, cy, w, h)"""
"""Generate bounding boxes in the format (center_x, center_y, width, height)"""
# Validate the input
if not isinstance(n_boxes, int):
raise ValueError("n_boxes must be an integer")
if n_boxes <= 0:
raise ValueError("n_boxes must be a positive integer")

# Generate random bounding boxes in the format (cx, cy, w, h)
# Generate random bounding boxes in the format (center_x, center_y, width, height)
bounding_boxes = torch.rand((n_boxes, 4))

# Extract the components
cx = bounding_boxes[:, 0]
cy = bounding_boxes[:, 1]
w = bounding_boxes[:, 2]
h = bounding_boxes[:, 3]
center_x = bounding_boxes[:, 0]
center_y = bounding_boxes[:, 1]
width = bounding_boxes[:, 2]
height = bounding_boxes[:, 3]

# Ensure width and height do not exceed bounds
w = torch.min(w, torch.tensor(1.0))
h = torch.min(h, torch.tensor(1.0))
width = torch.min(width, torch.tensor(1.0))
height = torch.min(height, torch.tensor(1.0))

# Ensure the bounding box stays within the normalized space
cx = torch.where(cx - w / 2 < 0, w / 2, cx)
cx = torch.where(cx + w / 2 > 1, 1 - w / 2, cx)
cy = torch.where(cy - h / 2 < 0, h / 2, cy)
cy = torch.where(cy + h / 2 > 1, 1 - h / 2, cy)
center_x = torch.where(center_x - width / 2 < 0, width / 2, center_x)
center_x = torch.where(center_x + width / 2 > 1, 1 - width / 2, center_x)
center_y = torch.where(center_y - height / 2 < 0, height / 2, center_y)
center_y = torch.where(center_y + height / 2 > 1, 1 - height / 2, center_y)

# Combine back into bounding boxes
bounding_boxes = torch.stack([cx, cy, w, h], dim=1)
bounding_boxes = torch.stack([center_x, center_y, width, height], dim=1)

return bounding_boxes

Expand Down

0 comments on commit 27625a0

Please sign in to comment.