class SentencePairClassifier(nn.Module):
    def __init__(self, bert_model="bert-base-multilingual-cased", freeze_bert=False):
        super(SentencePairClassifier, self).__init__()
        # Instantiating BERT-based model object
        self.bert_layer = AutoModel.from_pretrained(bert_model)
        self.hidden_size = 768
        self.LSTM = nn.LSTM(self.hidden_size, self.hidden_size, bidirectional=True)
        # Freeze bert layers and only train the classification layer weights
        if freeze_bert:
            for p in self.bert_layer.parameters():
                p.requires_grad = False
        # Classification layer
        self.cls_layer = nn.Linear(self.hidden_size * 2, 1)
        # self.dropout = nn.Dropout(p=0.1)

    @autocast()  # run in mixed precision
    def forward(self, input_ids, attn_masks, token_type_ids):
        # Feeding the inputs to the BERT-based model to obtain contextualized representations
        cont_reps, pooler_output = self.bert_layer(input_ids, attn_masks)
        cont_reps = cont_reps.permute(1, 0, 2)
        enc_hiddens, (last_hidden, last_cell) = self.LSTM(pack_padded_sequence(cont_reps, token_type_ids))
        output_hidden = torch.cat((last_hidden[0], last_hidden[1]), dim=1)
        output_hidden = F.dropout(output_hidden, 0.2)
        output = self.cls_layer(output_hidden)
        return F.sigmoid(output)
I tried a lot, but it doesn't work. Please help.
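The post does not include an error message, so the following is only a guess at one likely problem. `pack_padded_sequence` expects the number of real tokens in each sequence as its second argument, while the code above passes `token_type_ids`. A minimal sketch of deriving those lengths from the attention mask, assuming a batch-first layout and using a random tensor in place of the BERT output (the helper name `last_lstm_states` and the small hidden size are illustrative, not from the original code):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


def last_lstm_states(cont_reps, attn_masks, lstm):
    """Run a bidirectional LSTM over padded, batch-first token representations.

    cont_reps:  (batch, seq_len, hidden) float tensor, e.g. BERT's last hidden state
    attn_masks: (batch, seq_len) tensor with 1 for real tokens and 0 for padding
    """
    # Number of real tokens per sequence; the lengths must live on the CPU.
    lengths = attn_masks.sum(dim=1).cpu()
    # enforce_sorted=False lets the batch stay in its original order.
    packed = pack_padded_sequence(
        cont_reps, lengths, batch_first=True, enforce_sorted=False
    )
    _, (last_hidden, _) = lstm(packed)
    # For one bidirectional layer, last_hidden is (2, batch, hidden):
    # index 0 is the forward direction, index 1 the backward direction.
    return torch.cat((last_hidden[0], last_hidden[1]), dim=1)


torch.manual_seed(0)
hidden_size = 8  # small stand-in for BERT's 768 (assumption for the demo)
lstm = nn.LSTM(hidden_size, hidden_size, bidirectional=True, batch_first=True)
cont_reps = torch.randn(2, 5, hidden_size)  # random stand-in for BERT output
attn_masks = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
features = last_lstm_states(cont_reps, attn_masks, lstm)
print(features.shape)  # torch.Size([2, 16])
```

With packing done this way, the values at padded positions should not influence the returned features. Whether this addresses the actual failure cannot be confirmed without the traceback.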
Please add more information so that we can reproduce the error.
Describe the bug: A clear and concise description of what the bug is.
To reproduce: Steps to reproduce the behavior.
Operating environment:
Additional context: Add any other context about the problem here.
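For the operating environment item, a short script along these lines could collect the details to paste into the issue. It is a sketch using only the Python standard library; the package names `torch` and `transformers` are assumptions based on the code in the question, and `environment_report` is a hypothetical helper name.

```python
import platform
import sys
from importlib import metadata


def environment_report(packages=("torch", "transformers")):
    """Collect OS, Python, and package versions as 'name: value' lines."""
    lines = [
        f"OS: {platform.platform()}",
        f"Python: {sys.version.split()[0]}",
    ]
    for name in packages:
        try:
            lines.append(f"{name}: {metadata.version(name)}")
        except metadata.PackageNotFoundError:
            # Report missing packages instead of failing the whole report.
            lines.append(f"{name}: not installed")
    return lines


print("\n".join(environment_report()))
```

Pasting this output together with the full traceback and a small input that triggers the failure would cover most of the template above.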