# data.py
import json
from typing import Dict
from datasets import load_dataset, load_from_disk, Features, Value
from transformers import AutoTokenizer
from utils_and_base_types import read_path


def run_download(config: Dict, logger):
    """Downloads a dataset from the Hugging Face Hub and saves it to disk."""
    logger.info(f'(Progress) Download invoked with config \n{json.dumps(config, indent=2)}')
    if config['name'] == 'paws':
        # PAWS requires an explicit configuration name; use the human-labeled final data.
        dataset = load_dataset(config['name'], "labeled_final", split=config['split'])
    else:
        dataset = load_dataset(config['name'], split=config['split'])
    dataset.save_to_disk(dataset_path=read_path(config['path_out']))
    logger.info('(Progress) Terminated normally')
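
# Example usage (a minimal sketch; the dataset name, output path, and logger setup
# below are illustrative and not taken from a project config file):
#
#   import logging
#   logging.basicConfig(level=logging.INFO)
#   run_download(
#       config={'name': 'imdb', 'split': 'train', 'path_out': 'data/imdb_train'},
#       logger=logging.getLogger(__name__),
#   )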


def get_dataset(name: str, config: Dict):
    """Returns a PyTorch-formatted dataset."""
    if name.startswith('huggingface'):
        tokenizer = AutoTokenizer.from_pretrained(config['name_model'])
        dataset = load_from_disk(read_path(config['path_dataset']))
        # Optionally restrict the dataset to the slice [start, end); missing or
        # negative bounds fall back to the full range.
        start = 0 if 'start' not in config else config['start']
        if start < 0:
            start = 0
        end = len(dataset) if 'end' not in config else config['end']
        if end < 0:
            end = len(dataset)
        dataset = dataset.select(indices=range(start, end))
        # Tokenize sentence pairs for SNLI (premise / hypothesis).
        def encode_snli(instances):
            return tokenizer(instances['premise'],
                             instances['hypothesis'],
                             truncation=True,
                             padding='max_length',
                             max_length=config['max_length'],
                             return_special_tokens_mask=True)

        # Tokenize sentence pairs for PAWS (sentence1 / sentence2).
        def encode_paws(instances):
            return tokenizer(instances['sentence1'],
                             instances['sentence2'],
                             truncation=True,
                             padding='max_length',
                             max_length=config['max_length'],
                             return_special_tokens_mask=True)

        # Tokenize single-text datasets such as IMDB and AG News.
        def encode_text(instances):
            return tokenizer(instances['text'],
                             truncation=True,
                             padding='max_length',
                             max_length=config['max_length'],
                             return_special_tokens_mask=True)
        if name in ['huggingface.imdb', 'huggingface.ag_news']:
            dataset = dataset.map(encode_text, batched=True, batch_size=config['batch_size'])
        elif name == 'huggingface.snli':
            dataset = dataset.map(encode_snli, batched=True, batch_size=config['batch_size'])
        elif name == 'huggingface.paws':
            dataset = dataset.map(encode_paws, batched=True, batch_size=config['batch_size'])

        # Mirror the 'label' column as 'labels', the column name expected by transformers models.
        dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True,
                              batch_size=config['batch_size'])
        if name == 'huggingface.snli':
            # SNLI uses label -1 when the gold label is unknown; drop those examples.
            dataset = dataset.filter(lambda examples: examples['labels'] != -1)
        dataset.set_format(type='torch', columns=config['columns'])
        return dataset
    if name == 'local.explanations':
        # Keep only the columns requested in the config.
        def encode_local(instances):
            res = {k: instances[k] for k in config['columns']}
            return res

        paths_json_files = config['paths_json_files']
        if isinstance(paths_json_files, list):
            paths_json_files = [read_path(path) for path in paths_json_files]
        else:
            paths_json_files = read_path(paths_json_files)
        dataset = load_dataset('json', data_files=paths_json_files)
        # load_dataset returns a DatasetDict; JSON-lines files are placed under the
        # default key 'train'. Note that this key does not mean the data is an
        # actual train split.
        dataset = dataset['train']

        # Type handling: ensure token ids are integers and attributions are floats.
        features = dataset.features.copy()
        features['input_ids'].feature = Value(dtype='int64')
        features['attributions'].feature = Value(dtype='float32')
        features = Features(features)
        dataset.cast_(features)  # in-place cast to the updated feature types

        dataset = dataset.map(encode_local, batched=True)
        dataset.set_format(type='torch', columns=config['columns'])
        return dataset  # todo: why is the 'train' indexing necessary?
    else:
        raise NotImplementedError
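

# Example usage (a minimal sketch; the model name, dataset path, and column list
# below are illustrative placeholders, not values shipped with this module):
#
#   imdb_config = {
#       'name_model': 'bert-base-uncased',
#       'path_dataset': 'data/imdb_train',  # e.g. produced earlier by run_download
#       'max_length': 128,
#       'batch_size': 32,
#       'columns': ['input_ids', 'attention_mask', 'labels'],
#   }
#   imdb_dataset = get_dataset('huggingface.imdb', imdb_config)
#   print(imdb_dataset[0]['input_ids'].shape)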