Enable RMSNorm substitution for Transformers backend #26353
Changes from all commits: dee601f, 423e4eb, 8f945a1, 9c0e0fc, c8eba50, 2568fa5
@@ -43,7 +43,7 @@
 from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
-from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.layernorm import GemmaRMSNorm, RMSNorm
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     ReplicatedLinear,
@@ -194,15 +194,29 @@ def replace_rms_norm_class(rms_norm: nn.Module, hidden_size: int) -> RMSNorm:
     - `var_hidden_size` is only ever used for Intern vision encoder in vLLM
       and Transformers doesn't appear to have the same concept.
     """
-    kwargs = {
-        "hidden_size": hidden_size,
-        "eps": getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6),
-        "has_weight": getattr(rms_norm, "with_scale", True),
-    }
-    if (weight := getattr(rms_norm, "weight", None)) is not None:
-        # If weight is a Parameter, get its data tensor
-        weight = getattr(weight, "data", weight)
-        kwargs["dtype"] = weight.dtype
+    eps = getattr_iter(rms_norm, ("eps", "variance_epsilon"), 1e-6)
+    kwargs = {"hidden_size": hidden_size, "eps": eps}
+    # Update hidden size if weight is available
+    weight_meta = getattr(rms_norm, "weight", None)
+    if weight_meta is not None:
+        kwargs["hidden_size"] = weight_meta.size(0)
+    # Check if weight is all zeros, which indicates GemmaRMSNorm
+    # We must create a new instance because rms_norm is on meta
+    try:
+        with torch.device("cpu"):
+            weight_test = getattr(rms_norm.__class__(1), "weight", None)
+    except Exception:
+        logger.warning(
+            "Failed to determine if RMSNorm weight is centered on zero or one. "
+            "Defaulting to one."
+        )
+        weight_test = None
+    if weight_test is not None and torch.all(weight_test == 0):
+        return GemmaRMSNorm(**kwargs)
Comment on lines +214 to +215

Reviewer: Can we simply check the existence of …

Reply: That would work for Gemma models as they are currently implemented in Transformers. However: …
+
+    # Otherwise assume it's a regular RMSNorm
+    kwargs["has_weight"] = getattr(rms_norm, "with_scale", True)
+    if weight_meta is not None:
+        kwargs["dtype"] = weight_meta.dtype
+    else:
+        # No weight, fall back to weightless RMSNorm
+        kwargs["has_weight"] = False
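For context on why the zero-weight probe distinguishes the two classes: Gemma-style RMSNorm layers initialize their scale parameter to zeros and apply `x * (1 + weight)`, while standard RMSNorm layers initialize the scale to ones and apply `x * weight`. Below is a minimal, self-contained sketch of that convention and of the detection idea used in the diff; `ToyRMSNorm`, `ToyGemmaRMSNorm`, and `is_gemma_style` are illustrative names only, not vLLM or Transformers code.

```python
import torch
from torch import nn


class ToyRMSNorm(nn.Module):
    """Standard RMSNorm: scale initialized to ones, applied as x * weight."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * self.weight


class ToyGemmaRMSNorm(ToyRMSNorm):
    """Gemma-style RMSNorm: scale initialized to zeros, applied as x * (1 + weight)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__(dim, eps)
        nn.init.zeros_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * rms * (1.0 + self.weight)


def is_gemma_style(norm_cls) -> bool:
    # Instantiate a tiny fresh copy on CPU and inspect its initialized weight,
    # mirroring the heuristic in the diff (the real layer may live on "meta").
    with torch.device("cpu"):
        probe = norm_cls(1)
    weight = getattr(probe, "weight", None)
    return weight is not None and bool(torch.all(weight == 0))


print(is_gemma_style(ToyRMSNorm))       # False
print(is_gemma_style(ToyGemmaRMSNorm))  # True
```

Probing a freshly constructed CPU instance is needed because, at this point, the layer being replaced lives on the meta device, where its weight has no values to inspect.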
@@ -645,11 +659,10 @@ def _recursive_replace(module: nn.Module, prefix: str):
                     new_module = replace_linear_class(
                         child_module, style, self.quant_config, prefix=qual_name
                     )
-                # TODO(hmellor): Enable RMSNorm replacement once we have a way
-                # to choose RMSNorm vs GemmaRMSNorm
-                # elif child_module.__class__.__name__.endswith("RMSNorm"):
-                #     new_module = replace_rms_norm_class(
-                #         child_module, self.config.hidden_size)
+                elif child_module.__class__.__name__.endswith("RMSNorm"):
+                    new_module = replace_rms_norm_class(
+                        child_module, self.text_config.hidden_size
+                    )
                 else:
                     _recursive_replace(child_module, prefix=qual_name)
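The re-enabled `elif` branch hooks RMSNorm substitution into the backend's recursive module rewrite, now keyed off `self.text_config.hidden_size` rather than the `self.config.hidden_size` used in the old commented-out code (presumably so multimodal configs resolve to the text tower's hidden size). A rough standalone sketch of that traversal pattern, with a hypothetical `replace_fn` standing in for `replace_rms_norm_class`, not the actual `_recursive_replace` implementation:

```python
from torch import nn


def swap_rms_norms(model: nn.Module, hidden_size: int, replace_fn) -> None:
    """Walk the module tree and swap every *RMSNorm child in place.

    `replace_fn(child, hidden_size)` is assumed to build the replacement
    layer (e.g. a vLLM RMSNorm or GemmaRMSNorm).
    """

    def _recurse(module: nn.Module, prefix: str) -> None:
        for name, child in module.named_children():
            qual_name = f"{prefix}.{name}" if prefix else name
            if child.__class__.__name__.endswith("RMSNorm"):
                # Replace the child on its parent; no need to recurse into it
                setattr(module, name, replace_fn(child, hidden_size))
            else:
                _recurse(child, prefix=qual_name)

    _recurse(model, prefix="")
```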