[New Features] support dynamic src_length #7740

Merged: 9 commits merged on Jan 4, 2024
4 changes: 4 additions & 0 deletions llm/README.md
@@ -247,8 +247,12 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \
- `flask_port`: port number of the Flask service; defaults to 8010.
- For other parameters, see the inference parameter configuration in the [inference documentation](./docs/inference.md).

In addition, to run inference through an API script, refer to the `./request_flask_server.py` file.

</div></details>
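
For illustration only (this sketch is not part of the diff), a direct call to the Flask service in the spirit of `request_flask_server.py` might look like the following. The `"context"` and `"history"` keys are assumptions based on the web UI code; the generation parameters mirror the ones handled in this PR.

```python
import json

import requests

# Hedged sketch of a streaming request to the Flask service (default port 8010).
# The "context"/"history" keys are assumed; request_flask_server.py is the
# authoritative reference for the payload format.
payload = {
    "context": "Hello, what can you do?",  # assumed key for the current user query
    "history": [],                         # optional multi-turn history
    "top_k": 0,
    "top_p": 0.7,
    "temperature": 0.95,
    "max_length": 128,
    "src_length": 1024,                    # new in this PR: per-request source length
}
response = requests.post("http://0.0.0.0:8010/api/chat", json=payload, stream=True)
for line in response.iter_lines():
    result = json.loads(line)
    if result.get("error_code", 0) != 0:   # e.g. src_length + max_length too large
        raise RuntimeError(result["error_msg"])
    print(result["result"]["response"])
```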



### 6. PyTorch model weight conversion
PaddleNLP provides an interface that automatically converts PyTorch weights into Paddle weights; the code is as follows:

45 changes: 37 additions & 8 deletions llm/flask_server.py
@@ -17,7 +17,7 @@
import os
import socket
from contextlib import closing
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from time import sleep

import requests
@@ -68,6 +68,7 @@ def __init__(self, args: ServerArgument, predictor: BasePredictor):
self.args.flask_port + port_interval * predictor.tensor_parallel_rank,
self.args.flask_port + port_interval * (predictor.tensor_parallel_rank + 1),
)
self.total_max_length = predictor.config.src_length + predictor.config.max_length
Member:
If the predictor is started with the default value of None, this equals the model's maximum. But if the user specifies src_length and max_length, e.g. 128 and 64, then the maximum becomes 192.
I think this should be written as self.total_max_length = max(max_position (e.g. 4096), predictor.config.src_length + predictor.config.max_length).

Member:
The checks above already guarantee that predictor.config.src_length + predictor.config.max_length cannot exceed max_position, so this could simply be max_position...

Contributor Author (replying to the first comment):
That is exactly the current logic; see the initialization at the bottom of the create_predictor method.
So by the time we reach this point in flask_server, both values are guaranteed to have been initialized.
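
A hedged numeric sketch of the point under discussion, assuming a model whose max_position_embeddings is 4096:

```python
# If the server is started without --src_length / --max_length, create_predictor
# fills in model-specific defaults whose sum does not exceed max_position_embeddings.
# If the user passes explicit values, e.g. src_length=128 and max_length=64:
total_max_length = 128 + 64  # -> 192, the per-request budget enforced later in streaming()
# create_predictor has already rejected any combination with
# src_length + max_length > max_position_embeddings (4096 here), so
# total_max_length can never exceed the model's positional limit.
```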


if self.predictor.tensor_parallel_rank == 0:
# fetch port info
@@ -123,16 +124,44 @@ def streaming(data):

# build chat template
if self.predictor.tokenizer.chat_template is not None:
history = json.loads(history)
if not history:
history = []
# also support history data
elif isinstance(history, str):
history = json.loads(history)

assert len(history) % 2 == 0
Member:
If len(history) == 0, none of the code below needs to run; I'm worried query could end up as [[]].
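
To make the accepted history formats concrete, here is a hedged sketch of the two payload shapes handled by the branch below and the batched query it produces (values are hypothetical):

```python
# list[str]: alternating user / bot utterances
history_as_strings = ["Hi there", "Hello! How can I help?"]

# list[dict]: the same turns wrapped in {"utterance": ...}
history_as_dicts = [
    {"utterance": "Hi there"},
    {"utterance": "Hello! How can I help?"},
]

query = "Tell me about PaddleNLP"

# Either form is normalized to [user, bot] pairs, and the current query is
# appended as a single-element turn, giving a batch of size one:
# [[["Hi there", "Hello! How can I help?"], ["Tell me about PaddleNLP"]]]
```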

chat_query = []
for idx in range(0, len(history), 2):
chat_query.append(["", ""])
chat_query[-1][0], chat_query[-1][1] = history[idx]["utterance"], history[idx + 1]["utterance"]
query = [chat_query]
if isinstance(history[idx], str):
chat_query.append([history[idx], history[idx + 1]])
elif isinstance(history[idx], dict):
chat_query.append([history[idx]["utterance"], history[idx + 1]["utterance"]])
else:
raise ValueError(
"history data should be list[str] or list[dict], eg: ['sentence-1', 'sentece-2', ...], or "
"[{'utterance': 'sentence-1'}, {'utterance': 'sentence-2'}, ...]"
)

# the input of predictor should be batched.
# batched query: [ [[user, bot], [user, bot], ..., [user]] ]
query = [chat_query + [[query]]]

generation_args = data
self.predictor.config.max_length = generation_args["max_length"]
if "src_length" in generation_args:
Member:
Here the sum of src_length and max_length must not exceed the maximum, right? If it does, shouldn't it be adjusted?

Contributor Author @wj-Mcat (Dec 28, 2023):
We actually discussed this before; the conclusion was to use src_length and max_length separately to cap the input and output lengths, and not to enforce a hard combined limit.
The reason for not enforcing it is that the model may support length extrapolation, and a hard limit would restrict that capability.

Member:
Then enumerate the models that support extrapolation and leave them unrestricted; models that cannot extrapolate should be limited.
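
For illustration, assuming the server was started with src_length=3072 and max_length=1024 (so total_max_length is 4096), a request asking for src_length=4096 and max_length=512 would receive a single streamed JSON line shaped like this hedged sketch of the `output` dict built below (the numbers are hypothetical):

```python
# Hypothetical error line streamed back by /api/chat when the budget is exceeded.
error_line = {
    "error_code": 1,
    "error_msg": "The sum of src_length<4096> and max_length<512> should be smaller "
    "than or equal to the max-total-length<4096>",
}
```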

self.predictor.config.src_length = generation_args["src_length"]

if self.predictor.config.src_length + self.predictor.config.max_length > self.total_max_length:
output = {
"error_code": 1,
"error_msg": f"The sum of src_length<{self.predictor.config.src_length}> and "
f"max_length<{self.predictor.config.max_length}> should be smaller than or equal to "
f"the max-total-length<{self.total_max_length}>",
}
yield json.dumps(output, ensure_ascii=False) + "\n"
return

self.predictor.config.top_p = generation_args["top_p"]
self.predictor.config.temperature = generation_args["temperature"]
self.predictor.config.top_k = generation_args["top_k"]
@@ -160,13 +189,13 @@ def streaming(data):
# refer to: https://github.com/pallets/flask/blob/main/src/flask/app.py#L605
app.run(host="0.0.0.0", port=self.port, threaded=False)

def start_ui_service(self, args):
def start_ui_service(self, args, predictor_args):
# do not support start ui service in one command
from multiprocessing import Process

from gradio_ui import main

p = Process(target=main, args=(args,))
p = Process(target=main, args=(args, predictor_args))
p.daemon = True
p.start()

@@ -194,6 +223,6 @@ def start_ui_service(self, args):
server = PredictorServer(server_args, predictor)

if server.predictor.tensor_parallel_rank == 0:
server.start_ui_service(server_args)
server.start_ui_service(server_args, asdict(predictor.config))

server.start_flask_server()
89 changes: 72 additions & 17 deletions llm/gradio_ui.py
@@ -30,7 +30,29 @@ def setup_args():
return args


def launch(args):
def create_src_slider(value, maximum):
return gr.Slider(
minimum=1,
maximum=maximum,
value=value,
step=1,
label="Max Src Length",
info="最大输入长度。",
)


def create_max_slider(value, maximum):
return gr.Slider(
minimum=1,
maximum=maximum,
value=value,
step=1,
label="Max Decoding Length",
info="生成结果的最大长度。",
)


def launch(args, default_params: dict = {}):
"""Launch characters dialogue demo."""

def rollback(state):
@@ -42,7 +64,7 @@ def rollback(state):
shown_context = get_shown_context(context)
return utterance, shown_context, context, state

def regen(state, top_k, top_p, temperature, repetition_penalty, max_length):
def regen(state, top_k, top_p, temperature, repetition_penalty, max_length, src_length):
"""Regenerate response."""
context = state.setdefault("context", [])
if len(context) < 2:
@@ -74,13 +96,13 @@ def begin(utterance, state):
shown_context = get_shown_context(context)
return utterance, shown_context, context, state

def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_length):
def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length):
"""Model inference."""
utterance = utterance.strip().replace("<br>", "\n")
context = state.setdefault("context", [])

if not utterance:
gr.Warning("invalid inputs111")
gr.Warning("invalid inputs")
# gr.Warning("请输入有效问题")
shown_context = get_shown_context(context)
return None, shown_context, context, state
@@ -93,11 +115,17 @@ def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_l
"temperature": temperature,
"repetition_penalty": repetition_penalty,
"max_length": max_length,
"src_length": src_length,
"min_length": 1,
}
res = requests.post(f"http://0.0.0.0:{args.flask_port}/api/chat", json=data, stream=True)
for line in res.iter_lines():
result = json.loads(line)
if result["error_code"] != 0:
gr.Warning(result["error_msg"])
shown_context = get_shown_context(context)
return None, shown_context, context, state

bot_response = result["result"]["response"]

# replace \n with br: https://github.com/gradio-app/gradio/issues/4344
@@ -156,30 +184,57 @@ def get_shown_context(context):
with gr.Row():
with gr.Column(scale=1):
top_k = gr.Slider(
minimum=1, maximum=100, value=50, step=1, label="Top-k", info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。"
minimum=0,
maximum=default_params.get("top_k", 20),
value=0,
step=1,
label="Top-k",
info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。",
)
top_p = gr.Slider(
minimum=0, maximum=1, value=0.7, step=0.05, label="Top-p", info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。"
minimum=0,
maximum=1,
value=default_params.get("top_p", 0.7),
step=0.05,
label="Top-p",
info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。",
)
temperature = gr.Slider(
minimum=0.05,
maximum=1.5,
value=0.95,
value=default_params.get("temperature", 0.95),
step=0.05,
label="Temperature",
info="该参数越小,模型生成结果更加随机,反之生成结果更加确定。",
)
repetition_penalty = gr.Slider(
minimum=0.1,
maximum=10,
value=1.0,
value=default_params.get("repetition_penalty", 1.2),
step=0.05,
label="Repetition Penalty",
info="该参数越大,生成结果重复的概率越低。设置 1 则不开启。",
)
max_length = gr.Slider(
minimum=1, maximum=1024, value=50, step=1, label="Max Length", info="生成结果的最大长度。"
)
default_src_length = default_params["src_length"]
total_length = default_params["src_length"] + default_params["max_length"]
src_length = create_src_slider(default_src_length, total_length)
max_length = create_max_slider(50, total_length)

def src_length_change_event(src_length_value, max_length_value):
return create_max_slider(
min(total_length - src_length_value, max_length_value),
total_length - src_length_value,
)

def max_length_change_event(src_length_value, max_length_value):
return create_src_slider(
min(total_length - max_length_value, src_length_value),
total_length - max_length_value,
)

src_length.change(src_length_change_event, inputs=[src_length, max_length], outputs=max_length)
max_length.change(max_length_change_event, inputs=[src_length, max_length], outputs=src_length)
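
A hedged worked example of the coupling implemented by the two change handlers above, assuming the server reports src_length=3072 and max_length=1024 (total 4096):

```python
total_length = 3072 + 1024        # 4096, from the server's predictor config
src_length_value = 3500           # the user drags the source-length slider up
max_length_value = 1024           # current decoding-length slider value

# src_length_change_event rebuilds the decoding-length slider so that
# src_length + max_length can never exceed total_length:
new_max_maximum = total_length - src_length_value       # -> 596
new_max_value = min(new_max_maximum, max_length_value)   # -> 596
```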

with gr.Column(scale=4):
state = gr.State({})
context_chatbot = gr.Chatbot(label="Context")
@@ -200,7 +255,7 @@ def get_shown_context(context):
api_name="chat",
).then(
infer,
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
outputs=[utt_text, context_chatbot, raw_context_json, state],
)

@@ -219,13 +274,13 @@ def get_shown_context(context):
)
regen_btn.click(
regen,
inputs=[state, top_k, top_p, temperature, repetition_penalty, max_length],
inputs=[state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
outputs=[utt_text, context_chatbot, raw_context_json, state],
queue=False,
api_name="chat",
).then(
infer,
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
outputs=[utt_text, context_chatbot, raw_context_json, state],
)

@@ -237,15 +292,15 @@ def get_shown_context(context):
api_name="chat",
).then(
infer,
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
outputs=[utt_text, context_chatbot, raw_context_json, state],
)

block.queue().launch(server_name="0.0.0.0", server_port=args.port, debug=True)


def main(args):
launch(args)
def main(args, default_params: dict = {}):
launch(args, default_params)


if __name__ == "__main__":
42 changes: 40 additions & 2 deletions llm/predictor.py
@@ -29,7 +29,10 @@
from utils import (
dybatch_preprocess,
get_alibi_slopes,
get_default_max_decoding_length,
get_default_max_encoding_length,
get_infer_model_path,
get_model_max_position_embeddings,
get_prefix_tuning_params,
init_chat_template,
load_real_time_tokens,
@@ -56,8 +59,8 @@
class PredictorArgument:
model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."})
model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"})
src_length: int = field(default=1024, metadata={"help": "The max length of source text."})
max_length: int = field(default=2048, metadata={"help": "the max length for decoding."})
src_length: int = field(default=None, metadata={"help": "The max length of source text."})
max_length: int = field(default=None, metadata={"help": "the max length for decoding."})
top_k: int = field(default=0, metadata={"help": "top_k parameter for generation"})
top_p: float = field(default=0.7, metadata={"help": "top_p parameter for generation"})
temperature: float = field(default=0.95, metadata={"help": "top_p parameter for generation"})
@@ -693,6 +696,40 @@ def create_predictor(
if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.unk_token

config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)

max_position_embeddings = get_model_max_position_embeddings(config)
if max_position_embeddings is None:
max_position_embeddings = 2048
logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048")

if predictor_args.src_length is None:
if predictor_args.max_length is None:
predictor_args.src_length = get_default_max_encoding_length(config)
predictor_args.max_length = get_default_max_decoding_length(config)
else:
predictor_args.src_length = max_position_embeddings - predictor_args.max_length
if predictor_args.src_length <= 0:
raise ValueError(
f"--max_length<{predictor_args.max_length}> param should be smaller "
f"than max_position_embeddings<{max_position_embeddings}>"
)
else:
if predictor_args.max_length is None:
predictor_args.max_length = max_position_embeddings - predictor_args.src_length
if predictor_args.max_length <= 0:
raise ValueError(
f"--src_length<{predictor_args.src_length}> param should be smaller "
f"than max_position_embeddings<{max_position_embeddings}>"
)
else:
if predictor_args.src_length + predictor_args.max_length > max_position_embeddings:
raise ValueError(
f"The sum of src_length<{predictor_args.src_length}> and "
f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
f"the maximum position embedding size<{max_position_embeddings}>"
)

# update config parameter for inference predictor
if predictor_args.decode_strategy == "greedy_search":
predictor_args.top_p = 0.0
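
To summarize the default-resolution branch above, here is a hedged sketch that mirrors its four cases; `resolve_lengths`, `default_src`, and `default_max` are hypothetical names standing in for the inline logic and for `get_default_max_encoding_length` / `get_default_max_decoding_length`, whose actual values depend on the model config:

```python
def resolve_lengths(src_length, max_length, max_position_embeddings=4096,
                    default_src=3072, default_max=1024):
    """Hedged mirror of the length resolution performed in create_predictor."""
    if src_length is None and max_length is None:
        # Neither given: fall back to model-specific defaults.
        return default_src, default_max
    if src_length is None:
        # Only max_length given: the source text gets whatever budget remains.
        src_length = max_position_embeddings - max_length
        if src_length <= 0:
            raise ValueError("max_length should be smaller than max_position_embeddings")
    elif max_length is None:
        # Only src_length given: decoding gets the remaining budget.
        max_length = max_position_embeddings - src_length
        if max_length <= 0:
            raise ValueError("src_length should be smaller than max_position_embeddings")
    elif src_length + max_length > max_position_embeddings:
        # Both given: their sum must fit within the positional limit.
        raise ValueError("src_length + max_length should not exceed max_position_embeddings")
    return src_length, max_length
```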
@@ -885,6 +922,6 @@ def create_predictor(
predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
else:
raise ValueError("the `mode` should be one of [dynamic, static]")

return predictor

