diff --git a/simpletransformers/classification/multi_label_classification_model.py b/simpletransformers/classification/multi_label_classification_model.py
index 611d6eb5..2ff71fb5 100755
--- a/simpletransformers/classification/multi_label_classification_model.py
+++ b/simpletransformers/classification/multi_label_classification_model.py
@@ -10,6 +10,7 @@
     XLMForMultiLabelSequenceClassification,
     DistilBertForMultiLabelSequenceClassification,
     AlbertForMultiLabelSequenceClassification,
+    FlaubertForMultiLabelSequenceClassification,
 )
 
 from simpletransformers.config.global_args import global_args
@@ -27,6 +28,8 @@
     DistilBertTokenizer,
     AlbertConfig,
     AlbertTokenizer,
+    FlaubertConfig,
+    FlaubertTokenizer,
 )
 
 
@@ -53,6 +56,7 @@ def __init__(self, model_type, model_name, num_labels=None, pos_weight=None, arg
             "xlm": (XLMConfig, XLMForMultiLabelSequenceClassification, XLMTokenizer),
             "distilbert": (DistilBertConfig, DistilBertForMultiLabelSequenceClassification, DistilBertTokenizer,),
             "albert": (AlbertConfig, AlbertForMultiLabelSequenceClassification, AlbertTokenizer,),
+            "flaubert": (FlaubertConfig, FlaubertForMultiLabelSequenceClassification, FlaubertTokenizer,),
         }
 
         config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
diff --git a/simpletransformers/custom_models/models.py b/simpletransformers/custom_models/models.py
index 9830d0db..9b33217b 100755
--- a/simpletransformers/custom_models/models.py
+++ b/simpletransformers/custom_models/models.py
@@ -9,6 +9,7 @@
 from transformers.modeling_utils import SequenceSummary, PreTrainedModel
 from transformers import RobertaModel
 from transformers.configuration_roberta import RobertaConfig
+from transformers import FlaubertModel
 from torch.nn import BCEWithLogitsLoss
 
 from transformers.modeling_albert import (
@@ -351,3 +352,56 @@ def forward(
             outputs = (loss,) + outputs
 
         return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+class FlaubertForMultiLabelSequenceClassification(FlaubertModel):
+    """
+    Flaubert model adapted for multi-label sequence classification
+    """
+
+    def __init__(self, config, pos_weight=None):
+        super(FlaubertForMultiLabelSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels
+        self.pos_weight = pos_weight
+
+        self.transformer = FlaubertModel(config)
+        self.sequence_summary = SequenceSummary(config)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+    ):
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+        )
+
+        output = transformer_outputs[0]
+        logits = self.sequence_summary(output)
+
+        outputs = (logits,) + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here
+
+        if labels is not None:
+            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
+            labels = labels.float()
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
+            outputs = (loss,) + outputs
+
+        return outputs
\ No newline at end of file
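
A minimal usage sketch for the new "flaubert" model type (not part of the patch itself): it assumes the Flaubert checkpoint name "flaubert-base-cased" and a tiny two-label toy DataFrame; the "text"/"labels" column layout follows the existing MultiLabelClassificationModel convention of one multi-hot label list per example.

import pandas as pd
from simpletransformers.classification import MultiLabelClassificationModel

# Hypothetical toy data: "labels" holds one multi-hot list per example.
train_df = pd.DataFrame(
    [
        ["ceci est un exemple positif", [1, 0]],
        ["ceci est un exemple negatif", [0, 1]],
    ],
    columns=["text", "labels"],
)

# "flaubert-base-cased" is an assumed checkpoint name; any Flaubert checkpoint should work.
model = MultiLabelClassificationModel(
    "flaubert",
    "flaubert-base-cased",
    num_labels=2,
    use_cuda=False,
    args={"num_train_epochs": 1, "overwrite_output_dir": True},
)

model.train_model(train_df)
predictions, raw_outputs = model.predict(["une phrase à classer"])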