# inference_modes.py
import re
from funchub.math import *  # wildcard import supplies the "<op>_" implementations (e.g. add_) that eval() calls below
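# Toolken decoding loop for the math-reasoning tasks: decode in "func_embedding"
# mode until the generation ends with a toolken such as "<add>(", switch to
# "baseline" mode with the operator-specific template to complete the arguments,
# evaluate the call, splice the numeric result back into the generation, and
# resume decoding until no toolken is emitted.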
def func_embedding_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, return_top=5):
    cur_generation = ""
    cur_generation_with_func = ""
    start_length = []
    end_length = []
    logs = []
    funcmodel.inference_mode = "func_embedding"
    func_map = list(funcmodel.func_dict.keys())
    try:
        results = []
        func_calls = []
        while True:
            prompt = templates["general"].replace("[QUESTION]", question) + cur_generation
            results = funcmodel.generate([prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, stop_token=[13], return_top=return_top)
            if return_top > 0:
                results, token_log = results
                logs.append(token_log)
            endflag = True
            current_token = 0
            record_tokens = token_log[-1]
            cur_generation = results[0].replace(templates["general"].replace("[QUESTION]", question), "")
            for op in func_map:
                if cur_generation.endswith(op + "("):
                    endflag = False
                    if start_length and end_length:
                        bias = 0
                        # copy the current generation, then splice the recorded function
                        # calls back in at their original positions
                        cur_generation_with_func = cur_generation
                        for i in range(len(start_length)):
                            cur_generation_with_func = cur_generation_with_func[:start_length[i]+bias] + func_calls[i] + cur_generation_with_func[end_length[i]+bias:]
                            bias += len(func_calls[i]) - (end_length[i] - start_length[i])
                    else:
                        cur_generation_with_func = cur_generation
                    prompt = templates[op].replace("[QUESTION]", question) + cur_generation_with_func
                    len_prompt = len(prompt)
                    funcmodel.inference_mode = "baseline"
                    results = funcmodel.generate([prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, stop_token=[29897, 3892], return_top=return_top)
                    funcmodel.inference_mode = "func_embedding"
                    if return_top > 0:
                        results, token_log = results
                        logs.append(token_log)
                    generated = results[0][len_prompt:]
                    cur_generation += generated
                    # normalize the text after the toolken into a plain "(a, b)" argument tuple
                    args = cur_generation.split(op)[-1].replace("=", "").replace(">", "").replace("((", "(").replace("))", ")")
                    # remove any $ in the args
                    args = args.replace("$", "")
                    # drop thousands separators while keeping the commas between arguments
                    if ", " in args:
                        args = args.replace(", ", ";").replace(",", "").replace(";", ", ")
                    args = args.replace(" ", "")
                    if "(" not in args or ")" not in args:
                        raise Exception("invalid args")
                    # convert percentage arguments to decimals
                    if '%' in args:
                        temp = args.split("(")[1].split(")")[0].split(",")
                        for arg_i, arg in enumerate(temp):
                            if "%" in arg:
                                arg = arg.replace("%", "").strip()
                                arg = str(float(arg) / 100)
                            temp[arg_i] = arg
                        args = f"({', '.join(temp)})"
                    try:
                        # evaluate the call via the funchub implementation, e.g. "<add>(1, 2)" -> add_(1, 2)
                        res = eval(f"{op[1:-1]}_{args}")
                        func_calls.append(f"{op}{args} = {res}")
                        start_length.append(len(cur_generation.split(op)[0]))
                        cur_generation = cur_generation.split(op)[0] + str(res)
                        end_length.append(len(cur_generation))
                        # generate only the next token, with digit tokens disabled so the
                        # model does not immediately extend the inserted result
                        prompt = templates["general"].replace("[QUESTION]", question) + cur_generation
                        results = funcmodel.generate([prompt], max_gen_len=1, temperature=temperature, top_p=top_p, stop_token=[13], return_top=return_top, disable_token=[29900, 29896, 29906, 29941, 29946, 29945, 29953, 29955, 29947, 29929])  # disable the digit tokens 0-9
                        if return_top > 0:
                            results, token_log = results
                            logs.append(token_log)
                        cur_generation = results[0].replace(templates["general"].replace("[QUESTION]", question), "")
                    except:
                        # backtrack: fall back to the next candidate token recorded at this position
                        current_token += 1
                        # token ids >= 32000 are toolkens appended after the base vocabulary
                        decode_token = lambda x: funcmodel.tokenizer.decode(x) if x < 32000 else func_map[x - 32000]
                        cur_generation = cur_generation.split(op)[0] + decode_token(record_tokens[1][current_token][0])
                    break
            if endflag:
                break
        log = {
            "case_idx": case_idx,
            "question": question,
            "func_calls": func_calls,
            "generation": cur_generation.replace("\n", "\\n").strip(),
            "status": "success"
        }
    except Exception as e:
        log = {
            "case_idx": case_idx,
            "question": question,
            "func_calls": func_calls,
            "generation": cur_generation.replace("\n", "\\n").strip(),
            "status": str(e)
        }
    return log
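# Toolken decoding loop for the plan-generation (VirtualHome) setting: actions are
# decoded one token at a time, with the per-example disabled functions and the most
# recently used non-[WALK] functions masked out, until "[END]" appears in the plan.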
def vh_embedding_inference(case_idx, question, funcmodel, temperature, top_p, max_func_call):
    funcmodel.inference_mode = "func_embedding"
    inputs = question[0]
    disable_funcs = question[1]
    last_func = []
    for _ in range(max_func_call):
        inputs = funcmodel.generate([inputs], max_gen_len=1, temperature=temperature, top_p=top_p, return_top=0, disable_func=disable_funcs + last_func, no_left_parens=True)[0]
        if inputs.endswith(">"):
            inputs = inputs.replace("]<", "] <")
            inputs += '\n'
            last_func = [] if "[WALK]" in inputs.split("\n")[-2] else re.findall(r"\[.*?\]", inputs.split("\n")[-2])
            print("last func", last_func)
        if "[END]" in inputs.split("Plan:")[-1]:
            break
    log = {
        "case_idx": case_idx,
        "question": question[0],
        "func_calls": inputs.replace(question[0], "").strip().split("\n"),
        "generation": inputs.replace("\n", "\\n").strip(),
        # no need to return logs
        # "token_log": logs,
        "status": "success"
    }
    return log
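# Toolken decoding for the KAMEL relation task: a single "func_embedding" pass per
# question is sufficient, since only one function (relation) token is generated.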
def kamel_embedding_inference(templates, case_idx, question, funcmodel, temperature, top_p, max_gen_len, max_func_call):
    funcmodel.inference_mode = "func_embedding"
    cur_generation = ""
    if "funcgeneral" not in templates:
        templates["funcgeneral"] = templates["general"]
    try:
        results = []
        func_calls = []
        while True:
            if max_func_call == 0:
                break
            prompt = templates["funcgeneral"].replace("[QUESTION]", question) + cur_generation
            results = funcmodel.generate([prompt], max_gen_len=max_gen_len, temperature=temperature, top_p=top_p, stop_token=[13])
            max_func_call -= 1
            cur_generation = results[0].replace(templates["funcgeneral"].replace("[QUESTION]", question), "")
            # one function token is enough
            break
        log = {
            "case_idx": case_idx,
            "question": question,
            "func_calls": func_calls,
            "generation": cur_generation.replace("\n", "\\n").strip(),
            "status": "success"
        }
        # f.write(json.dumps(log) + "\n")
    except Exception as e:
        # if local_rank == 0:
        log = {
            "case_idx": case_idx,
            "question": question,
            "func_calls": func_calls,
            "generation": cur_generation.replace("\n", "\\n").strip(),
            "status": str(e)
        }
    return log
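# ---------------------------------------------------------------------------
# Illustrative usage (a sketch, not part of the original script). It assumes a
# funcmodel with trained toolken embeddings has already been loaded and that
# `templates` maps "general" and each toolken (e.g. "<add>") to a few-shot
# prompt containing a "[QUESTION]" placeholder; the values below are
# hypothetical.
#
#   log = func_embedding_inference(
#       templates, case_idx=0, question="What is 130 + 12?",
#       funcmodel=funcmodel, temperature=0.0, top_p=0.95, max_gen_len=512)
#   print(log["status"], log["func_calls"], log["generation"])
# ---------------------------------------------------------------------------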