Skip to content

Commit 1283877

Browse files
authored
[superglue] fix wrong concatenation which made batching results wrong (#38850)
1 parent f8b8886 commit 1283877

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/transformers/models/superglue/modeling_superglue.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -725,8 +725,8 @@ def _match_image_pair(
725725
matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
726726
matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
727727

728-
matches = torch.cat([matches0, matches1]).reshape(batch_size, 2, -1)
729-
matching_scores = torch.cat([matching_scores0, matching_scores1]).reshape(batch_size, 2, -1)
728+
matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1)
729+
matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1)
730730

731731
if output_hidden_states:
732732
all_hidden_states = all_hidden_states + encoded_keypoints[1]

0 commit comments

Comments
 (0)