This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

[Model] Adds Phi-3 support (vllm-project#4298)
caiom authored and robertgshaw2-neuralmagic committed Apr 26, 2024
1 parent 382fb33 commit e207f23
Showing 5 changed files with 148 additions and 9 deletions.
4 changes: 4 additions & 0 deletions docs/source/models/supported_models.rst
@@ -115,6 +115,10 @@ Alongside each architecture, we include some popular models that use it.
- Phi
- :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
-
* - :code:`Phi3ForCausalLM`
- Phi-3
- :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, etc.
-
* - :code:`QWenLMHeadModel`
- Qwen
- :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
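
For reference, once this change lands, the newly listed Phi-3 checkpoints are served like any other supported architecture. A minimal offline-inference sketch (sampling values are arbitrary; depending on your transformers version you may also need trust_remote_code=True):

# Illustrative usage sketch, not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(model="microsoft/Phi-3-mini-4k-instruct")
params = SamplingParams(temperature=0.0, max_tokens=64)
outputs = llm.generate(["Explain rotary position embeddings in one sentence."], params)
print(outputs[0].outputs[0].text)
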
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -1096,7 +1096,7 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
if rope_scaling is not None and rope_scaling["type"] != "su":
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
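
This guard is needed because the Phi-3 long-context configs use the "su" scaling type, whose rope_scaling dict carries per-dimension short_factor / long_factor lists and no single "factor" key; without the extra condition the assert below it would fail, and the derived length would be wrongly multiplied. A standalone sketch of the branch (config values illustrative and heavily truncated):

# Illustrative sketch, not part of this commit.
rope_scaling = {
    "type": "su",
    "short_factor": [1.0, 1.0, 1.0],   # real configs carry head_dim // 2 entries
    "long_factor": [1.5, 2.0, 4.0],    # values made up for the example
}
derived_max_model_len = 131072          # max_position_embeddings of the 128k variant

if rope_scaling is not None and rope_scaling["type"] != "su":
    # "linear" / "dynamic" / "yarn" carry a single multiplicative factor
    derived_max_model_len *= rope_scaling["factor"]

print(derived_max_model_len)            # 131072 -- the "su" limit is used as-is
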
136 changes: 133 additions & 3 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -338,6 +338,114 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
return cache


class Phi3SuScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.
Based on the original RotaryEmbedding implementation.
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
original_max_position_embeddings: int,
base: int,
is_neox_style: bool,
short_factor: List[float],
long_factor: List[float],
short_mscale: float = 1.1,
long_mscale: float = 1.225,
):
super().__init__()

if rotary_dim != head_size:
raise ValueError(
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")

self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
self.short_factor = short_factor
self.long_factor = long_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale

short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(torch.get_default_dtype())
self.register_buffer("short_cos_sin_cache",
short_cache,
persistent=False)

long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(torch.get_default_dtype())
self.register_buffer("long_cos_sin_cache",
long_cache,
persistent=False)

long_short_cache = torch.cat(
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
self.register_buffer("long_short_cos_sin_cache",
long_short_cache,
persistent=False)

def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
return inv_freq

def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: List[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale
sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1)
return cache

def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)

k = self.original_max_position_embeddings
long_prompt_offset = (torch.any(positions > k).float() *
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
idx.device)
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)

query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin

return query.flatten(-2), key.flatten(-2)
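
In plain terms, the class above builds two cos/sin tables — one from short_factor for prompts within original_max_position_embeddings and one from long_factor for prompts beyond it (each inverse frequency is divided by its per-dimension factor, and each table is scaled by the matching mscale) — concatenates them into a single buffer, and forward() selects between them with an index offset: if any position in the batch exceeds the original limit, every position is shifted by that limit so the lookup lands in the long-factor half. A small sketch of just that indexing trick (tiny numbers for illustration):

# Illustrative sketch of the cache-row selection, not part of this commit.
import torch

k = 4                                     # stands in for original_max_position_embeddings
positions = torch.tensor([0, 1, 2, 5])    # 5 > k, so the long half is selected

long_prompt_offset = (torch.any(positions > k).float() *
                      torch.full_like(positions, k)).long()
idx = positions + long_prompt_offset
print(idx)                                # tensor([4, 5, 6, 9]) -> rows in the long half

# With all positions <= k the offset is zero and the lookup stays in the
# short half (rows 0 .. k-1).
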


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


@@ -349,17 +457,26 @@ def get_rope(
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
key = (head_size, rotary_dim, max_position, base, is_neox_style,
tuple(rope_scaling.items()) if rope_scaling is not None else None)
rope_scaling_args)
if key in _ROPE_DICT:
return _ROPE_DICT[key]

if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"]
if scaling_type != "su":
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
Expand All @@ -383,6 +500,19 @@ def get_rope(
base, is_neox_style,
scaling_factor,
**extra_kwargs)
elif scaling_type == "su":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, short_factor, long_factor, **extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
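
The key-construction change in get_rope is needed because "su" is the first scaling type whose dict holds list values (short_factor / long_factor); Python lists are unhashable, so the previous tuple(rope_scaling.items()) key would raise as soon as it was looked up in _ROPE_DICT. A short sketch of the failure mode and the fix (values illustrative):

# Illustrative sketch, not part of this commit.
rope_scaling = {"type": "su", "short_factor": [1.0, 1.2], "long_factor": [2.0, 4.0]}

# Old key: the tuple still contains lists, so hashing it (e.g. in a dict lookup)
# raises TypeError: unhashable type: 'list'.
# hash(tuple(rope_scaling.items()))

# New key: list values are converted to tuples first.
rope_scaling_tuple = {k: tuple(v) if isinstance(v, list) else v
                      for k, v in rope_scaling.items()}
key = tuple(rope_scaling_tuple.items())
hash(key)   # fine -- nested tuples are hashable
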
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -46,6 +46,7 @@
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
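
Note that no new model file is added: the Phi-3 decoder is architecturally a Llama-style model (its checkpoint simply ships the qkv_proj and gate_up_proj projections already fused), so the Phi3ForCausalLM architecture string is pointed at the existing Llama implementation. A minimal sketch of how such a (module, class) registry resolves an architecture name — the resolve() helper below is hypothetical, not vLLM's actual loader:

# Hypothetical resolver, for illustration only.
import importlib

_MODELS = {
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),   # reuse the Llama code path
}

def resolve(architecture: str):
    module_name, class_name = _MODELS[architecture]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)

# resolve("Phi3ForCausalLM") would hand back LlamaForCausalLM.
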
14 changes: 9 additions & 5 deletions vllm/model_executor/models/llama.py
@@ -180,6 +180,10 @@ def __init__(
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
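
The copy above is needed because the Phi-3 HF config keeps original_max_position_embeddings as a top-level attribute rather than inside rope_scaling, while get_rope() (and Phi3SuScaledRotaryEmbedding behind it) reads everything from the dict. Roughly, assuming a config shaped like the 128k variant (values illustrative):

# Illustrative sketch, not part of this commit.
class FakePhi3Config:                      # stand-in for the HF config object
    rope_scaling = {"type": "su", "short_factor": [1.0], "long_factor": [2.0]}
    original_max_position_embeddings = 4096
    max_position_embeddings = 131072

config = FakePhi3Config()
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
        config, "original_max_position_embeddings", None):
    rope_scaling["original_max_position_embeddings"] = (
        config.original_max_position_embeddings)

# rope_scaling now carries everything the "su" branch of get_rope() needs.
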
@@ -378,11 +382,11 @@ def sample(
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
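
The leading dots matter precisely because Phi-3 checkpoints already store the fused projections: a weight named e.g. model.layers.0.mlp.gate_up_proj.weight contains the bare substring "up_proj" (and the fused attention weight contains "v_proj"), so without the dot the loader would route it through the per-shard branch and mangle its name instead of loading it directly. A tiny check of that substring behaviour:

# Illustrative sketch, not part of this commit.
fused_mlp = "model.layers.0.mlp.gate_up_proj.weight"
fused_qkv = "model.layers.0.self_attn.qkv_proj.weight"

print("up_proj" in fused_mlp, ".up_proj" in fused_mlp)   # True False
print("v_proj" in fused_qkv, ".v_proj" in fused_qkv)     # True False

# With the dotted shard names, already-fused Phi-3 weights no longer match the
# mapping and are loaded as-is; per-shard Llama weights like ".q_proj" still do.
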
