add MultiTaskDataset
anselmwang authored and root committed Jan 4, 2020
1 parent c635d33 commit 0b584fd
Showing 3 changed files with 29 additions and 7 deletions.
24 changes: 23 additions & 1 deletion mt_dnn/batcher.py
@@ -12,13 +12,35 @@
 UNK_ID=100
 BOS_ID=101
 
-class MTDNNDataset(Dataset):
+class MultiTaskDataset(Dataset):
+    def __init__(self, datasets):
+        self._datasets = datasets
+
+        task_id_2_data_set_dic = {}
+        for dataset in datasets:
+            task_id = dataset.get_task_id()
+            assert task_id not in task_id_2_data_set_dic, "Duplicate task_id %s" % task_id
+            task_id_2_data_set_dic[task_id] = dataset
+
+        self._task_id_2_data_set_dic = task_id_2_data_set_dic
+
+    def __len__(self):
+        return sum(len(dataset) for dataset in self._datasets)
+
+    def __getitem__(self, idx):
+        task_id, sample_id = idx
+        return self._task_id_2_data_set_dic[task_id][sample_id]
+
+class SingleTaskDataset(Dataset):
     def __init__(self, path, is_train=True, maxlen=128, factor=1.0, task_id=0, task_type=TaskType.Classification, data_type=DataFormat.PremiseOnly):
         self._data = self.load(path, is_train, maxlen, factor, task_type)
         self._task_id = task_id
         self._task_type = task_type
         self._data_type = data_type
 
+    def get_task_id(self):
+        return self._task_id
+
     @staticmethod
     def load(path, is_train=True, maxlen=128, factor=1.0, task_type=None):
         assert task_type is not None
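Note: MultiTaskDataset.__getitem__ takes a (task_id, sample_id) tuple rather than a flat integer index, so a DataLoader over it needs a batch sampler that yields lists of such tuples. A minimal sketch of one, assuming the wrapped SingleTaskDataset instances carry distinct task_ids; the helper below is hypothetical and not part of this commit:

import random

def build_task_batches(datasets, batch_size, shuffle=True):
    """Hypothetical helper: return batches of (task_id, sample_id) tuples,
    each batch drawn from a single SingleTaskDataset, in the index format
    MultiTaskDataset.__getitem__ expects."""
    batches = []
    for dataset in datasets:
        task_id = dataset.get_task_id()
        indices = [(task_id, i) for i in range(len(dataset))]
        if shuffle:
            random.shuffle(indices)
        batches.extend(indices[k:k + batch_size]
                       for k in range(0, len(indices), batch_size))
    if shuffle:
        random.shuffle(batches)  # interleave tasks across training steps
    return batches

# Usage sketch: one loader over all tasks instead of one per task.
# multi_task_dataset = MultiTaskDataset(single_task_datasets)
# train_data = DataLoader(multi_task_dataset,
#                         batch_sampler=build_task_batches(single_task_datasets, 32),
#                         collate_fn=train_collater.collate_fn)

Keeping each batch single-task matches Collater, which is constructed with one task_id and task_type at a time elsewhere in this diff.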
4 changes: 2 additions & 2 deletions predict.py
@@ -7,7 +7,7 @@
 from experiments.exp_def import TaskDefs, EncoderModelType
 #from experiments.glue.glue_utils import eval_model
 
-from mt_dnn.batcher import MTDNNDataset, Collater
+from mt_dnn.batcher import SingleTaskDataset, Collater
 from mt_dnn.model import MTDNNModel
 from data_utils.metrics import calc_metrics
 from mt_dnn.inference import eval_model
@@ -57,7 +57,7 @@ def dump(path, data):
 model.load(checkpoint_path)
 encoder_type = config.get('encoder_type', EncoderModelType.BERT)
 # load data
-test_data_set = MTDNNDataset(args.prep_input, False, task_type=task_type, maxlen=args.max_seq_len)
+test_data_set = SingleTaskDataset(args.prep_input, False, task_type=task_type, maxlen=args.max_seq_len)
 collater = Collater(gpu=args.cuda, is_train=False, task_id=args.task_id, task_type=task_type,
                     data_type=data_type, encoder_type=encoder_type)
 test_data = DataLoader(test_data_set, batch_size=args.batch_size_eval, collate_fn=collater.collate_fn, pin_memory=args.cuda)
8 changes: 4 additions & 4 deletions train.py
@@ -17,7 +17,7 @@
 from data_utils.log_wrapper import create_logger
 from data_utils.utils import set_environment
 from data_utils.task_def import TaskType, EncoderModelType
-from mt_dnn.batcher import MTDNNDataset, Collater
+from mt_dnn.batcher import SingleTaskDataset, Collater
 from mt_dnn.model import MTDNNModel
 
 
@@ -207,7 +207,7 @@ def main():
 
         train_path = os.path.join(data_dir, '{}_train.json'.format(dataset))
         logger.info('Loading {} as task {}'.format(train_path, task_id))
-        train_data_set = MTDNNDataset(train_path, True, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
+        train_data_set = SingleTaskDataset(train_path, True, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
         train_data = DataLoader(train_data_set, batch_size=args.batch_size, shuffle=True, collate_fn=train_collater.collate_fn, pin_memory=args.cuda)
         train_data_list.append(train_data)
 
@@ -237,14 +237,14 @@ def main():
         dev_path = os.path.join(data_dir, '{}_dev.json'.format(dataset))
         dev_data = None
         if os.path.exists(dev_path):
-            dev_data_set = MTDNNDataset(dev_path, False, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
+            dev_data_set = SingleTaskDataset(dev_path, False, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
             dev_data = DataLoader(dev_data_set, batch_size=args.batch_size_eval, collate_fn=test_collater.collate_fn, pin_memory=args.cuda)
         dev_data_list.append(dev_data)
 
         test_path = os.path.join(data_dir, '{}_test.json'.format(dataset))
         test_data = None
         if os.path.exists(test_path):
-            test_data_set = MTDNNDataset(test_path, False, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
+            test_data_set = SingleTaskDataset(test_path, False, maxlen=args.max_seq_len, task_id=task_id, task_type=task_type, data_type=data_type)
             test_data = DataLoader(test_data_set, batch_size=args.batch_size_eval, collate_fn=test_collater.collate_fn, pin_memory=args.cuda)
         test_data_list.append(test_data)
 
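Note: with this commit, train.py still builds one DataLoader per task (train_data_list above); MultiTaskDataset is added but not yet wired into training. A hypothetical sketch of that wiring, reusing the build_task_batches helper sketched after the batcher.py diff (the names below are assumptions, not part of this change):

# single_task_datasets is assumed to collect the train_data_set instances
# built in the loop above.
multi_task_train_data = MultiTaskDataset(single_task_datasets)
train_data = DataLoader(multi_task_train_data,
                        batch_sampler=build_task_batches(single_task_datasets, args.batch_size),
                        collate_fn=train_collater.collate_fn,
                        pin_memory=args.cuda)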