Description
Original title: [Bug]: vllm0.8.5 + vllm0.8.5rc1 + Qwen models run failed with VLLM_USE_V1=1 and pp > 1
Your current environment
NPU: 910B4
python: 3.10.16
CANN: 8.1RC1
vllm: 0.8.5
vllm-ascend: 0.8.5rc1
🐛 Describe the bug
The parameters I used to initialize vLLM:
from vllm import LLM

llm = LLM(
    model=model_path,
    dtype='float16',
    block_size=128,
    swap_space=16,
    cpu_offload_gb=10,
    pipeline_parallel_size=4,
    tensor_parallel_size=1,
    enforce_eager=True,
    max_model_len=16384,
    distributed_executor_backend="ray",
)
The error output during decoding:
Traceback (most recent call last):
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 400, in run_engine_core
    raise e
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 389, in run_engine_core
    engine_core.run_busy_loop()
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 413, in run_busy_loop
    self._process_engine_step()
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 438, in _process_engine_step
    outputs = self.step_fn()
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 249, in step_with_batch_queue
    model_output = future.result()
  File "/home/git_repo/vllm/vllm/v1/executor/ray_distributed_executor.py", line 24, in result
    return self.ref.get()
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/ray/experimental/compiled_dag_ref.py", line 150, in get
    return _process_return_vals(return_vals, True)
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/ray/experimental/compiled_dag_ref.py", line 27, in _process_return_vals
    raise val.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::RayWorkerWrapper.__ray_call__() (pid=3953894, ip=51.36.133.73)
  File "/home/git_repo/vllm/vllm/executor/ray_utils.py", line 137, in execute_model_ray
    output = self.worker.model_runner.execute_model(
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/vllm_ascend/worker/model_runner_v1.py", line 642, in execute_model
    logits = self.model.compute_logits(hidden_states, None)
  File "/home/git_repo/vllm/vllm/model_executor/models/qwen2.py", line 475, in compute_logits
    logits = self.logits_processor(self.lm_head, hidden_states,
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/git_repo/vllm/vllm/model_executor/layers/logits_processor.py", line 70, in forward
    logits = self._get_logits(hidden_states, lm_head, embedding_bias)
  File "/home/git_repo/vllm/vllm/model_executor/layers/logits_processor.py", line 108, in _get_logits
    logits = lm_head.quant_method.apply(lm_head,
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'PPMissingLayer' object has no attribute 'quant_method'
The same run completes successfully in a GPU environment, so this appears to be a vLLM Ascend bug.
In qwen2.py, the model's lm_head is initialized as follows:
if get_pp_group().is_last_rank:
    if config.tie_word_embeddings:
        self.lm_head = self.model.embed_tokens
    else:
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(
                                          prefix, "lm_head"))
else:
    self.lm_head = PPMissingLayer()
lm_head is set to a PPMissingLayer when the current worker is not the last rank in the PP group. However, model_runner_v1.NPUModelRunner.execute_model always computes logits via logits = self.model.compute_logits(hidden_states[sample_indices], None), regardless of rank. In compute_logits, model.lm_head is passed as a parameter, and model.lm_head.quant_method.apply is eventually called:
def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # Get the logits for the next tokens.
    logits = lm_head.quant_method.apply(lm_head,
                                        hidden_states,
                                        bias=embedding_bias)
PPMissingLayer, however, has no quant_method attribute, which causes the error. Either NPUModelRunner.execute_model should return early when the current worker is not the last PP rank, or qwen2.py should be modified (other models, such as llama, have the same problem).
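To make the failure and the first proposed fix concrete, here is a minimal pure-Python sketch. It does not use vLLM or torch: LMHead, MissingLayer, ToyModel, and both execute_model_* functions are hypothetical stand-ins for ParallelLMHead, PPMissingLayer, the Qwen2 model, and NPUModelRunner.execute_model. What a non-last rank should actually return is an assumption here (the hidden states are simply passed through).

```python
from typing import List


class QuantMethod:
    """Stand-in for the object normally stored in lm_head.quant_method."""

    def apply(self, layer: "LMHead", hidden_states: List[float]) -> List[float]:
        # A real implementation would do a matmul; scaling is enough here.
        return [h * layer.scale for h in hidden_states]


class LMHead:
    """Stand-in for ParallelLMHead: it owns a quant_method."""

    def __init__(self, scale: float = 2.0) -> None:
        self.scale = scale
        self.quant_method = QuantMethod()


class MissingLayer:
    """Stand-in for PPMissingLayer: a placeholder with no quant_method."""


class ToyModel:
    def __init__(self, lm_head: object) -> None:
        self.lm_head = lm_head

    def compute_logits(self, hidden_states: List[float]) -> List[float]:
        # Mirrors the failing call chain: lm_head.quant_method.apply(...)
        return self.lm_head.quant_method.apply(self.lm_head, hidden_states)


def execute_model_unguarded(model: ToyModel, hidden_states: List[float]):
    # Always computes logits, as described for the NPU runner above.
    return model.compute_logits(hidden_states)


def execute_model_guarded(model: ToyModel, hidden_states: List[float],
                          is_last_rank: bool):
    if not is_last_rank:
        # Assumption: intermediate stages hand hidden states onward
        # and never touch lm_head.
        return hidden_states
    return model.compute_logits(hidden_states)


last_stage = ToyModel(LMHead())
mid_stage = ToyModel(MissingLayer())

print(execute_model_guarded(last_stage, [1.0, 2.0], is_last_rank=True))
print(execute_model_guarded(mid_stage, [1.0, 2.0], is_last_rank=False))
try:
    execute_model_unguarded(mid_stage, [1.0, 2.0])
except AttributeError as exc:
    print("unguarded call failed:", exc)
```

The unguarded call on the intermediate stage raises the same kind of AttributeError as in the traceback, while the guarded version never reaches lm_head on non-last ranks.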
This bug is probably reproducible on main as well, since the latest code also lacks an early return in execute_model.
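The second option mentioned above, changing the model rather than the runner, could look roughly like the following sketch. It is again pure Python with hypothetical stand-ins (MissingLayer for PPMissingLayer, Head for a real LM head), and it assumes that callers are prepared to receive None on non-last ranks, which would need to be verified against the actual runner code.

```python
from typing import List, Optional


class MissingLayer:
    """Stand-in for PPMissingLayer: a placeholder with no quant_method."""


class QuantMethod:
    """Stand-in for the quantization method attached to a real LM head."""

    def apply(self, layer: object, hidden_states: List[float]) -> List[float]:
        # Identity projection, purely for illustration.
        return list(hidden_states)


class Head:
    """Stand-in for a real LM head that owns a quant_method."""

    def __init__(self) -> None:
        self.quant_method = QuantMethod()


def compute_logits(lm_head: object,
                   hidden_states: List[float]) -> Optional[List[float]]:
    # Model-side guard: skip logits when this stage has no real lm_head.
    if isinstance(lm_head, MissingLayer):
        return None
    return lm_head.quant_method.apply(lm_head, hidden_states)


print(compute_logits(Head(), [0.5, 1.5]))
print(compute_logits(MissingLayer(), [0.5, 1.5]))
```

Compared with an early return in the runner, this guard would have to be repeated in every affected model, which is one reason the runner-side fix may be preferable.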