From 0b584fd5fa7e8aca632068eea5679ea360cfc0d8 Mon Sep 17 00:00:00 2001
From: Yu Wang
Date: Fri, 3 Jan 2020 23:28:35 +0000
Subject: [PATCH] add MultiTaskDataset

---
 mt_dnn/batcher.py | 24 +++++++++++++++++++++++-
 predict.py        |  4 ++--
 train.py          |  8 ++++----
 3 files changed, 29 insertions(+), 7 deletions(-)

diff --git a/mt_dnn/batcher.py b/mt_dnn/batcher.py
index 5ec9cb3c..bc8c1ace 100644
--- a/mt_dnn/batcher.py
+++ b/mt_dnn/batcher.py
@@ -12,13 +12,35 @@
 UNK_ID=100
 BOS_ID=101
 
-class MTDNNDataset(Dataset):
+class MultiTaskDataset(Dataset):
+    def __init__(self, datasets):
+        self._datasets = datasets
+
+        task_id_2_data_set_dic = {}
+        for dataset in datasets:
+            task_id = dataset.get_task_id()
+            assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
+            task_id_2_data_set_dic[task_id] = dataset
+
+        self._task_id_2_data_set_dic = task_id_2_data_set_dic
+
+    def __len__(self):
+        return sum(len(dataset) for dataset in self._datasets)
+
+    def __getitem__(self, idx):
+        task_id, sample_id = idx
+        return self._task_id_2_data_set_dic[task_id][sample_id]
+
+class SingleTaskDataset(Dataset):
     def __init__(self, path, is_train=True, maxlen=128, factor=1.0, task_id=0, task_type=TaskType.Classification, data_type=DataFormat.PremiseOnly):
         self._data = self.load(path, is_train, maxlen, factor, task_type)
         self._task_id = task_id
         self._task_type = task_type
         self._data_type = data_type
 
+    def get_task_id(self):
+        return self._task_id
+
     @staticmethod
     def load(path, is_train=True, maxlen=128, factor=1.0, task_type=None):
         assert task_type is not None
diff --git a/predict.py b/predict.py
index 09478cbd..62138e5e 100644
--- a/predict.py
+++ b/predict.py
@@ -7,7 +7,7 @@
 
 from experiments.exp_def import TaskDefs, EncoderModelType
 #from experiments.glue.glue_utils import eval_model
-from mt_dnn.batcher import MTDNNDataset, Collater
+from mt_dnn.batcher import SingleTaskDataset, Collater
 from mt_dnn.model import MTDNNModel
 from data_utils.metrics import calc_metrics
 from mt_dnn.inference import eval_model
@@ -57,7 +57,7 @@
 model.load(checkpoint_path)
 encoder_type = config.get('encoder_type', EncoderModelType.BERT)
 # load data
-test_data_set = MTDNNDataset(args.prep_input, False, task_type=task_type, maxlen=args.max_seq_len)
+test_data_set = SingleTaskDataset(args.prep_input, False, task_type=task_type, maxlen=args.max_seq_len)
 collater = Collater(gpu=args.cuda, is_train=False, task_id=args.task_id, task_type=task_type, data_type=data_type, encoder_type=encoder_type)
 test_data = DataLoader(test_data_set, batch_size=args.batch_size_eval, collate_fn=collater.collate_fn, pin_memory=args.cuda)
 
diff --git a/train.py b/train.py
index 4a346602..f61606af 100644
--- a/train.py
+++ b/train.py
@@ -17,7 +17,7 @@
 from data_utils.log_wrapper import create_logger
 from data_utils.utils import set_environment
 from data_utils.task_def import TaskType, EncoderModelType
-from mt_dnn.batcher import MTDNNDataset, Collater
+from mt_dnn.batcher import SingleTaskDataset, Collater
 from mt_dnn.model import MTDNNModel
 
 
@@ -207,7 +207,7 @@
 
         train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
         logger.info('Loading {} as task {}'.format(train_path, task_id))
-        train_data_set = MTDNNDataset(train_path, True, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
+        train_data_set = SingleTaskDataset(train_path, True, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
         train_data = DataLoader(train_data_set,
                                 batch_size=args.batch_size, shuffle=True, collate_fn=train_collater.collate_fn, pin_memory=args.cuda)
         train_data_list.append(train_data)
@@ -237,14 +237,14 @@
         dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
         dev_data = None
         if os.path.exists(dev_path):
-            dev_data_set = MTDNNDataset(dev_path, False, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
+            dev_data_set = SingleTaskDataset(dev_path, False, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
             dev_data = DataLoader(dev_data_set, batch_size=args.batch_size_eval, collate_fn=test_collater.collate_fn, pin_memory=args.cuda)
         dev_data_list.append(dev_data)
 
         test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
         test_data = None
         if os.path.exists(test_path):
-            test_data_set = MTDNNDataset(test_path, False, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
+            test_data_set = SingleTaskDataset(test_path, False, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
            test_data = DataLoader(test_data_set, batch_size=args.batch_size_eval, collate_fn=test_collater.collate_fn, pin_memory=args.cuda)
         test_data_list.append(test_data)
 
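Note (not part of the patch): a minimal usage sketch of the new MultiTaskDataset,
assuming two already-preprocessed task files. The file paths and task ids below
are hypothetical, and the multi-task batch sampler is only indicated, not
implemented by this patch.

    from mt_dnn.batcher import SingleTaskDataset, MultiTaskDataset

    # Each SingleTaskDataset must carry a unique task_id; MultiTaskDataset
    # asserts on duplicates while building its task_id -> dataset map.
    mnli = SingleTaskDataset('data/mnli_train.json', True, task_id=0)  # hypothetical path
    cola = SingleTaskDataset('data/cola_train.json', True, task_id=1)  # hypothetical path
    multi = MultiTaskDataset([mnli, cola])

    # __len__ sums the member datasets; __getitem__ expects a
    # (task_id, sample_id) tuple and routes it to the matching dataset.
    assert len(multi) == len(mnli) + len(cola)
    sample = multi[(1, 0)]  # sample 0 of the task registered with task_id=1

    # A DataLoader over MultiTaskDataset therefore needs a batch_sampler that
    # yields lists of (task_id, sample_id) tuples, one task per batch, rather
    # than the plain integer indices a default sampler produces.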