[New Features] support dynamic src_length #7740
Conversation
Thanks for your contribution!
I scanned the full list of paddlenlp models with the following script:

```python
import importlib
import inspect

from paddlenlp.transformers import AutoConfig, PretrainedModel
from tqdm import tqdm


def find_max_position_embeddings(config):
    """Probe the config keys that may hold a model's maximum length."""
    names = [
        "max_sequence_length",
        "max_position_embeddings",
    ]
    for name in names:
        max_length = config.get(name, None)
        if max_length is not None:
            return max_length
    return None


def find_pretrained_model_classes():
    """Collect every built-in model name and check its config for a max length."""
    module = importlib.import_module("paddlenlp.transformers")
    model_names = []
    for attr_name in dir(module):
        if attr_name.startswith("_"):
            continue
        obj = getattr(module, attr_name)
        if not inspect.isclass(obj):
            continue
        if not issubclass(obj, PretrainedModel):
            continue
        model_names.extend(list(obj.pretrained_init_configuration.keys()))
    model_names.append("meta-llama/Llama-2-7b")

    for model_name in tqdm(model_names):
        try:
            config = AutoConfig.from_pretrained(model_name)
        except Exception:
            continue
        max_length = find_max_position_embeddings(config)
        if max_length is None:
            print("error ->", model_name)


find_pretrained_model_classes()
```

The printed log (ANSI escape codes stripped, repeated lines collapsed):

```
  0%|          | 0/2429 [00:00<?, ?it/s]
error -> dalle-mini
error -> dalle-mega-v16
error -> dalle-mega-v26
error -> dalle-mega
(the four dalle lines above are printed 5 times in total, once per model class that registers them)
error -> ernie-code-base
error -> ernie-code-base-L512
(the two ernie-code lines above are printed 4 times in total)
 63%|████████▍ | 1519/2429 [00:00<00:00, 2813.16it/s]
[2023-12-28 15:52:40,981] [ WARNING] - You are using a model of type layoutlmv2 to instantiate a model of type layoutxlm. This is not supported for all configurations of models and can yield errors.
(the layoutxlm warning above is printed 14 times)
error -> t5-small
error -> t5-base
error -> t5-large
error -> t5-v1_1-base
error -> t5-v1_1-large
error -> t5-3b
error -> t5-11b
(the seven t5 lines above are printed 4 times in total)
100%|██████████| 2429/2429 [00:09<00:00, 210.77it/s]
```
As the log above shows, apart from the dalle, ernie-code, and t5 families, every model exposes its maximum length through `max_position_embeddings` (or `max_sequence_length`) in its config.
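Building on that scan, here is a minimal sketch of how per-direction defaults could be derived from the probed value. The 50/50 split and the 2048 fallback are illustrative assumptions, not necessarily what this PR implements; `find_max_position_embeddings` is the probe from the script above:

```python
# Sketch only: the even split and the 2048 fallback are assumptions.
def get_default_max_total_length(model_config) -> int:
    max_length = find_max_position_embeddings(model_config)
    # Fall back for the dalle/ernie-code/t5 families, whose configs carry
    # neither max_sequence_length nor max_position_embeddings.
    return max_length if max_length is not None else 2048


def get_default_max_encoding_length(model_config) -> int:
    # Reserve half of the window for the prompt (src_length).
    return get_default_max_total_length(model_config) // 2


def get_default_max_decoding_length(model_config) -> int:
    # Reserve the other half for the generated output (max_length).
    return get_default_max_total_length(model_config) // 2
```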
Codecov Report

All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff             @@
##           develop    #7740     +/-   ##
===========================================
+ Coverage    57.29%   57.31%   +0.01%
===========================================
  Files          584      584
  Lines        87636    87688     +52
===========================================
+ Hits         50209    50254     +45
- Misses       37427    37434      +7
```

☔ View full report in Codecov by Sentry.
```python
bot_response["utterance"] = bot_response["utterance"][:-5]
text += bot_response["utterance"]

print("result -> ", text)
```
Move the `print` outside.
Then every call site outside would need its own `print`.
This is a sample script that just shows how to use the API, so that doesn't seem necessary.
```python
query = [chat_query]

generation_args = data
self.predictor.config.max_length = generation_args["max_length"]
if "src_length" in generation_args:
```
Shouldn't `src_length` plus `max_length` be capped at the model's maximum here? If the sum exceeds it, it needs to be adjusted, right?
We actually discussed this before. The conclusion was to use `src_length` and `max_length` separately to cap the input and output lengths, and not to enforce a combined limit.
The reason for not enforcing it is that a model may support length extrapolation, and a hard cap would negate that ability.
Then enumerate the models that support extrapolation and leave those unrestricted, and restrict the ones that can't extrapolate.
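A minimal sketch of that policy, assuming a hand-maintained allowlist; the set contents, the helper name, and the attribute access are illustrative, not part of this PR:

```python
# Model types believed to extrapolate beyond their trained window (assumption).
EXTRAPOLATION_CAPABLE_MODEL_TYPES = {"llama"}


def check_total_length(config, model_config):
    max_position = getattr(model_config, "max_position_embeddings", None)
    if max_position is None:
        return  # no known window for this model; nothing to enforce
    if getattr(model_config, "model_type", None) in EXTRAPOLATION_CAPABLE_MODEL_TYPES:
        return  # extrapolation-capable: leave unrestricted
    total = config.src_length + config.max_length
    if total > max_position:
        raise ValueError(
            f"src_length ({config.src_length}) + max_length ({config.max_length}) "
            f"= {total} exceeds the model maximum of {max_position}"
        )
```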
```python
# also support history data
elif isinstance(history, str):
    history = json.loads(history)

assert len(history) % 2 == 0
```
If `len(history) == 0`, none of the code below needs to run; I'm worried `query` ends up as `[[]]`.
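A minimal sketch of the early exit being suggested; the variable names follow the hunk above, and the exact shape of `query` is an assumption:

```python
import json

# also support history passed as a JSON string
if isinstance(history, str):
    history = json.loads(history)

assert len(history) % 2 == 0, "history must hold alternating user/bot turns"

if len(history) == 0:
    # Empty history: skip the pairing logic below entirely so that
    # `query` can never degenerate into [[]].
    query = [chat_query]
```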
llm/gradio_ui.py
Outdated
```python
    info="生成结果的最大长度。",
)
default_src_length = default_params["src_length"]
src_length = gr.Slider(
```
Move this up a bit so the `src_length` slider comes before the `max_length` one.
llm/predictor.py
Outdated
```python
if predictor.config.src_length is None:
    predictor.config.src_length = get_default_max_encoding_length(predictor.model_config)

if predictor.config.max_length is None:
    predictor.config.max_length = get_default_max_decoding_length(predictor.model_config)
```
Shouldn't there be a check here for when the user specifies these two parameters explicitly? Can `src_length + max_length` exceed the model maximum then? Your defaults can't exceed it, but a user-supplied pair can.
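A minimal sketch of that check, reusing `find_max_position_embeddings` from the scan script earlier; raising `ValueError` (rather than clamping) is an assumption about the desired behavior:

```python
# Validate user-specified lengths against the model window (sketch).
max_position = find_max_position_embeddings(predictor.model_config)
if max_position is not None:
    total = predictor.config.src_length + predictor.config.max_length
    if total > max_position:
        raise ValueError(
            f"src_length + max_length ({total}) must not exceed "
            f"max_position_embeddings ({max_position})"
        )
```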
"temperature": 0.95, | ||
"repetition_penalty": 1.3, | ||
"max_length": 100, | ||
"src_length": 100, |
The Flask server currently has no safety net: when `src_length + max_length` grows without bound, the process simply crashes.
This still doesn't solve the problem. Should we return an error message instead?
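A minimal sketch of such a guard; the payload keys and the idea of returning the error as JSON are assumptions about how the server might respond:

```python
from flask import jsonify


def validate_request_lengths(data, total_max_length):
    """Return a (response, status) pair if the request is oversized, else None."""
    total = data.get("src_length", 0) + data.get("max_length", 0)
    if total > total_max_length:
        return (
            jsonify(
                error_msg=f"src_length + max_length ({total}) exceeds "
                f"the server limit ({total_max_length})"
            ),
            400,
        )
    return None
```

A request handler would call this before generation and return the pair when it is not `None`, instead of letting the predictor crash.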
```
@@ -68,6 +68,7 @@ def __init__(self, args: ServerArgument, predictor: BasePredictor):
            self.args.flask_port + port_interval * predictor.tensor_parallel_rank,
            self.args.flask_port + port_interval * (predictor.tensor_parallel_rank + 1),
        )
        self.total_max_length = predictor.config.src_length + predictor.config.max_length
```
If the predictor is started with the defaults of `None`, this equals the model maximum. But if the user specifies `src_length` and `max_length`, say 128 and 64, then the maximum becomes just 192.
I'd write this as `self.total_max_length = max(max_position (e.g. 4096), predictor.config.src_length + predictor.config.max_length)`.
The earlier check already guarantees that `predictor.config.src_length + predictor.config.max_length` can never exceed `max_position`, so using `max_position` directly here would be enough...
> If the predictor is started with the defaults of `None`, this equals the model maximum. But if the user specifies `src_length` and `max_length`, say 128 and 64, then the maximum becomes just 192.
> I'd write this as `self.total_max_length = max(max_position (e.g. 4096), predictor.config.src_length + predictor.config.max_length)`.

That is exactly the current logic; see the initialization at the bottom of the `create_predictor` function.
So by the time we are inside `flask_server`, both values are guaranteed to have been initialized.
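Condensing the two hunks above into one picture (a sketch of the described flow, not a verbatim excerpt from the PR):

```python
# Inside create_predictor: fill in defaults so both fields are always set
# and their sum stays within the model window.
if predictor.config.src_length is None:
    predictor.config.src_length = get_default_max_encoding_length(predictor.model_config)
if predictor.config.max_length is None:
    predictor.config.max_length = get_default_max_decoding_length(predictor.model_config)

# Inside flask_server.__init__: both values are guaranteed to be set here.
self.total_max_length = predictor.config.src_length + predictor.config.max_length
```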
Force-pushed from 151ae6f to f4cd00c
LGTM
* support dynamic src_length
* revert max_position_embedding
* update doc
* update flask_server
* update max_length control
* update request flask_server
* fix max-position-embeddings
* update error message
* update predictor length init
PR types
New features
PR changes
LLM
Description