
I had this error "RuntimeError: 'lengths' argument should be a 1D CPU int64 tensor" when I try to add an LSTM on top of BERT #275

Open
Nada-1234 opened this issue Mar 2, 2023 · 1 comment


@Nada-1234

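import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from torch.nn.utils.rnn import pack_padded_sequence
from transformers import AutoModel


class SentencePairClassifier(nn.Module):

    def __init__(self, bert_model="bert-base-multilingual-cased", freeze_bert=False):
        super(SentencePairClassifier, self).__init__()
        # Instantiating BERT-based model object
        self.bert_layer = AutoModel.from_pretrained(bert_model)
        self.hidden_size = 768

        # Bidirectional LSTM on top of the BERT token representations
        self.LSTM = nn.LSTM(self.hidden_size, self.hidden_size, bidirectional=True)

        # Freeze bert layers and only train the classification layer weights
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False

        # Classification layer: final forward + backward LSTM states -> 1 output
        self.cls_layer = nn.Linear(self.hidden_size * 2, 1)

        # self.dropout = nn.Dropout(p=0.1)

    @autocast()  # run in mixed precision
    def forward(self, input_ids, attn_masks, token_type_ids):

        # Feeding the inputs to the BERT-based model to obtain contextualized representations
        # (return_dict=False so the output can be unpacked as a tuple)
        cont_reps, pooler_output = self.bert_layer(input_ids, attn_masks, return_dict=False)
        cont_reps = cont_reps.permute(1, 0, 2)  # (batch, seq, hidden) -> (seq, batch, hidden)

        # This call raises the RuntimeError: token_type_ids is passed as the `lengths` argument
        enc_hiddens, (last_hidden, last_cell) = self.LSTM(pack_padded_sequence(cont_reps, token_type_ids))

        # Concatenate the final forward and backward hidden states: (batch, 2 * hidden)
        output_hidden = torch.cat((last_hidden[0], last_hidden[1]), dim=1)

        # training=self.training so dropout is disabled after model.eval()
        output_hidden = F.dropout(output_hidden, 0.2, training=self.training)
        output = self.cls_layer(output_hidden)
        return torch.sigmoid(output)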

I have tried many things, but it still doesn't work. Please help.
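The failing call can be isolated without loading BERT. This is a minimal sketch, assuming random tensors stand in for BERT's permuted `last_hidden_state` (shape `(seq_len, batch, hidden)`) and a made-up batch of two sentence pairs; `reproduce` is a hypothetical helper. Passing the 2D `token_type_ids` as `lengths` trips the same check, while a 1D CPU tensor holding one length per sequence packs without error.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

batch, seq_len, hidden = 2, 6, 768

# Stand-in for BERT's last_hidden_state after .permute(1, 0, 2): (seq_len, batch, hidden)
cont_reps = torch.randn(seq_len, batch, hidden)

# Made-up segment ids for two sentence pairs, shape (batch, seq_len): 0 = sentence A, 1 = sentence B
token_type_ids = torch.tensor([[0, 0, 0, 1, 1, 1],
                               [0, 0, 1, 1, 1, 1]])


def reproduce():
    """Hypothetical helper: return the error message raised by the original call, if any."""
    try:
        pack_padded_sequence(cont_reps, token_type_ids)  # 2D tensor passed as `lengths`
    except RuntimeError as err:
        return str(err)
    return None


message = reproduce()
print(message)

# What `lengths` should look like: one count of real (non-padding) steps per sequence,
# as a 1D int64 CPU tensor, in decreasing order because enforce_sorted defaults to True.
lengths = torch.tensor([6, 4])
packed = pack_padded_sequence(cont_reps, lengths)
print(packed.batch_sizes)
```

The printed message contains the text from the issue title. A 1D tensor that lives on the GPU would also be rejected, since the check requires a CPU tensor.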

@zanshuxun
Collaborator

Please add more information so that we can reproduce the error.

Describe the bug
A clear and concise description of what the bug is.

To Reproduce
Steps to reproduce the behavior:

  1. Go to '...'
  2. Click on '....'
  3. Scroll down to '....'
  4. See error

Operating environment:

  • Python version [e.g. 3.6, 3.7]
  • torch version [e.g. 1.9.0, 1.10.0]
  • deepctr-torch version [e.g. 0.2.9]

Additional context
Add any other context about the problem here.
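For reference, the error comes from the `lengths` argument of `pack_padded_sequence`: it must hold the number of real (non-padding) time steps of each sequence, as a list or a 1D int64 tensor on the CPU. `token_type_ids` is a 2D `(batch, seq_len)` tensor of segment ids (0 for the first sentence, 1 for the second), usually on the GPU, so it fails that check. Below is a minimal sketch of one way to fix it, assuming `attn_masks` is 1 for real tokens and 0 for padding, so that its row sums are the sequence lengths. `BiLSTMHead` is a hypothetical module name; it takes BERT's `last_hidden_state` in batch-first layout `(batch, seq_len, hidden)`.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class BiLSTMHead(nn.Module):
    """Hypothetical helper: BiLSTM + linear classifier over BERT token states."""

    def __init__(self, hidden_size=768, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True, bidirectional=True)
        # nn.Dropout is switched off automatically by model.eval()
        self.dropout = nn.Dropout(dropout)
        self.cls_layer = nn.Linear(hidden_size * 2, 1)

    def forward(self, cont_reps, attn_masks):
        # cont_reps:  (batch, seq_len, hidden), BERT's last_hidden_state
        # attn_masks: (batch, seq_len), assumed 1 for real tokens and 0 for padding
        # Real tokens per sequence as a 1D int64 CPU tensor, as pack_padded_sequence requires
        lengths = attn_masks.sum(dim=1).to(device="cpu", dtype=torch.int64)
        # enforce_sorted=False: batches are not sorted by length; PyTorch sorts internally
        packed = pack_padded_sequence(cont_reps, lengths, batch_first=True, enforce_sorted=False)
        # h_n is (num_directions, batch, hidden) and is returned in the original batch order
        _, (last_hidden, _) = self.lstm(packed)
        # Final forward state [0] and final backward state [1]: (batch, 2 * hidden)
        output_hidden = torch.cat((last_hidden[0], last_hidden[1]), dim=1)
        logits = self.cls_layer(self.dropout(output_hidden))
        return torch.sigmoid(logits)
```

In `SentencePairClassifier`, this would replace the LSTM part of `forward`: create `self.head = BiLSTMHead(self.hidden_size)` in `__init__`, take `cont_reps = self.bert_layer(input_ids, attention_mask=attn_masks, token_type_ids=token_type_ids).last_hidden_state`, drop the `permute` (this LSTM uses `batch_first=True`), and return `self.head(cont_reps, attn_masks)`. `token_type_ids` still belongs in the BERT call, just not in `pack_padded_sequence`. Because the sequences are packed, padding positions never reach the LSTM, so the final backward state starts from the last real token rather than from padding.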
