Skip to content

Commit

Permalink
add auto to language modeling and sort classes
Browse files Browse the repository at this point in the history
  • Loading branch information
flozi00 committed May 9, 2020
1 parent d658342 commit 7dd8a47
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions simpletransformers/language_modeling/language_modeling_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@
from transformers import (
WEIGHTS_NAME,
AdamW,
AutoConfig,
AutoTokenizer,
AutoModelWithLMHead,
BertConfig,
BertForMaskedLM,
BertTokenizer,
Expand Down Expand Up @@ -78,13 +81,14 @@
logger = logging.getLogger(__name__)

MODEL_CLASSES = {
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
"openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
"auto": (AutoConfig, AutoModelWithLMHead, AutoTokenizer),
"bert": (BertConfig, BertForMaskedLM, BertTokenizer),
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
"camembert": (CamembertConfig, CamembertForMaskedLM, CamembertTokenizer),
"distilbert": (DistilBertConfig, DistilBertForMaskedLM, DistilBertTokenizer),
"electra": (ElectraConfig, ElectraForLanguageModelingModel, ElectraTokenizer),
"gpt2": (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
"openai-gpt": (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
"roberta": (RobertaConfig, RobertaForMaskedLM, RobertaTokenizer),
}


Expand Down

0 comments on commit 7dd8a47

Please sign in to comment.