Skip to content

Commit

Permalink
updatae model
Browse files Browse the repository at this point in the history
  • Loading branch information
TanyaZhao committed Oct 10, 2020
1 parent f28889e commit 772b8ae
Showing 1 changed file with 2 additions and 37 deletions.
39 changes: 2 additions & 37 deletions models/bert_mrc.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, \
valid_output[i][jj] = sequence_output[i][j]

last_bert_layer = self.dropout(valid_output)

# last_bert_layer = self.answer_selection(last_bert_layer, num_ques=self.num_ques)


logits = self.classifier(last_bert_layer) # batch*3, max_seq_len, n_class

if labels is not None:
Expand All @@ -86,37 +84,4 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, \
return loss
else:
logits = F.log_softmax(logits, dim=2) # batch, max_seq_len, n_class
return logits


def answer_selection(self, logits, num_ques):
'''
:param logits: batch*3, max_seq_len, hidden
:return:
'''
batch_size, max_len, feat_dim = logits.size()
logits = logits.view(-1, num_ques, max_len, feat_dim)
logits = logits.permute(1, 0, 2, 3) # 3, batch, max_len, hidden
answer1 = logits[0] # batch, max_len, hidden
answer2 = logits[1]
answer3 = logits[2]
concat_answer = torch.cat([answer1, answer2, answer3])

w1 = F.sigmoid(self.relation1(concat_answer)) # batch, max_len, hidden
w2 = F.sigmoid(self.relation2(concat_answer))
w3 = F.sigmoid(self.relation3(concat_answer))

answer = (answer1*w1 + answer2*w2 + answer3*w3) / 3

return answer










return logits

0 comments on commit 772b8ae

Please sign in to comment.