Skip to content

Commit 5e8d662

Browse files
authored
Update linearchain.py
1 parent fba3d5b commit 5e8d662

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

torch_struct/linearchain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def to_parts(sequence, extra, lengths=None):
104104
batch, N = sequence.shape
105105
labels = torch.zeros(batch, N - 1, C, C).long()
106106
if lengths is None:
107-
lengths = torch.LongTensor([N] * batch).to(edge.device)
107+
lengths = torch.LongTensor([N] * batch)
108108
for n in range(1, N):
109109
labels[torch.arange(batch), n - 1, sequence[:, n], sequence[:, n - 1]] = 1
110110
for b in range(batch):

0 commit comments

Comments
 (0)