benchmark.py
"""
Just a benchmark script for evaluating Bits-Per-Byte,Bits-Per-Character, and
Perplexity-Per-Length for enwik8,text8, and WikiText2 respectively. Handles all
my models + HF models.
All other benchmarks reported are computed using:
https://github.com/EleutherAI/lm-evaluation-harness/
"""
import argparse

import pandas as pd
from transformers import GPT2TokenizerFast

from benchmark.benchmark_map import DATASET_MAP
from benchmark.benchmark_metrics import METRIC_REGISTRY


def parse():
    parser = argparse.ArgumentParser(description="LM Benchmarks")
    parser.add_argument("--dataset", type=str, help="comma-separated dataset keys from DATASET_MAP")
    parser.add_argument("--model", type=str, help="base, base*, medium, medium*, or XL*")
    parser.add_argument("--type", type=str, help="model family (only GPT2 is supported)")
    parser.add_argument("--eval-ctx", type=str, help="comma-separated evaluation context lengths")
    parser.add_argument("--hf-model", default=False, action="store_true", help="benchmark a pretrained HuggingFace model")
    parser.add_argument("--bit-quantize", default=False, action="store_true")
    args = parser.parse_args()
    return args


def check_args(args):
    assert args.model in ["base", "medium*", "base*", "XL*", "medium"]
    assert args.type in ["GPT2"]


def main():
    args = parse()
    check_args(args)

    if not args.hf_model:
        # Local checkpoints: starred model names map to pre-trained weight files.
        from src.models.GPT2 import model_getter

        if "*" in args.model:
            save_paths = {
                "base*": "checkpoints/127_weights.pth.tar",
                "medium*": "checkpoints/303_weights.pth.tar",
                "XL*": "checkpoints/1B_weights_8bit.pth.tar",
            }
            model = model_getter(
                args.model,
                vocab_size=50257,
                num_ctx=512,
                model_checkpoint=save_paths[args.model],
                **{
                    "fused_residuals": True,
                    "num_head": 8,
                    "use_alibi": True,
                    "quantized_state": True if "XL" in args.model else False,
                },
            )
        elif args.model == "medium":
            model = model_getter(
                "medium",
                vocab_size=50257,
                num_ctx=1024,
                model_checkpoint="checkpoints/354_weights.pth.tar",
                **{"fused_residuals": False, "use_alibi": False},
            )
    else:
        # HuggingFace baselines: GPT-2 small or GPT-2 medium.
        from transformers import AutoModelForCausalLM

        model_name = "gpt2" if args.model == "base" else "gpt2-medium"
        model = AutoModelForCausalLM.from_pretrained(model_name)

    model.cuda()
    model.eval()

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

    dataset_list = [item for item in args.dataset.split(",")]

    # Evaluate every dataset at every requested context length and collect the
    # results into a single DataFrame.
    task_df = pd.DataFrame()
    for dataset in dataset_list:
        args.dataset = dataset
        dataset_tuple = DATASET_MAP[dataset]
        benchmark_function = METRIC_REGISTRY[dataset_tuple.metric]

        context_list = [int(item) for item in args.eval_ctx.split(",")]
        eval_value = []
        ctx = []
        for eval_ctx in context_list:
            # Use a stride equal to the context length (non-overlapping windows).
            stride, max_length = eval_ctx, eval_ctx
            assert dataset_tuple.metric in ["PPL", "BPB", "BPC"]
            metric = benchmark_function(
                model, args, tokenizer, stride=stride, max_length=max_length
            )
            eval_value.append(metric.cpu().numpy())
            ctx.append(eval_ctx)

        single_task_df = pd.DataFrame(
            {
                "task": [dataset_tuple.dataset_name] * len(context_list),
                "metric": [dataset_tuple.metric] * len(context_list),
                "value": eval_value,
                "eval context length": ctx,
            }
        )
        task_df = pd.concat([task_df, single_task_df])

    print(task_df)


if __name__ == "__main__":
    main()