Commit d45db7c

Author: xusenlin

Commit message: tiny fix

1 parent afa5f7c commit d45db7c

File tree

4 files changed (+1, -224 lines)


api/config.py

Lines changed: 0 additions & 4 deletions
@@ -129,10 +129,6 @@ class LLMSettings(BaseModel):
         description="Use flash attention."
     )
 
-    use_streamer_v2: Optional[bool] = Field(
-        default=get_bool_env("USE_STREAMER_V2", "true"),
-        description="Support for transformers.TextIteratorStreamer."
-    )
     interrupt_requests: Optional[bool] = Field(
         default=get_bool_env("INTERRUPT_REQUESTS", "true"),
         description="Whether to interrupt requests when a new request is received.",

api/engine/hf.py

Lines changed: 1 addition & 5 deletions
@@ -36,7 +36,7 @@
 from api.templates import get_template
 from api.templates.glm import generate_stream_chatglm, generate_stream_chatglm_v3
 from api.templates.minicpm import generate_stream_minicpm_v
-from api.templates.stream import generate_stream, generate_stream_old
+from api.templates.stream import generate_stream
 from api.templates.utils import get_context_length
 from api.utils import create_error_response
 
@@ -57,7 +57,6 @@ def __init__(
         model_name: str,
         template_name: Optional[str] = None,
         max_model_length: Optional[int] = None,
-        use_streamer_v2: Optional[bool] = True,
     ) -> None:
        self.model = model
        self.tokenizer = tokenizer
 
@@ -80,9 +79,6 @@ def __init__(
        elif self.model.config.model_type == "minicpmv":
            self.generate_stream_func = generate_stream_minicpm_v
 
-        if not use_streamer_v2:
-            self.generate_stream_func = generate_stream_old
-
        logger.info(f"Using {self.model_name} Model for Chat!")
        logger.info(f"Using {self.template} for Chat!")
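
With the fallback gone, the engine always streams through the generate_stream function built on transformers.TextIteratorStreamer. A minimal sketch of that streaming pattern follows; the model name and generation settings are illustrative assumptions, not the project's configuration.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Any causal LM works; "gpt2" is only a small illustrative choice.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer("Hello, my name is", return_tensors="pt")
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs in a background thread while the
# streamer yields decoded text pieces as soon as they are ready.
thread = Thread(
    target=model.generate,
    kwargs=dict(**inputs, max_new_tokens=32, streamer=streamer),
)
thread.start()

for piece in streamer:
    print(piece, end="", flush=True)
thread.join()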

api/models.py

Lines changed: 0 additions & 1 deletion
@@ -86,7 +86,6 @@ def create_hf_llm():
         model_name=SETTINGS.model_name,
         max_model_length=SETTINGS.context_length if SETTINGS.context_length > 0 else None,
         template_name=SETTINGS.chat_template,
-        use_streamer_v2=SETTINGS.use_streamer_v2,
     )
 
 
api/templates/stream.py

Lines changed: 0 additions & 214 deletions
@@ -9,18 +9,10 @@
     Dict,
     Any,
     TYPE_CHECKING,
-    Iterable,
 )
 
 import torch
 from transformers import TextIteratorStreamer
-from transformers.generation.logits_process import (
-    LogitsProcessorList,
-    RepetitionPenaltyLogitsProcessor,
-    TemperatureLogitsWarper,
-    TopKLogitsWarper,
-    TopPLogitsWarper,
-)
 
 from api.templates.utils import apply_stopping_strings
 
@@ -132,209 +124,3 @@ def generate_stream(
 
     gc.collect()
     torch.cuda.empty_cache()
-
-
-def prepare_logits_processor(
-    temperature: float, repetition_penalty: float, top_p: float, top_k: int
-) -> LogitsProcessorList:
-    """
-    Prepare a list of logits processors based on the provided parameters.
-
-    Args:
-        temperature (float): The temperature value for temperature warping.
-        repetition_penalty (float): The repetition penalty value.
-        top_p (float): The top-p value for top-p warping.
-        top_k (int): The top-k value for top-k warping.
-
-    Returns:
-        LogitsProcessorList: A list of logits processors.
-    """
-    processor_list = LogitsProcessorList()
-    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op, so we skip two cases.
-    if temperature >= 1e-5 and temperature != 1.0:
-        processor_list.append(TemperatureLogitsWarper(temperature))
-    if repetition_penalty > 1.0:
-        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
-    if 1e-8 <= top_p < 1.0:
-        processor_list.append(TopPLogitsWarper(top_p))
-    if top_k > 0:
-        processor_list.append(TopKLogitsWarper(top_k))
-    return processor_list
-
-
-def is_partial_stop(output: str, stop_str: str):
-    """ Check whether the output contains a partial stop str. """
-    return any(
-        stop_str.startswith(output[-i:])
-        for i in range(0, min(len(output), len(stop_str)))
-    )
-
-
-@torch.inference_mode()
-def generate_stream_old(
-    model: "PreTrainedModel",
-    tokenizer: "PreTrainedTokenizer",
-    params: Dict[str, Any],
-):
-    # Read parameters
-    input_ids = params.get("inputs")
-    prompt = params.get("prompt")
-    model_name = params.get("model", "llm")
-    temperature = float(params.get("temperature", 1.0))
-    repetition_penalty = float(params.get("repetition_penalty", 1.0))
-    top_p = float(params.get("top_p", 1.0))
-    top_k = int(params.get("top_k", -1))  # -1 means disable
-    max_new_tokens = int(params.get("max_tokens", 256))
-    echo = bool(params.get("echo", True))
-    stop_str = params.get("stop")
-
-    stop_token_ids = params.get("stop_token_ids") or []
-    if tokenizer.eos_token_id not in stop_token_ids:
-        stop_token_ids.append(tokenizer.eos_token_id)
-
-    logits_processor = prepare_logits_processor(
-        temperature, repetition_penalty, top_p, top_k
-    )
-
-    output_ids = list(input_ids)
-    input_echo_len = len(input_ids)
-
-    device = next(model.parameters()).device
-    start_ids = torch.as_tensor([input_ids], device=device)
-
-    past_key_values, sent_interrupt = None, False
-    completion_id: str = f"cmpl-{str(uuid.uuid4())}"
-    created: int = int(time.time())
-    previous_text = ""
-    for i in range(max_new_tokens):
-        if i == 0:  # prefill
-            out = model(input_ids=start_ids, use_cache=True)
-            logits = out.logits
-            past_key_values = out.past_key_values
-        else:  # decoding
-            out = model(
-                input_ids=torch.as_tensor(
-                    [[token] if not sent_interrupt else output_ids],
-                    device=device,
-                ),
-                use_cache=True,
-                past_key_values=past_key_values if not sent_interrupt else None,
-            )
-            sent_interrupt = False
-            logits = out.logits
-            past_key_values = out.past_key_values
-
-        if logits_processor:
-            if repetition_penalty > 1.0:
-                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
-            else:
-                tmp_output_ids = None
-            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
-        else:
-            last_token_logits = logits[0, -1, :]
-
-        if device == "mps":
-            # Switch to CPU by avoiding some bugs in mps backend.
-            last_token_logits = last_token_logits.float().to("cpu")
-
-        if temperature < 1e-5 or top_p < 1e-8:  # greedy
-            _, indices = torch.topk(last_token_logits, 2)
-            tokens = [int(index) for index in indices.tolist()]
-        else:
-            probs = torch.softmax(last_token_logits, dim=-1)
-            indices = torch.multinomial(probs, num_samples=2)
-            tokens = [int(token) for token in indices.tolist()]
-
-        token = tokens[0]
-        output_ids.append(token)
-
-        if token in stop_token_ids:
-            stopped = True
-        else:
-            stopped = False
-
-        # Yield the output tokens
-        if i % 2 == 0 or i == max_new_tokens - 1 or stopped:
-            if echo:
-                tmp_output_ids = output_ids
-                rfind_start = len(prompt)
-            else:
-                tmp_output_ids = output_ids[input_echo_len:]
-                rfind_start = 0
-
-            output = tokenizer.decode(
-                tmp_output_ids,
-                skip_special_tokens=True,
-                spaces_between_special_tokens=False,
-                clean_up_tokenization_spaces=True,
-            )
-
-            partially_stopped, finish_reason = False, None
-            if stop_str:
-                if isinstance(stop_str, str):
-                    pos = output.rfind(stop_str, rfind_start)
-                    if pos != -1:
-                        output = output[:pos]
-                        stopped = True
-                    else:
-                        partially_stopped = is_partial_stop(output, stop_str)
-                elif isinstance(stop_str, Iterable):
-                    for each_stop in stop_str:
-                        pos = output.rfind(each_stop, rfind_start)
-                        if pos != -1:
-                            output = output[:pos]
-                            stopped = True
-                            if each_stop == "Observation:":
-                                finish_reason = "function_call"
-                            break
-                        else:
-                            partially_stopped = is_partial_stop(output, each_stop)
-                            if partially_stopped:
-                                break
-                else:
-                    raise ValueError("Invalid stop field type.")
-
-            # Prevent yielding partial stop sequence
-            if (not partially_stopped) and output and output[-1] != "�":
-                delta_text = output[len(previous_text):]
-                previous_text = output
-
-                yield {
-                    "id": completion_id,
-                    "object": "text_completion",
-                    "created": created,
-                    "model": model_name,
-                    "delta": delta_text,
-                    "text": output,
-                    "logprobs": None,
-                    "finish_reason": finish_reason,
-                    "usage": {
-                        "prompt_tokens": input_echo_len,
-                        "completion_tokens": i,
-                        "total_tokens": input_echo_len + i,
-                    },
-                }
-
-        if stopped:
-            break
-
-    yield {
-        "id": completion_id,
-        "object": "text_completion",
-        "created": created,
-        "model": model_name,
-        "delta": "",
-        "text": output,
-        "logprobs": None,
-        "finish_reason": "stop",
-        "usage": {
-            "prompt_tokens": input_echo_len,
-            "completion_tokens": i,
-            "total_tokens": input_echo_len + i,
-        },
-    }
-
-    # Clean
-    del past_key_values, out
-    gc.collect()
-    torch.cuda.empty_cache()
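
The deleted generate_stream_old path did its own token sampling with Hugging Face logits processors instead of delegating to the TextIteratorStreamer-based path. For reference, a minimal sketch of how such a LogitsProcessorList is applied to last-step logits before sampling; the tensor shapes and hyperparameter values below are illustrative, not taken from the project.

import torch
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)

# The same kind of processor list that prepare_logits_processor returned.
processors = LogitsProcessorList([
    TemperatureLogitsWarper(0.7),           # skipped in the old code when temperature was ~0 or exactly 1.0
    RepetitionPenaltyLogitsProcessor(1.1),  # only added when the penalty was > 1.0
    TopPLogitsWarper(0.9),
    TopKLogitsWarper(50),
])

input_ids = torch.tensor([[1, 2, 3]])   # tokens generated so far (illustrative)
logits = torch.randn(1, 32000)          # last-step logits over an illustrative vocab size

warped = processors(input_ids, logits)  # apply temperature/penalty/top-p/top-k in order
probs = torch.softmax(warped, dim=-1)
next_token = int(torch.multinomial(probs, num_samples=1))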
