Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions inference/models/owlv2/owlv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,16 +235,25 @@ def filter_tensors_by_objectness(
logit_shift: torch.Tensor,
logit_scale: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
# Fuse squeeze operations for potential speedup and clarity
objectness = objectness.squeeze(0)
objectness, objectness_indices = torch.topk(objectness, MAX_DETECTIONS, dim=0)
boxes = boxes.squeeze(0)
image_class_embeds = image_class_embeds.squeeze(0)
logit_shift = logit_shift.squeeze(0).squeeze(1)
logit_scale = logit_scale.squeeze(0).squeeze(1)
boxes = boxes[objectness_indices]
image_class_embeds = image_class_embeds[objectness_indices]
logit_shift = logit_shift[objectness_indices]
logit_scale = logit_scale[objectness_indices]
# Combine sequential squeeze ops into one for logit_shift and logit_scale
logit_shift = logit_shift.squeeze()
logit_scale = logit_scale.squeeze()

# topk returns values and indices in one go, so only indices needed for all tensors
objectness, objectness_indices = torch.topk(objectness, MAX_DETECTIONS, dim=0)

# Apply advanced indexing once for all tensors
# Avoids repeated indexing overhead
indices = objectness_indices
boxes = boxes.index_select(0, indices)
image_class_embeds = image_class_embeds.index_select(0, indices)
logit_shift = logit_shift.index_select(0, indices)
logit_scale = logit_scale.index_select(0, indices)

return objectness, boxes, image_class_embeds, logit_shift, logit_scale


Expand Down