[New Features] support dynamic src_length #7740
Changes from all commits: dbec012, 9d788c8, c88fc40, 0938c7f, a4772b0, afb81cc, 5dade6a, ea9f5b3, f4cd00c
@@ -17,7 +17,7 @@
 import os
 import socket
 from contextlib import closing
-from dataclasses import dataclass, field
+from dataclasses import asdict, dataclass, field
 from time import sleep

 import requests
@@ -68,6 +68,7 @@ def __init__(self, args: ServerArgument, predictor: BasePredictor): | |
self.args.flask_port + port_interval * predictor.tensor_parallel_rank, | ||
self.args.flask_port + port_interval * (predictor.tensor_parallel_rank + 1), | ||
) | ||
self.total_max_length = predictor.config.src_length + predictor.config.max_length | ||
|
||
if self.predictor.tensor_parallel_rank == 0: | ||
# fetch port info | ||
|
@@ -123,16 +124,44 @@ def streaming(data):

             # build chat template
             if self.predictor.tokenizer.chat_template is not None:
-                history = json.loads(history)
+                if not history:
+                    history = []
+                # also support history data
+                elif isinstance(history, str):
+                    history = json.loads(history)

                 assert len(history) % 2 == 0
Review comment: If len(history) == 0, none of the code below needs to run; the concern is that query would end up as [[]].
                 chat_query = []
                 for idx in range(0, len(history), 2):
-                    chat_query.append(["", ""])
-                    chat_query[-1][0], chat_query[-1][1] = history[idx]["utterance"], history[idx + 1]["utterance"]
-                query = [chat_query]
+                    if isinstance(history[idx], str):
+                        chat_query.append([history[idx], history[idx + 1]])
+                    elif isinstance(history[idx], dict):
+                        chat_query.append([history[idx]["utterance"], history[idx + 1]["utterance"]])
+                    else:
+                        raise ValueError(
+                            "history data should be list[str] or list[dict], eg: ['sentence-1', 'sentence-2', ...], or "
+                            "[{'utterance': 'sentence-1'}, {'utterance': 'sentence-2'}, ...]"
+                        )

+                # the input of predictor should be batched.
+                # batched query: [ [[user, bot], [user, bot], ..., [user]] ]
+                query = [chat_query + [[query]]]

             generation_args = data
             self.predictor.config.max_length = generation_args["max_length"]
+            if "src_length" in generation_args:
Review comment: Here the sum of src_length and max_length should not exceed the maximum, right? If it does, shouldn't it be adjusted?

Reply: This was actually discussed before. The conclusion was to use src_length and max_length separately to cap the input and output lengths, and not to enforce a hard combined limit. The reason is that the model may support length extrapolation, and a hard cap would restrict that capability.

Reply: Then enumerate the models that support extrapolation and leave them unrestricted, while models that cannot extrapolate should be limited.
+                self.predictor.config.src_length = generation_args["src_length"]
+
+            if self.predictor.config.src_length + self.predictor.config.max_length > self.total_max_length:
+                output = {
+                    "error_code": 1,
+                    "error_msg": f"The sum of src_length<{self.predictor.config.src_length}> and "
+                    f"max_length<{self.predictor.config.max_length}> should be smaller than or equal to "
+                    f"the max-total-length<{self.total_max_length}>",
+                }
+                yield json.dumps(output, ensure_ascii=False) + "\n"
+                return

             self.predictor.config.top_p = generation_args["top_p"]
             self.predictor.config.temperature = generation_args["temperature"]
             self.predictor.config.top_k = generation_args["top_k"]
@@ -160,13 +189,13 @@ def streaming(data):
         # refer to: https://github.com/pallets/flask/blob/main/src/flask/app.py#L605
         app.run(host="0.0.0.0", port=self.port, threaded=False)

-    def start_ui_service(self, args):
+    def start_ui_service(self, args, predictor_args):
         # do not support start ui service in one command
         from multiprocessing import Process

         from gradio_ui import main

-        p = Process(target=main, args=(args,))
+        p = Process(target=main, args=(args, predictor_args))
         p.daemon = True
         p.start()
@@ -194,6 +223,6 @@ def start_ui_service(self, args): | |
server = PredictorServer(server_args, predictor) | ||
|
||
if server.predictor.tensor_parallel_rank == 0: | ||
server.start_ui_service(server_args) | ||
server.start_ui_service(server_args, asdict(predictor.config)) | ||
|
||
server.start_flask_server() |
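A hedged sketch of why `asdict` is now imported and why `start_ui_service` gains a `predictor_args` parameter: the predictor config (a dataclass) is converted to a plain dict so it can be handed to the Gradio UI process. `PredictorConfig` and the `main` body below are illustrative stand-ins, not the actual PaddleNLP classes.

```python
from dataclasses import asdict, dataclass
from multiprocessing import Process


@dataclass
class PredictorConfig:
    # Illustrative stand-in for predictor.config; only the fields used in the diff.
    src_length: int = 1024
    max_length: int = 512


def main(args, predictor_args):
    # Stand-in for gradio_ui.main: the UI process receives the predictor config
    # as a plain dict, e.g. to set slider limits for src_length/max_length.
    print("ui args:", args)
    print("predictor config:", predictor_args)


if __name__ == "__main__":
    config = PredictorConfig()
    # asdict() turns the dataclass into a plain dict before it crosses the
    # process boundary, so the UI process does not need the dataclass type itself.
    p = Process(target=main, args=({"title": "demo"}, asdict(config)))
    p.daemon = True
    p.start()
    p.join()
```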
Review comment: Here, when the predictor starts with the default value of None, this equals the model's maximum length. But if the user specifies src_length and max_length, e.g. 128 and 64, the maximum becomes only 192. I think this should be written as self.total_max_length = max(max_position (e.g. 4096), predictor.config.src_length + predictor.config.max_length).
Reply: The earlier check already guarantees that predictor.config.src_length + predictor.config.max_length cannot exceed max_position, so using max_position directly here would be enough...
Reply: That is exactly the current logic; see the initialization at the bottom of the create_predictor method. So by the time control reaches flask_server, these values are guaranteed to have been initialized.
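To make the trade-off discussed in this thread concrete, here is a hedged sketch of the two options (the numbers and the `max_position` name are illustrative; `max_position` is the reviewer's shorthand for the model's maximum context length, not an attribute defined in this PR):

```python
# Hypothetical values for illustration only.
max_position = 4096               # the model's own maximum context length
src_length, max_length = 128, 64  # user-specified startup values

# As implemented in the PR: the budget is whatever the predictor was started with.
total_max_length_pr = src_length + max_length                             # 192

# The reviewer's suggestion: never let the budget fall below the model's limit.
total_max_length_suggested = max(max_position, src_length + max_length)   # 4096

print(total_max_length_pr, total_max_length_suggested)
```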