-
Notifications
You must be signed in to change notification settings - Fork 725
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
added FlaubertForSequenceClassification
- Loading branch information
1 parent
1c8835a
commit 048b3eb
Showing
2 changed files
with
89 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
simpletransformers/classification/transformer_models/flaubert_model.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from transformers.modeling_xlm import SequenceSummary, XLMModel, XLMPreTrainedModel | ||
from transformers.modeling_flaubert import FlaubertModel | ||
import torch | ||
import torch.nn as nn | ||
from torch.nn import CrossEntropyLoss, MSELoss | ||
|
||
|
||
class FlaubertForSequenceClassification(XLMPreTrainedModel): | ||
r""" | ||
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: | ||
Labels for computing the sequence classification/regression loss. | ||
Indices should be in ``[0, ..., config.num_labels - 1]``. | ||
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss), | ||
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy). | ||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: | ||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: | ||
Classification (or regression if config.num_labels==1) loss. | ||
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)`` | ||
Classification (or regression if config.num_labels==1) scores (before SoftMax). | ||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``) | ||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings) | ||
of shape ``(batch_size, sequence_length, hidden_size)``: | ||
Hidden-states of the model at the output of each layer plus the initial embedding outputs. | ||
**attentions**: (`optional`, returned when ``config.output_attentions=True``) | ||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``: | ||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads. | ||
Examples:: | ||
tokenizer = FlaubertTokenizer.from_pretrained('flaubert-base-uncased') | ||
model = FlaubertForSequenceClassification.from_pretrained('flaubert-base-uncased') | ||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1 | ||
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1 | ||
outputs = model(input_ids, labels=labels) | ||
loss, logits = outputs[:2] | ||
""" # noqa: ignore flake8" | ||
|
||
def __init__(self, config, weight=None): | ||
super(FlaubertForSequenceClassification, self).__init__(config) | ||
self.num_labels = config.num_labels | ||
self.weight = weight | ||
|
||
self.transformer = FlaubertModel(config) | ||
self.sequence_summary = SequenceSummary(config) | ||
|
||
self.init_weights() | ||
|
||
def forward( | ||
self, | ||
input_ids=None, | ||
attention_mask=None, | ||
langs=None, | ||
token_type_ids=None, | ||
position_ids=None, | ||
lengths=None, | ||
cache=None, | ||
head_mask=None, | ||
inputs_embeds=None, | ||
labels=None, | ||
): | ||
transformer_outputs = self.transformer( | ||
input_ids, | ||
attention_mask=attention_mask, | ||
langs=langs, | ||
token_type_ids=token_type_ids, | ||
position_ids=position_ids, | ||
lengths=lengths, | ||
cache=cache, | ||
head_mask=head_mask, | ||
) | ||
|
||
output = transformer_outputs[0] | ||
logits = self.sequence_summary(output) | ||
|
||
outputs = (logits,) + transformer_outputs[1:] # Keep new_mems and attention/hidden states if they are here | ||
|
||
if labels is not None: | ||
if self.num_labels == 1: | ||
# We are doing regression | ||
loss_fct = MSELoss() | ||
loss = loss_fct(logits.view(-1), labels.view(-1)) | ||
else: | ||
loss_fct = CrossEntropyLoss(weight=self.weight) | ||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) | ||
outputs = (loss,) + outputs | ||
|
||
return outputs |