Skip to content

Commit

Permalink
[Model] Support pp for qwen2-vl (vllm-project#8696)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuyanyi authored Sep 23, 2024
1 parent 3e83c12 commit a79e522
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 14 deletions.
8 changes: 8 additions & 0 deletions tests/distributed/test_pipeline_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import os

import pytest
from packaging import version
from transformers import __version__ as transformers_version

from vllm.logger import init_logger

Expand Down Expand Up @@ -37,6 +39,7 @@
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
],
)
@fork_new_process_for_each_test
Expand All @@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")

# Skip tests that require transformers>=4.45.0
if "Qwen2-VL" in MODEL_NAME and version.parse(
transformers_version) < version.parse("4.45.0.dev0"):
pytest.skip("This test requires transformers>=4.45.0")

pp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
Expand Down
1 change: 1 addition & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"Qwen2ForCausalLM",
"Qwen2MoeForCausalLM",
"QWenLMHeadModel",
"Qwen2VLForConditionalGeneration",
]


Expand Down
22 changes: 15 additions & 7 deletions vllm/model_executor/models/qwen2.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers


class Qwen2MLP(nn.Module):
Expand Down Expand Up @@ -235,11 +235,16 @@ def __init__(
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
if get_pp_group().is_first_rank or (config.tie_word_embeddings
and get_pp_group().is_last_rank):
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
)
else:
self.embed_tokens = PPMissingLayer()

self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: Qwen2DecoderLayer(config=config,
Expand All @@ -248,7 +253,10 @@ def __init__(
prefix=f"{prefix}.layers",
)

self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
Expand Down
29 changes: 22 additions & 7 deletions vllm/model_executor/models/qwen2_vl.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
from vllm.attention.selector import (_Backend, backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import parallel_state
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.logger import init_logger
Expand All @@ -68,6 +68,9 @@
from vllm.sequence import IntermediateTensors, SequenceData
from vllm.transformers_utils.processor import get_processor

from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory)

logger = init_logger(__name__)

# === Vision Inputs === #
Expand Down Expand Up @@ -856,15 +859,21 @@ def __init__(self,

self.model = Qwen2Model(config, cache_config, quant_config)

if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
if get_pp_group().is_last_rank:
if config.tie_word_embeddings:
self.lm_head = self.model.embed_tokens
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
else:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.lm_head = PPMissingLayer()

self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

def _validate_and_reshape_mm_tensor(self,
mm_input: Union[torch.Tensor,
Expand Down Expand Up @@ -979,7 +988,8 @@ def forward(
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)

if image_input is None and video_input is None:
if (image_input is None
and video_input is None) or not get_pp_group().is_first_rank:
inputs_embeds = None
else:
if getattr(self.config, "rope_scaling", {}).get("type",
Expand Down Expand Up @@ -1015,6 +1025,7 @@ def forward(
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
)
return hidden_states
Expand Down Expand Up @@ -1055,6 +1066,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
Expand All @@ -1081,6 +1094,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
except KeyError:
print(params_dict.keys())
Expand Down

0 comments on commit a79e522

Please sign in to comment.