[New Features] support dynamic src_length (#7740)
* support dynamic src_length

* revert max_position_embedding

* update doc

* update flask_server

* update max_length control

* update request flask_server

* fix max-position-embeddings

* update error message

* update predictor length init
wj-Mcat authored and JunnYu committed Jan 4, 2024
1 parent 866d834 commit cfdecf6
Showing 8 changed files with 310 additions and 31 deletions.
9 changes: 9 additions & 0 deletions llm/README.md
@@ -378,8 +378,11 @@ python -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" flask_server.py \
- `flask_port`: port of the Flask service; defaults to 8010.
- For other parameters, refer to the parameters described for dynamic-graph inference.

In addition, to run inference through an API script, refer to the `./request_flask_server.py` file; a minimal request sketch is shown below.
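
For reference, a minimal request script might look like the following sketch. It assumes the `/api/chat` route, the JSON-lines streaming response, and the `context`/`history` payload fields used by `flask_server.py` and `gradio_ui.py` elsewhere in this commit; treat `request_flask_server.py` itself as the authoritative version.

```python
import json

import requests

payload = {
    "context": "你好,请介绍一下你自己。",
    "history": [],            # list[str] or list[dict] with "utterance" keys
    "src_length": 1024,       # maximum input length (dynamic since this change)
    "max_length": 512,        # maximum decoding length
    "top_k": 0,
    "top_p": 0.7,
    "temperature": 0.95,
    "repetition_penalty": 1.2,
    "min_length": 1,
}

# flask_port defaults to 8010; the server streams one JSON object per line.
response = requests.post("http://127.0.0.1:8010/api/chat", json=payload, stream=True)
for line in response.iter_lines():
    if not line:
        continue
    result = json.loads(line)
    if result.get("error_code", 0) != 0:
        raise RuntimeError(result["error_msg"])
    print(result["result"]["response"])
```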

</div></details>

<<<<<<< HEAD
## 6. Quantization

Quantization algorithms represent model weights and activations with lower-bit numeric types, effectively reducing GPU memory usage and compute overhead. Below we provide GPTQ and PaddleSlim's self-developed PTQ strategy, implementing WINT4 and W8A8 quantization respectively. For more technical details, see the [detailed quantization strategy tutorial](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/docs/zh_cn/tutorials/quant/advanced_quantization.md).
@@ -465,6 +468,12 @@ python finetune_generation.py ./llama/gptq_argument.json
### 7.2 Converting PyTorch Weights

PaddleNLP provides an interface that automatically converts PyTorch weights into Paddle weights; the code is as follows:
=======


### 6. PyTorch Model Weight Conversion
PaddleNLP provides an interface that automatically converts PyTorch weights into Paddle weights; the code is as follows:
>>>>>>> 04dc6251f ([New Features] support dynamic src_length (#7740))
```python
from paddlenlp.transformers import AutoModelForCausalLM
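# (Not part of this diff.) A typical complete call, assuming the
# `convert_from_torch` and `dtype` keyword arguments, might look like:
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b",  # hypothetical checkpoint name
    convert_from_torch=True,
    dtype="float16",
)
```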
45 changes: 37 additions & 8 deletions llm/flask_server.py
@@ -17,7 +17,7 @@
import os
import socket
from contextlib import closing
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from time import sleep

import requests
@@ -66,6 +66,7 @@ def __init__(self, args: ServerArgument, predictor: BasePredictor):
self.args.base_port + port_interval * predictor.tensor_parallel_rank,
self.args.base_port + port_interval * (predictor.tensor_parallel_rank + 1),
)
self.total_max_length = predictor.config.src_length + predictor.config.max_length

if self.predictor.tensor_parallel_rank == 0:
# fetch port info
@@ -121,16 +122,44 @@ def streaming(data):

# build chat template
if self.predictor.tokenizer.chat_template is not None:
history = json.loads(history)
if not history:
history = []
# also support history data
elif isinstance(history, str):
history = json.loads(history)

assert len(history) % 2 == 0
chat_query = []
for idx in range(0, len(history), 2):
chat_query.append(["", ""])
chat_query[-1][0], chat_query[-1][1] = history[idx]["utterance"], history[idx + 1]["utterance"]
query = [chat_query]
if isinstance(history[idx], str):
chat_query.append([history[idx], history[idx + 1]])
elif isinstance(history[idx], dict):
chat_query.append([history[idx]["utterance"], history[idx + 1]["utterance"]])
else:
raise ValueError(
"history data should be list[str] or list[dict], eg: ['sentence-1', 'sentece-2', ...], or "
"[{'utterance': 'sentence-1'}, {'utterance': 'sentence-2'}, ...]"
)

# the input of predictor should be batched.
# batched query: [ [[user, bot], [user, bot], ..., [user]] ]
query = [chat_query + [[query]]]

generation_args = data
self.predictor.config.max_length = generation_args["max_length"]
if "src_length" in generation_args:
self.predictor.config.src_length = generation_args["src_length"]

if self.predictor.config.src_length + self.predictor.config.max_length > self.total_max_length:
output = {
"error_code": 1,
"error_msg": f"The sum of src_length<{self.predictor.config.src_length}> and "
f"max_length<{self.predictor.config.max_length}> should be smaller than or equal to "
f"the max-total-length<{self.total_max_length}>",
}
yield json.dumps(output, ensure_ascii=False) + "\n"
return

self.predictor.config.top_p = generation_args["top_p"]
self.predictor.config.temperature = generation_args["temperature"]
self.predictor.config.top_k = generation_args["top_k"]
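
For illustration, a standalone sketch (not part of the diff) of the two history layouts accepted by the chat-template branch above:

```python
# list[str]: alternating user / bot utterances; the length must be even
history_as_strings = ["你好", "你好,请问有什么可以帮您?"]

# list[dict]: the same turns wrapped in {"utterance": ...}
history_as_dicts = [
    {"utterance": "你好"},
    {"utterance": "你好,请问有什么可以帮您?"},
]
```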
@@ -158,13 +187,13 @@ def streaming(data):
# refer to: https://github.com/pallets/flask/blob/main/src/flask/app.py#L605
app.run(host="0.0.0.0", port=self.port, threaded=False)

def start_ui_service(self, args):
def start_ui_service(self, args, predictor_args):
# do not support start ui service in one command
from multiprocessing import Process

from gradio_ui import main

p = Process(target=main, args=(args,))
p = Process(target=main, args=(args, predictor_args))
p.daemon = True
p.start()

@@ -184,6 +213,6 @@ def start_ui_service(self, args):
server = PredictorServer(server_args, predictor)

if server.predictor.tensor_parallel_rank == 0:
server.start_ui_service(server_args)
server.start_ui_service(server_args, asdict(predictor.config))

server.start_flask_server()
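
The UI process now receives the predictor configuration as a plain dict via `asdict(predictor.config)`. A minimal standalone sketch of the shape involved, using a hypothetical stand-in dataclass rather than the real `PredictorArgument`:

```python
from dataclasses import asdict, dataclass


@dataclass
class FakePredictorConfig:  # hypothetical stand-in for predictor.config
    src_length: int = 1024
    max_length: int = 1024
    top_p: float = 0.7


# asdict() flattens the dataclass into the dict that gradio_ui.main receives
# as `default_params`, so the sliders can read their default values from it.
print(asdict(FakePredictorConfig()))
# {'src_length': 1024, 'max_length': 1024, 'top_p': 0.7}
```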
89 changes: 72 additions & 17 deletions llm/gradio_ui.py
@@ -30,7 +30,29 @@ def setup_args():
return args


def launch(args):
def create_src_slider(value, maximum):
return gr.Slider(
minimum=1,
maximum=maximum,
value=value,
step=1,
label="Max Src Length",
info="最大输入长度。",
)


def create_max_slider(value, maximum):
return gr.Slider(
minimum=1,
maximum=maximum,
value=value,
step=1,
label="Max Decoding Length",
info="生成结果的最大长度。",
)


def launch(args, default_params: dict = {}):
"""Launch characters dialogue demo."""

def rollback(state):
@@ -42,7 +64,7 @@ def rollback(state):
shown_context = get_shown_context(context)
return utterance, shown_context, context, state

def regen(state, top_k, top_p, temperature, repetition_penalty, max_length):
def regen(state, top_k, top_p, temperature, repetition_penalty, max_length, src_length):
"""Regenerate response."""
context = state.setdefault("context", [])
if len(context) < 2:
@@ -74,13 +96,13 @@ def begin(utterance, state):
shown_context = get_shown_context(context)
return utterance, shown_context, context, state

def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_length):
def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length):
"""Model inference."""
utterance = utterance.strip().replace("<br>", "\n")
context = state.setdefault("context", [])

if not utterance:
gr.Warning("invalid inputs111")
gr.Warning("invalid inputs")
# gr.Warning("请输入有效问题")
shown_context = get_shown_context(context)
return None, shown_context, context, state
@@ -93,11 +115,17 @@ def infer(utterance, state, top_k, top_p, temperature, repetition_penalty, max_l
"temperature": temperature,
"repetition_penalty": repetition_penalty,
"max_length": max_length,
"src_length": src_length,
"min_length": 1,
}
res = requests.post(f"http://0.0.0.0:{args.base_port}/api/chat", json=data, stream=True)
for line in res.iter_lines():
result = json.loads(line)
if result["error_code"] != 0:
gr.Warning(result["error_msg"])
shown_context = get_shown_context(context)
return None, shown_context, context, state

bot_response = result["result"]["response"]

# replace \n with br: https://github.com/gradio-app/gradio/issues/4344
@@ -156,30 +184,57 @@ def get_shown_context(context):
with gr.Row():
with gr.Column(scale=1):
top_k = gr.Slider(
minimum=1, maximum=100, value=50, step=1, label="Top-k", info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。"
minimum=0,
maximum=default_params.get("top_k", 20),
value=0,
step=1,
label="Top-k",
info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。",
)
top_p = gr.Slider(
minimum=0, maximum=1, value=0.7, step=0.05, label="Top-p", info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。"
minimum=0,
maximum=1,
value=default_params.get("top_p", 0.7),
step=0.05,
label="Top-p",
info="该参数越大,模型生成结果更加随机,反之生成结果更加确定。",
)
temperature = gr.Slider(
minimum=0.05,
maximum=1.5,
value=0.95,
value=default_params.get("temperature", 0.95),
step=0.05,
label="Temperature",
info="该参数越小,模型生成结果更加随机,反之生成结果更加确定。",
)
repetition_penalty = gr.Slider(
minimum=0.1,
maximum=10,
value=1.0,
value=default_params.get("repetition_penalty", 1.2),
step=0.05,
label="Repetition Penalty",
info="该参数越大,生成结果重复的概率越低。设置 1 则不开启。",
)
max_length = gr.Slider(
minimum=1, maximum=1024, value=50, step=1, label="Max Length", info="生成结果的最大长度。"
)
default_src_length = default_params["src_length"]
total_length = default_params["src_length"] + default_params["max_length"]
src_length = create_src_slider(default_src_length, total_length)
max_length = create_max_slider(50, total_length)

def src_length_change_event(src_length_value, max_length_value):
return create_max_slider(
min(total_length - src_length_value, max_length_value),
total_length - src_length_value,
)

def max_length_change_event(src_length_value, max_length_value):
return create_src_slider(
min(total_length - max_length_value, src_length_value),
total_length - max_length_value,
)

src_length.change(src_length_change_event, inputs=[src_length, max_length], outputs=max_length)
max_length.change(max_length_change_event, inputs=[src_length, max_length], outputs=src_length)
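
As a quick check on the coupled-slider arithmetic above, a standalone sketch (assuming a hypothetical total budget of 4096 tokens):

```python
total_length = 4096  # hypothetical src_length + max_length budget


def new_max_decoding_bounds(src_length_value, max_length_value):
    # Mirrors src_length_change_event: the decoding slider can only use
    # whatever budget the source slider leaves over.
    maximum = total_length - src_length_value
    return min(maximum, max_length_value), maximum


print(new_max_decoding_bounds(3000, 2048))  # -> (1096, 1096)
print(new_max_decoding_bounds(1024, 512))   # -> (512, 3072)
```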

with gr.Column(scale=4):
state = gr.State({})
context_chatbot = gr.Chatbot(label="Context")
@@ -200,7 +255,7 @@ def get_shown_context(context):
api_name="chat",
).then(
infer,
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
outputs=[utt_text, context_chatbot, raw_context_json, state],
)

@@ -219,13 +274,13 @@ def get_shown_context(context):
)
regen_btn.click(
regen,
inputs=[state, top_k, top_p, temperature, repetition_penalty, max_length],
inputs=[state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
outputs=[utt_text, context_chatbot, raw_context_json, state],
queue=False,
api_name="chat",
).then(
infer,
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
outputs=[utt_text, context_chatbot, raw_context_json, state],
)

@@ -237,15 +292,15 @@ def get_shown_context(context):
api_name="chat",
).then(
infer,
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length],
inputs=[utt_text, state, top_k, top_p, temperature, repetition_penalty, max_length, src_length],
outputs=[utt_text, context_chatbot, raw_context_json, state],
)

block.queue().launch(server_name="0.0.0.0", server_port=args.port, debug=True)


def main(args):
launch(args)
def main(args, default_params: dict = {}):
launch(args, default_params)


if __name__ == "__main__":
42 changes: 40 additions & 2 deletions llm/predictor.py
@@ -29,7 +29,10 @@
from utils import (
dybatch_preprocess,
get_alibi_slopes,
get_default_max_decoding_length,
get_default_max_encoding_length,
get_infer_model_path,
get_model_max_position_embeddings,
get_prefix_tuning_params,
init_chat_template,
load_real_time_tokens,
@@ -56,8 +59,8 @@
class PredictorArgument:
model_name_or_path: str = field(default=None, metadata={"help": "The directory of model."})
model_prefix: str = field(default="model", metadata={"help": "the prefix name of static model"})
src_length: int = field(default=1024, metadata={"help": "The max length of source text."})
max_length: int = field(default=2048, metadata={"help": "the max length for decoding."})
src_length: int = field(default=None, metadata={"help": "The max length of source text."})
max_length: int = field(default=None, metadata={"help": "the max length for decoding."})
top_k: int = field(default=0, metadata={"help": "top_k parameter for generation"})
top_p: float = field(default=0.7, metadata={"help": "top_p parameter for generation"})
temperature: float = field(default=0.95, metadata={"help": "temperature parameter for generation"})
@@ -732,6 +735,40 @@ def create_predictor(
if isinstance(tokenizer, LlamaTokenizer) and not tokenizer.pad_token:
tokenizer.pad_token = tokenizer.unk_token

config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)

max_position_embeddings = get_model_max_position_embeddings(config)
if max_position_embeddings is None:
max_position_embeddings = 2048
logger.warning("Can not retrieval `max_position_embeddings` from config.json, use default value 2048")

if predictor_args.src_length is None:
if predictor_args.max_length is None:
predictor_args.src_length = get_default_max_encoding_length(config)
predictor_args.max_length = get_default_max_decoding_length(config)
else:
predictor_args.src_length = max_position_embeddings - predictor_args.max_length
if predictor_args.src_length <= 0:
raise ValueError(
f"--max_length<{predictor_args.max_length}> param should be smaller "
f"than max_position_embeddings<{max_position_embeddings}>"
)
else:
if predictor_args.max_length is None:
predictor_args.max_length = max_position_embeddings - predictor_args.src_length
if predictor_args.max_length <= 0:
raise ValueError(
f"--src_length<{predictor_args.src_length}> param should be smaller "
f"than max_position_embeddings<{max_position_embeddings}>"
)
else:
if predictor_args.src_length + predictor_args.max_length > max_position_embeddings:
raise ValueError(
f"The sum of src_length<{predictor_args.src_length}> and "
f"max_length<{predictor_args.max_length}> should be smaller than or equal to "
f"the maximum position embedding size<{max_position_embeddings}>"
)

# update config parameter for inference predictor
if predictor_args.decode_strategy == "greedy_search":
predictor_args.top_p = 0.0
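
For reference, the resolution rules above can be condensed into a small standalone helper (a sketch only; `resolve_lengths` and its default arguments are hypothetical, while the branching mirrors the diff):

```python
def resolve_lengths(src_length, max_length, max_position_embeddings=2048,
                    default_src=1024, default_max=1024):
    if src_length is None and max_length is None:
        # neither flag given: fall back to per-model defaults
        return default_src, default_max
    if src_length is None:
        # only --max_length given: the prompt gets the remaining budget
        src_length = max_position_embeddings - max_length
        if src_length <= 0:
            raise ValueError("--max_length must be smaller than max_position_embeddings")
        return src_length, max_length
    if max_length is None:
        # only --src_length given: decoding gets the remaining budget
        max_length = max_position_embeddings - src_length
        if max_length <= 0:
            raise ValueError("--src_length must be smaller than max_position_embeddings")
        return src_length, max_length
    # both given: their sum must fit into the position-embedding budget
    if src_length + max_length > max_position_embeddings:
        raise ValueError("src_length + max_length must not exceed max_position_embeddings")
    return src_length, max_length


# e.g. a 4096-position model with only --max_length 1024 leaves 3072 tokens for the prompt
assert resolve_lengths(None, 1024, max_position_embeddings=4096) == (3072, 1024)
```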
@@ -925,6 +962,7 @@ def create_predictor(
predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
else:
raise ValueError("the `mode` should be one of [dynamic, static]")

return predictor

