[Model] Adds Phi-3 support #4298

Merged 11 commits on Apr 25, 2024
Changes from 9 commits
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -993,7 +993,7 @@ def _get_and_verify_max_len(
derived_max_model_len = default_max_len

rope_scaling = getattr(hf_config, "rope_scaling", None)
if rope_scaling is not None:
if rope_scaling is not None and rope_scaling["type"] != "su":
assert "factor" in rope_scaling
scaling_factor = rope_scaling["factor"]
if rope_scaling["type"] == "yarn":
136 changes: 133 additions & 3 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -337,6 +337,114 @@ def _compute_cos_sin_cache(self) -> torch.Tensor:
return cache


class Phi3SuScaledRotaryEmbedding(nn.Module):
"""Phi3 family of models scaled rotary embedding.

Based on the original RotaryEmbedding implementation.
"""

def __init__(
self,
head_size: int,
rotary_dim: int,
max_position_embeddings: int,
original_max_position_embeddings: int,
base: int,
is_neox_style: bool,
short_factor: List[float],
long_factor: List[float],
short_mscale: float = 1.1,
long_mscale: float = 1.225,
):
super().__init__()

if rotary_dim != head_size:
raise ValueError(
f"`Phi3SuScaledRotaryEmbedding` does not support rotary_dim != \
head_size ({rotary_dim}!={head_size}).")
if is_neox_style is False:
raise ValueError(
"`Phi3SuScaledRotaryEmbedding` only supports neox_style.")

self.head_size = head_size
self.max_position_embeddings = max_position_embeddings
self.original_max_position_embeddings = original_max_position_embeddings
self.base = base
self.short_factor = short_factor
self.long_factor = long_factor
self.short_mscale = short_mscale
self.long_mscale = long_mscale

short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = short_cache.to(torch.get_default_dtype())
self.register_buffer("short_cos_sin_cache",
short_cache,
persistent=False)

long_cache = self._compute_cos_sin_cache(max_position_embeddings,
long_factor, long_mscale)
long_cache = long_cache.to(torch.get_default_dtype())
self.register_buffer("long_cos_sin_cache",
long_cache,
persistent=False)

long_short_cache = torch.cat(
[self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0)
self.register_buffer("long_short_cos_sin_cache",
long_short_cache,
persistent=False)

def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor:
rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange(
0, self.head_size, 2, dtype=torch.float) / self.head_size)))
return inv_freq

def _compute_cos_sin_cache(
self,
max_position_embeddings: int,
rescale_factors: List[float],
mscale: float,
) -> torch.Tensor:
inv_freq = self._compute_inv_freq(rescale_factors)
t = torch.arange(max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos() * mscale
sin = freqs.sin() * mscale
cache = torch.cat((cos, sin), dim=-1)
return cache

def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)

k = self.original_max_position_embeddings
long_prompt_offset = (torch.any(positions > k).float() *
torch.full_like(positions, k)).long()
Comment (@mindest) on lines +429 to +430:
I have a question on this line: in a batch where some prompts are longer than k and others are shorter, a uniform offset is applied (selecting the long cos_sin_cache), so the shorter prompts will also use the long cos_sin_cache, and their outputs will differ from when they are processed alone. Is this expected? I think the short cos_sin_cache should be used for the shorter prompts.

Reply (Contributor Author):
Hi @mindest, that's correct: when a batch contains both long and short sequences, it will always use the long factor, even for the short samples. Currently we don't support such mixed batches.

idx = (torch.add(positions, long_prompt_offset)
if long_prompt_offset is not None else positions)
self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to(
idx.device)
idx = torch.add(idx, offsets) if offsets is not None else idx
cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

cos, sin = cos_sin.chunk(2, dim=-1)
cos = cos.repeat(1, 2).unsqueeze(-2)
sin = sin.repeat(1, 2).unsqueeze(-2)

query = query * cos + _rotate_neox(query) * sin
key = key * cos + _rotate_neox(key) * sin

return query.flatten(-2), key.flatten(-2)
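
To make the batch-level cache selection discussed in the review comment above concrete, here is a small standalone sketch with toy values. It mirrors the offset computation in the diff but is not the vLLM code path, and all names below are local stand-ins.

import torch

# Standalone toy sketch (not the vLLM code path): a single offset derived from
# torch.any routes every position in the batch into the long half of the
# concatenated cache once any prompt exceeds k.
k = 4                                    # stands in for original_max_position_embeddings
max_pos = 8                              # stands in for max_position_embeddings
short_half = torch.zeros(k)              # rows 0..k-1: "short" cache, marked 0
long_half = torch.ones(max_pos)          # rows k..k+max_pos-1: "long" cache, marked 1
long_short_cache = torch.cat([short_half, long_half])

positions = torch.tensor([1, 2, 6])      # one position exceeds k

# torch.any reduces over the whole batch, so one long prompt shifts every index.
offset = (torch.any(positions > k).float() *
          torch.full_like(positions, k)).long()
idx = positions + offset                 # tensor([5, 6, 10])
print(torch.index_select(long_short_cache, 0, idx))  # tensor([1., 1., 1.])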


_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {}


@@ -348,17 +456,26 @@ def get_rope(
is_neox_style: bool = True,
rope_scaling: Optional[Dict[str, Any]] = None,
) -> RotaryEmbedding:
if rope_scaling is not None:
# Transforms every value that is a list into a tuple for caching calls
rope_scaling_tuple = {
k: tuple(v) if isinstance(v, list) else v
for k, v in rope_scaling.items()
}
rope_scaling_args = tuple(rope_scaling_tuple.items())
else:
rope_scaling_args = None
key = (head_size, rotary_dim, max_position, base, is_neox_style,
tuple(rope_scaling.items()) if rope_scaling is not None else None)
rope_scaling_args)
if key in _ROPE_DICT:
return _ROPE_DICT[key]

if rope_scaling is None:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style)
else:
scaling_type = rope_scaling["type"]
scaling_factor = rope_scaling["factor"]
if scaling_type != "su":
scaling_factor = rope_scaling["factor"]
if scaling_type == "linear":
rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim,
max_position, base,
@@ -382,6 +499,19 @@
base, is_neox_style,
scaling_factor,
**extra_kwargs)
elif scaling_type == "su":
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position = rope_scaling[
"original_max_position_embeddings"]
extra_kwargs = {
k: v
for k, v in rope_scaling.items()
if k in ("short_mscale", "long_mscale")
}
rotary_emb = Phi3SuScaledRotaryEmbedding(
head_size, rotary_dim, max_position, original_max_position,
base, is_neox_style, short_factor, long_factor, **extra_kwargs)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
_ROPE_DICT[key] = rotary_emb
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -46,6 +46,7 @@
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
Comment (@davidgxue, Apr 27, 2024):
A quick retrospective question (more for my own understanding): why are these mapped to llama and LlamaForCausalLM? Isn't this architecture defined as Phi3ForCausalLM in the Hugging Face transformers library? https://huggingface.co/docs/transformers/main/en/model_doc/phi3

It seems the module and class name are later used to import from vLLM's /models directory, and we haven't added a phi3.py there or given the model its own file yet.

I think Phi-3, while on paper very similar to Llama 2, does have a few modifications that set it apart, such as fused QKV and MLP. Should we give Phi-3 its own file, or do you think it is fine to keep inheriting from llama this way? (Based on the rest of the PR, you may have already handled this by changing llama.py, so maybe that just works.)

Reply (Contributor Author):

The model is very similar to llama, hence it is better to reuse the existing code base. The same is true for Mistral.

Follow-up (@davidgxue):
Ah, I saw your changes to the other files and they make sense now. Ty!

"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
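
As context for the registry question answered above, here is a simplified, hypothetical sketch of how a (module, class) entry can be resolved by dynamic import. The function name and package path are assumptions for illustration, not vLLM's actual loader.

import importlib

# Hypothetical sketch of resolving a registry entry by dynamic import.
# The function name and package path are illustrative assumptions.
_MODELS = {
    "PhiForCausalLM": ("phi", "PhiForCausalLM"),
    # Phi-3 reuses the llama module, so no separate phi3.py is required.
    "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
}

def load_model_cls(architecture: str):
    module_name, cls_name = _MODELS[architecture]
    module = importlib.import_module(
        f"vllm.model_executor.models.{module_name}")  # assumed package path
    return getattr(module, cls_name)

# e.g. load_model_cls("Phi3ForCausalLM") would return LlamaForCausalLM.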
14 changes: 9 additions & 5 deletions vllm/model_executor/models/llama.py
@@ -180,6 +180,10 @@ def __init__(
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
sliding_window = getattr(config, "sliding_window", None)
@@ -378,11 +382,11 @@ def sample(
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
(".qkv_proj", ".q_proj", "q"),

Comment (Collaborator):
A quick question: why add the dot here?


Comment:
Does Phi have a different set of mappings?

  • Does this not break the existing llama executor?

Reply (Contributor Author @caiom, Apr 25, 2024):
This ensures that we replace the full layer-name component rather than a partial match.

Phi-3 already uses the vLLM layer naming convention (no replacement necessary), so it has a layer named gate_up_proj; without the dots, that layer would be incorrectly renamed to gate_gate_up_proj by the replacement on line 389.

The dot in ".up_proj" makes sure we only rename a layer whose name component is exactly up_proj, instead of any layer containing "up_proj" anywhere in its name.

This change is compatible with the Llama and Mistral models.

Reply (Collaborator):
Thanks for the clarification! Makes sense.

(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
]
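
To make the renaming issue from the thread above concrete, here is a tiny standalone sketch (toy weight name, not the vLLM loader) of the plain substring replacement the author describes:

checkpoint_name = "model.layers.0.mlp.gate_up_proj.weight"  # toy Phi-3-style fused name

# Without the leading dot, "gate_up_proj" itself contains "up_proj",
# so the replacement mangles the already-fused name:
print(checkpoint_name.replace("up_proj", "gate_up_proj"))
# -> model.layers.0.mlp.gate_gate_up_proj.weight

# With the leading dot, only a standalone ".up_proj" component matches,
# and the fused Phi-3 name is left untouched:
print(checkpoint_name.replace(".up_proj", ".gate_up_proj"))
# -> model.layers.0.mlp.gate_up_proj.weight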
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights: