
Add support for chatglm2 #649

Closed · wants to merge 5 commits into from

Conversation

@canghongjian commented Aug 2, 2023

Hi repository maintainers! Thanks for your exciting work. I have implemented chatglm2 support for vllm. Here are a few things you should note:

  • I register the class as "ChatGLMModel".
  • I implement RoPE and multi-query attention in chatglm.py and use the standard PagedAttention.

I provide a quick evaluation script as follows:

from vllm import LLM, SamplingParams
  
def build_prompt(prompt):
    return "[Round 1]\n\n问:{}\n\n答:".format(prompt)

content = "你好"
prompts = [
    content,
    "晚上睡不着应该怎么办"
]
prompts = [build_prompt(item) for item in prompts]
# edit this model_url
model_url = 'model/chatglm2-6b'

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=8192)
llm = LLM(model=model_url, gpu_memory_utilization=0.98, dtype="float16", trust_remote_code=True)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

You should see the results:

Prompt: '[Round 1]\n\n问:你好\n\n答:', Generated text: '你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。'
Prompt: '[Round 1]\n\n问:晚上睡不着应该怎么办\n\n答:', Generated text: '以下是一些有助于晚上睡觉的技巧:\n\n1. 建立一个固定的睡眠时间表:每天在相同的时间上床和起床可以帮助身体形成固定的生物钟。\n\n2. 创建一个舒适的睡眠环境:保持房间安静、黑暗、凉爽和舒适可以帮助睡眠。\n\n3. 避免使用电子设备:在睡觉前一小时避免使用电子设备,如手机、电脑、平板等,因为它们可能会干扰睡眠。\n\n4. 放松身心:在睡觉前进行一些放松身心的活动,如泡个热水澡、听轻柔的音乐或进行冥想,有助于缓解压力和焦虑。\n\n5. 限制咖啡因和酒精:避免在睡觉前摄入咖啡因和酒精,因为它们可能会影响睡眠。\n\n6. 远离刺激:避免在睡觉前进行刺激性的活动,如激烈的运动或紧张的任务。\n\n7. 远离躺在床上翻来覆去:如果躺在床上超过20分钟还不能入睡,不要躺在床上翻来覆去,而是起床去做一些平静的活动,如阅读或听轻柔的音乐,直到感到困倦再返回床上。\n\n8. 考虑使用睡眠辅助工具:如果以上方法不能解决问题,可以考虑使用睡眠辅助工具,如睡眠面具、耳塞或鼻塞,但请在使用前咨询医生或专业人士的意见。'

They are the same as the original implementation.
Additionally, I have compared the speed of chatglm2_vllm and chatglm2_original on an A10. On the ShareGPT dataset, chatglm2_original achieves 119.3 tokens per second and chatglm2_vllm achieves 1015.1 tokens per second. On a Chinese dataset (abstractive summarization), chatglm2_original achieves 737.3 tokens per second and chatglm2_vllm achieves 3317.6 tokens per second. So vllm gives an 8.5x speedup for English and a 4.5x speedup for Chinese. Here 'tokens' means the sum of prompt tokens and output length.

I also tested the API server speed for Chinese. chatglm2_original achieves 28 generated tokens per second and chatglm2_vllm achieves 487 generated tokens per second, a 17.4x speedup.
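As a rough illustration only, throughput of this kind can be measured with something like the snippet below, which reuses llm, prompts and sampling_params from the script above (the numbers reported here came from the benchmark scripts, not from this sketch):

import time

# Rough tokens-per-second measurement: (prompt tokens + generated tokens) / wall time.
start = time.time()
outputs = llm.generate(prompts, sampling_params)
elapsed = time.time() - start
total_tokens = sum(
    len(o.prompt_token_ids) + len(o.outputs[0].token_ids) for o in outputs)
print(f"{total_tokens / elapsed:.1f} tokens/s")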

Hi @zhuohan123 @WoosukKwon, could you please have a quick check? Let me know if there is any problem.

@exceedzhang

After integrating this into the FastChat project, there is no output.

@canghongjian (Author)

After integrating this into the FastChat project, there is no output.

Hi, here is my result, where the sampling params are taken from your screenshot:

Maybe you should check this script in vllm first, then check how you integrate vllm into FastChat.

@zhuohan123 (Member) left a comment

Thank you for your contribution! Can you resolve the issues and make sure the formatting check passes by running format.sh? Thanks again!

vllm/model_executor/models/chatglm.py

Comment on lines 306 to 320
if self.multi_query_attention:
    key_layer = key_layer.unsqueeze(-2)
    key_layer = key_layer.expand(
        -1, -1,
        self.num_attention_heads_per_partition //
        self.num_multi_query_groups_per_partition, -1)
    key_layer = key_layer.contiguous().view(
        key_layer.size()[:1] +
        (self.num_attention_heads_per_partition,
         self.hidden_size_per_attention_head))
    value_layer = value_layer.unsqueeze(-2)
    value_layer = value_layer.expand(
        -1, -1,
        self.num_attention_heads_per_partition //
        self.num_multi_query_groups_per_partition, -1)
    value_layer = value_layer.contiguous().view(
        value_layer.size()[:1] +
        (self.num_attention_heads_per_partition,
         self.hidden_size_per_attention_head))
Member

vLLM has native support for multi-query attention and this expand operation can be avoided. This should greatly improve the speed and memory utilization of ChatGLM models. You can refer to GPTBigCode's implementation below:

self.multi_query = config.multi_query
if self.multi_query:
    self.num_kv_heads = 1
    self.kv_dim = self.head_dim
    self.c_attn_q = ColumnParallelLinear(self.hidden_size,
                                         self.hidden_size,
                                         bias=True,
                                         gather_output=False,
                                         perform_initialization=False)
    self.c_attn_kv = nn.Linear(self.hidden_size,
                               2 * self.kv_dim,
                               bias=True)
else:
    self.num_kv_heads = self.num_heads
    self.kv_dim = self.num_kv_heads * self.head_dim
    self.c_attn = ColumnParallelLinear(self.hidden_size,
                                       self.hidden_size + 2 * self.kv_dim,
                                       bias=True,
                                       gather_output=False,
                                       perform_initialization=False)
self.c_proj = RowParallelLinear(self.hidden_size,
                                self.hidden_size,
                                bias=True,
                                input_is_parallel=True,
                                perform_initialization=False)
self.attn = PagedAttention(self.num_heads,
                           self.head_dim,
                           scale=self.scale,
                           num_kv_heads=self.num_kv_heads)

Author

Hi, can I just replace the expand operation with repeat_interleave? I found that if I passed the num_kv_heads param (which is 2 in chatglm2) to PagedAttention, the output became garbled. I thought this was caused by self.head_mapping. If you have any other advice, please tell me.

Member

Hi @canghongjian! Sorry for the delayed response. The current implementation uses the multi-head attention kernel instead of vLLM's grouped-query attention kernel. Using the multi-head attention kernel requires a redundant copy of the key and value tensors; in other words, the expand or repeat_interleave operators here are redundant and waste memory.

I believe setting num_kv_heads=2 is correct, but it seems like the default head_mapping is incorrect. I took a look at ChatGLM's source code on Hugging Face, and it seems that instead of letting the head_mapping be:

[0, 0, ..., 0, 0, 1, 1, ..., 1, 1]

It should be:

[0, 1, ..., 0, 1, 0, 1, ..., 0, 1].

Can you try changing the torch.repeat_interleave to torch.repeat in attention.py (below) and see whether that makes the model output correct?

self.head_mapping = torch.repeat_interleave(
    torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
    self.num_queries_per_kv)

key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
value = torch.repeat_interleave(value,
                                self.num_queries_per_kv,
                                dim=1)
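For concreteness, here is a toy sketch (an illustration only, not the attention.py code) of the two mappings being discussed, assuming chatglm2-6b's shapes: 32 query heads grouped onto 2 key/value heads, i.e. 16 query heads per KV head.

import torch

num_kv_heads, num_queries_per_kv = 2, 16  # assumed chatglm2-6b values

kv_ids = torch.arange(num_kv_heads, dtype=torch.int32)

# Current default: blocked grouping [0, 0, ..., 0, 1, 1, ..., 1]
blocked = torch.repeat_interleave(kv_ids, num_queries_per_kv)

# Suggested alternative: interleaved grouping [0, 1, 0, 1, ..., 0, 1]
interleaved = kv_ids.repeat(num_queries_per_kv)

print(blocked.tolist())
print(interleaved.tolist())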

Author

I have tried it, @zhuohan123. The main modification is:

# in chatglm.py
......

self.atten = PagedAttention(
    config.num_attention_heads,
    self.hidden_size_per_attention_head,
    self.norm_factor,
    num_kv_heads=(self.num_multi_query_groups_per_partition
                  if self.multi_query_attention else config.num_attention_heads))

......

query_layer = apply_rotary_pos_emb(query_layer, rotary_pos).reshape(
    -1, (self.num_attention_heads_per_partition *
         self.hidden_size_per_attention_head))
key_layer = apply_rotary_pos_emb(key_layer, rotary_pos).reshape(
    -1, (self.num_multi_query_groups_per_partition *
         self.hidden_size_per_attention_head))

context_layer = self.atten(query_layer, key_layer, value_layer,
                           k_cache, v_cache, input_metadata,
                           cache_event)

# in attention.py
self.head_mapping = torch.arange(
    self.num_kv_heads, dtype=torch.int32,
    device="cuda").repeat(self.num_queries_per_kv)

......

if self.num_kv_heads != self.num_heads:
    # Project the key and value tensors to the desired number of heads.
    # key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
    # value = torch.repeat_interleave(value,
    #                                 self.num_queries_per_kv,
    #                                 dim=1)
    key = key.repeat([1, self.num_queries_per_kv, 1])
    value = value.repeat([1, self.num_queries_per_kv, 1])

However, the output is still garbled. Did I make any mistakes?

In addition, I don't really understand your statement that "Using the multi-head attention kernel requires a redundant copy of key and value tensors." In my original code, I just copy the key and value tensors before the PagedAttention call, and in attention.py it never enters the key = torch.repeat_interleave branch, because self.num_kv_heads equals self.num_heads when I don't pass the num_kv_heads param. Therefore, the copy operation is performed only once in both settings.

Looking forward to your reply.

Member

Oh, I believe my main point is that our PagedAttention class supports the case where the number of query heads is not the same as the number of key/value heads. In your code, this property is not utilized. In other words, "the copy of the key and value tensors before PagedAttention" is redundant and is supposed to be avoided.

I believe the confusing results are caused by an incorrect mapping between query heads and key/value heads. When the number of query heads is different from the number of key/value heads, we use the head_mapping array to record which key/value head each query head corresponds to.

Can you take a more detailed look into this mapping? I believe once we set the head_mapping correctly, we should be able to get correct results. Let me know if there are any difficulties.

Author

After some attempts, I still failed. Could you please try some debugging?

@Kevinddddddd

Hi, tensor parallelism (TP) does not seem to work properly:

2023-08-09 11:13:24,653 INFO worker.py:1636 -- Started a local Ray instance.
INFO 08-09 11:13:25 llm_engine.py:70] Initializing an LLM engine with config: model='/new_mount_point/llms_checkpoint/chatglm2-6b', tokenizer='/new_mount_point/llms_checkpoint/chatglm2-6b', tokenizer_mode=auto, trust_remote_code=True, dtype=torch.float16, use_dummy_weights=False, download_dir=None, use_np_weights=False, tensor_parallel_size=2, seed=0)
WARNING 08-09 11:13:25 tokenizer.py:63] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
Traceback (most recent call last):
File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/new_mount_point/dhr/vllm/vllm/entrypoints/api_server.py", line 78, in
engine = AsyncLLMEngine.from_engine_args(engine_args)
File "/new_mount_point/dhr/vllm/vllm/engine/async_llm_engine.py", line 232, in from_engine_args
engine = cls(engine_args.worker_use_ray,
File "/new_mount_point/dhr/vllm/vllm/engine/async_llm_engine.py", line 55, in init
self.engine = engine_class(*args, **kwargs)
File "/new_mount_point/dhr/vllm/vllm/engine/llm_engine.py", line 99, in init
self._init_workers_ray(placement_group)
File "/new_mount_point/dhr/vllm/vllm/engine/llm_engine.py", line 161, in _init_workers_ray
self._run_workers("init_worker",
File "/new_mount_point/dhr/vllm/vllm/engine/llm_engine.py", line 474, in _run_workers
all_outputs = ray.get(all_outputs)
File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 18, in auto_init_wrapper
return fn(*args, **kwargs)
File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
return func(*args, **kwargs)
File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/worker.py", line 2540, in get
raise value.as_instanceof_cause()
ray.exceptions.RayTaskError: ray::RayWorker.execute_method() (pid=1715621, ip=172.18.0.44, actor_id=6b6fc7add92690564a2df33901000000, repr=<vllm.engine.ray_utils.RayWorker object at 0x7eea75d519c0>)
At least one of the input arguments for this task could not be computed:
ray.exceptions.RaySystemError: System error: No module named 'transformers_modules'
traceback: Traceback (most recent call last):
ModuleNotFoundError: No module named 'transformers_modules'
(RayWorker pid=1715621) 2023-08-09 11:13:31,087 ERROR serialization.py:387 -- No module named 'transformers_modules'
(RayWorker pid=1715621) Traceback (most recent call last):
(RayWorker pid=1715621) File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/serialization.py", line 385, in deserialize_objects
(RayWorker pid=1715621) obj = self._deserialize_object(data, metadata, object_ref)
(RayWorker pid=1715621) File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/serialization.py", line 268, in _deserialize_object
(RayWorker pid=1715621) return self._deserialize_msgpack_data(data, metadata_fields)
(RayWorker pid=1715621) File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/serialization.py", line 223, in _deserialize_msgpack_data
(RayWorker pid=1715621) python_objects = self._deserialize_pickle5_data(pickle5_data)
(RayWorker pid=1715621) File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/serialization.py", line 213, in _deserialize_pickle5_data
(RayWorker pid=1715621) obj = pickle.loads(in_band)
(RayWorker pid=1715621) ModuleNotFoundError: No module named 'transformers_modules'
(RayWorker pid=1715620) 2023-08-09 11:13:31,103 ERROR serialization.py:387 -- No module named 'transformers_modules'
(RayWorker pid=1715620) Traceback (most recent call last):
(RayWorker pid=1715620) File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/serialization.py", line 385, in deserialize_objects
(RayWorker pid=1715620) obj = self._deserialize_object(data, metadata, object_ref)
(RayWorker pid=1715620) File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/serialization.py", line 268, in _deserialize_object
(RayWorker pid=1715620) return self._deserialize_msgpack_data(data, metadata_fields)
(RayWorker pid=1715620) File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/serialization.py", line 223, in _deserialize_msgpack_data
(RayWorker pid=1715620) python_objects = self._deserialize_pickle5_data(pickle5_data)
(RayWorker pid=1715620) File "/new_mount_point/anaconda3/envs/deploy/lib/python3.10/site-packages/ray/_private/serialization.py", line 213, in _deserialize_pickle5_data
(RayWorker pid=1715620) obj = pickle.loads(in_band)
(RayWorker pid=1715620) ModuleNotFoundError: No module named 'transformers_modules'

@baildagq commented Aug 9, 2023

After integrating this into the FastChat project, there is no output.

Hi, here is my result, where the sampling params are taken from your screenshot. Maybe you should check this script in vllm first, then check how you integrate vllm into FastChat.

Thanks to @canghongjian's contribution, @exceedzhang, I have just integrated chatglm2 support into FastChat's vllm_worker. The problem I found is that chatglm2's tokenizer.eos_token is '</s>', which decodes to ''.
So in vllm's generation loop, the stop condition is triggered as soon as the first token is generated, because any generated text ends with ''.
In my opinion, this bug should be fixed in FastChat, for example by validating tokenizer.eos_token_id again before adding it to the SamplingParams stop list.

https://github.com/lm-sys/FastChat/blob/3dc91c522e1ed82b6f24cb9866d8d9c06ff28d7b/fastchat/serve/vllm_worker.py#L71
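A minimal sketch of the kind of check I mean (hypothetical helper name; the actual vllm_worker code at the link above is structured differently):

def build_stop_list(tokenizer, user_stop=None):
    # Only add the eos token text to the stop list if it decodes to a
    # non-empty string; chatglm2's '</s>' decodes to '', which would match
    # every partial output and stop generation immediately.
    stop = list(user_stop or [])
    eos_text = tokenizer.decode([tokenizer.eos_token_id])
    if eos_text.strip():
        stop.append(eos_text)
    return stop

# e.g. sampling_params = SamplingParams(..., stop=build_stop_list(tokenizer))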

@baildagq commented Aug 9, 2023

Hi, tensor parallelism (TP) does not seem to work properly:
... ModuleNotFoundError: No module named 'transformers_modules'

I think this may be caused by a transformers version problem.

@exceedzhang

@baildagq Thanks for the explanation, I have found a solution!

Thanks for your contribution! I have already run and verified this code locally. One suggestion is to follow the original code structure: the existing design adds a new config class for models that need special handling. Here we need to map the attribute num_layers to num_hidden_layers, so it is better to add an entry to _CONFIG_REGISTRY with the key-value pair 'chatglm': ChatGLMConfig. Inside ChatGLMConfig, add an attribute_map for the attribute rename; the attribute map is processed in transformers' __setattr__:
https://github.com/huggingface/transformers/blob/d0c1aebea467af499331234e7b285a6bf91ea073/src/transformers/configuration_utils.py#L253

A working ChatGLMConfig follows:

from transformers import PretrainedConfig


class ChatGLMConfig(PretrainedConfig):
    attribute_map = {
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        num_layers=28,
        padded_vocab_size=65024,
        hidden_size=4096,
        ffn_hidden_size=13696,
        kv_channels=128,
        num_attention_heads=32,
        seq_length=2048,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        layernorm_epsilon=1e-5,
        rmsnorm=True,
        apply_residual_connection_post_layernorm=False,
        post_layer_norm=True,
        add_bias_linear=False,
        add_qkv_bias=False,
        interleaved_qkv=False,
        bias_dropout_fusion=True,
        multi_query_attention=False,
        multi_query_group_num=1,
        apply_query_key_layer_scaling=True,
        attention_softmax_in_fp32=True,
        fp32_residual_connection=False,
        quantization_bit=0,
        pre_seq_len=None,
        prefix_projection=False,
        **kwargs
    ):
        self.num_layers = num_layers
        self.vocab_size = padded_vocab_size
        self.padded_vocab_size = padded_vocab_size
        self.hidden_size = hidden_size
        self.ffn_hidden_size = ffn_hidden_size
        self.kv_channels = kv_channels
        self.num_attention_heads = num_attention_heads
        self.seq_length = seq_length
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.layernorm_epsilon = layernorm_epsilon
        self.rmsnorm = rmsnorm
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.post_layer_norm = post_layer_norm
        self.add_bias_linear = add_bias_linear
        self.add_qkv_bias = add_qkv_bias
        self.bias_dropout_fusion = bias_dropout_fusion
        self.multi_query_attention = multi_query_attention
        self.multi_query_group_num = multi_query_group_num
        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = attention_softmax_in_fp32
        self.fp32_residual_connection = fp32_residual_connection
        self.quantization_bit = quantization_bit
        self.pre_seq_len = pre_seq_len
        self.prefix_projection = prefix_projection
        super().__init__(**kwargs)
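
A quick sanity check that the attribute_map behaves as intended (assuming the class above; transformers resolves mapped attribute names in PretrainedConfig's attribute access):

cfg = ChatGLMConfig()
# num_hidden_layers is resolved to num_layers through attribute_map.
assert cfg.num_hidden_layers == cfg.num_layers == 28
cfg.num_hidden_layers = 30  # writes through to num_layers as well
assert cfg.num_layers == 30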

_CONFIG_REGISTRY:

_CONFIG_REGISTRY = {
    "mpt": MPTConfig,
    "baichuan": BaiChuanConfig,
    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
    "chatglm": ChatGLMConfig
}

And you need to import the ChatGLMConfig class in vllm/transformers_utils/configs/__init__.py.
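For example, something along these lines (the module path configs/chatglm.py and the other exported names are assumptions; adapt to the actual file contents):

# vllm/transformers_utils/configs/__init__.py (sketch)
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig

__all__ = [
    # ... the configs already exported (MPTConfig, BaiChuanConfig, RWConfig, ...)
    "ChatGLMConfig",
]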

Author

Good suggestions, I will update it and push a commit later. Thank you!


Hi, I still get the error 'ChatGLMConfig' object has no attribute 'num_hidden_layers' with chatglm2-6b. How can I fix this? Thank you.

python: 3.8.10
transformers: 4.33.2

Author

Hi, I still get the error 'ChatGLMConfig' object has no attribute 'num_hidden_layers' with chatglm2-6b. How can I fix this? Thank you.

python: 3.8.10, transformers: 4.33.2

Oh, maybe you should check whether your code has been modified. You can get the code with git clone https://github.com/canghongjian/vllm.git.


Thank you and greetings. However, I ran into another problem after reinstalling from the cloned project: "The model's config.json must contain one of the following keys to determine the original maximum length of the model: ['max_position_embeddings', 'n_positions', 'max_seq_len', 'max_sequence_length', 'max_seq_length', 'seq_len']". It seems something is not compatible.


Hi, thanks for this. I successfully installed by pulling the latest code, and ran into this problem while running the demo code above with the original chatglm2-6b:

File ~/autodl-tmp/vllm/vllm/model_executor/models/chatglm.py:478, in ChatGLMModel.__init__(self, config)
    475     super().__init__()
    477     self.config = config
--> 478     self.embedding = VocabParallelEmbedding(
    479         config.vocab_size,
    480         config.hidden_size,
    481         perform_initialization=False,
    482         params_dtype=config.torch_dtype)
    484     self.num_layers = config.num_layers
    485     self.multi_query_group_num = config.multi_query_group_num

TypeError: __init__() got an unexpected keyword argument 'perform_initialization'

Is this a bug?

Contributor

Yes, thanks for the report. Please pull the latest code again.


A new question occurred:

File ~/autodl-tmp/vllm/vllm/model_executor/models/chatglm.py:546, in ChatGLMModel.load_weights(self, model_name_or_path, cache_dir, load_format, revision)
    543     if name.startswith("embedding"):
    544         name = name.replace(".word_embeddings", "")
--> 546     param = state_dict[name]
    547     load_tensor_parallel_weights(param, loaded_weight, name,
    548                                  self._column_parallel_weights,
    549                                  self._row_parallel_weights,
    550                                  tensor_model_parallel_rank)

KeyError: 'lm_head.weight'

Contributor

I just tested the official original model. Is the model you are using different?


Thank you so much. I just reloaded the environment and successfully ran the demo script on the original model.

This issue happens when using a model fine-tuned with LoRA.

@cdj0311 commented Aug 10, 2023

Thanks for sharing. I tested this code on an A100, but the speed is the same as with huggingface, GPU memory usage reaches 70 GB, and only 30-40 tokens are generated per second. Here is my code; where is the problem?

import time
from vllm import LLM, SamplingParams

model_path = "THUDM/chatglm2-6b"

prompts = [
    "def quick_sort(array):\n",
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=500)

llm = LLM(model=model_path, trust_remote_code=True)

for _ in range(5):
    t1 = time.time()
    outputs = llm.generate(prompts, sampling_params)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0]
        print(generated_text.text)
    print(len(generated_text.token_ids), time.time() - t1)

@canghongjian (Author)

Thanks for sharing. I tested this code on an A100, but the speed is the same as with huggingface, GPU memory usage reaches 70 GB, and only 30-40 tokens are generated per second. Here is my code; where is the problem?

vllm preallocates 90% of GPU memory by default, and you can change this ratio via gpu_memory_utilization. I also think you are not making full use of vllm with repeated single-prompt calls to generate; vllm shines when many prompts are batched into one call. You can refer to the official benchmark code to see the real speed.
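To illustrate the point about batching, a sketch only (the model path is taken from your snippet; exact numbers will vary with hardware):

import time
from vllm import LLM, SamplingParams

llm = LLM(model="THUDM/chatglm2-6b", trust_remote_code=True, dtype="float16")
params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=500)

# One batched call over many prompts lets vllm schedule them together,
# instead of 64 sequential single-prompt calls.
prompts = ["def quick_sort(array):\n"] * 64
start = time.time()
outputs = llm.generate(prompts, params)
total = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"{total / (time.time() - start):.1f} generated tokens/s")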

@chenhaiwu

Hi, I ran on 4 x V100 with tensor_parallel_size=4 and load_weights fails. BTW: a single V100 works fine.

Namespace(backend='vllm', dataset='/mnt/chw/ShareGPT_V3_unfiltered_cleaned_split.json', hf_max_batch_size=None, model='/mnt/chw/chatglm2/chatglm2-6b', n=1, num_prompts=1, seed=0, tensor_parallel_size=4, tokenizer='/mnt/chw/chatglm2/chatglm2-6b', trust_remote_code=True, use_beam_search=False)

ray.exceptions.RayTaskError(AssertionError): ray::RayWorker.execute_method() (pid=2394870, ip=10.1.1.242, actor_id=aee050fd938cd8f3d6997a6401000000, repr=<vllm.engine.ray_utils.RayWorker object at 0x7efe5659a9a0>)
File "/mnt/chw/vllm/vllm/engine/ray_utils.py", line 25, in execute_method
return executor(*args, **kwargs)
File "/mnt/chw/vllm/vllm/worker/worker.py", line 67, in init_model
self.model = get_model(self.model_config)
File "/mnt/chw/vllm/vllm/model_executor/model_loader.py", line 57, in get_model
model.load_weights(model_config.model, model_config.download_dir,
File "/mnt/chw/vllm/vllm/model_executor/models/chatglm.py", line 552, in load_weights
load_tensor_parallel_weights(param, loaded_weight, name,
File "/mnt/chw/vllm/vllm/model_executor/weight_utils.py", line 105, in load_tensor_parallel_weights
assert param.shape == loaded_weight.shape, (

AssertionError: encoder.layers.22.self_attention.query_key_value.bias shape mismatch between model and checkpoint: torch.Size([1152]) != torch.Size([4608])

@canghongjian (Author) commented Aug 22, 2023

Hi, I ran on 4 x V100 with tensor_parallel_size=4 and load_weights fails. BTW: a single V100 works fine.

AssertionError: encoder.layers.22.self_attention.query_key_value.bias shape mismatch between model and checkpoint: torch.Size([1152]) != torch.Size([4608])

Thanks for pointing out this issue, @chenhaiwu. I have reproduced the problem and can confirm that the existing chatglm code does not support tensor parallelism. Fixing this might take a few days, and I believe the gains would be minimal: chatglm2 uses only 2 MQA key/value groups by default, which allows a tensor_parallel_size of only 1 or 2. Considering this, I recommend using the tensor_parallel_size=1 setting.
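To make the constraint concrete, a toy sketch (not vllm code): chatglm2-6b has 2 key/value groups, so only a tensor_parallel_size that divides 2 splits the KV projection evenly.

num_kv_heads = 2  # multi_query_group_num in chatglm2-6b's config
for tp_size in (1, 2, 4, 8):
    ok = num_kv_heads % tp_size == 0
    print(f"tensor_parallel_size={tp_size}: {'ok' if ok else 'unsupported'}")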

@UncleFB commented Aug 28, 2023

Hi, if I want to understand why chatglm.py was modified in this particular way, what background knowledge do I need? I am not very familiar with this area yet; any pointers would be appreciated.

@white-wolf-tech commented Aug 31, 2023

I merged this PR and reinstalled vllm, but a single A10 (24 GB) cannot load chatglm2-6b; it runs out of memory. Is this a bug? Crash info:

File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/entrypoints/llm.py", line 66, in init
self.llm_engine = LLMEngine.from_engine_args(engine_args)
File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/engine/llm_engine.py", line 220, in from_engine_args
engine = cls(*engine_configs,
File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/engine/llm_engine.py", line 104, in init
self._init_cache()
File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/engine/llm_engine.py", line 208, in _init_cache
self._run_workers("init_cache_engine", cache_config=self.cache_config)
File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/engine/llm_engine.py", line 470, in _run_workers
output = executor(*args, **kwargs)
File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/worker/worker.py", line 139, in init_cache_engine
self.cache_engine = CacheEngine(self.cache_config, self.model_config,
File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/worker/cache_engine.py", line 44, in init
self.gpu_cache = self.allocate_gpu_cache()
File "/data/.conda/envs/glm200/lib/python3.9/site-packages/vllm-0.1.3-py3.9-linux-x86_64.egg/vllm/worker/cache_engine.py", line 75, in allocate_gpu_cache
key_blocks = torch.empty(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 176.00 MiB (GPU 0; 22.20 GiB total capacity; 20.95 GiB already allocated; 96.12 MiB free; 20.96 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

@zhuohan123 added the "new model" (requests for new models) label on Sep 12, 2023
@liuyanyi (Contributor)

Hi, thanks for your work. I have tried this PR, but I can only get ~50% utilization on a single A100, while the Qwen model reaches about 95%. Do you have any suggestions?

@canghongjian (Author)

Hi, thanks for your work. I have tried this PR, but I can only get ~50% utilization on a single A100, while the Qwen model reaches about 95%. Do you have any suggestions?

What does '~50% usage' mean? Is it the GPU memory usage? Does the output look correct? Please provide more detailed information.

@liuyanyi (Contributor)

What does '~50% usage' mean? Is it the GPU memory usage? Does the output look correct? Please provide more detailed information.

The model output seems correct; ~50% means GPU utilization is around 50%.

@liuyanyi (Contributor)

I used benchmarks/benchmark_throughput.py to compare the performance of Qwen-7B and ChatGLM2-6B.

Metric       Qwen-vllm   ChatGLM2-vllm
tokens/s     3008.20     1997.71
requests/s   7.45        4.22

While testing ChatGLM, the GPU utilization was as shown below:
[GPU utilization screenshot]

@danny-zhu commented Sep 22, 2023

I used the vllm 0.1.17 code, applied the changes from the submitted files, and ran it with the test script provided by vllm; it looks fine. The model is chatglm2-6b-32k.

You are using a model of type chatglm to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO 09-22 15:08:08 llm_engine.py:72] Initializing an LLM engine with config: model='/home/top/LLM/models/THUDM/chatglm2-6b-32k', tokenizer='/home/top/LLM/models/THUDM/chatglm2-6b-32k', tokenizer_mode=auto, trust_remote_code=True, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
WARNING 09-22 15:08:08 tokenizer.py:64] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
INFO 09-22 15:08:15 llm_engine.py:199] # GPU blocks: 1269, # CPU blocks: 585
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  2.73it/s]
[RequestOutput(request_id=0, prompt='Hello, my name is', prompt_token_ids=[64790, 64792, 13755, 30932, 552, 1462, 323], outputs=[CompletionOutput(index=0, text='Sarah and I am a 20-year-old student studying a B', token_ids=[8274, 293, 307, 674, 260, 30910, 30943, 30940, 30941, 2475, 30941, 717, 3052, 9436, 260, 347], cumulative_logprob=-19.13202815562545, logprobs={}, finish_reason=length)], finished=True)]

The script provided by @canghongjian also runs normally:


You are using a model of type chatglm to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
INFO 09-22 15:27:50 llm_engine.py:72] Initializing an LLM engine with config: model='/home/top/LLM/models/THUDM/chatglm2-6b-32k', tokenizer='/home/top/LLM/models/THUDM/chatglm2-6b-32k', tokenizer_mode=auto, trust_remote_code=True, dtype=torch.float16, download_dir=None, load_format=auto, tensor_parallel_size=1, seed=0)
WARNING 09-22 15:27:50 tokenizer.py:64] Using a slow tokenizer. This might cause a significant slowdown. Consider using a fast tokenizer instead.
INFO 09-22 15:27:57 llm_engine.py:199] # GPU blocks: 1546, # CPU blocks: 585
Processed prompts: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:05<00:00,  2.52s/it]
Prompt: '[Round 1]\n\n问:你好\n\n答:', Generated text: '你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。'
Prompt: '[Round 1]\n\n问:晚上睡不着应该怎么办\n\n答:', Generated text: '晚上睡不着觉可能会让人感到很困扰,可以尝试以下方法帮助入睡:\n\n1. 放松身体:试着放松身体,深呼吸或做一些轻松的伸展运动,以缓解身体的紧张。\n\n2. 创建一个舒适的睡眠环境:确保卧室安静、黑暗、凉爽、舒适,床垫、枕头舒适,保证床铺的干净和整洁。\n\n3. 避免使用电子产品:避免使用电子产品,如手机、电视、电脑等,这些产品的蓝光会干扰睡眠。\n\n4. 建立一个固定的睡眠时间:尽量在同一时间入睡和起床,帮助身体建立一个固定的睡眠节律。\n\n5. 避免饮用刺激性饮料:避免饮用含咖啡因的刺激性饮料,如咖啡、茶、可乐等。\n\n6. 尝试一些放松技巧:可以尝试一些深呼吸、渐进性肌肉松弛、瑜伽或冥想等放松技巧。\n\n如果以上方法不能解决您的问题,可以尝试寻求专业的帮助,如咨询医生或睡眠专家。'

@yanxiyue (Contributor)

good job~

@128Ghe980 commented Oct 8, 2023

vllm==0.2.0
transformers==4.33.2
After applying the changes from this PR, I get the following error:

Traceback (most recent call last):
File "chatglm-vllm.py", line 74, in
engine, generation_config, tokenizer = init()
File "chatglm-vllm.py", line 67, in init
engine = AsyncLLMEngine.from_engine_args(engine_args)
File "/home/deployer/venv38_vllm/lib/python3.8/site-packages/vllm/engine/async_llm_engine.py", line 486, in from_engine_args
engine = cls(engine_args.worker_use_ray,
File "/home/deployer/venv38_vllm/lib/python3.8/site-packages/vllm/engine/async_llm_engine.py", line 270, in init
self.engine = self._init_engine(*args, **kwargs)
File "/home/deployer/venv38_vllm/lib/python3.8/site-packages/vllm/engine/async_llm_engine.py", line 306, in _init_engine
return engine_class(*args, **kwargs)
File "/home/deployer/venv38_vllm/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 108, in init
self._init_workers(distributed_init_method)
File "/home/deployer/venv38_vllm/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 140, in _init_workers
self._run_workers(
File "/home/deployer/venv38_vllm/lib/python3.8/site-packages/vllm/engine/llm_engine.py", line 692, in _run_workers
output = executor(*args, **kwargs)
File "/home/deployer/venv38_vllm/lib/python3.8/site-packages/vllm/worker/worker.py", line 68, in init_model
self.model = get_model(self.model_config)
File "/home/deployer/venv38_vllm/lib/python3.8/site-packages/vllm/model_executor/model_loader.py", line 102, in get_model
model.load_weights(model_config.model, model_config.download_dir,
TypeError: load_weights() takes from 2 to 4 positional arguments but 5 were given

Here is my loading code:

engine_args = AsyncEngineArgs(
    model=model_path,
    trust_remote_code=True,
    dtype="float16",
    seed=42,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
generation_config = GenerationConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path,
                                          use_fast=False,
                                          trust_remote_code=True)

Is this because the vllm version is too new?

@128Ghe980

I used the vllm 0.1.17 code, applied the changes from the submitted files, and ran it with the test script provided by vllm; it looks fine. The model is chatglm2-6b-32k. The script provided by @canghongjian also runs normally.

Which version of transformers are you using?

@128Ghe980

OK, I got it running.
Versions:
vllm==0.2.0
transformers==4.33.2

Apply the changes from this PR, then modify def load_weights in class ChatGLMModel in vllm/model_executor/models/chatglm.py from:
def load_weights(self,
                 model_name_or_path: str,
                 cache_dir: Optional[str] = None,
                 use_np_cache: bool = False):

    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
    state_dict = self.state_dict()

    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, use_np_cache):
        if "rotary_pos_emb.inv_freq" in name:
            continue

to:
def load_weights(self,
                 model_name_or_path: str,
                 cache_dir: Optional[str] = None,
                 load_format: str = "auto",
                 revision: Optional[str] = None):

    tensor_model_parallel_rank = get_tensor_model_parallel_rank()
    state_dict = self.state_dict()

    for name, loaded_weight in hf_model_weights_iterator(
            model_name_or_path, cache_dir, load_format, revision):
        if "rotary_pos_emb.inv_freq" in name:
            continue

Then load it with the following code:
engine_args = AsyncEngineArgs(
    model=model_path,
    trust_remote_code=True,
    dtype="float16",
    seed=42,
)
engine = AsyncLLMEngine.from_engine_args(engine_args)
# generation_config = GenerationConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path,
                                          use_fast=False,
                                          trust_remote_code=True)

That should work.

@xuxingya

Did you run pip install . after modifying the package? I keep getting the error "The NVIDIA driver on your system is too old". @128Ghe980

@Alwin4Zhang commented Oct 11, 2023

vllm preallocates 90% of GPU memory by default, and you can change this ratio via gpu_memory_utilization. I also think you are not making full use of vllm with repeated single-prompt calls to generate. You can refer to the official benchmark code to see the real speed.

On a single A10, launched via the API server, generation on the demo data only reaches about 27 tokens/s. Where could the problem be?

@128Ghe980

Did you run pip install . after modifying the package? I keep getting the error "The NVIDIA driver on your system is too old". @128Ghe980

No, I just pip install vllm==0.2.0 directly and then applied the changes.

@simon-mo (Collaborator) commented Nov 2, 2023

Thank you for the contribution. Unfortunately, this PR seems to have become stale, and ChatGLM3 has also come out. Feel free to coordinate the contribution here if you have bandwidth!

#1552

@Midnight-719

Hi, I still encounter AttributeError: 'ChatGLMConfig' object has no attribute 'num_hidden_layers'. I tried updating the vllm version and still have problems. My versions: vllm==0.2.1.post1, transformers==4.27.0.

@simon-mo (Collaborator) commented Nov 6, 2023

@Midnight-719 I would recommend taking a look at the following working PRs:

#1558
#1261

@zhuohan123 (Member)

Again, thank you for your contribution! We merged #1261 and the current main branch supports ChatGLM now. Let us know if the main branch does not look good and feel free to propose any changes!

@zhuohan123 closed this Nov 7, 2023
@PeterXiaTian

I used the vllm 0.1.17 code, applied the changes from the submitted files, and ran it with the test script provided by vllm; it looks fine. The model is chatglm2-6b-32k. The script provided by @canghongjian also runs normally.

What are the GPU requirements for 0.1.17?
