Add support for chatglm2 #649
Conversation
Thank you for your contribution! Can you resolve the issues and make sure the formatting check passes by running `format.sh`? Thanks again!
if self.multi_query_attention:
    key_layer = key_layer.unsqueeze(-2)
    key_layer = key_layer.expand(
        -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
    )
    key_layer = key_layer.contiguous().view(
        key_layer.size()[:1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
    )
    value_layer = value_layer.unsqueeze(-2)
    value_layer = value_layer.expand(
        -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1
    )
    value_layer = value_layer.contiguous().view(
        value_layer.size()[:1] + (self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
    )
vLLM has native support for multi-query attention, and this `expand` operation can be avoided. This should greatly improve the speed and memory utilization of ChatGLM models. You can refer to GPTBigCode's implementation below:
vllm/vllm/model_executor/models/gpt_bigcode.py
Lines 59 to 89 in a57d13c
self.multi_query = config.multi_query
if self.multi_query:
    self.num_kv_heads = 1
    self.kv_dim = self.head_dim
    self.c_attn_q = ColumnParallelLinear(self.hidden_size,
                                         self.hidden_size,
                                         bias=True,
                                         gather_output=False,
                                         perform_initialization=False)
    self.c_attn_kv = nn.Linear(self.hidden_size,
                               2 * self.kv_dim,
                               bias=True)
else:
    self.num_kv_heads = self.num_heads
    self.kv_dim = self.num_kv_heads * self.head_dim
    self.c_attn = ColumnParallelLinear(self.hidden_size,
                                       self.hidden_size +
                                       2 * self.kv_dim,
                                       bias=True,
                                       gather_output=False,
                                       perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
                                self.hidden_size,
                                bias=True,
                                input_is_parallel=True,
                                perform_initialization=False)
self.attn = PagedAttention(self.num_heads,
                           self.head_dim,
                           scale=self.scale,
                           num_kv_heads=self.num_kv_heads)
Hi, can I just replace the `expand` operation with `repeat_interleave`? I found that if I passed the `num_kv_heads` param (which is 2 in chatglm2) to `PagedAttention`, the output would be confusing. I thought it was caused by `self.head_mapping`. If you have any other advice, please tell me.
Hi @canghongjian! Sorry for the delayed response. The current implementation uses the multi-head attention kernel instead of vLLM's grouped-query attention. Using the multi-head attention kernel requires a redundant copy of the key and value tensors. In other words, the `expand` or `repeat_interleave` operators here are redundant and waste memory.

I believe setting `num_kv_heads=2` is correct, but it seems like the default `head_mapping` is incorrect. I took a look at ChatGLM's source code on Hugging Face, and it seems that instead of the `head_mapping` being:
[0, 0, ..., 0, 0, 1, 1, ..., 1, 1]
it should be:
[0, 1, ..., 0, 1, 0, 1, ..., 0, 1]

Can you try changing `torch.repeat_interleave` to `torch.repeat` in `attention.py` (below) and see whether this makes the model correct?
vllm/vllm/model_executor/layers/attention.py
Lines 69 to 71 in 65fc1c3
self.head_mapping = torch.repeat_interleave(
    torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
    self.num_queries_per_kv)
vllm/vllm/model_executor/layers/attention.py
Lines 105 to 108 in 65fc1c3
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value,
                                self.num_queries_per_kv,
                                dim=1)
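For reference, here is a small standalone illustration (not code from the PR) of the two `head_mapping` patterns being discussed, using chatglm2-6b's head counts (32 query heads, 2 KV groups):

```python
import torch

num_heads = 32        # query heads in chatglm2-6b
num_kv_heads = 2      # multi_query_group_num in chatglm2-6b
num_queries_per_kv = num_heads // num_kv_heads

kv_ids = torch.arange(num_kv_heads, dtype=torch.int32)

# torch.repeat_interleave produces blocked groups: [0, 0, ..., 0, 1, 1, ..., 1]
blocked = torch.repeat_interleave(kv_ids, num_queries_per_kv)

# Tensor.repeat produces an alternating pattern: [0, 1, 0, 1, ..., 0, 1]
alternating = kv_ids.repeat(num_queries_per_kv)

print(blocked.tolist())
print(alternating.tolist())
```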
I have tried, @zhuohan123. The main modifications are:
# in chatglm.py
......
self.atten = PagedAttention(config.num_attention_heads,
                            self.hidden_size_per_attention_head,
                            self.norm_factor,
                            num_kv_heads=(self.num_multi_query_groups_per_partition
                                          if self.multi_query_attention else config.num_attention_heads))
......
query_layer = apply_rotary_pos_emb(query_layer, rotary_pos).reshape(
    -1, (self.num_attention_heads_per_partition *
         self.hidden_size_per_attention_head))
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos).reshape(
    -1, (self.num_multi_query_groups_per_partition *
         self.hidden_size_per_attention_head))
context_layer = self.atten(query_layer, key_layer, value_layer,
                           k_cache, v_cache, input_metadata,
                           cache_event)

# in attention.py
self.head_mapping = torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda").repeat(self.num_queries_per_kv)
......
if self.num_kv_heads != self.num_heads:
    # Project the key and value tensors to the desired number of heads.
    # key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
    # value = torch.repeat_interleave(value,
    #                                 self.num_queries_per_kv,
    #                                 dim=1)
    key = key.repeat([1, self.num_queries_per_kv, 1])
    value = value.repeat([1, self.num_queries_per_kv, 1])
However, the output is still confusing. Did I make a mistake?

In addition, I don't really understand your statement that "Using the multi-head attention kernel requires a redundant copy of key and value tensors." In my original code, I just copy the key and value tensors before the `PagedAttention` call, and in `attention.py` it never reaches the `key = torch.repeat_interleave` part because `self.num_kv_heads` is equal to `self.num_heads` when I don't pass the `num_kv_heads` param. Therefore, the copy operation is only performed once in both settings.

Looking forward to your reply.
Oh, I believe my main point is that our `PagedAttention` class supports the case where the number of query heads is not the same as the number of key/value heads. In your code, this property is not utilized. In other words, the copy of the key and value tensors before `PagedAttention` is redundant and is supposed to be avoided.

I believe the confusing results are caused by an incorrect mapping between query heads and key/value heads. When the number of query heads is different from the number of key/value heads, we use the `head_mapping` array to record which key/value head each query head corresponds to.

Can you take a more detailed look into this mapping? I believe once we set the `head_mapping` correctly, we should be able to get correct results. Let me know if there are any difficulties.
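Conceptually, `head_mapping` is just an index from each query head to its key/value head, and gathering the KV heads with it reproduces the expanded layout that the multi-head path builds explicitly. A minimal sketch with made-up shapes (illustration only, not PR code):

```python
import torch

num_heads, num_kv_heads, head_dim, num_tokens = 32, 2, 128, 4
num_queries_per_kv = num_heads // num_kv_heads

key = torch.randn(num_tokens, num_kv_heads, head_dim)

# Blocked mapping: query heads 0..15 use KV head 0, heads 16..31 use KV head 1.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads), num_queries_per_kv)

# Gathering along the head dimension with head_mapping is equivalent to the
# explicit repeat_interleave copy done on the multi-head attention path.
expanded_via_mapping = key[:, head_mapping, :]
expanded_via_copy = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
assert torch.equal(expanded_via_mapping, expanded_via_copy)
```

Which mapping (blocked or alternating) is correct depends on how the checkpoint orders its query heads relative to the KV groups, which is exactly what this thread is trying to pin down.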
After some attempts, I still failed. Could you please try some debugging?
Hi, TP (tensor parallelism) doesn't seem to work properly. 2023-08-09 11:13:24,653 INFO worker.py:1636 -- Started a local Ray instance.
Thanks to @canghongjian's contribution. @exceedzhang I have just finished integrating the chatglm2 support into the FastChat vllm_worker. The problem I found is that chatglm2's tokenizer.eos_token is '</s>', which decodes as ''.
I think this may be because of a transformers version problem.
@baildagq Thanks for the explanation; I have found a solution!
vllm/transformers_utils/config.py
Thanks for your contribution! I have already run and verified this code locally. One code suggestion is to follow the original code structure design. In my opinion, the original design is to add a new config class for models that need special handling. Here we need to rename the attribute num_layers to num_hidden_layers, so it is better to add an entry to _CONFIG_REGISTRY with the key-value pair 'chatglm': ChatGLMConfig. Inside ChatGLMConfig, add an `attribute_map` for the attribute rename. The attribute map is processed in transformers' `__setattr__`:
https://github.com/huggingface/transformers/blob/d0c1aebea467af499331234e7b285a6bf91ea073/src/transformers/configuration_utils.py#L253
A working ChatGLMConfig follows:
from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    attribute_map = {
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        interleaved_qkv=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        super().__init__(**kwargs)
_CONFIG_REGISTRY:
_CONFIG_REGISTRY = {
    "mpt": MPTConfig,
    "baichuan": BaiChuanConfig,
    "RefinedWeb": RWConfig,       # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
    "chatglm": ChatGLMConfig,
}
And you need to import the ChatGLMConfig class in vllm/transformers_utils/configs/__init__.py.
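To see the `attribute_map` mechanism in action (a quick check, not part of the PR): `PretrainedConfig` transparently redirects reads and writes of the mapped name to the underlying attribute.

```python
# Assumes the ChatGLMConfig class defined above.
cfg = ChatGLMConfig(num_layers=28)

print(cfg.num_hidden_layers)   # 28, redirected to cfg.num_layers

cfg.num_hidden_layers = 30     # writes are redirected as well
print(cfg.num_layers)          # 30
```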
Good suggestions, I will update it and push a commit later. Thank you!
Hi, I still get the error 'ChatGLMConfig' object has no attribute 'num_hidden_layers' with chatglm2-6b. How can I fix this? Thank you.
Python: 3.8.10
transformers: 4.33.2
> Hi, I still get the error 'ChatGLMConfig' object has no attribute 'num_hidden_layers' with chatglm2-6b. How can I fix this? Thank you.
> Python: 3.8.10, transformers: 4.33.2

Oh, maybe you should check whether your code is modified. You can get the code with git clone https://github.com/canghongjian/vllm.git.
Thank you, and greetings. However, I met another problem after reinstalling from the git-cloned project: "The model's config.json must contain one of the following keys to determine the original maximum length of the model: ['max_position_embeddings', 'n_positions', 'max_seq_len', 'max_sequence_length', 'max_seq_length', 'seq_len']". It seems something is not compatible.
Hi, thanks for this. I successfully installed it by pulling the latest code, and ran into this problem while running the demo code above with the original chatglm2-6b:

File ~/autodl-tmp/vllm/vllm/model_executor/models/chatglm.py:478, in ChatGLMModel.__init__(self, config)
    475 super().__init__()
    477 self.config = config
--> 478 self.embedding = VocabParallelEmbedding(
    479     config.vocab_size,
    480     config.hidden_size,
    481     perform_initialization=False,
    482     params_dtype=config.torch_dtype)
    484 self.num_layers = config.num_layers
    485 self.multi_query_group_num = config.multi_query_group_num

TypeError: __init__() got an unexpected keyword argument 'perform_initialization'

Is this a bug?
Yes, thanks for reporting. Please pull the latest again.
A new question occurred:

File ~/autodl-tmp/vllm/vllm/model_executor/models/chatglm.py:546, in ChatGLMModel.load_weights(self, model_name_or_path, cache_dir, load_format, revision)
    543 if name.startswith("embedding"):
    544     name = name.replace(".word_embeddings", "")
--> 546 param = state_dict[name]
    547 load_tensor_parallel_weights(param, loaded_weight, name,
    548                              self._column_parallel_weights,
    549                              self._row_parallel_weights,
    550                              tensor_model_parallel_rank)

KeyError: 'lm_head.weight'
I just tested the official original model. Is the model you are using different?
Thank you so much. I just reloaded the environment and successfully ran the demo script on the original model. This problem happens when using a model fine-tuned with LoRA.
Thanks for sharing!
vLLM preallocates 90% of GPU memory by default, and you can change this ratio by modifying the `gpu_memory_utilization` parameter.
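For example, a minimal sketch with the offline `LLM` entrypoint (the model path is a placeholder, and `gpu_memory_utilization` is assumed to be the standard vLLM engine argument):

```python
from vllm import LLM

# Lower the preallocated fraction from the default 0.90 to 0.80.
llm = LLM(model="THUDM/chatglm2-6b",
          trust_remote_code=True,
          gpu_memory_utilization=0.8)
```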
Hi, I run on 4× V100 with tensor_parallel_size=4, and loading the weights fails. BTW: running on only 1 V100 works fine.
Namespace(backend='vllm', dataset='/mnt/chw/ShareGPT_V3_unfiltered_cleaned_split.json', hf_max_batch_size=None, model='/mnt/chw/chatglm2/chatglm2-6b', n=1, num_prompts=1, seed=0, tensor_parallel_size=4, tokenizer='/mnt/chw/chatglm2/chatglm2-6b', trust_remote_code=True, use_beam_search=False)
ray.exceptions.RayTaskError(AssertionError): ray::RayWorker.execute_method() (pid=2394870, ip=10.1.1.242, actor_id=aee050fd938cd8f3d6997a6401000000, repr=<vllm.engine.ray_utils.RayWorker object at 0x7efe5659a9a0>)
AssertionError: encoder.layers.22.self_attention.query_key_value.bias shape mismatch between model and checkpoint: torch.Size([1152]) != torch.Size([4608])
Thanks for pointing out this issue, @chenhaiwu. I have reproduced this problem and can confirm that the existing chatglm code does not support tensor parallelism. Fixing this issue might take a few days, while I believe the improvement would be minimal. This is because chatglm2 sets the default number of MQA key/value heads to 2, allowing only a tensor_parallel_size of 2 or 1. Considering this, I recommend that you use it in
Hello, if I want to understand why chatglm.py was modified in exactly this way, what knowledge do I need to learn? I'm not very familiar with this area yet; any pointers would be appreciated.
I merged this PR and reinstalled vLLM, but a single A10 (24 GB) cannot load chatglm6b. Any bugs? OOM, crash info: File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/entrypoints/llm.py", line 66, in __init__
Hi, thanks for your work. I have tried this PR, but I can only get ~50% usage on a single A100, while the Qwen model can use about 95%. Do you have any suggestions?
What does '~50% usage' mean? Is it the GPU memory usage? Does the output seem correct? Please provide more detailed information.
Model output seems correct.
I used the vLLM 0.1.17 code, modified it according to the files in this PR, and ran it with the test script provided by vLLM; it looks fine. The model is chatglm2-6b-32k.
It also runs correctly with the script provided by @canghongjian.
Good job!
vllm==0.2.0. Traceback (most recent call last): … Here is my loading code:
May I ask which version of transformers you are using?
OK, I got it running. Apply the changes from this "Add support" PR, then modify def load_weights in class ChatGLMModel in vllm/model_executor/models/chatglm.py as follows:
Then load the model with the following code, and it should work.
Did you install it with pip install . after modifying the package? I keep getting the error "The NVIDIA driver on your system is too old". @128Ghe980
Single A10, started via the API server, generating with the data from the demo: I only get about 27 tokens/s. Where could the problem be?
No, I directly ran pip install vllm==0.2.0 and then made the changes afterwards.
Thank you for the contribution. Unfortunately, this PR seems to have become stale, and ChatGLM3 has also come out. Feel free to coordinate the contribution here if you have bandwidth!
Hi, I still encounter the problem AttributeError: 'ChatGLMConfig' object has no attribute 'num_hidden_layers'. I tried updating the vLLM version and still have problems. My vllm==0.2.1.post1, transformers==4.27.0.
@Midnight-719 I would recommend taking a look at the following working PRs:
Again, thank you for your contribution! We merged #1261 and the current main branch supports ChatGLM now. Let us know if the main branch does not look good, and feel free to propose any changes!
What are the GPU requirements for 0.1.17?
Hi repository maintainers! Thanks for your excellent work. I have implemented chatglm2 support for vLLM. Here is something you should notice: I add a new `chatglm.py` and use the standard `PagedAttention`.

I provide a quick evaluation script as follows:

You should see the results:

They are the same as the original version.
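The original evaluation script is not preserved in this thread; a minimal sketch of a comparable check with vLLM's offline API could look like the following (the model path and prompts are placeholders, and greedy decoding is used so the outputs can be compared with the original Hugging Face implementation):

```python
from vllm import LLM, SamplingParams

llm = LLM(model="THUDM/chatglm2-6b", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)

prompts = ["Hello, who are you?", "Briefly introduce the ChatGLM2 model."]
for output in llm.generate(prompts, sampling_params):
    print(output.prompt)
    print(output.outputs[0].text)
```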
Additionally, I have tested the speed of `chatglm2_vllm` versus `chatglm2_original` on an A10. For the ShareGPT dataset, `chatglm2_original` achieves 119.3 tokens per second and `chatglm2_vllm` achieves 1015.1 tokens per second. For a Chinese dataset (abstractive summarization), `chatglm2_original` achieves 737.3 tokens per second and `chatglm2_vllm` achieves 3317.6 tokens per second. We can see vLLM speeds up 8.5x for English and 4.5x for Chinese. Here, 'tokens' means the sum of prompt tokens and output length.

I also tested the API server speed for Chinese. `chatglm2_original` achieves 28 generated tokens per second and `chatglm2_vllm` achieves 487 generated tokens per second, which is a 17.4x speedup.

Hi @zhuohan123 @WoosukKwon, could you please have a quick check? Let me know if there is any problem.