[Bugfix] More faithful implementation of Gemma (vllm-project#3653)
WoosukKwon authored Mar 27, 2024
1 parent c26d013 commit a882ca5
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions vllm/model_executor/models/gemma.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model compatible with HuggingFace weights."""
from functools import lru_cache
from typing import List, Optional, Tuple

import torch
@@ -22,6 +23,7 @@

from vllm.attention import Attention, AttentionMetadata
from vllm.config import LoRAConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import GeluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (LinearMethodBase,
@@ -40,13 +42,43 @@
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput

logger = init_logger(__name__)


@lru_cache(maxsize=None)
def _get_gemma_act_fn(
hidden_act: Optional[str],
hidden_activation: Optional[str],
) -> nn.Module:
if hidden_activation is None:
if hidden_act is not None:
logger.warning(
"Gemma's activation function was incorrectly set to exact GeLU "
"in the config JSON file when it was initially released. "
"Changing the activation function to approximate GeLU "
"(`gelu_pytorch_tanh`). If you want to use the legacy "
f"`{hidden_act}`, edit the config JSON to set "
f"`hidden_activation={hidden_act}` instead of `hidden_act`. "
"See https://github.com/huggingface/transformers/pull/29402 "
"for more details.")
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu_pytorch_tanh":
return GeluAndMul(approximate="tanh")
elif hidden_activation == "gelu":
return GeluAndMul(approximate="none")
else:
raise ValueError(f"Activation function {hidden_activation} is not "
"supported for Gemma models.")


class GemmaMLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
hidden_act: Optional[str] = None,
hidden_activation: Optional[str] = None,
linear_method: Optional[LinearMethodBase] = None,
) -> None:
super().__init__()
@@ -58,7 +90,7 @@ def __init__(
hidden_size,
bias=False,
linear_method=linear_method)
- self.act_fn = GeluAndMul()
+ self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation)

def forward(self, x):
gate_up, _ = self.gate_up_proj(x)
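
As context for the hunk above (not part of the commit), here is a small sketch of how the new _get_gemma_act_fn dispatch is expected to resolve the three config cases; the import path follows this diff and the config values are made up.

# Illustrative sketch only (assumed usage of the helper added above).
from vllm.model_executor.models.gemma import _get_gemma_act_fn

# Legacy config: only `hidden_act` is set -> warn once, use the tanh approximation.
legacy = _get_gemma_act_fn(hidden_act="gelu", hidden_activation=None)

# Newer configs select the activation explicitly via `hidden_activation`.
approx = _get_gemma_act_fn(None, "gelu_pytorch_tanh")  # GeluAndMul(approximate="tanh")
exact = _get_gemma_act_fn(None, "gelu")                # GeluAndMul(approximate="none")

# Because of @lru_cache, every GemmaMLP built from the same config shares a
# single GeluAndMul instance; the module is stateless, so sharing is safe.
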
@@ -162,6 +194,8 @@ def __init__(
self.mlp = GemmaMLP(
hidden_size=self.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
hidden_activation=getattr(config, "hidden_activation", None),
linear_method=linear_method,
)
self.input_layernorm = RMSNorm(config.hidden_size,
@@ -218,6 +252,13 @@ def __init__(
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

# Normalize the embedding by sqrt(hidden_size)
# The normalizer's data type should be downcasted to the model's
# data type such as bfloat16, not float32.
# See https://github.com/huggingface/transformers/pull/29402
normalizer = self.config.hidden_size**0.5
self.register_buffer("normalizer", torch.tensor(normalizer))

def forward(
self,
input_ids: torch.Tensor,
@@ -226,8 +267,7 @@ def forward(
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
- # Normalize the embedding by sqrt(hidden_size)
- hidden_states *= self.config.hidden_size**0.5
+ hidden_states *= self.normalizer

residual = None
for i in range(len(self.layers)):

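The other half of the change registers the sqrt(hidden_size) normalizer as a buffer instead of recomputing it in forward(). Below is a minimal standalone sketch (illustrative, not vLLM code; the toy module and sizes are made up) of why that matters: buffers follow Module.to(dtype), so the scale is rounded to the model dtype such as bfloat16 before the multiply, matching the reference behavior noted in the comment, whereas the old code multiplied by an unrounded float32 scalar.

# Illustrative sketch of the dtype effect (hypothetical toy module).
import torch


class ToyEmbedScale(torch.nn.Module):
    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        # Same registration pattern as the hunk above.
        self.register_buffer("normalizer", torch.tensor(hidden_size**0.5))


hidden_size = 3072
m = ToyEmbedScale(hidden_size).to(torch.bfloat16)  # buffer is cast with the module
x = torch.randn(2, hidden_size, dtype=torch.bfloat16)

scaled_new = x * m.normalizer      # scale already rounded to bfloat16
scaled_old = x * hidden_size**0.5  # previous behavior: unrounded float scale
# The two results can differ in the low-order bits of bfloat16, which is the
# small numerical mismatch this commit removes.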