Skip to content

Commit

Permalink
fix bart for QA
Browse files Browse the repository at this point in the history
  • Loading branch information
flozi00 committed Jun 14, 2020
1 parent 28c9399 commit 4c624a0
Showing 1 changed file with 3 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -704,7 +704,7 @@ def evaluate(self, eval_data, output_dir, verbose_logging=False):
"token_type_ids": batch[2],
}

if self.args["model_type"] in ["xlm", "roberta", "distilbert", "camembert", "electra", "xlmroberta"]:
if self.args["model_type"] in ["xlm", "roberta", "distilbert", "camembert", "electra", "xlmroberta","bart]:
del inputs["token_type_ids"]

example_indices = batch[3]
Expand Down Expand Up @@ -834,7 +834,7 @@ def predict(self, to_predict, n_best_size=None):
"token_type_ids": batch[2],
}

if self.args["model_type"] in ["xlm", "roberta", "distilbert", "camembert", "electra", "xlmroberta"]:
if self.args["model_type"] in ["xlm", "roberta", "distilbert", "camembert", "electra", "xlmroberta","bart"]:
del inputs["token_type_ids"]

example_indices = batch[3]
Expand Down Expand Up @@ -958,7 +958,7 @@ def _get_inputs_dict(self, batch):
"end_positions": batch[4],
}

if self.args["model_type"] in ["xlm", "roberta", "distilbert", "camembert", "electra", "xlmroberta"]:
if self.args["model_type"] in ["xlm", "roberta", "distilbert", "camembert", "electra", "xlmroberta","bart"]:
del inputs["token_type_ids"]

if self.args["model_type"] in ["xlnet", "xlm"]:
Expand Down

0 comments on commit 4c624a0

Please sign in to comment.