diff --git a/.all-contributorsrc b/.all-contributorsrc
index 094511f2..ed02193f 100644
--- a/.all-contributorsrc
+++ b/.all-contributorsrc
@@ -164,6 +164,15 @@
"contributions": [
"maintenance"
]
+ },
+ {
+ "login": "adrienrenaud",
+ "name": "Adrien Renaud",
+ "avatar_url": "https://avatars3.githubusercontent.com/u/6208157?v=4",
+ "profile": "https://github.com/adrienrenaud",
+ "contributions": [
+ "code"
+ ]
}
],
"contributorsPerLine": 7,
diff --git a/README.md b/README.md
index 573f9968..ff63a290 100755
--- a/README.md
+++ b/README.md
@@ -1,6 +1,6 @@
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Downloads](https://pepy.tech/badge/simpletransformers)](https://pepy.tech/project/simpletransformers)
-[![All Contributors](https://img.shields.io/badge/all_contributors-17-orange.svg?style=flat-square)](#contributors-)
+[![All Contributors](https://img.shields.io/badge/all_contributors-18-orange.svg?style=flat-square)](#contributors-)
# Simple Transformers
@@ -1149,6 +1149,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
nagenshukla 💻 |
flaviussn 💻 📖 |
Marc Torrellas 🚧 |
+ Adrien Renaud 💻 |
diff --git a/simpletransformers/classification/multi_label_classification_model.py b/simpletransformers/classification/multi_label_classification_model.py
index 611d6eb5..2ff71fb5 100755
--- a/simpletransformers/classification/multi_label_classification_model.py
+++ b/simpletransformers/classification/multi_label_classification_model.py
@@ -10,6 +10,7 @@
    XLMForMultiLabelSequenceClassification,
    DistilBertForMultiLabelSequenceClassification,
    AlbertForMultiLabelSequenceClassification,
+    FlaubertForMultiLabelSequenceClassification,
)
from simpletransformers.config.global_args import global_args
@@ -27,6 +28,8 @@
    DistilBertTokenizer,
    AlbertConfig,
    AlbertTokenizer,
+    FlaubertConfig,
+    FlaubertTokenizer,
)
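
FlaubertConfig and FlaubertTokenizer come directly from transformers, so nothing beyond these imports is needed on the tokenization side. A quick sanity-check sketch (the checkpoint name below is an assumption, not part of this diff):

```python
# Minimal check that the Flaubert pieces load from transformers.
# "flaubert/flaubert_base_cased" is an assumed checkpoint name; substitute
# whichever Flaubert checkpoint your transformers version provides.
from transformers import FlaubertConfig, FlaubertTokenizer

config = FlaubertConfig.from_pretrained("flaubert/flaubert_base_cased")
tokenizer = FlaubertTokenizer.from_pretrained("flaubert/flaubert_base_cased")

print(config.emb_dim)                          # hidden size of the encoder
print(tokenizer.tokenize("Bonjour le monde"))  # BPE tokens for a French sentence
```
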
@@ -53,6 +56,7 @@ def __init__(self, model_type, model_name, num_labels=None, pos_weight=None, arg
"xlm": (XLMConfig, XLMForMultiLabelSequenceClassification, XLMTokenizer),
"distilbert": (DistilBertConfig, DistilBertForMultiLabelSequenceClassification, DistilBertTokenizer,),
"albert": (AlbertConfig, AlbertForMultiLabelSequenceClassification, AlbertTokenizer,),
+ "flaubert": (FlaubertConfig, FlaubertForMultiLabelSequenceClassification, FlaubertTokenizer,),
}
config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
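
With the "flaubert" entry registered in MODEL_CLASSES, the new architecture is selected the same way as the existing ones. A minimal usage sketch, assuming a Flaubert checkpoint name and toy data (both illustrative, not part of this diff):

```python
# Illustrative sketch only: checkpoint name, toy data, and args are assumptions.
import pandas as pd
from simpletransformers.classification import MultiLabelClassificationModel

# Tiny two-label toy dataset; "labels" holds a multi-hot list per row.
train_df = pd.DataFrame(
    [["Ce film est excellent", [1, 0]], ["Je n'ai pas aimé ce film", [0, 1]]],
    columns=["text", "labels"],
)

model = MultiLabelClassificationModel(
    "flaubert",
    "flaubert/flaubert_base_cased",  # assumed checkpoint name
    num_labels=2,
    use_cuda=False,
    args={"num_train_epochs": 1, "overwrite_output_dir": True},
)

model.train_model(train_df)
predictions, raw_outputs = model.predict(["Quel beau film !"])
```
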
diff --git a/simpletransformers/custom_models/models.py b/simpletransformers/custom_models/models.py
index 9830d0db..9b33217b 100755
--- a/simpletransformers/custom_models/models.py
+++ b/simpletransformers/custom_models/models.py
@@ -9,6 +9,7 @@
from transformers.modeling_utils import SequenceSummary, PreTrainedModel
from transformers import RobertaModel
from transformers.configuration_roberta import RobertaConfig
+from transformers import FlaubertModel
from torch.nn import BCEWithLogitsLoss
from transformers.modeling_albert import (
@@ -351,3 +352,56 @@ def forward(
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+class FlaubertForMultiLabelSequenceClassification(FlaubertModel):
+    """
+    Flaubert model adapted for multi-label sequence classification
+    """
+
+    def __init__(self, config, pos_weight=None):
+        super(FlaubertForMultiLabelSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels
+        self.pos_weight = pos_weight
+
+        self.transformer = FlaubertModel(config)
+        self.sequence_summary = SequenceSummary(config)
+
+        self.init_weights()
+
+    def forward(
+        self,
+        input_ids=None,
+        attention_mask=None,
+        langs=None,
+        token_type_ids=None,
+        position_ids=None,
+        lengths=None,
+        cache=None,
+        head_mask=None,
+        inputs_embeds=None,
+        labels=None,
+    ):
+        transformer_outputs = self.transformer(
+            input_ids,
+            attention_mask=attention_mask,
+            langs=langs,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            lengths=lengths,
+            cache=cache,
+            head_mask=head_mask,
+        )
+
+        output = transformer_outputs[0]
+        logits = self.sequence_summary(output)  # pool token hidden states into one vector per sequence
+
+        outputs = (logits,) + transformer_outputs[1:]  # keep hidden states and attentions if they are here
+
+        if labels is not None:
+            loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
+            labels = labels.float()  # BCEWithLogitsLoss expects float targets
+            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
+            outputs = (loss,) + outputs
+
+        return outputs
\ No newline at end of file
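
The new head follows the same pattern as the other multi-label classes in models.py: SequenceSummary pools the token states and BCEWithLogitsLoss scores every label independently. A standalone sketch of that loss branch, with made-up shapes and values for illustration:

```python
# Standalone illustration of the multi-label loss computed in forward();
# the logits and labels below are made-up values, not model output.
import torch
from torch.nn import BCEWithLogitsLoss

num_labels = 3
logits = torch.tensor([[0.2, -1.5, 3.0]])   # (batch_size, num_labels), as produced by sequence_summary
labels = torch.tensor([[1, 0, 1]]).float()  # multi-hot targets, cast to float as in forward()

# pos_weight, when given, up-weights the positive term per label (the model's
# pos_weight argument); left at None here.
loss_fct = BCEWithLogitsLoss(pos_weight=None)
loss = loss_fct(logits.view(-1, num_labels), labels.view(-1, num_labels))
print(loss.item())  # mean sigmoid + binary cross-entropy over all labels
```
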