Merged
43 changes: 28 additions & 15 deletions vllm/model_executor/models/transformers.py
@@ -43,7 +43,7 @@
 from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     ReplicatedLinear,
@@ -194,15 +194,29 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
     - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
       and Transformers doesn't appear to have the same concept.
     """
-    kwargs = {
-        "hidden_size": hidden_size,
-        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
-        "has_weight": getattr(rms_norm, "with_scale", True),
-    }
-    if (weight := getattr(rms_norm, "weight", None)) is not None:
-        # If weight is a Parameter, get its data tensor
-        weight = getattr(weight, "data", weight)
-        kwargs["dtype"] = weight.dtype
+    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
+    kwargs = {"hidden_size": hidden_size, "eps": eps}
+    # Update hidden size if weight is available
+    weight_meta = getattr(rms_norm, "weight", None)
+    if weight_meta is not None:
+        kwargs["hidden_size"] = weight_meta.size(0)
+    # Check if weight is all zeros, which indicates GemmaRMSNorm
+    # We must create a new instance because rms_norm is on meta
+    try:
+        with torch.device("cpu"):
+            weight_test = getattr(rms_norm.__class__(1), "weight", None)
+    except Exception:
+        logger.warning(
+            "Failed to determine if RMSNorm weight is centered on zero or one. "
+            "Defaulting to one."
+        )
+        weight_test = None
+    if weight_test is not None and torch.all(weight_test == 0):
+        return GemmaRMSNorm(**kwargs)
Comment on lines +214 to +215

Member Author:
That would work for Gemma models as they are currently implemented in Transformers. However:

  • Custom models may not use this pattern
  • _norm is a private method and so may change under us
  • A counter-example would be Moshi, which implements _norm but does x * w instead of x * (1 + w)

+    # Otherwise assume it's a regular RMSNorm
+    kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
+    if weight_meta is not None:
+        kwargs["dtype"] = weight_meta.dtype
     else:
         # No weight, fall back to weightless RMSNorm
         kwargs["has_weight"] = False
@@ -645,11 +659,10 @@ def _recursive_replace(module: nn.Module, prefix: str):
                     new_module = replace_linear_class(
                         child_module, style, self.quant_config, prefix=qual_name
                     )
-                # TODO(hmellor): Enable RMSNorm replacement once we have a way
-                # to choose RMSNorm vs GemmaRMSNorm
-                # elif child_module.__class__.__name__.endswith("RMSNorm"):
-                #     new_module = replace_rms_norm_class(
-                #         child_module, self.config.hidden_size)
+                elif child_module.__class__.__name__.endswith("RMSNorm"):
+                    new_module = replace_rms_norm_class(
+                        child_module, self.text_config.hidden_size
+                    )
                 else:
                     _recursive_replace(child_module, prefix=qual_name)
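
For readers skimming the hunk above: _recursive_replace walks the module tree and swaps qualifying children in place, so the new elif simply adds RMSNorm-named modules to the set of layers that get replaced. A rough, self-contained sketch of that traversal pattern follows; it is a simplification under assumed names, not vLLM's actual implementation, which also handles linear layers, quantization, and the rebinding of new_module on the parent.

from torch import nn


def recursive_replace(module: nn.Module, should_replace, make_replacement, prefix: str = ""):
    # Depth-first walk: children matching the predicate are rebound on their
    # parent via setattr; everything else is recursed into.
    for child_name, child_module in module.named_children():
        qual_name = f"{prefix}.{child_name}" if prefix else child_name
        if should_replace(child_module):
            setattr(module, child_name, make_replacement(child_module, qual_name))
        else:
            recursive_replace(child_module, should_replace, make_replacement, qual_name)


# Example predicate mirroring the new elif branch: match any module whose
# class name ends with "RMSNorm", regardless of which model defined it.
def is_rms_norm(m: nn.Module) -> bool:
    return m.__class__.__name__.endswith("RMSNorm")

Matching on the class-name suffix rather than an isinstance check lets the same branch cover the per-model RMSNorm variants that Transformers defines locally, without importing each one.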
