From df1b5f0eaf5f3f9ba0fb588fc4b6461034696cbe Mon Sep 17 00:00:00 2001
From: Yu Wang
Date: Mon, 23 Sep 2019 15:21:13 -0700
Subject: [PATCH] add MRC to BatchGen & fix unit test

---
 experiments/glue/glue_prepro.py |  2 +-
 mt_dnn/batcher.py               | 13 +++++++++++--
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/experiments/glue/glue_prepro.py b/experiments/glue/glue_prepro.py
index cbd552b0..dda138f6 100644
--- a/experiments/glue/glue_prepro.py
+++ b/experiments/glue/glue_prepro.py
@@ -3,9 +3,9 @@
 import random
 from sys import path
+path.append(os.getcwd())
 from experiments.common_utils import dump_rows
-path.append(os.getcwd())
 from data_utils.log_wrapper import create_logger
 from experiments.glue.glue_utils import *
diff --git a/mt_dnn/batcher.py b/mt_dnn/batcher.py
index 149c72f9..5130234f 100644
--- a/mt_dnn/batcher.py
+++ b/mt_dnn/batcher.py
@@ -130,11 +130,20 @@ def __iter__(self):
                 # in training model, label is used by Pytorch, so would be tensor
                 if self.task_type == TaskType.Regression:
                     batch_data.append(torch.FloatTensor(labels))
-                else:
+                    batch_info['label'] = len(batch_data) - 1
+                elif self.task_type in (TaskType.Classification, TaskType.Ranking):
                     batch_data.append(torch.LongTensor(labels))
-                batch_info['label'] = len(batch_data) - 1
+                    batch_info['label'] = len(batch_data) - 1
+                elif self.task_type == TaskType.Span:
+                    start = [sample['token_start'] for sample in batch]
+                    end = [sample['token_end'] for sample in batch]
+                    batch_data.extend([torch.LongTensor(start), torch.LongTensor(end)])
+                    batch_info['start'] = len(batch_data) - 2
+                    batch_info['end'] = len(batch_data) - 1
+
                 # soft label generated by ensemble models for knowledge distillation
                 if self.soft_label_on and (batch[0].get('softlabel', None) is not None):
+                    assert self.task_type != TaskType.Span  # Span task doesn't support soft label yet.
                     sortlabels = [sample['softlabel'] for sample in batch]
                     sortlabels = torch.FloatTensor(sortlabels)
                     batch_info['soft_label'] = self.patch(sortlabels.pin_memory()) if self.gpu else sortlabels
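
Not part of the patch: a minimal, self-contained sketch of how the new TaskType.Span branch lays out a mini-batch. The start/end token positions of each sample become two LongTensors appended to batch_data, and batch_info records their indices. Only the 'token_start'/'token_end' keys and the tensor layout come from the diff above; the sample values and the standalone print statements are hypothetical.

import torch

# Hypothetical MRC/span samples; only the 'token_start'/'token_end' keys mirror the patch.
batch = [
    {'token_start': 12, 'token_end': 15},
    {'token_start': 3,  'token_end': 7},
]

batch_data, batch_info = [], {}

# Same packing as the new TaskType.Span branch in mt_dnn/batcher.py:
start = [sample['token_start'] for sample in batch]
end = [sample['token_end'] for sample in batch]
batch_data.extend([torch.LongTensor(start), torch.LongTensor(end)])
batch_info['start'] = len(batch_data) - 2  # index of the start-position tensor
batch_info['end'] = len(batch_data) - 1    # index of the end-position tensor

print(batch_data[batch_info['start']])  # tensor([12,  3])
print(batch_data[batch_info['end']])    # tensor([15,  7])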