1
1
import concurrent
2
2
import concurrent .futures
3
3
import os
4
+ import random
4
5
import socket
5
6
import time
6
- import traceback
7
7
from typing import Dict , List , Optional , Union
8
8
9
9
import requests
10
10
from requests .adapters import HTTPAdapter
11
+ from requests .exceptions import ConnectionError
11
12
from urllib3 .connection import HTTPConnection
12
13
13
14
try :
21
22
22
23
PromptType = Union [PromptList , str ]
23
24
24
- BAILING_RETRY_DELAY : int = 30
25
-
26
25
27
26
class HTTPAdapterWithSocketOptions (HTTPAdapter ):
28
27
@@ -104,7 +103,7 @@ def __init__(
104
103
def generate (
105
104
self ,
106
105
inputs : Union [List [str ], PromptList ],
107
- max_out_len : int = 4096 ,
106
+ max_out_len : int = 11264 ,
108
107
) -> List [str ]:
109
108
"""Generate results given a list of inputs.
110
109
@@ -128,24 +127,33 @@ def generate(
128
127
): i
129
128
for i , input in enumerate (inputs )
130
129
}
131
- results = []
130
+ results = ['' ] * len ( inputs )
132
131
for future in concurrent .futures .as_completed (future_to_m ):
133
132
m = future_to_m [future ] # noqa F841
134
133
resp = future .result ()
135
134
if resp and resp .status_code == 200 :
136
135
try :
137
136
result = resp .json ()
138
137
except Exception as e : # noqa F841
139
- results .append ('' )
138
+ self .logger .error (f'Fail to inference; '
139
+ f'model_name={ self .path } ; '
140
+ f'error={ e } , '
141
+ f'request={ inputs [m ]} ' )
140
142
else :
141
143
if (result .get ('choices' )
142
144
and result ['choices' ][0 ].get ('message' ) and
143
145
result ['choices' ][0 ]['message' ].get ('content' )
144
146
is not None ):
145
- results .append (
146
- result ['choices' ][0 ]['message' ]['content' ])
147
+ results [m ] = \
148
+ result ['choices' ][0 ]['message' ]['content' ]
149
+ else :
150
+ self .logger .error (f'Receive invalid result. '
151
+ f'result={ result } ; '
152
+ f'request={ inputs [m ]} ' )
147
153
else :
148
- results .append ('' )
154
+ self .logger .error (f'Receive invalid response. '
155
+ f'response={ resp } ; '
156
+ f'request={ inputs [m ]} ' )
149
157
self .flush ()
150
158
return results
151
159
@@ -184,39 +192,31 @@ def _generate(
184
192
message ['role' ] = item ['role' ]
185
193
messages .append (message )
186
194
request = {
187
- 'model' :
188
- self ._model ,
189
- 'messages' :
190
- messages ,
191
- 'max_seq_len' :
192
- max (
193
- max_out_len if max_out_len else 4096 ,
194
- self .max_seq_len if self .max_seq_len else 4096 ,
195
- ),
195
+ 'model' : self ._model ,
196
+ 'messages' : messages ,
197
+ 'max_tokens' : max_out_len ,
196
198
}
197
199
request .update (self .generation_kwargs )
198
- try :
199
- retry_num = 0
200
- while retry_num < self . retry :
200
+ retry_num = 0
201
+ while retry_num < self . retry :
202
+ try :
201
203
response = self ._infer_result (request , sess )
202
- if response .status_code == 200 :
203
- break # success
204
- elif response .status_code == 426 :
205
- retry_num += 1 # retry
206
- elif response .status_code in [429 , 500 , 504 ]:
207
- time .sleep (BAILING_RETRY_DELAY )
208
- retry_num += 1 # retry
209
- else :
210
- raise ValueError (f'Status code = { response .status_code } ' )
204
+ except ConnectionError :
205
+ time .sleep (random .randint (10 , 30 ))
206
+ retry_num += 1 # retry
207
+ continue
208
+ if response .status_code == 200 :
209
+ break # success
210
+ elif response .status_code == 426 :
211
+ retry_num += 1 # retry
212
+ elif response .status_code in [302 , 429 , 500 , 504 ]:
213
+ time .sleep (random .randint (10 , 30 ))
214
+ retry_num += 1 # retry
211
215
else :
212
- raise ValueError (
213
- f'Exceed the maximal retry times. Last status code '
214
- f'= { response .status_code } ' )
215
- except Exception as e :
216
- self .logger .error (f'Fail to inference request={ request } ; '
217
- f'model_name={ self .path } ; error={ e } , '
218
- f'stack:{ traceback .format_exc ()} ' )
219
- raise e
216
+ raise ValueError (f'Status code = { response .status_code } ' )
217
+ else :
218
+ # Exceed the maximal retry times.
219
+ return ''
220
220
return response
221
221
222
222
# @retry(stop_max_attempt_number=3, wait_fixed=16000) # ms
0 commit comments