     Dict,
     Any,
     TYPE_CHECKING,
-    Iterable,
 )
 
 import torch
 from transformers import TextIteratorStreamer
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-)
 
 from api.templates.utils import apply_stopping_strings
 
@@ -132,209 +124,3 @@ def generate_stream(
 
     gc.collect()
     torch.cuda.empty_cache()
-
-
-def prepare_logits_processor(
-    temperature: float, repetition_penalty: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    """
-    Prepare a list of logits processors based on the provided parameters.
-
-    Args:
-        temperature (float): The temperature value for temperature warping.
-        repetition_penalty (float): The repetition penalty value.
-        top_p (float): The top-p value for top-p warping.
-        top_k (int): The top-k value for top-k warping.
-
-    Returns:
-        LogitsProcessorList: A list of logits processors.
-    """
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0; 1.0 makes it a no-op, so we skip both cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if repetition_penalty > 1.0:
-        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-def is_partial_stop(output: str, stop_str: str):
-    """Check whether the output ends with a prefix of the stop string."""
-    return any(
-        stop_str.startswith(output[-i:])
-        for i in range(0, min(len(output), len(stop_str)))
-    )
-
-
-@torch.inference_mode()
-def generate_stream_old(
-    model: "PreTrainedModel",
-    tokenizer: "PreTrainedTokenizer",
-    params: Dict[str, Any],
-):
-    # Read parameters
-    input_ids = params.get("inputs")
-    prompt = params.get("prompt")
-    model_name = params.get("model", "llm")
-    temperature = float(params.get("temperature", 1.0))
-    repetition_penalty = float(params.get("repetition_penalty", 1.0))
-    top_p = float(params.get("top_p", 1.0))
-    top_k = int(params.get("top_k", -1))  # -1 means disable
-    max_new_tokens = int(params.get("max_tokens", 256))
-    echo = bool(params.get("echo", True))
-    stop_str = params.get("stop")
-
-    stop_token_ids = params.get("stop_token_ids") or []
-    if tokenizer.eos_token_id not in stop_token_ids:
-        stop_token_ids.append(tokenizer.eos_token_id)
-
-    logits_processor = prepare_logits_processor(
-        temperature, repetition_penalty, top_p, top_k
-    )
-
-    output_ids = list(input_ids)
-    input_echo_len = len(input_ids)
-
-    device = next(model.parameters()).device
-    start_ids = torch.as_tensor([input_ids], device=device)
-
-    past_key_values, sent_interrupt = None, False
-    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
-    created: int = int(time.time())
-    previous_text = ""
-    for i in range(max_new_tokens):
-        if i == 0:  # prefill
-            out = model(input_ids=start_ids, use_cache=True)
-            logits = out.logits
-            past_key_values = out.past_key_values
-        else:  # decoding
-            out = model(
-                input_ids=torch.as_tensor(
-                    [[token] if not sent_interrupt else output_ids],
-                    device=device,
-                ),
-                use_cache=True,
-                past_key_values=past_key_values if not sent_interrupt else None,
-            )
-            sent_interrupt = False
-            logits = out.logits
-            past_key_values = out.past_key_values
-
-        if logits_processor:
-            if repetition_penalty > 1.0:
-                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
-            else:
-                tmp_output_ids = None
-            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
-        else:
-            last_token_logits = logits[0, -1, :]
-
-        if device == "mps":
-            # Switch to CPU to avoid some bugs in the mps backend.
-            last_token_logits = last_token_logits.float().to("cpu")
-
-        if temperature < 1e-5 or top_p < 1e-8:  # greedy
-            _, indices = torch.topk(last_token_logits, 2)
-            tokens = [int(index) for index in indices.tolist()]
-        else:
-            probs = torch.softmax(last_token_logits, dim=-1)
-            indices = torch.multinomial(probs, num_samples=2)
-            tokens = [int(token) for token in indices.tolist()]
-
-        token = tokens[0]
-        output_ids.append(token)
-
-        if token in stop_token_ids:
-            stopped = True
-        else:
-            stopped = False
-
-        # Yield the output tokens
-        if i % 2 == 0 or i == max_new_tokens - 1 or stopped:
-            if echo:
-                tmp_output_ids = output_ids
-                rfind_start = len(prompt)
-            else:
-                tmp_output_ids = output_ids[input_echo_len:]
-                rfind_start = 0
-
-            output = tokenizer.decode(
-                tmp_output_ids,
-                skip_special_tokens=True,
-                spaces_between_special_tokens=False,
-                clean_up_tokenization_spaces=True,
-            )
-
-            partially_stopped, finish_reason = False, None
-            if stop_str:
-                if isinstance(stop_str, str):
-                    pos = output.rfind(stop_str, rfind_start)
-                    if pos != -1:
-                        output = output[:pos]
-                        stopped = True
-                    else:
-                        partially_stopped = is_partial_stop(output, stop_str)
-                elif isinstance(stop_str, Iterable):
-                    for each_stop in stop_str:
-                        pos = output.rfind(each_stop, rfind_start)
-                        if pos != -1:
-                            output = output[:pos]
-                            stopped = True
-                            if each_stop == "Observation:":
-                                finish_reason = "function_call"
-                            break
-                        else:
-                            partially_stopped = is_partial_stop(output, each_stop)
-                            if partially_stopped:
-                                break
-                else:
-                    raise ValueError("Invalid stop field type.")
-
-            # Prevent yielding a partial stop sequence
-            if (not partially_stopped) and output and output[-1] != "�":
-                delta_text = output[len(previous_text):]
-                previous_text = output
-
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "delta": delta_text,
-                    "text": output,
-                    "logprobs": None,
-                    "finish_reason": finish_reason,
-                    "usage": {
-                        "prompt_tokens": input_echo_len,
-                        "completion_tokens": i,
-                        "total_tokens": input_echo_len + i,
-                    },
-                }
-
-        if stopped:
-            break
-
-    yield {
-        "id": completion_id,
-        "object": "text_completion",
-        "created": created,
-        "model": model_name,
-        "delta": "",
-        "text": output,
-        "logprobs": None,
-        "finish_reason": "stop",
-        "usage": {
-            "prompt_tokens": input_echo_len,
-            "completion_tokens": i,
-            "total_tokens": input_echo_len + i,
-        },
-    }
-
-    # Clean up
-    del past_key_values, out
-    gc.collect()
-    torch.cuda.empty_cache()
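For reference, the removed prepare_logits_processor helper follows the standard transformers pattern: build a LogitsProcessorList and apply it to the last-position logits before sampling. The snippet below is a minimal, self-contained sketch of that pattern, not part of the diff above; the "gpt2" checkpoint, the prompt, and the sampling values are placeholders chosen purely for illustration.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# Placeholder checkpoint; any causal LM works the same way.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Same chaining order as the removed helper: temperature, repetition penalty, top-p, top-k.
processors = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),
    RepetitionPenaltyLogitsProcessor(1.1),
    TopPLogitsWarper(0.9),
    TopKLogitsWarper(50),
])

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
with torch.inference_mode():
    logits = model(input_ids).logits[:, -1, :]   # logits for the last position
    warped = processors(input_ids, logits)       # apply the processors in order
    probs = torch.softmax(warped, dim=-1)
    next_token = torch.multinomial(probs, num_samples=1)

print(tokenizer.decode(next_token[0]))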