From 4fca644aaaa32563f10843ce887079d5d51cba68 Mon Sep 17 00:00:00 2001
From: Yanyi Liu
Date: Mon, 23 Sep 2024 21:46:59 +0800
Subject: [PATCH] [Model] Support pp for qwen2-vl (#8696)

---
 tests/distributed/test_pipeline_parallel.py |  8 ++++++
 vllm/config.py                              |  1 +
 vllm/model_executor/models/qwen2.py         | 22 +++++++++++-----
 vllm/model_executor/models/qwen2_vl.py      | 29 ++++++++++++++++-----
 4 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 02288dc9dac90..280a8abdd13a7 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -8,6 +8,8 @@
 import os
 
 import pytest
+from packaging import version
+from transformers import __version__ as transformers_version
 
 from vllm.logger import init_logger
 
@@ -37,6 +39,7 @@
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
         (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
         (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
+        (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
     ],
 )
 @fork_new_process_for_each_test
@@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
 
+    # Skip tests that require transformers>=4.45.0
+    if "Qwen2-VL" in MODEL_NAME and version.parse(
+            transformers_version) < version.parse("4.45.0.dev0"):
+        pytest.skip("This test requires transformers>=4.45.0")
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
diff --git a/vllm/config.py b/vllm/config.py
index fae2d44f174bd..960a8d3928584 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -51,6 +51,7 @@
     "Qwen2ForCausalLM",
     "Qwen2MoeForCausalLM",
     "QWenLMHeadModel",
+    "Qwen2VLForConditionalGeneration",
 ]
 
 
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index a64e08c422bc3..5e6737ad7fa47 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -49,7 +49,7 @@
 from vllm.sequence import IntermediateTensors
 
 from .interfaces import SupportsLoRA
-from .utils import is_pp_missing_parameter, make_layers
+from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers
 
 
 class Qwen2MLP(nn.Module):
@@ -235,11 +235,16 @@ def __init__(
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
 
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            quant_config=quant_config,
-        )
+        if get_pp_group().is_first_rank or (config.tie_word_embeddings
+                                            and get_pp_group().is_last_rank):
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: Qwen2DecoderLayer(config=config,
@@ -248,7 +253,10 @@ def __init__(
             prefix=f"{prefix}.layers",
         )
 
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
 
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 1011c9256793e..9f72210c60bf9 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -45,7 +45,7 @@
 from vllm.attention.selector import (_Backend, backend_name_to_enum,
                                      get_global_forced_attn_backend)
 from vllm.config import CacheConfig, MultiModalConfig
-from vllm.distributed import parallel_state
+from vllm.distributed import get_pp_group, parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
 from vllm.logger import init_logger
@@ -68,6 +68,9 @@
 from vllm.sequence import IntermediateTensors, SequenceData
 from vllm.transformers_utils.processor import get_processor
 
+from .utils import (PPMissingLayer, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory)
+
 logger = init_logger(__name__)
 
 # === Vision Inputs === #
@@ -856,15 +859,21 @@ def __init__(self,
 
         self.model = Qwen2Model(config, cache_config, quant_config)
 
-        if config.tie_word_embeddings:
-            self.lm_head = self.model.embed_tokens
+        if get_pp_group().is_last_rank:
+            if config.tie_word_embeddings:
+                self.lm_head = self.model.embed_tokens
+            else:
+                self.lm_head = ParallelLMHead(config.vocab_size,
+                                              config.hidden_size,
+                                              quant_config=quant_config)
         else:
-            self.lm_head = ParallelLMHead(config.vocab_size,
-                                          config.hidden_size,
-                                          quant_config=quant_config)
+            self.lm_head = PPMissingLayer()
 
         self.logits_processor = LogitsProcessor(config.vocab_size)
         self.sampler = Sampler()
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
 
     def _validate_and_reshape_mm_tensor(self, mm_input: Union[torch.Tensor,
                                                                List[torch.Tensor]],
@@ -979,7 +988,8 @@ def forward(
         image_input = self._parse_and_validate_image_input(**kwargs)
         video_input = self._parse_and_validate_video_input(**kwargs)
 
-        if image_input is None and video_input is None:
+        if (image_input is None
+                and video_input is None) or not get_pp_group().is_first_rank:
             inputs_embeds = None
         else:
             if getattr(self.config, "rope_scaling", {}).get("type",
@@ -1015,6 +1025,7 @@
             positions=positions,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
+            intermediate_tensors=intermediate_tensors,
             inputs_embeds=inputs_embeds,
         )
         return hidden_states
@@ -1055,6 +1066,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                if is_pp_missing_parameter(name, self):
+                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -1081,6 +1094,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
+                    if is_pp_missing_parameter(name, self):
+                        continue
                     param = params_dict[name]
                 except KeyError:
                     print(params_dict.keys())
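
As a rough illustration of the construction pattern the diff applies (a minimal sketch in plain PyTorch; ToyConfig, build_stage, and this simplified PPMissingLayer are illustrative stand-ins, not vLLM's actual classes): each pipeline rank materializes only the modules it owns. The first rank builds the token embeddings (or the last rank does, when they are tied to the LM head), the last rank builds the final norm and LM head, and every other slot is filled with a parameter-free placeholder, which is why weight loading must skip parameter names that do not live on the current rank.

from dataclasses import dataclass

import torch.nn as nn


@dataclass
class ToyConfig:
    vocab_size: int = 1000
    hidden_size: int = 64
    tie_word_embeddings: bool = False


class PPMissingLayer(nn.Module):
    """Parameter-free stand-in for a module owned by another pipeline rank."""

    def forward(self, x):
        return x


def build_stage(config: ToyConfig, is_first_rank: bool,
                is_last_rank: bool) -> nn.ModuleDict:
    # First rank (or last rank when embeddings are tied to the LM head)
    # materializes the token embeddings.
    if is_first_rank or (config.tie_word_embeddings and is_last_rank):
        embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
    else:
        embed_tokens = PPMissingLayer()

    # Only the last rank materializes the final norm and the LM head.
    if is_last_rank:
        norm = nn.LayerNorm(config.hidden_size)
        if config.tie_word_embeddings:
            lm_head = embed_tokens
        else:
            lm_head = nn.Linear(config.hidden_size, config.vocab_size,
                                bias=False)
    else:
        norm = PPMissingLayer()
        lm_head = PPMissingLayer()

    return nn.ModuleDict({
        "embed_tokens": embed_tokens,
        "norm": norm,
        "lm_head": lm_head,
    })


if __name__ == "__main__":
    # A middle rank owns none of these modules, so it exposes no parameters
    # for them; a loader must skip those names on this rank.
    stage = build_stage(ToyConfig(), is_first_rank=False, is_last_rank=False)
    print(list(dict(stage.named_parameters()).keys()))  # -> []

A middle rank therefore reports no parameters for these slots, which mirrors the is_pp_missing_parameter checks the patch adds to load_weights.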