We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 7013db4 commit ede76f8Copy full SHA for ede76f8
torch_struct/cky.py
@@ -23,7 +23,7 @@ def _dp(self, scores, lengths=None, force_grad=False, cache=True):
23
semiring.convert(roots).requires_grad_(True),
24
)
25
if lengths is None:
26
- lengths = torch.LongTensor([N] * batch)
+ lengths = torch.LongTensor([N] * batch).to(terms.device)
27
28
# Charts
29
beta = [
0 commit comments