
[Feature]: Support PP with VLLM_USE_V1=1 #1302

Open

Description

@silveryshine

Original title: [Bug]: vllm 0.8.5 + vllm-ascend 0.8.5rc1 + Qwen models fail with VLLM_USE_V1=1 and pp > 1

Your current environment

NPU: 910B4
python: 3.10.16
CANN: 8.1RC1
vllm: 0.8.5
vllm-ascend: 0.8.5rc1

🐛 Describe the bug

The parameters I used to initialize vLLM:

from vllm import LLM

# model_path points to a local Qwen checkpoint
llm = LLM(model=model_path,
          dtype='float16',
          block_size=128,
          swap_space=16,
          cpu_offload_gb=10,
          pipeline_parallel_size=4,
          tensor_parallel_size=1,
          enforce_eager=True,
          max_model_len=16384,
          distributed_executor_backend="ray")
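
For completeness, the V1 engine was enabled through an environment variable before constructing the LLM. A minimal sketch, assuming it was set in the script itself (the original may set it in the shell instead):

import os

os.environ["VLLM_USE_V1"] = "1"  # enable the V1 engine, as in the issue title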

The error output during the decode phase:

Traceback (most recent call last):
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 400, in run_engine_core
    raise e
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 389, in run_engine_core
    engine_core.run_busy_loop()
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 413, in run_busy_loop
    self._process_engine_step()
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 438, in _process_engine_step
    outputs = self.step_fn()
  File "/home/git_repo/vllm/vllm/v1/engine/core.py", line 249, in step_with_batch_queue
    model_output = future.result()
  File "/home/git_repo/vllm/vllm/v1/executor/ray_distributed_executor.py", line 24, in result
    return self.ref.get()
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/ray/experimental/compiled_dag_ref.py", line 150, in get
    return _process_return_vals(return_vals, True)
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/ray/experimental/compiled_dag_ref.py", line 27, in _process_return_vals
    raise val.as_instanceof_cause()
ray.exceptions.RayTaskError(AttributeError): ray::RayWorkerWrapper.__ray_call__() (pid=3953894, ip=51.36.133.73)
  File "/home/git_repo/vllm/vllm/executor/ray_utils.py", line 137, in execute_model_ray
    output = self.worker.model_runner.execute_model(
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/vllm_ascend/worker/model_runner_v1.py", line 642, in execute_model
    logits = self.model.compute_logits(hidden_states, None)
  File "/home/git_repo/vllm/vllm/model_executor/models/qwen2.py", line 475, in compute_logits
    logits = self.logits_processor(self.lm_head, hidden_states,
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/git_repo/vllm/vllm/model_executor/layers/logits_processor.py", line 70, in forward
    logits = self._get_logits(hidden_states, lm_head, embedding_bias)
  File "/home/git_repo/vllm/vllm/model_executor/layers/logits_processor.py", line 108, in _get_logits
    logits = lm_head.quant_method.apply(lm_head,
  File "/home/conda/envs/vllm_0.8.5/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1931, in __getattr__
    raise AttributeError(
AttributeError: 'PPMissingLayer' object has no attribute 'quant_method'

The same setup runs successfully in a GPU environment, so this appears to be a vllm-ascend bug.

In qwen2.py, the model's lm_head is initialized as follows:

if get_pp_group().is_last_rank:
    if config.tie_word_embeddings:
        self.lm_head = self.model.embed_tokens
    else:
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config,
                                      prefix=maybe_prefix(prefix, "lm_head"))
else:
    self.lm_head = PPMissingLayer()
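
For reference, PPMissingLayer (defined in vllm/model_executor/models/utils.py) is essentially an nn.Identity placeholder for layers that are not materialized on the current pipeline-parallel rank. A minimal self-contained sketch, using a stand-in class, reproduces the failure mode seen in the traceback:

import torch

# Stand-in mimicking vllm's PPMissingLayer: an identity module holding
# no parameters and no quantization metadata.
class PPMissingLayer(torch.nn.Identity):
    def __init__(self, *args, **kwargs):
        super().__init__()

lm_head = PPMissingLayer()
try:
    # nn.Module.__getattr__ only resolves registered parameters, buffers,
    # and submodules, so this fails exactly as in the traceback above.
    lm_head.quant_method
except AttributeError as e:
    print(e)  # 'PPMissingLayer' object has no attribute 'quant_method'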

lm_head is assigned a PPMissingLayer when the current worker is not the last rank in the pipeline-parallel group. However, in model_runner_v1.NPUModelRunner.execute_model, logits are always computed via logits = self.model.compute_logits(hidden_states[sample_indices], None). compute_logits passes model.lm_head along as a parameter, and model.lm_head.quant_method.apply is eventually called:

def _get_logits(
    self,
    hidden_states: torch.Tensor,
    lm_head: VocabParallelEmbedding,
    embedding_bias: Optional[torch.Tensor],
) -> Optional[torch.Tensor]:
    # Get the logits for the next tokens.
    logits = lm_head.quant_method.apply(lm_head,
                                        hidden_states,
                                        bias=embedding_bias)

PPMissingLayer, however, does not hold a quant_method, which causes the AttributeError. Either NPUModelRunner.execute_model should return early when the current worker is not the last pipeline-parallel rank, or qwen2.py should be modified (other models, such as llama, have the same structure). The bug is likely reproducible on main as well, since the latest code also lacks an early return in execute_model.
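
A minimal, runnable sketch of the first option, the early-return guard. DummyRunner is a hypothetical stand-in for NPUModelRunner; the real fix would check get_pp_group().is_last_rank inside execute_model and forward the hidden states (e.g. wrapped as IntermediateTensors) to the next stage:

import torch

class PPMissingLayer(torch.nn.Identity):
    """Placeholder for layers not materialized on this PP rank."""

class DummyRunner:
    """Hypothetical stand-in for NPUModelRunner, showing only the guard."""

    def __init__(self, is_last_rank: bool):
        # is_last_rank stands in for get_pp_group().is_last_rank
        self.is_last_rank = is_last_rank
        self.lm_head = (torch.nn.Linear(8, 16) if is_last_rank
                        else PPMissingLayer())

    def execute_model(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if not self.is_last_rank:
            # Early return: intermediate PP ranks own no lm_head, so they
            # must hand hidden states downstream instead of computing logits.
            return hidden_states
        return self.lm_head(hidden_states)  # logits only on the last rank

hidden = torch.randn(2, 8)
print(DummyRunner(is_last_rank=False).execute_model(hidden).shape)  # passthrough: (2, 8)
print(DummyRunner(is_last_rank=True).execute_model(hidden).shape)   # logits: (2, 16)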
