added FlaubertForSequenceClassification
adrienrenaud committed Feb 10, 2020
1 parent 1c8835a commit 048b3eb
Showing 2 changed files with 89 additions and 0 deletions.
4 changes: 4 additions & 0 deletions simpletransformers/classification/classification_model.py
@@ -48,6 +48,8 @@
    CamembertTokenizer,
    XLMRobertaConfig,
    XLMRobertaTokenizer,
    FlaubertConfig,
    FlaubertTokenizer,
)

from simpletransformers.classification.classification_utils import (
@@ -63,6 +65,7 @@
from simpletransformers.classification.transformer_models.albert_model import AlbertForSequenceClassification
from simpletransformers.classification.transformer_models.camembert_model import CamembertForSequenceClassification
from simpletransformers.classification.transformer_models.xlm_roberta_model import XLMRobertaForSequenceClassification
from simpletransformers.classification.transformer_models.flaubert_model import FlaubertForSequenceClassification

from simpletransformers.config.global_args import global_args

@@ -97,6 +100,7 @@ def __init__(
            "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
            "camembert": (CamembertConfig, CamembertForSequenceClassification, CamembertTokenizer),
            "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
            "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
        }

        if args and 'manual_seed' in args:
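With this entry in MODEL_CLASSES, a FlauBERT checkpoint can be trained through the usual ClassificationModel interface. A minimal sketch, assuming pandas is available and using the `flaubert-base-uncased` checkpoint name from the docstring below; the toy data, column names, and training args are illustrative, not part of this commit:

import pandas as pd
from simpletransformers.classification import ClassificationModel

# Two-column DataFrame ("text", "labels"), the layout simpletransformers expects.
train_df = pd.DataFrame(
    [["Ce film est excellent", 1], ["Ce film est nul", 0]],
    columns=["text", "labels"],
)

# "flaubert" now resolves to (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer).
model = ClassificationModel(
    "flaubert",
    "flaubert-base-uncased",
    num_labels=2,
    use_cuda=False,
    args={"num_train_epochs": 1, "overwrite_output_dir": True},
)

model.train_model(train_df)
predictions, raw_outputs = model.predict(["Quel chef-d'oeuvre !"])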
85 changes: 85 additions & 0 deletions simpletransformers/classification/transformer_models/flaubert_model.py
@@ -0,0 +1,85 @@
from transformers.modeling_xlm import SequenceSummary, XLMModel, XLMPreTrainedModel
from transformers.modeling_flaubert import FlaubertModel
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss


class FlaubertForSequenceClassification(XLMPreTrainedModel):
    r"""
        **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
            Labels for computing the sequence classification/regression loss.
            Indices should be in ``[0, ..., config.num_labels - 1]``.
            If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
            If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Classification (or regression if config.num_labels==1) loss.
        **logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = FlaubertTokenizer.from_pretrained('flaubert-base-uncased')
        model = FlaubertForSequenceClassification.from_pretrained('flaubert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]
    """  # noqa: ignore flake8

    def __init__(self, config, weight=None):
        super(FlaubertForSequenceClassification, self).__init__(config)
        self.num_labels = config.num_labels
        self.weight = weight

        self.transformer = FlaubertModel(config)
        self.sequence_summary = SequenceSummary(config)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        langs=None,
        token_type_ids=None,
        position_ids=None,
        lengths=None,
        cache=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
    ):
        transformer_outputs = self.transformer(
            input_ids,
            attention_mask=attention_mask,
            langs=langs,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            lengths=lengths,
            cache=cache,
            head_mask=head_mask,
        )

        output = transformer_outputs[0]
        logits = self.sequence_summary(output)

        outputs = (logits,) + transformer_outputs[1:]  # Keep new_mems and attention/hidden states if they are here

        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss(weight=self.weight)
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs
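The `weight` argument is the main difference from the stock Hugging Face sequence-classification heads: it is forwarded to `CrossEntropyLoss` so rare classes can be up-weighted. A minimal sketch of using the class directly, assuming the transformers 2.x API this commit targets; the checkpoint name and weight values are illustrative:

import torch
from transformers import FlaubertConfig, FlaubertTokenizer
from simpletransformers.classification.transformer_models.flaubert_model import FlaubertForSequenceClassification

tokenizer = FlaubertTokenizer.from_pretrained("flaubert-base-uncased")
config = FlaubertConfig.from_pretrained("flaubert-base-uncased", num_labels=2)

# Up-weight the rarer class (index 1) in the cross-entropy loss.
class_weights = torch.tensor([1.0, 3.0])

# Keyword arguments that from_pretrained does not consume (here `weight`) are passed on to __init__.
model = FlaubertForSequenceClassification.from_pretrained(
    "flaubert-base-uncased", config=config, weight=class_weights
)

input_ids = torch.tensor(tokenizer.encode("Ce film est excellent")).unsqueeze(0)  # Batch size 1
labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1

loss, logits = model(input_ids, labels=labels)[:2]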
