
Commit b9e03ab

Author: Bei Chen
add cert
1 parent d5f23f2 commit b9e03ab

20 files changed, +5662 −0 lines changed

cert/README.md

+61
@@ -0,0 +1,61 @@
# CERT: Continual Pre-training on Sketches for Library-oriented Code Generation

This directory contains CERT's source code and our crafted evaluation benchmarks.

## Installation

### Installing the benchmarks

```
$ unzip human-eval.zip
$ pip install -e human-eval
```
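
Once installed, the crafted problems can be loaded through the human-eval API. A minimal sanity-check sketch; the benchmark file name under human-eval/data is an assumption, so use whatever the unzipped archive actually contains:

```
from human_eval.data import read_problems

# "PandasEval.jsonl.gz" is an assumed file name -- check human-eval/data
# after unzipping for the actual benchmark files.
problems = read_problems("human-eval/data/PandasEval.jsonl.gz")
print(len(problems), "problems loaded")
```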

### Installing the CERT runtime environment

```
$ pip install -r requirements.txt
```

## Usage

### Encoding the cleaned code corpus

- Convert each code file into many code blocks.
- Convert each code block into a code sketch (see the toy example below).
- Tokenize the code and convert the text into a binary file.

```
$ bash scripts/run_encode_domain.sh
```
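
For intuition, here is a toy illustration of the block-to-sketch step. It assumes the sketch anonymizes user-defined names and literals while keeping keywords and library-oriented parts; the actual rules live in the encoding scripts, so treat the keep-list and placeholder tokens below as hypothetical:

```
import io
import keyword
import tokenize

# Hypothetical keep-list of library names for the Pandas domain.
LIB_NAMES = {"pd", "read_csv", "mean", "DataFrame"}

def to_sketch(code: str) -> str:
    """Toy sketcher: mask user-defined names and literals, keep the rest."""
    pieces = []
    for tok in tokenize.generate_tokens(io.StringIO(code).readline):
        if tok.type == tokenize.NAME and not keyword.iskeyword(tok.string) \
                and tok.string not in LIB_NAMES:
            pieces.append("<id>")   # user-defined identifier
        elif tok.type in (tokenize.NUMBER, tokenize.STRING):
            pieces.append("<lit>")  # literal value
        elif tok.type in (tokenize.COMMENT, tokenize.NL, tokenize.NEWLINE,
                          tokenize.INDENT, tokenize.DEDENT, tokenize.ENDMARKER):
            continue
        else:
            pieces.append(tok.string)
    return " ".join(pieces)

print(to_sketch("df = pd.read_csv('train.csv')\n"))
# -> <id> = pd . read_csv ( <lit> )
```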

### Training CERT

```
$ bash run_cert.sh
```

### Evaluating CERT

Our crafted PandasEval and NumpyEval benchmarks are placed in human-eval/data.

```
$ bash run_eval_monitor.sh
```

Assign the output file path from the previous step to the POST_PATH variable in run_eval_monitor_step2.sh, then run:

```
$ bash run_eval_monitor_step2.sh
```
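
As eval_cert.py shows, the samples file from the previous step follows the pattern `{domain}_API_eval.{model_version}.t{temperature}.p{top_p}.l{max_new_tokens}.n{num_samples_per_task}.samples.jsonl`; with the default arguments, POST_PATH would therefore look something like `output/<model>/Pandas_API_eval.CERT.t0.2.p0.95.l100.n100.samples.jsonl` (the directory depends on whether the model path is local or fetched from the Hugging Face Hub).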

## Citation

Please cite using the following BibTeX entry:

```
@inproceedings{CERT,
  title={{CERT}: Continual Pre-training on Sketches for Library-oriented Code Generation},
  author={Zan, Daoguang and Chen, Bei and Yang, Dejian and Lin, Zeqi and Kim, Minsu and Guan, Bei and Wang, Yongji and Chen, Weizhu and Lou, Jian-Guang},
  booktitle={The 2022 International Joint Conference on Artificial Intelligence},
  year={2022}
}
```

cert/eval_cert.py

+170
@@ -0,0 +1,170 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import re
from typing import Optional

import torch
from transformers import pipeline, set_seed
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.pipelines.base import Pipeline

from tqdm import tqdm

from human_eval.data import write_jsonl, read_problems

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
torch.set_num_threads(16)

def remove_samples_in_comments(prompt: str) -> str:
    """Strip example sections (">>>" doctests, "Example" blocks) from a prompt's docstring."""
    prompt = prompt.replace('\n', '#N#')
    p = re.compile(r'(#N#)+    (>>>|Example).*"""')
    return p.sub('\n    """', prompt).replace('#N#', '\n')

def load_generation_pipe(model_name_or_path: str, gpu_device: int = 0):
    """Build a text-generation pipeline from a local checkpoint or a Hugging Face model name."""
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

    pipe = pipeline(
        'text-generation',
        model=model,
        tokenizer=tokenizer,
        device=gpu_device
    )

    print("load generation pipeline from {} over, vocab size = {}, eos id = {}, gpu device = {}.".format(
        model_name_or_path, len(tokenizer), tokenizer.eos_token_id, gpu_device)
    )

    return pipe

def first_block(string):
    """Split off the first block of code by scanning for class, def, etc. on newlines."""
    return re.split("\nclass|\ndef|\n#|\n@|\nprint|\nif", string)[0].rstrip()

def complete_code(pipe, prompt, num_completions=1, **gen_kwargs):
    """Complete the prompt with the text-generation pipeline and return num_completions samples."""
    set_seed(123)

    code_gens = pipe(prompt,
        num_return_sequences=num_completions,
        **gen_kwargs
    )

    return [first_block(code_gen["generated_text"][len(prompt):]) for code_gen in code_gens]

def evaluate_on_human_eval(
    model_name_or_path: str,
    temperature: float,
    top_p: float,
    num_samples_per_task: int,
    max_new_tokens: int,
    gpu_device: int,
    domain: str,
    model_version: str,
    overwrite: bool = True,
    fetch_from_huggingface: bool = True
) -> Optional[str]:

    # Resolve the output directory: a local checkpoint writes next to itself;
    # a Hugging Face model name gets a sanitized directory under output/.
    if os.path.exists(model_name_or_path):
        output_dir = model_name_or_path
    elif fetch_from_huggingface:
        output_dir = "output/{}".format(model_name_or_path.replace("/", "_"))
        os.makedirs(output_dir, exist_ok=True)
    else:
        return None

    eval_name = f"{domain}_API_eval.{model_version}.t{temperature}.p{top_p}.l{max_new_tokens}.n{num_samples_per_task}"
    saved_path = os.path.join(output_dir, f"{eval_name}.samples.jsonl")

    print(f"saved_path: {saved_path}")

    if not overwrite and os.path.exists(saved_path):
        return saved_path

    pipe: Pipeline = load_generation_pipe(model_name_or_path, gpu_device=gpu_device)
    gen_kwargs = {
        "do_sample": True,  # Default: True
        "temperature": temperature,
        "max_new_tokens": max_new_tokens,
        "top_p": top_p,
        "top_k": 0,
        "pad_token_id": pipe.tokenizer.pad_token_id if pipe.tokenizer.pad_token_id else pipe.tokenizer.eos_token_id,
        "eos_token_id": pipe.tokenizer.eos_token_id
    }

    problems = read_problems()
    samples = []
    generate_batch_size = min(50, num_samples_per_task)

    bos_token = pipe.tokenizer.bos_token if pipe.tokenizer.bos_token else pipe.tokenizer.eos_token

    # Sample completions for every task, in batches of generate_batch_size.
    for task_id in tqdm(problems):
        prompt = problems[task_id]["prompt"]
        for _ in range(num_samples_per_task // generate_batch_size):
            input_prompt = bos_token + prompt
            gen_results = complete_code(pipe, input_prompt, num_completions=generate_batch_size, **gen_kwargs)
            for gen_result in gen_results:
                samples.append(dict(task_id=task_id, completion=gen_result))

    write_jsonl(saved_path, samples)
    return saved_path

def run_samples_test(model_name_or_path: str):
    """Smoke-test the generation pipeline on a few hand-written prompts."""
    pipe = load_generation_pipe(model_name_or_path)

    prompt = 'def convert_hours_minutes(hours):'
    complete_code(pipe, prompt, num_completions=4)

    prompt = '''def area_of_rectangle(a: float, b: float):
    """Returns the area of the rectangle."""'''
    complete_code(pipe, prompt, num_completions=2)

    prompt = '''def get_urls_from_html(html):
    """Get all embedded URLs in an HTML string."""'''
    complete_code(pipe, prompt, num_completions=4)

    prompt = '''def is_sorted(lst):
    """
    Given a list of numbers, return whether or not they are sorted
    in ascending order. If the list has more than 1 duplicate of the same
    number, return False. Assume no negative numbers and only integers.
    """'''
    # The stray positional 200 collided with num_completions (TypeError);
    # it is passed as max_new_tokens here.
    complete_code(pipe, prompt, num_completions=4, max_new_tokens=200)

    prompt = '''def is_sorted(lst):
    """
    Given a list of numbers, return whether or not they are sorted in ascending order.
    If the list has more than 1 duplicate of the same number, return False. Assume no negative numbers and only integers.
    """'''
    complete_code(pipe, prompt, num_completions=4, max_new_tokens=200)

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Run evaluation for a code generation model.')

    parser.add_argument('-model', '--model_name_or_path', type=str, required=True)
    parser.add_argument('-n', '--num_completions', type=int, default=100)
    parser.add_argument('-t', '--temperature', type=float, default=0.2)
    parser.add_argument('-p', '--top_p', type=float, default=0.95)
    parser.add_argument('-l', '--max_new_tokens', type=int, default=100)
    parser.add_argument('-gpu', '--gpu_device', type=int, default=0)
    parser.add_argument('-d', '--domain', type=str, default="Pandas", choices=["Pandas", "Numpy", "NLTK"])
    parser.add_argument('-mv', '--model_version', type=str, default="CERT", choices=["PYCODEGPT", "PYCODEGPT_XL", "CERT"])

    args = parser.parse_args()

    print(evaluate_on_human_eval(
        model_name_or_path=args.model_name_or_path,
        temperature=args.temperature,
        top_p=args.top_p,
        num_samples_per_task=args.num_completions,
        gpu_device=args.gpu_device,
        max_new_tokens=args.max_new_tokens,
        domain=args.domain,
        model_version=args.model_version
    ))
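
For reference, the same evaluation can be driven from Python instead of the shell wrappers; a minimal sketch mirroring the argparse defaults above, where the checkpoint path is a placeholder:

```
from eval_cert import evaluate_on_human_eval

# "models/cert-pandas" is a placeholder checkpoint path, not a repo artifact.
saved_path = evaluate_on_human_eval(
    model_name_or_path="models/cert-pandas",
    temperature=0.2,
    top_p=0.95,
    num_samples_per_task=100,
    max_new_tokens=100,
    gpu_device=0,
    domain="Pandas",
    model_version="CERT",
)
print(saved_path)
```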
