diff --git a/simpletransformers/question_answering/question_answering_model.py b/simpletransformers/question_answering/question_answering_model.py index 3b4c8c12..0d3a517a 100755 --- a/simpletransformers/question_answering/question_answering_model.py +++ b/simpletransformers/question_answering/question_answering_model.py @@ -704,7 +704,7 @@ def evaluate(self, eval_data, output_dir, verbose_logging=False): "token_type_ids": batch[2], } - if self.args["model_type"] in ["xlm", "roberta", "distilbert", "camembert", "electra", "xlmroberta","bart]: + if self.args["model_type"] in ["xlm", "roberta", "distilbert", "camembert", "electra", "xlmroberta","bart"]: del inputs["token_type_ids"] example_indices = batch[3]