# inference.py
import openai
import os
import time
import threading
import json
import _thread
from tqdm import tqdm
from contextlib import contextmanager
from collections import defaultdict
# Read the API key from the environment rather than hard-coding it;
# the string fallback is only a placeholder.
openai.api_key = os.getenv("OPENAI_API_KEY", "your openai api key")
class TimeoutException(Exception):
    def __init__(self, msg=''):
        super().__init__(msg)
        self.msg = msg


@contextmanager
def time_limit(seconds, msg=''):
    # Interrupt the main thread if the wrapped block runs past `seconds`.
    timer = threading.Timer(seconds, lambda: _thread.interrupt_main())
    timer.start()
    try:
        yield
    except KeyboardInterrupt:
        raise TimeoutException("Timed out for operation {}".format(msg))
    finally:
        # If the block finishes within the time limit, cancel the pending interrupt.
        timer.cancel()
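# Illustrative usage of time_limit (not part of the original file;
# some_slow_function is a hypothetical callable):
#
#     try:
#         with time_limit(5, 'slow call'):
#             result = some_slow_function()
#     except TimeoutException as e:
#         print(e.msg)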
def rmreturn(s):
    # Collapse newlines so the text fits on a single prompt line.
    s = s.replace('\n\n', ' ')
    s = s.replace('\n', ' ')
    return s.strip()


def add_prompt(item, prompt):
    # Fill the {query} slot, and the {background} slot when the item
    # carries previously generated background text in item['output'].
    query = item['question']
    prompt = prompt.replace('{query}', query)
    if item.get('output'):  # background info
        backinfo = rmreturn(item['output'][0])
        prompt = prompt.replace('{background}', backinfo)
    return prompt
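# Example (hypothetical values, for illustration only): with
#   prompt = 'Q: {query} Context: {background} A:'
#   item = {'question': 'Who wrote Hamlet?', 'output': ['Hamlet is a tragedy...']}
# add_prompt returns
#   'Q: Who wrote Hamlet? Context: Hamlet is a tragedy... A:'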
def clustering_prompt(items, prompt):
    # Build one long prompt from a cluster of items: each item contributes
    # its filled template plus background text, and the bare template comes last.
    cluster_prompts = []
    for item in items:
        query = item['question']
        backinfo = rmreturn(item['output'][0])
        item_prompt = prompt.replace('{query}', query)
        item_prompt += f' {backinfo}'
        cluster_prompts.append(item_prompt)
    cluster_prompts.append(prompt)
    return ' \n\n\n\n '.join(cluster_prompts)
def run_embeddings(input_text, engine='text-similarity-davinci-001'):
    # Newlines can degrade embedding quality, so remove them first.
    texts = [t.replace('\n', '') for t in input_text]
    outputs = openai.Embedding.create(input=texts, model=engine)['data']
    embeddings = [o['embedding'] for o in outputs]
    return embeddings
def run_inference(inputs_with_prompts, engine, max_tokens, num_sequence=1, temp=0):
    completions = {"choices": []}
    # Retry on transient API errors and timeouts, up to 200 attempts.
    for _ in range(200):
        try:
            with time_limit(20, 'run gpt-3'):
                completions = openai.Completion.create(
                    engine=engine,
                    max_tokens=max_tokens,
                    prompt=inputs_with_prompts,
                    temperature=temp,
                    n=num_sequence,  # number of sequences returned per prompt
                )
            break
        except Exception:  # includes TimeoutException; back off and retry
            time.sleep(2)

    outputs = [c["text"] for c in completions["choices"]]
    return outputs
def run_main(inlines, outfile, engine, prompt, max_tokens, n=1, temp=0):
    if os.path.exists(outfile):
        # Resume: the first line of the output file stores the prompt,
        # so the number of already-processed items is num_lines - 1.
        with open(outfile, 'r', encoding='utf8') as f:
            num_lines = len(f.readlines())
        outs = open(outfile, 'a', encoding='utf8')
        inlines = inlines[num_lines - 1:]
    else:
        outs = open(outfile, 'a', encoding='utf8')
        outs.write(json.dumps({"prompt": prompt}) + '\n')

    pbar = tqdm(total=len(inlines))
    index = 0
    while index < len(inlines):
        inputs, answers = [], []
        inputs_with_prompts = []
        for _ in range(20):  # batch up to 20 prompts per API call
            if index >= len(inlines): break
            input_with_prompt = add_prompt(inlines[index], prompt)
            inputs.append(inlines[index]['question'])  # a string
            answers.append(inlines[index]['answer'])  # a list of strings
            inputs_with_prompts.append(input_with_prompt)
            index += 1

        samples = defaultdict(list)
        outputs = run_inference(inputs_with_prompts,
                                engine, max_tokens, n, temp)
        # Choices come back input-major: the j-th choice belongs to prompt j // n.
        for j, output in enumerate(outputs):
            samples[j // n].append(output)

        for i in range(len(inputs_with_prompts)):
            outs.write(json.dumps({
                'question': inputs[i],
                'answer': answers[i],
                'output': samples[i]})
                + '\n')

        pbar.update(len(inputs_with_prompts))

    pbar.close()
    outs.close()
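
# --- Minimal usage sketch (not part of the original file) ---
# Assumes a JSONL input where each line has "question" and "answer" fields;
# the file names, engine, and prompt template below are illustrative placeholders.
if __name__ == '__main__':
    inlines = []
    with open('dev.jsonl', 'r', encoding='utf8') as f:  # hypothetical input file
        for line in f:
            inlines.append(json.loads(line))

    # Template consumed by add_prompt: {query} is required, {background}
    # is filled only when an item already carries generated background text.
    prompt = 'Answer the question. Question: {query} Context: {background} Answer:'

    run_main(inlines, 'dev_outputs.jsonl', 'text-davinci-002',
             prompt, max_tokens=64, n=1, temp=0)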