Skip to content

Commit d738374

Browse files
committed
fix gpu support
1 parent 5177651 commit d738374

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

netshare/models/doppelganger_torch/network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ def forward(
428428
mask_[:, i] *= tmp_mask
429429
tmp_mask = mask_[:, i]
430430

431-
mask_shift = torch.cat((torch.ones(mask_.size()[0], 1), mask_[:, :-1]), axis=1)
431+
mask_shift = torch.cat((torch.ones(mask_.size()[0], 1).to(self.device), mask_[:, :-1]), axis=1)
432432
# mask shape: (batch_size, step/sample_len, 1)
433433
mask_shift = torch.unsqueeze(mask_shift, 2)
434434
# mask shape: (batch_size, step/sample_len, num_feature*sample_len)

0 commit comments

Comments
 (0)