Added Phi #132

Merged: 12 commits, Dec 15, 2023

Changes from 1 commit
Fixed layer norm
tgaddair committed Dec 15, 2023
commit a89cb090381ac37cf9cd25b919426ee5effb6c84
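
What the commit does: the hand-rolled PhiRMSNorm class is removed and both call sites switch to FastLayerNorm.load from lorax_server.utils.layers, presumably because the Phi checkpoints use a standard LayerNorm (mean subtraction plus a learned scale and bias) rather than an RMSNorm. For orientation only, here is a minimal plain-PyTorch sketch of the two normalizations; rms_norm and layer_norm are hypothetical helpers, not lorax_server code and not part of this diff:

import torch

def rms_norm(x, weight, eps=1e-5):
    # RMSNorm: scale by the root mean square; no mean subtraction, no bias.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * (x * torch.rsqrt(variance + eps))

def layer_norm(x, weight, bias, eps=1e-5):
    # LayerNorm: subtract the mean, normalize by the standard deviation,
    # then apply a learned scale and a learned bias.
    mean = x.mean(-1, keepdim=True)
    var = x.var(-1, keepdim=True, unbiased=False)
    return (x - mean) * torch.rsqrt(var + eps) * weight + bias

Loading weights trained for one of these into an implementation of the other yields incorrect activations, which is consistent with the commit message "Fixed layer norm".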
server/lorax_server/models/custom_modeling/flash_phi_modeling.py: 69 changes (14 additions, 55 deletions)
@@ -49,6 +49,7 @@
 from lorax_server.utils import flash_attn
 from lorax_server.utils import paged_attn
 from lorax_server.utils.layers import (
+    FastLayerNorm,
     TensorParallelAdapterRowLinear,
     TensorParallelRowLinear,
     TensorParallelColumnLinear,
@@ -64,59 +65,6 @@
 ATTN_OUT_PROJ = "mixer.out_proj"
 MLP_FC1 = "mlp.fc1"
 MLP_FC2 = "mlp.fc2"
-
-
-class PhiRMSNorm(nn.Module):
-    def __init__(self, prefix, weights, eps=1e-6):
-        """
-        PhiRMSNorm is equivalent to LlamaLayerNorm
-        """
-        super().__init__()
-
-        weight = weights.get_tensor(f"{prefix}.weight")
-        self.weight = nn.Parameter(weight)
-        self.variance_epsilon = eps
-
-    def forward(self, hidden_states, residual=None):
-        if hidden_states.shape[-1] > 8192:
-            if residual is not None:
-                hidden_states += residual
-            residual = hidden_states
-
-            hidden_states = hidden_states.to(torch.float32)
-            variance = hidden_states.pow(2).mean(-1, keepdim=True)
-            hidden_states = hidden_states * torch.rsqrt(
-                variance + self.variance_epsilon
-            )
-
-            # convert into half-precision if necessary
-            if self.weight.dtype in [torch.float16, torch.bfloat16]:
-                hidden_states = hidden_states.to(self.weight.dtype)
-
-            return self.weight * hidden_states, residual
-        else:
-            # faster post attention rms norm
-            normed_hidden_states, res, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.weight,
-                None,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.variance_epsilon,
-                1.0,
-                0,
-                None,
-                False,
-                True,  # Activate RMSNorm
-            )
-            if res is None:
-                res = hidden_states
-
-            return normed_hidden_states, res


 def load_attention(config, prefix, weights, layer_id, head_dim, n_head, n_head_kv):
@@ -191,6 +139,7 @@ def __init__(
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)
+        self.layer_id = layer_id

     def forward(
         self,
@@ -206,6 +155,9 @@ def forward(
         adapter_data,
     ):
         qkv = self.Wqkv(hidden_states, adapter_data)
+        if self.layer_id == 0:
+            print(qkv.shape, qkv.norm().item())
+
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -223,6 +175,10 @@
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

+        if self.layer_id == 0:
+            print(query.shape, query.norm().item())
+            print(kv.shape, kv.norm().item())
+
         # output tensor
         attn_output = torch.empty_like(query)

@@ -302,7 +258,7 @@ def __init__(self, layer_id, config, weights):
         super().__init__()
         prefix = f"transformer.h.{layer_id}"

-        self.ln = PhiRMSNorm(
+        self.ln = FastLayerNorm.load(
             prefix=f"{prefix}.ln", weights=weights, eps=config.layer_norm_epsilon
         )
         self.mixer = FlashPhiAttention(
@@ -326,6 +282,8 @@ def forward(
         adapter_data,
     ):
         normed_hidden_states, _ = self.ln(hidden_states, residual=None)
+        if self.mixer.layer_id == 0:
+            print(normed_hidden_states.shape, normed_hidden_states.norm().item())

         attn_output = self.mixer(
             normed_hidden_states,
@@ -389,6 +347,7 @@ def forward(
         adapter_data: AdapterBatchData,
     ) -> torch.Tensor:
         hidden_states = self.embd(input_ids)
+        print(hidden_states.shape, hidden_states.norm().item())

         # Get rotary cos and sin for this forward
         # Avoid to index in each layer
@@ -420,7 +379,7 @@ def __init__(self, config, weights):
         super().__init__()

         prefix = "lm_head"
-        self.ln = PhiRMSNorm(
+        self.ln = FastLayerNorm.load(
             prefix=f"{prefix}.ln", weights=weights, eps=config.layer_norm_epsilon
         )
         self.linear = TensorParallelAdapterRowLinear.load(TensorParallelHead.load(
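
The remaining additions are print calls that log tensor shapes and norms after the embedding lookup and at layer 0 (qkv, query, kv, and the post-norm hidden states), which reads as temporary instrumentation for tracking down the normalization mismatch. The swap itself is drop-in at the forward call sites because the replacement is invoked exactly like the removed class and returns a (normed, residual) pair. A stand-in sketch of that calling convention; DropInLayerNorm is hypothetical and is not the real FastLayerNorm:

import torch
import torch.nn as nn

class DropInLayerNorm(nn.LayerNorm):
    # Mimics the call-site contract seen in this diff: takes an optional
    # residual, adds it in, and returns (normed_hidden_states, residual).
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        return super().forward(hidden_states), hidden_states

ln = DropInLayerNorm(2048, eps=1e-5)  # arbitrary hidden size for the sketch
x = torch.randn(4, 2048)
normed, residual = ln(x, residual=None)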