-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
245 lines (195 loc) · 7.42 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
import os
import torch
import pathlib
import re
import collections
import functools
import inspect
import sys
import pytest
from typing import List
from torch.utils.data import RandomSampler, DataLoader, Subset, Dataset
class ExitCodeError(Exception):
    """Signals that a shell command run through `sh` returned a non-zero status."""
def sh(x):
    """
    Run a command line through `os.system` and fail loudly on a non-zero status.

    :param x: str
        Command line passed unchanged to `os.system`.
    :raises ExitCodeError:
        If `os.system` returns anything other than 0.
    """
    status = os.system(x)
    if status != 0:
        raise ExitCodeError()
def simple_parse_args_string(args_string):
    """
    Parse a comma-separated list of `key=value` pairs into a dictionary.

    Example: `"arg1=val1,arg2=val2"` -> `{"arg1": "val1", "arg2": "val2"}`.

    Only the first `=` of each pair separates key from value, so values may
    themselves contain `=` (e.g. `"opt=a=b"` -> `{"opt": "a=b"}`).

    :param args_string: str
        The raw argument string; surrounding whitespace is ignored.
    :return: dict
        Mapping of key to value, both as strings. Empty if the input is blank.
    :raises ValueError:
        If a comma-separated segment contains no `=`.
    """
    args_string = args_string.strip()
    if not args_string:
        return {}
    args_dict = {}
    for arg in args_string.split(","):
        # partition splits on the first "=" only, so "=" inside a value is kept
        key, sep, value = arg.partition("=")
        if not sep:
            raise ValueError(f"Expected 'key=value' but got {arg!r}")
        args_dict[key] = value
    return args_dict
def join_iters(iters):
    """
    Lazily concatenate several iterables into one stream.

    :param iters: iterable of iterables
    :return: generator
        Yields every element of the first iterable, then the second, and so on.
    """
    for sub_iterable in iters:
        yield from sub_iterable
def chunks(iter, n):
    """
    Split an iterable into consecutive lists of length `n`.

    :param iter: iterable
        Source of items (the name shadows the builtin but is part of the signature).
    :param n: int
        Number of items per emitted list.
    :return: generator
        Yields lists of exactly `n` items, followed by one shorter list holding
        any leftover items.
    """
    batch = []
    for item in iter:
        batch.append(item)
        if len(batch) != n:
            continue
        yield batch
        batch = []
    # Flush the trailing partial batch, if any
    if len(batch) > 0:
        yield batch
def group(arr, fn):
    """
    Bucket the items of `arr` by the key that `fn` returns for each of them.

    :param arr: iterable
        Items to bucket.
    :param fn: callable
        Maps an item to a hashable key.
    :return: list of lists
        One list per distinct key, ordered by first appearance of the key;
        items inside a bucket keep their input order.
    """
    buckets = {}
    for item in arr:
        buckets.setdefault(fn(item), []).append(item)
    return list(buckets.values())
def general_detokenize(string):
    """
    Undo common tokenizer spacing artifacts in a piece of text.

    Removes the space before "n't", before ")", after "(", on either side of
    a double quote, and before an apostrophe, period, or comma.

    :param string: str
    :return: str
        The cleaned-up text.
    """
    # Applied in this exact order; later rules see the output of earlier ones
    literal_fixes = (
        (" n't", "n't"),
        (" )", ")"),
        ("( ", "("),
        ('" ', '"'),
        (' "', '"'),
    )
    for needle, replacement in literal_fixes:
        string = string.replace(needle, replacement)
    return re.sub(r" (['.,])", r"\1", string)
def get_rolling_token_windows(token_list, prefix_token, max_seq_len, context_len):
    """
    Cut a token sequence into (input, prediction) windows for rolling evaluation.

    Every token in `token_list` appears in exactly one prediction window. The
    first window predicts up to `max_seq_len` tokens, conditioned on
    `prefix_token` followed by the preceding tokens. Each later window predicts
    up to `max_seq_len - context_len + 1` tokens, and its input is the
    `max_seq_len` tokens that end just before the window's last predicted token.

    :param token_list: list
        Tokens to be PREDICTED
    :param prefix_token: token
        Dummy token like <eos> so the first token has something to condition on
    :param max_seq_len: int
        max_seq_len of the model (or the max_seq_len we want to use)
    :param context_len: int
        Amount of desired token context for prediction. Needs to be at least 1
        and at most max_seq_len.
    :return: generator
        Generator of tuples
            (input_tokens, pred_tokens)
        Note: Score only the last len(pred_tokens) logits of the LM
    """
    assert 1 <= context_len <= max_seq_len
    if not token_list:
        return

    total = len(token_list)
    # Tokens predicted per later window; the +1 is the input -> prediction shift
    stride = max_seq_len - context_len + 1

    # The first window is special: it predicts every token it covers
    head = min(max_seq_len, total)
    yield [prefix_token] + token_list[: head - 1], token_list[:head]

    start = head
    while start < total:
        step = min(total - start, stride)
        stop = start + step
        # Input is shifted one position left of the predictions it ends on
        yield token_list[stop - max_seq_len - 1 : stop - 1], token_list[start:stop]
        start = stop
def make_disjoint_window(pair):
    """
    Takes output from get_rolling_token_windows and makes the context not
    overlap with the continuation.

    The last `len(b) - 1` tokens of the input `a` are dropped and `b` is
    returned unchanged.

    :param pair: tuple
        (input_tokens, pred_tokens)
    :return: tuple
        (trimmed_input_tokens, pred_tokens)
    """
    a, b = pair
    # Use an explicit end index: the negative-slice form `a[:-(len(b) - 1)]`
    # becomes `a[:0]` when len(b) == 1 and would discard the whole context.
    keep = max(len(a) - (len(b) - 1), 0)
    return a[:keep], b
class Reorderer:
    """
    Deduplicates and sorts a list of items by a key function, and can later
    map per-item results back to the original positions.

    Items whose key `fn(item)` is equal are collapsed into a single
    representative (the first one seen); the representatives are kept sorted
    by that same key.
    """

    def __init__(self, arr, fn):
        """
        :param arr: list
            Items to reorder.
        :param fn: callable
            Key function; must return a hashable, orderable value.
        """
        self.size = len(arr)
        indexed = list(enumerate(arr))
        buckets = group(indexed, lambda pair: fn(pair[1]))

        # Each entry: (original positions sharing the key, first item with that key)
        collapsed = []
        for bucket in buckets:
            positions = [pos for pos, _ in bucket]
            representative = bucket[0][1]
            collapsed.append((positions, representative))
        collapsed.sort(key=lambda entry: fn(entry[1]))

        self.arr = collapsed

    def get_reordered(self):
        """Return the representative items in sorted-key order."""
        return [item for _, item in self.arr]

    def get_original(self, newarr):
        """
        Scatter results computed on the reordered items back to input order.

        :param newarr: iterable
            One result per item returned by `get_reordered`, in the same order.
        :return: list
            A list of length `self.size` in which every original position
            holds the result of its representative.
        """
        restored = [None] * self.size
        seen = [False] * self.size

        for entry, value in zip(self.arr, newarr):
            for pos in entry[0]:
                restored[pos] = value
                seen[pos] = True

        # Every original position must have received a result
        assert all(seen)
        return restored
def positional_deprecated(fn):
    """
    A decorator to nudge users into passing only keyword args (`kwargs`) to the
    wrapped function, `fn`.

    A warning is printed whenever the wrapped callable receives positional
    arguments beyond the implicit receiver. The receiver (`self`/`cls`) of a
    function decorated inside a class body arrives as the first positional
    argument and is not counted against the caller.

    :param fn: callable
        The function or method to wrap.
    :return: callable
        A wrapper with the same metadata as `fn` that forwards all arguments.
    """
    # How many leading positional args are implicit rather than user-supplied.
    # A bound method already carries its receiver, so none show up in *args.
    if inspect.ismethod(fn):
        implicit_args = 0
    else:
        try:
            param_names = list(inspect.signature(fn).parameters)
        except (TypeError, ValueError):
            # Signature not introspectable; assume no implicit receiver
            param_names = []
        implicit_args = 1 if param_names[:1] in (["self"], ["cls"]) else 0

    @functools.wraps(fn)
    def _wrapper(*args, **kwargs):
        # Explicit comparison: the former `a != 1 if cond else 0` parsed as
        # `(a != 1) if cond else 0` and therefore never warned for functions.
        if len(args) > implicit_args:
            print(
                f"WARNING: using {fn.__name__} with positional arguments is "
                "deprecated and will be disallowed in a future version of "
                "lm-evaluation-harness!"
            )
        return fn(*args, **kwargs)

    return _wrapper
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
    """
    Search upward in the directory tree to a maximum of three layers
    to find and return the package root (containing the 'tests' folder)

    The resolved `start_path` itself counts as the first layer; a layer matches
    when it contains `tests/test_version_stable.py`.

    :param start_path: pathlib.Path
        Path at which the upward search begins.
    :return: pathlib.Path
        The first inspected path that contains `tests/test_version_stable.py`.
    :raises FileNotFoundError:
        If none of the inspected layers contains that file.
    """
    cur_path = start_path.resolve()
    max_layers = 3
    for _ in range(max_layers):
        if (cur_path / "tests" / "test_version_stable.py").exists():
            return cur_path
        cur_path = cur_path.parent.resolve()
    # Single f-string: the previous two-part concatenation rendered "upwardsof"
    raise FileNotFoundError(
        f"Unable to find package root within {max_layers} layers upwards of {start_path}"
    )
@positional_deprecated
def run_task_tests(task_list: List[str]):
    """
    Locate the package root and run the version-stability tests for some tasks.

    The task names are joined with " or " and handed to pytest as a `-k`
    expression against `tests/test_version_stable.py`. The package root is
    appended to `sys.path` before pytest is invoked.

    :param task_list: List[str]
        Names of the tasks whose tests should run.
    :raises ValueError:
        If `pytest.main` returns a non-zero exit code.
    """
    package_root = find_test_root(start_path=pathlib.Path(__file__))
    task_string = " or ".join(task_list)

    sys.path.append(str(package_root))

    test_file = f"{package_root}/tests/test_version_stable.py"
    rootdir_flag = f"--rootdir={package_root}"
    pytest_return_val = pytest.main([test_file, rootdir_flag, "-k", f"{task_string}"])

    if not pytest_return_val:
        return
    raise ValueError(
        f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}"
    )
def get_input_from_ctx_cont(ctx, continuation, tokenizer, max_length=2048):
    """
    Encode a context/continuation pair into one zero-padded token tensor.

    An empty context is represented by the tokenizer's single EOS token id so
    that the continuation always has something to condition on.

    :param ctx: str
        Context text; "" means "no context".
    :param continuation: str
        Text that follows the context.
    :param tokenizer:
        Object exposing `eos_token_id` and
        `encode(text, add_special_tokens=False)` returning a list of ids.
    :param max_length: int
        Length the tensor is right-padded to with zeros. Defaults to 2048.
        Longer sequences are returned as-is (neither padded nor truncated).
    :return: tuple
        (inp, len(ctx_enc), len(continuation_enc)) where `inp` is a 1-D
        `torch.long` tensor of context ids followed by continuation ids and
        then zero padding.
    """
    if ctx == "":
        ctx_enc = [tokenizer.eos_token_id]
    else:
        ctx_enc = tokenizer.encode(ctx, add_special_tokens=False)
    continuation_enc = tokenizer.encode(continuation, add_special_tokens=False)

    inp = torch.tensor(ctx_enc + continuation_enc, dtype=torch.long)
    pad_len = max_length - len(inp)
    if pad_len > 0:
        # Right-pad with zeros up to max_length
        inp = torch.cat([inp, torch.zeros(pad_len, dtype=torch.long)], dim=0)
    return inp, len(ctx_enc), len(continuation_enc)
def get_dataloader_from_dataset(dataset, subset_size = None, batch_size = 1):
    """
    Wrap a dataset in a randomly-sampling `DataLoader`, optionally truncated.

    :param dataset:
        Any object usable as a torch map-style dataset (`__len__`/`__getitem__`).
    :param subset_size: int or None
        If truthy and smaller than `len(dataset)`, only the first
        `subset_size` items are kept.
    :param batch_size: int
        Batch size handed to the `DataLoader`.
    :return: DataLoader
        Loader that draws from the (possibly truncated) dataset through a
        `RandomSampler`.
    """
    wants_subset = bool(subset_size) and subset_size < len(dataset)
    if wants_subset:
        first_indices = list(range(subset_size))
        dataset = Subset(dataset, first_indices)
    return DataLoader(
        dataset,
        sampler=RandomSampler(dataset),
        batch_size=batch_size,
    )
def create_dataloader(tokenizer, docs, fewshot_context, doc_to_cont, subset_size = None, batch_size = 1, num_fewshot = 0):
    """
    Build a randomly-sampling `DataLoader` over tokenized (context, continuation) pairs.

    :param tokenizer:
        Tokenizer forwarded to `get_input_from_ctx_cont`.
    :param docs:
        Indexable, sized collection of documents.
    :param fewshot_context: callable
        Called as `fewshot_context(doc, num_fewshot)`; returns the context text.
    :param doc_to_cont: callable
        Called as `doc_to_cont(doc)`; returns the continuation text.
    :param subset_size: int or None
        Forwarded to `get_dataloader_from_dataset`.
    :param batch_size: int
        Forwarded to `get_dataloader_from_dataset`.
    :param num_fewshot: int
        Second argument passed to `fewshot_context` for every document.
    :return:
        Whatever `get_dataloader_from_dataset` returns for the wrapped dataset.
        Each dataset item is
        `(ctx, continuation, inp, len_ctx_enc, len_continuation_enc)`.
    """

    class NewDataset(Dataset):
        """Map-style dataset that tokenizes one document per `__getitem__` call."""

        def __init__(self, tokenizer, docs, fewshot_context, doc_to_cont, num_fewshot):
            super().__init__()
            self.tokenizer = tokenizer
            self.docs = docs
            self.fewshot_context = fewshot_context
            self.doc_to_cont = doc_to_cont
            # Keep the value on the instance instead of relying on the
            # enclosing function's variable being captured by __getitem__
            self.num_fewshot = num_fewshot

        def __getitem__(self, i):
            out_doc = self.docs[i]
            ctx = self.fewshot_context(out_doc, self.num_fewshot)
            continuation = self.doc_to_cont(out_doc)
            inp, l_ctx_enc, l_cont_enc = get_input_from_ctx_cont(ctx, continuation, self.tokenizer)
            return ctx, continuation, inp, l_ctx_enc, l_cont_enc

        def __len__(self):
            return len(self.docs)

    ds = NewDataset(tokenizer, docs, fewshot_context, doc_to_cont, num_fewshot)
    return get_dataloader_from_dataset(ds, subset_size = subset_size, batch_size = batch_size)