From 281cc4b3cd2f6c84c2cd8272ef83d97edd1c323a Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 18 Nov 2024 13:04:14 -0500 Subject: [PATCH] [Model][Bugfix] Support TP for PixtralHF ViT (#10405) Signed-off-by: mgoin --- vllm/model_executor/models/pixtral.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d44a538d56b8c..f7f46770057e2 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -17,6 +17,7 @@ from vllm.attention import AttentionMetadata from vllm.config import ModelConfig, VllmConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) from vllm.model_executor.layers.activation import get_act_and_mul_fn @@ -843,17 +844,20 @@ def __init__( self.config = config assert not config.hidden_size % config.num_attention_heads - self.n_heads = config.num_attention_heads + self.total_num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + self.n_heads = divide(config.num_attention_heads, tp_size) self.head_dim = config.hidden_size // config.num_attention_heads self.qkv_proj = QKVParallelLinear( hidden_size=config.hidden_size, head_size=self.head_dim, - total_num_heads=self.n_heads, + total_num_heads=self.total_num_heads, bias=False, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) + assert self.total_num_heads * self.head_dim == config.hidden_size self.o_proj = RowParallelLinear( input_size=config.hidden_size, output_size=config.hidden_size,