Skip to content

Commit

Permalink
add automodel and sort classes for QA
Browse files Browse the repository at this point in the history
  • Loading branch information
flozi00 authored May 9, 2020
1 parent bc94b34 commit 38b1538
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@
from transformers import (
WEIGHTS_NAME,
AdamW,
AutoConfig,
AutoTokenizer,
AutoModelForQuestionAnswering,
AlbertConfig,
AlbertForQuestionAnswering,
AlbertTokenizer,
Expand Down Expand Up @@ -90,13 +93,14 @@ def __init__(self, model_type, model_name, args=None, use_cuda=True, cuda_device
""" # noqa: ignore flake8"

MODEL_CLASSES = {
"albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
"auto": (AutoConfig, AutoTokenizer, AutoModelForQuestionAnswering),
"bert": (BertConfig, BertForQuestionAnswering, BertTokenizer),
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer),
"electra": (ElectraConfig, ElectraForQuestionAnswering, ElectraTokenizer),
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
"xlnet": (XLNetConfig, XLNetForQuestionAnswering, XLNetTokenizer),
"xlm": (XLMConfig, XLMForQuestionAnswering, XLMTokenizer),
"distilbert": (DistilBertConfig, DistilBertForQuestionAnswering, DistilBertTokenizer,),
"albert": (AlbertConfig, AlbertForQuestionAnswering, AlbertTokenizer),
"roberta": (RobertaConfig, RobertaForQuestionAnswering, RobertaTokenizer),
"electra": (ElectraConfig, ElectraForQuestionAnswering, ElectraTokenizer),
}

if args and "manual_seed" in args:
Expand Down

0 comments on commit 38b1538

Please sign in to comment.