Model PP support
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
andoorve committed Aug 5, 2024
1 parent 8069495 commit f327221
Showing 48 changed files with 1,052 additions and 663 deletions.
49 changes: 34 additions & 15 deletions vllm/model_executor/models/arctic.py
@@ -1,12 +1,12 @@
"""Inference-only Snowflake Arctic model."""
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn

from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
@@ -32,6 +32,8 @@
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.arctic import ArcticConfig

from .utils import is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory

logger = init_logger(__name__)


@@ -364,6 +366,7 @@ def __init__(
config: ArcticConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
@@ -372,28 +375,35 @@ def __init__(
self.vocab_size,
config.hidden_size,
org_num_embeddings=self.vocab_size)
self.layers = nn.ModuleList([
ArcticDecoderLayer(config,
layer_idx,
cache_config,
quant_config=quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: ArcticDecoderLayer(config,
int(prefix.split(".")[-1]),
cache_config, quant_config),
prefix=f"{prefix}.layers")
self._attn_implementation = config._attn_implementation
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
for i in range(len(self.layers)):
intermediate_tensors: IntermediateTensors,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states = layer(positions, hidden_states, kv_caches[i],
hidden_states = layer(positions, hidden_states, kv_caches[i - self.start_layer],
attn_metadata)
if not get_pp_group().is_last_rank:
return IntermediateTensors({"hidden_states": hidden_states})
hidden_states = self.norm(hidden_states)
return hidden_states

@@ -420,6 +430,7 @@ def __init__(self,
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

def forward(
self,
@@ -428,9 +439,9 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
@@ -498,6 +509,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -507,6 +520,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -517,6 +532,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
@@ -527,6 +544,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
else:
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]

weight_loader = getattr(param, "weight_loader",
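The arctic.py hunks above establish the pipeline-parallel (PP) pattern that the rest of the commit repeats across model files: the first PP rank embeds the input tokens, later ranks resume from the hidden_states received via IntermediateTensors, each rank runs only its [start_layer, end_layer) slice of decoder layers, and every rank except the last hands its activations to the next stage instead of applying the final norm and lm_head. The standalone sketch below illustrates that control flow; TinyPPModel, the even layer split, and the plain dict used for the hand-off are illustrative assumptions for the example, not vLLM code.

# Minimal standalone illustration of the PP forward pattern in this commit.
# TinyPPModel and the even layer split are assumptions for this sketch only.
import torch
from torch import nn


class TinyPPModel(nn.Module):

    def __init__(self, vocab: int, hidden: int, num_layers: int,
                 pp_rank: int, pp_size: int) -> None:
        super().__init__()
        self.pp_rank, self.pp_size = pp_rank, pp_size
        per_rank = num_layers // pp_size
        self.start_layer = pp_rank * per_rank
        self.end_layer = (num_layers if pp_rank == pp_size - 1
                          else self.start_layer + per_rank)
        # Full-length list, but only the local slice holds real layers
        # (mirroring how the diff indexes self.layers[i] directly).
        self.layers = nn.ModuleList(
            nn.Linear(hidden, hidden)
            if self.start_layer <= i < self.end_layer else nn.Identity()
            for i in range(num_layers))
        self.embed = nn.Embedding(vocab, hidden) if pp_rank == 0 else None
        self.norm = nn.LayerNorm(hidden) if pp_rank == pp_size - 1 else None

    def forward(self, input_ids, intermediate_tensors=None):
        if self.pp_rank == 0:
            # First rank turns token ids into embeddings.
            hidden_states = self.embed(input_ids)
        else:
            # Later ranks resume from the previous stage's activations.
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[i](hidden_states)
        if self.pp_rank != self.pp_size - 1:
            # Hand off to the next pipeline stage instead of finishing.
            return {"hidden_states": hidden_states}
        return self.norm(hidden_states)


# Two stages run back to back stand in for the send/recv between ranks.
stage0 = TinyPPModel(vocab=100, hidden=16, num_layers=4, pp_rank=0, pp_size=2)
stage1 = TinyPPModel(vocab=100, hidden=16, num_layers=4, pp_rank=1, pp_size=2)
tokens = torch.randint(0, 100, (2, 8))
out = stage1(None, intermediate_tensors=stage0(tokens))
print(out.shape)  # torch.Size([2, 8, 16])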
50 changes: 35 additions & 15 deletions vllm/model_executor/models/baichuan.py
@@ -19,7 +19,7 @@
# limitations under the License.
"""Inference-only BaiChuan model compatible with HuggingFace weights."""
import math
from typing import Iterable, List, Optional, Tuple
from typing import Iterable, List, Optional, Tuple, Union

import torch
from torch import nn
@@ -28,7 +28,8 @@
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_tensor_model_parallel_world_size,
get_pp_group)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
@@ -46,6 +47,7 @@
from vllm.sequence import IntermediateTensors, SamplerOutput

from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers, make_empty_intermediate_tensors_factory


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -255,7 +257,8 @@ def __init__(self,
config: PretrainedConfig,
position_embedding: str,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None):
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
self.config = config
self.padding_idx = config.pad_token_id
@@ -265,31 +268,43 @@ def __init__(self,
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
BaiChuanDecoderLayer(config, position_embedding, cache_config,
quant_config)
for _ in range(config.num_hidden_layers)
])
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: BaiChuanDecoderLayer(config, position_embedding, cache_config, quant_config),
prefix=f"{prefix}.layers",
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(["hidden_states", "residual"], config.hidden_size)

def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
intermediate_tensors: IntermediateTensors,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
hidden_states = self.embed_tokens(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
hidden_states, residual = layer(
positions,
hidden_states,
kv_caches[i],
kv_caches[i - self.start_layer],
attn_metadata,
residual,
)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual,
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states

@@ -333,6 +348,7 @@ def __init__(
quant_config=quant_config)
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.make_empty_intermediate_tensors = self.model.make_empty_intermediate_tensors

def forward(
self,
@@ -341,9 +357,9 @@ def forward(
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
) -> torch.Tensor:
) -> Union[torch.Tensor, IntermediateTensors]:
hidden_states = self.model(input_ids, positions, kv_caches,
attn_metadata)
attn_metadata, intermediate_tensors)
return hidden_states

def compute_logits(self, hidden_states: torch.Tensor,
@@ -389,6 +405,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -397,6 +415,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
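Both files also start importing is_pp_missing_parameter and make_empty_intermediate_tensors_factory from .utils, whose bodies are not part of this excerpt. The sketch below gives one plausible reading of what they do, based only on how the diff uses them: skip weights for layers the current rank does not own, and pre-allocate receive buffers keyed the same way the model's forward returns them. Signatures and internals here are assumptions; vLLM's actual helpers may differ.

# Hypothetical sketches of the .utils helpers used above; the real vLLM
# implementations are not shown in this commit excerpt and may differ.
import re
from typing import Callable, Dict, List

import torch
from torch import nn


def is_pp_missing_parameter(name: str, model: nn.Module) -> bool:
    """Guess whether `name` belongs to a decoder layer owned by another
    pipeline rank, so load_weights() can skip it (assumed index scheme)."""
    match = re.search(r"\.layers\.(\d+)\.", name)
    if match is None:
        return False
    layer_idx = int(match.group(1))
    # Assume some submodule (e.g. model.model) exposes the local layer range.
    for module in model.modules():
        if hasattr(module, "start_layer") and hasattr(module, "end_layer"):
            return not module.start_layer <= layer_idx < module.end_layer
    return False


def make_empty_intermediate_tensors_factory(
        keys: List[str],
        hidden_size: int) -> Callable[..., Dict[str, torch.Tensor]]:
    """Return a callable that allocates zeroed buffers (one per key) into
    which the previous pipeline stage's activations can be received."""

    def make_empty_intermediate_tensors(batch_size: int, dtype: torch.dtype,
                                        device: torch.device):
        return {
            key: torch.zeros((batch_size, hidden_size),
                             dtype=dtype,
                             device=device)
            for key in keys
        }

    return make_empty_intermediate_tensors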
11 changes: 8 additions & 3 deletions vllm/model_executor/models/blip2.py
@@ -1,4 +1,4 @@
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union

import torch
import torch.nn as nn
@@ -21,7 +21,7 @@
from .blip import (BlipVisionModel, dummy_image_for_blip,
get_max_blip_image_tokens)
from .interfaces import SupportsVision
from .utils import merge_vision_embeddings
from .utils import merge_vision_embeddings, is_pp_missing_parameter, make_empty_intermediate_tensors_factory

_KEYS_TO_MODIFY_MAPPING = {
"language_model.lm_head": "lm_head",
@@ -558,7 +558,7 @@ def forward(
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors] = None,
**kwargs: object,
) -> SamplerOutput:
) -> Union[SamplerOutput, IntermediateTensors]:
"""Run forward pass for BLIP-2.
One key thing to understand is the `input_ids` already accounts for the
@@ -607,6 +607,7 @@ def forward(
positions,
kv_caches,
attn_metadata,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds)

return hidden_states
@@ -656,13 +657,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
shard_id) in stacked_params_mapping:
if weight_name not in name:
continue
if is_pp_missing_parameter(name.replace(weight_name, param_name), self):
continue
param = params_dict[name.replace(weight_name, param_name)]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
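One detail worth noting from the hunks above: inside the per-rank loop the models index self.layers[i] with the global layer index but kv_caches[i - self.start_layer] with a local one. That only works if make_layers returns a ModuleList of full length, with placeholders for the layers other ranks own, while the KV caches are allocated only for the local layers. The sketch below is one way such a helper could look; PPMissingLayer, the explicit pp_rank/pp_size arguments, and the even split are assumptions made for the example, not the commit's actual .utils code.

# Assumed sketch of a make_layers-style helper consistent with the indexing
# in the diff; not the actual vLLM implementation.
from typing import Callable, Tuple

from torch import nn


class PPMissingLayer(nn.Identity):
    """Placeholder for a layer that lives on another pipeline rank."""


def make_layers(
    num_layers: int,
    layer_fn: Callable[[str], nn.Module],
    prefix: str,
    pp_rank: int,
    pp_size: int,
) -> Tuple[int, int, nn.ModuleList]:
    per_rank = num_layers // pp_size
    start_layer = pp_rank * per_rank
    end_layer = (num_layers
                 if pp_rank == pp_size - 1 else start_layer + per_rank)
    # Full-length list: real layers for the local slice, placeholders elsewhere.
    layers = nn.ModuleList(
        layer_fn(f"{prefix}.{idx}")
        if start_layer <= idx < end_layer else PPMissingLayer()
        for idx in range(num_layers))
    return start_layer, end_layer, layers

With this shape, self.layers[i] is valid for any global index, but only the local entries hold parameters, which is also why the weight loader needs is_pp_missing_parameter to skip names that would map onto the placeholders.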