Skip to content

Commit 1b2a564

Browse files
authored
Update alignment.py
1 parent 40b23a5 commit 1b2a564

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch_struct/alignment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def _check_potentials(self, edge, lengths=None):
3939
assert M >= N
4040

4141
if lengths is None:
42-
lengths = torch.LongTensor([N] * batch)
42+
lengths = torch.LongTensor([N] * batch).to(edge.device)
4343

4444
assert max(lengths) <= N, "Length longer than edge scores"
4545
assert max(lengths) == N, "One length must be at least N"

0 commit comments

Comments
 (0)