Merge branch 'adrienrenaud-add-flaubert-model-multilabel-v2'
ThilinaRajapakse committed Feb 11, 2020
2 parents e1400fc + 2078766 commit c6a8286
Showing 4 changed files with 69 additions and 1 deletion.
9 changes: 9 additions & 0 deletions .all-contributorsrc
@@ -164,6 +164,15 @@
"contributions": [
"maintenance"
]
},
{
"login": "adrienrenaud",
"name": "Adrien Renaud",
"avatar_url": "https://avatars3.githubusercontent.com/u/6208157?v=4",
"profile": "https://github.com/adrienrenaud",
"contributions": [
"code"
]
}
],
"contributorsPerLine": 7,
3 changes: 2 additions & 1 deletion README.md
@@ -1,6 +1,6 @@
[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) [![Downloads](https://pepy.tech/badge/simpletransformers)](https://pepy.tech/project/simpletransformers)
<!-- ALL-CONTRIBUTORS-BADGE:START - Do not remove or modify this section -->
[![All Contributors](https://img.shields.io/badge/all_contributors-17-orange.svg?style=flat-square)](#contributors-)
[![All Contributors](https://img.shields.io/badge/all_contributors-18-orange.svg?style=flat-square)](#contributors-)
<!-- ALL-CONTRIBUTORS-BADGE:END -->

# Simple Transformers
@@ -1149,6 +1149,7 @@ Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/d
<td align="center"><a href="https://github.com/nagenshukla"><img src="https://avatars0.githubusercontent.com/u/39196228?v=4" width="100px;" alt=""/><br /><sub><b>nagenshukla</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=nagenshukla" title="Code">💻</a></td>
<td align="center"><a href="https://www.linkedin.com/in/flaviussn/"><img src="https://avatars0.githubusercontent.com/u/20523032?v=4" width="100px;" alt=""/><br /><sub><b>flaviussn</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=flaviussn" title="Code">💻</a> <a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=flaviussn" title="Documentation">📖</a></td>
<td align="center"><a href="http://marctorrellas.github.com"><img src="https://avatars1.githubusercontent.com/u/22045779?v=4" width="100px;" alt=""/><br /><sub><b>Marc Torrellas</b></sub></a><br /><a href="#maintenance-marctorrellas" title="Maintenance">🚧</a></td>
<td align="center"><a href="https://github.com/adrienrenaud"><img src="https://avatars3.githubusercontent.com/u/6208157?v=4" width="100px;" alt=""/><br /><sub><b>Adrien Renaud</b></sub></a><br /><a href="https://github.com/ThilinaRajapakse/simpletransformers/commits?author=adrienrenaud" title="Code">💻</a></td>
</tr>
</table>

4 changes: 4 additions & 0 deletions simpletransformers/classification/multi_label_classification_model.py
@@ -10,6 +10,7 @@
XLMForMultiLabelSequenceClassification,
DistilBertForMultiLabelSequenceClassification,
AlbertForMultiLabelSequenceClassification,
FlaubertForMultiLabelSequenceClassification,
)
from simpletransformers.config.global_args import global_args

@@ -27,6 +28,8 @@
DistilBertTokenizer,
AlbertConfig,
AlbertTokenizer,
FlaubertConfig,
FlaubertTokenizer,
)


@@ -53,6 +56,7 @@ def __init__(self, model_type, model_name, num_labels=None, pos_weight=None, arg
"xlm": (XLMConfig, XLMForMultiLabelSequenceClassification, XLMTokenizer),
"distilbert": (DistilBertConfig, DistilBertForMultiLabelSequenceClassification, DistilBertTokenizer,),
"albert": (AlbertConfig, AlbertForMultiLabelSequenceClassification, AlbertTokenizer,),
"flaubert": (FlaubertConfig, FlaubertForMultiLabelSequenceClassification, FlaubertTokenizer,),
}

config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
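With the `flaubert` entry registered in `MODEL_CLASSES`, the new model type becomes available through the regular `MultiLabelClassificationModel` API. Below is a minimal usage sketch, not part of this commit; the checkpoint name `flaubert-base-cased`, the toy data, and the `use_cuda` flag are assumptions based on how the other model types in this table are used.

```python
import pandas as pd

from simpletransformers.classification import MultiLabelClassificationModel

# Hypothetical toy dataset: text plus multi-hot labels over 3 classes.
train_df = pd.DataFrame(
    [
        ["ceci est un exemple positif", [1, 0, 1]],
        ["un autre exemple", [0, 1, 0]],
    ],
    columns=["text", "labels"],
)

# "flaubert-base-cased" is an assumed checkpoint name; any FlauBERT
# checkpoint compatible with FlaubertTokenizer should work here.
model = MultiLabelClassificationModel(
    "flaubert",
    "flaubert-base-cased",
    num_labels=3,
    use_cuda=False,  # assumption: run on CPU for the sketch
)

model.train_model(train_df)
predictions, raw_outputs = model.predict(["une phrase à classer"])
```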
54 changes: 54 additions & 0 deletions simpletransformers/custom_models/models.py
@@ -9,6 +9,7 @@
from transformers.modeling_utils import SequenceSummary, PreTrainedModel
from transformers import RobertaModel
from transformers.configuration_roberta import RobertaConfig
from transformers import FlaubertModel
from torch.nn import BCEWithLogitsLoss

from transformers.modeling_albert import (
@@ -351,3 +352,56 @@ def forward(
outputs = (loss,) + outputs

return outputs # (loss), logits, (hidden_states), (attentions)


class FlaubertForMultiLabelSequenceClassification(FlaubertModel):
"""
Flaubert model adapted for multi-label sequence classification
"""

def __init__(self, config, pos_weight=None):
super(FlaubertForMultiLabelSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.pos_weight = pos_weight

self.transformer = FlaubertModel(config)
self.sequence_summary = SequenceSummary(config)

self.init_weights()

def forward(
self,
input_ids=None,
attention_mask=None,
langs=None,
token_type_ids=None,
position_ids=None,
lengths=None,
cache=None,
head_mask=None,
inputs_embeds=None,
labels=None,
):
transformer_outputs = self.transformer(
input_ids,
attention_mask=attention_mask,
langs=langs,
token_type_ids=token_type_ids,
position_ids=position_ids,
lengths=lengths,
cache=cache,
head_mask=head_mask,
)

output = transformer_outputs[0]
logits = self.sequence_summary(output)

outputs = (logits,) + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here

if labels is not None:
loss_fct = BCEWithLogitsLoss(pos_weight=self.pos_weight)
labels = labels.float()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
outputs = (loss,) + outputs

return outputs
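As a quick sanity check, the new head can be exercised directly with a randomly initialised configuration. The following is a minimal smoke-test sketch, not part of the commit; the tiny config values and dummy tensors are assumptions, and no pretrained weights are involved.

```python
import torch
from transformers import FlaubertConfig

# Assumed tiny configuration; emb_dim must be divisible by n_heads.
config = FlaubertConfig(
    vocab_size=100,
    emb_dim=32,
    n_layers=2,
    n_heads=2,
    num_labels=4,
)

model = FlaubertForMultiLabelSequenceClassification(config, pos_weight=torch.ones(4))
model.eval()

input_ids = torch.randint(0, 100, (2, 8))     # batch of 2 sequences, length 8
labels = torch.randint(0, 2, (2, 4)).float()  # multi-hot targets, 4 labels

# With labels supplied, the forward above returns (loss, logits, ...).
loss, logits = model(input_ids=input_ids, labels=labels)[:2]
print(loss.item(), logits.shape)  # scalar BCE-with-logits loss, logits of shape (2, 4)
```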
