Skip to content

Commit

Permalink
Deprecate apply_permutation
Browse files Browse the repository at this point in the history
  • Loading branch information
yzhangcs committed May 11, 2023
1 parent bbe23fe commit 42688d0
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions supar/modules/lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
import torch.nn as nn
from supar.modules.dropout import SharedDropout
from torch.nn.modules.rnn import apply_permutation
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence


Expand Down Expand Up @@ -170,10 +169,7 @@ def permute_hidden(
) -> Tuple[torch.Tensor, torch.Tensor]:
if permutation is None:
return hx
h = apply_permutation(hx[0], permutation)
c = apply_permutation(hx[1], permutation)

return h, c
return hx[0].index_select(1, permutation), hx[1].index_select(1, permutation)

def layer_forward(
self,
Expand Down

0 comments on commit 42688d0

Please sign in to comment.