We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 40b23a5 commit 1b2a564Copy full SHA for 1b2a564
torch_struct/alignment.py
@@ -39,7 +39,7 @@ def _check_potentials(self, edge, lengths=None):
39
assert M >= N
40
41
if lengths is None:
42
- lengths = torch.LongTensor([N] * batch)
+ lengths = torch.LongTensor([N] * batch).to(edge.device)
43
44
assert max(lengths) <= N, "Length longer than edge scores"
45
assert max(lengths) == N, "One length must be at least N"
0 commit comments