diff --git a/simpletransformers/classification/transformer_models/flaubert_model.py b/simpletransformers/classification/transformer_models/flaubert_model.py index 1434ccd6..1ec90400 100644 --- a/simpletransformers/classification/transformer_models/flaubert_model.py +++ b/simpletransformers/classification/transformer_models/flaubert_model.py @@ -5,7 +5,7 @@ from torch.nn import CrossEntropyLoss, MSELoss -class FlaubertForSequenceClassification(XLMPreTrainedModel): +class FlaubertForSequenceClassification(FlaubertModel): r""" **labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: Labels for computing the sequence classification/regression loss.