
[New Features] support dynamic src_length #7740

Merged
merged 9 commits into PaddlePaddle:develop on Jan 4, 2024

Conversation

@wj-Mcat (Contributor) commented Dec 28, 2023

PR types

New features

PR changes

LLM

Description

  • support src_length
  • init max_position_embedding from config.json
  • fix history data structure: support both str and dict

paddle-bot bot commented Dec 28, 2023

Thanks for your contribution!

@wj-Mcat (Contributor, Author) commented Dec 28, 2023

I scanned the full list of paddlenlp models with the following script:

import importlib
import inspect

from paddlenlp.transformers import AutoConfig, PretrainedModel
from tqdm import tqdm


def find_max_position_embeddings(config):
    # Candidate config keys that store the model's maximum input length.
    names = [
        "max_sequence_length",
        "max_position_embeddings",
    ]
    for name in names:
        max_length = config.get(name, None)
        if max_length is not None:
            return max_length
    return None


def find_pretrained_model_classes():
    # Collect every built-in pretrained model name exposed by paddlenlp.transformers.
    module = importlib.import_module("paddlenlp.transformers")
    model_names = []
    for attr_name in dir(module):
        if attr_name.startswith("_"):
            continue
        obj = getattr(module, attr_name)
        if not inspect.isclass(obj):
            continue
        if not issubclass(obj, PretrainedModel):
            continue
        model_names.extend(obj.pretrained_init_configuration.keys())
    model_names.append("meta-llama/Llama-2-7b")

    # Report every model whose config exposes none of the candidate length keys.
    for model_name in tqdm(model_names):
        try:
            config = AutoConfig.from_pretrained(model_name)
        except Exception:
            continue
        max_length = find_max_position_embeddings(config)
        if max_length is None:
            print("error ->", model_name)


find_pretrained_model_classes()

The printed log:

  0%|                                        | 0/2429 [00:00<?, ?it/s]
error -> dalle-mini
error -> dalle-mega-v16
error -> dalle-mega-v26
error -> dalle-mega
error -> ernie-code-base
error -> ernie-code-base-L512
[2023-12-28 15:52:40,981] [ WARNING] - You are using a model of type layoutlmv2 to instantiate a model of type layoutxlm. This is not supported for all configurations of models and can yield errors.
error -> t5-small
error -> t5-base
error -> t5-large
error -> t5-v1_1-base
error -> t5-v1_1-large
error -> t5-3b
error -> t5-11b
100%|████████████████████████████████████████| 2429/2429 [00:09<00:00, 246.22it/s]

@wj-Mcat (Contributor, Author) commented Dec 28, 2023

As the log above shows, the candidate key list for max-position-embeddings identifies the maximum length for the vast majority of models.
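
Building on that, here is a minimal sketch of how per-model defaults could be derived from the discovered key, reusing find_max_position_embeddings from the script above. The helper names appear in this PR's diff below, but the 50/50 encoding/decoding split and the fallback value are illustrative assumptions, not necessarily what the PR implements:

def get_default_max_encoding_length(model_config, default=1024):
    # Derive a default input budget from the model's positional limit.
    max_position = find_max_position_embeddings(model_config)
    if max_position is None:
        return default  # fall back when the config exposes no length key
    return max_position // 2  # assumption: reserve half for generated tokens


def get_default_max_decoding_length(model_config, default=1024):
    # Derive a default output budget the same way.
    max_position = find_max_position_embeddings(model_config)
    if max_position is None:
        return default
    return max_position // 2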

codecov bot commented Dec 28, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (dca0575) 57.29% compared to head (f4cd00c) 57.31%.
Report is 17 commits behind head on develop.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #7740      +/-   ##
===========================================
+ Coverage    57.29%   57.31%   +0.01%     
===========================================
  Files          584      584              
  Lines        87636    87688      +52     
===========================================
+ Hits         50209    50254      +45     
- Misses       37427    37434       +7     


bot_response["utterance"] = bot_response["utterance"][:-5]
text += bot_response["utterance"]

print("result -> ", text)
Member:

Move the print outside.

@wj-Mcat (Contributor, Author):

Then every caller outside would need its own print.

This is an example script that just shows how to use it; that doesn't seem necessary here.

query = [chat_query]

generation_args = data
self.predictor.config.max_length = generation_args["max_length"]
if "src_length" in generation_args:
Member:

Shouldn't src_length and max_length together stay within the maximum here? If they exceed it, doesn't that need to be fixed?

@wj-Mcat (Contributor, Author) commented Dec 28, 2023:

This was actually discussed before, and the conclusion was: use src_length and max_length separately to cap the input and output lengths, and do not enforce a hard limit on their sum.

We don't enforce it because a model may support length extrapolation, and a hard limit would cap the extrapolation capability.

Member:

Then enumerate the models that support extrapolation and leave them unrestricted; models that cannot extrapolate should be restricted.
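
A minimal sketch of that suggestion, reusing find_max_position_embeddings from the script above; the allowlist contents and the clamping policy are hypothetical and not what the PR ships:

# Hypothetical allowlist of model types known to support length extrapolation;
# the entries are illustrative only.
EXTRAPOLATION_CAPABLE_MODEL_TYPES = {"llama", "qwen"}


def clamp_lengths(model_config, src_length, max_length):
    # Leave extrapolation-capable models unrestricted, as suggested above.
    if model_config.model_type in EXTRAPOLATION_CAPABLE_MODEL_TYPES:
        return src_length, max_length
    # Otherwise cap the sum at the model's positional limit.
    max_position = find_max_position_embeddings(model_config)
    if max_position is not None and src_length + max_length > max_position:
        # Keep the requested input budget; shrink the output budget.
        max_length = max(1, max_position - src_length)
    return src_length, max_length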

# also support history data
elif isinstance(history, str):
history = json.loads(history)

assert len(history) % 2 == 0
Member:

If len(history) == 0, none of the code below needs to run. I'm worried this ends up as query = [[]].
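
A minimal sketch of the guard being suggested; the early return and the function wrapper are assumptions for illustration:

import json


def parse_history(history):
    # Also accept history serialized as a JSON string.
    if isinstance(history, str):
        history = json.loads(history)
    # Empty history: skip everything below so query never becomes [[]].
    if not history:
        return []
    assert len(history) % 2 == 0, "history must be (user, bot) turn pairs"
    return history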

llm/gradio_ui.py Outdated
info="生成结果的最大长度。",
)
default_src_length = default_params["src_length"]
src_length = gr.Slider(
Member:

Move this up a bit: src_length first, then max_length.

llm/predictor.py Outdated
Comment on lines 891 to 895
if predictor.config.src_length is None:
predictor.config.src_length = get_default_max_encoding_length(predictor.model_config)

if predictor.config.max_length is None:
predictor.config.max_length = get_default_max_decoding_length(predictor.model_config)
Member:

Shouldn't there be a check here for when the user specifies both parameters: can src_length + max_length exceed the maximum? The defaults can never exceed it, but once the user sets these explicitly, they might.
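
A minimal sketch of such a check, reusing find_max_position_embeddings from the script above; the helper name and error wording are assumptions:

def validate_lengths(predictor, model_config):
    # Defaults can never exceed the limit; only user-specified values can.
    max_position = find_max_position_embeddings(model_config)
    total = predictor.config.src_length + predictor.config.max_length
    if max_position is not None and total > max_position:
        raise ValueError(
            f"src_length ({predictor.config.src_length}) + max_length "
            f"({predictor.config.max_length}) = {total} exceeds the model "
            f"maximum of {max_position}"
        )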

"temperature": 0.95,
"repetition_penalty": 1.3,
"max_length": 100,
"src_length": 100,
Member:

The Flask server currently has no fallback: when src_length + max_length is arbitrarily large, the process simply crashes.

Member:

This still doesn't solve the problem. Should we return an error message?
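
A minimal sketch of the kind of fallback being asked for; the route, field names, and the cap value are assumptions, not the PR's actual flask_server code:

from flask import Flask, jsonify, request

app = Flask(__name__)
TOTAL_MAX_LENGTH = 4096  # assumed server-side cap, e.g. the model's positional limit


@app.route("/api/chat", methods=["POST"])
def chat():
    data = request.get_json()
    src_length = data.get("src_length", 0)
    max_length = data.get("max_length", 0)
    if src_length + max_length > TOTAL_MAX_LENGTH:
        # Reject oversized requests with an error message instead of crashing.
        return jsonify(
            {"error": f"src_length + max_length must not exceed {TOTAL_MAX_LENGTH}"}
        ), 400
    # Hand off to the predictor here; placeholder response for the sketch.
    return jsonify({"result": "ok"})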

@@ -68,6 +68,7 @@ def __init__(self, args: ServerArgument, predictor: BasePredictor):
self.args.flask_port + port_interval * predictor.tensor_parallel_rank,
self.args.flask_port + port_interval * (predictor.tensor_parallel_rank + 1),
)
self.total_max_length = predictor.config.src_length + predictor.config.max_length
Member:

If the predictor starts with the default (None), this equals the model maximum. But if the user specifies src_length and max_length, say 128 and 64, the total becomes only 192.

I think this should be written as self.total_max_length = max(max_position (e.g. 4096), predictor.config.src_length + predictor.config.max_length).

Member:

The earlier check already guarantees that predictor.config.src_length + predictor.config.max_length cannot exceed max_position, so this can simply be max_position...

@wj-Mcat (Contributor, Author):

> If the predictor starts with the default (None), this equals the model maximum. But if the user specifies src_length and max_length, say 128 and 64, the total becomes only 192.
>
> I think this should be written as self.total_max_length = max(max_position (e.g. 4096), predictor.config.src_length + predictor.config.max_length).

That is exactly what the current logic does; see the initialization at the very bottom of the create_predictor method.

So by the time we get to flask_server, these values are guaranteed to be initialized.
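
A minimal sketch of the initialization flow being described, combining the helpers sketched earlier; everything beyond the names quoted in the diffs above is an assumption:

def init_lengths(predictor, model_config):
    # Tail end of create_predictor: fill in defaults, then validate, so that
    # flask_server always sees concrete, already-checked values.
    if predictor.config.src_length is None:
        predictor.config.src_length = get_default_max_encoding_length(model_config)
    if predictor.config.max_length is None:
        predictor.config.max_length = get_default_max_decoding_length(model_config)
    validate_lengths(predictor, model_config)


# flask_server can then rely on the sum directly:
# self.total_max_length = predictor.config.src_length + predictor.config.max_length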

@wj-Mcat wj-Mcat marked this pull request as ready for review January 2, 2024 11:23
@JunnYu (Member) left a comment:

LGTM

@wj-Mcat wj-Mcat merged commit 04dc625 into PaddlePaddle:develop Jan 4, 2024
10 of 11 checks passed
JunnYu pushed a commit that referenced this pull request Jan 4, 2024
* support dynamic src_length

* revert max_position_embedding

* update doc

* update flask_server

* update max_length control

* update request flask_server

* fix max-position-embeddings

* update error message

* update predictor length init