Commit
added FlaubertForMultiLabelSequenceClassification
adrienrenaud committed Feb 10, 2020
1 parent 9170750 commit 2078766
Showing 2 changed files with 58 additions and 0 deletions.
4 changes: 4 additions & 0 deletions simpletransformers/classification/multi_label_classification_model.py
@@ -10,6 +10,7 @@
     XLMForMultiLabelSequenceClassification,
     DistilBertForMultiLabelSequenceClassification,
     AlbertForMultiLabelSequenceClassification,
+    FlaubertForMultiLabelSequenceClassification,
 )
 from simpletransformers.config.global_args import global_args

@@ -27,6 +28,8 @@
     DistilBertTokenizer,
     AlbertConfig,
     AlbertTokenizer,
+    FlaubertConfig,
+    FlaubertTokenizer,
 )

@@ -53,6 +56,7 @@ def __init__(self, model_type, model_name, num_labels=None, pos_weight=None, arg
             "xlm": (XLMConfig, XLMForMultiLabelSequenceClassification, XLMTokenizer),
             "distilbert": (DistilBertConfig, DistilBertForMultiLabelSequenceClassification, DistilBertTokenizer,),
             "albert": (AlbertConfig, AlbertForMultiLabelSequenceClassification, AlbertTokenizer,),
+            "flaubert": (FlaubertConfig, FlaubertForMultiLabelSequenceClassification, FlaubertTokenizer,),
         }

         config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
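With this entry in MODEL_CLASSES, the new class is wired into the standard simpletransformers API and can be selected by passing "flaubert" as the model type. A minimal usage sketch follows; the checkpoint name and label count are illustrative assumptions, not part of this commit:

from simpletransformers.classification import MultiLabelClassificationModel

# "flaubert/flaubert_base_cased" is an assumed pretrained checkpoint name;
# any Flaubert checkpoint compatible with FlaubertModel should work here.
model = MultiLabelClassificationModel(
    "flaubert",
    "flaubert/flaubert_base_cased",
    num_labels=4,
    use_cuda=False,
)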
54 changes: 54 additions & 0 deletions simpletransformers/custom_models/models.py
@@ -9,6 +9,7 @@
 from transformers.modeling_utils import SequenceSummary, PreTrainedModel
 from transformers import RobertaModel
 from transformers.configuration_roberta import RobertaConfig
+from transformers import FlaubertModel
 from torch.nn import BCEWithLogitsLoss

 from transformers.modeling_albert import (
@@ -351,3 +352,56 @@ def forward(
             outputs = (loss,) + outputs

         return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+class FlaubertForMultiLabelSequenceClassification(FlaubertModel):
+    """
+    Flaubert model adapted for multi-label sequence classification
+    """
+
+    def __init__(self, config, pos_weight=None):
+        super(FlaubertForMultiLabelSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels
+        self.pos_weight = pos_weight
+
+        self.transformer = FlaubertModel(config)
+        self.sequence_summary = SequenceSummary(config)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+    ):
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+        )
+
+        output = transformer_outputs[0]
+        logits = self.sequence_summary(output)
+
+        outputs = (logits,) + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
+
+        if labels is not None:
+            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
+            labels = labels.float()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
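As with the other custom classes in this module, forward returns (loss, logits, ...) when labels are supplied and (logits, ...) otherwise, with the multi-label loss computed by BCEWithLogitsLoss over a multi-hot label matrix. A rough standalone sketch of that loss path, assuming a transformers version that ships Flaubert (2.4+) and a tiny randomly initialized config whose sizes are illustrative assumptions, not values from this commit:

import torch
from transformers import FlaubertConfig
# FlaubertForMultiLabelSequenceClassification is the class added above,
# importable from simpletransformers.custom_models.models.

# Tiny random-weight config, just enough to exercise the forward pass.
config = FlaubertConfig(vocab_size=100, emb_dim=32, n_layers=1, n_heads=2, num_labels=4)
model = FlaubertForMultiLabelSequenceClassification(config)

input_ids = torch.randint(5, 100, (2, 8))  # batch of 2 sequences, length 8
labels = torch.randint(0, 2, (2, 4))       # multi-hot targets, 4 labels per example
outputs = model(input_ids, labels=labels)
loss, logits = outputs[0], outputs[1]      # loss is BCEWithLogitsLoss over the logits; labels are cast to float internally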
