Minor fix to correctly append hidden_states and attentions
monologg committed May 9, 2020
1 parent 57cf4f7 · commit b6b3e79
Showing 2 changed files with 2 additions and 2 deletions.
model/modeling_jointalbert.py (1 addition, 1 deletion)
@@ -60,7 +60,7 @@ def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
                 slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
             total_loss += self.args.slot_loss_coef * slot_loss

-        outputs = ((intent_logits, slot_logits),) + outputs[1:]  # add hidden states and attention if they are here
+        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here

         outputs = (total_loss,) + outputs
model/modeling_jointbert.py (1 addition, 1 deletion)
@@ -60,7 +60,7 @@ def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
                 slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
             total_loss += self.args.slot_loss_coef * slot_loss

-        outputs = ((intent_logits, slot_logits),) + outputs[1:]  # add hidden states and attention if they are here
+        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here

         outputs = (total_loss,) + outputs
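Why the one-index change matters: with the tuple-returning transformers API used here, the underlying encoder yields (sequence_output, pooled_output, hidden_states, attentions), where the last two are present only when requested. The joint model replaces the first two elements with its own (intent_logits, slot_logits) pair, so slicing from index 1 wrongly carried pooled_output along, while slicing from index 2 appends only the optional extras. A minimal, self-contained sketch (placeholder strings stand in for the real tensors; all names besides outputs, intent_logits, and slot_logits are illustrative):

# Sketch of the slicing bug, assuming the tuple layout
# (sequence_output, pooled_output, hidden_states, attentions).
sequence_output = "sequence_output"
pooled_output = "pooled_output"
hidden_states = "hidden_states"
attentions = "attentions"
intent_logits = "intent_logits"
slot_logits = "slot_logits"

outputs = (sequence_output, pooled_output, hidden_states, attentions)

# Before the fix: outputs[1:] drags pooled_output into the result.
buggy = ((intent_logits, slot_logits),) + outputs[1:]
assert buggy == (("intent_logits", "slot_logits"),
                 "pooled_output", "hidden_states", "attentions")

# After the fix: outputs[2:] appends only hidden_states and attentions.
fixed = ((intent_logits, slot_logits),) + outputs[2:]
assert fixed == (("intent_logits", "slot_logits"),
                 "hidden_states", "attentions")

After the later (total_loss,) prepend, downstream code then sees loss, logits, and the optional hidden states and attentions in the expected order.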
