Fix TorchAOConfig skip layers #19147
Closed · +39 −7
Commits (9), all by mobicham:
f8be4e2  Fix layer skip in TorchAOConfig + avoid nn.Linear(in_features, out_fe…
2153c73  Merge branch 'vllm-project:main' into main
ddd8d50  switch to config.get(modules_to_not_convert, [])
4dbe05a  fix init
eef5934  Merge branch 'vllm-project:main' into main
3690738  add qwenvl loading test
c4201b6  skip layers in module_fqn_to_config
263c10c  Merge branch 'main' into main
6e4c2b7  pre-commit fix
@@ -20,8 +20,9 @@
 class TorchAOConfig(QuantizationConfig):
     """Config class for torchao."""

-    def __init__(self, torchao_config) -> None:
-        self.torchao_config = torchao_config
+    def __init__(self,
+                 torchao_config,
+                 skip_modules: Optional[list[str]] = None) -> None:
         """
         # TorchAO quantization relies on tensor subclasses. In order,
         # to enable proper caching this needs standalone compile
@@ -36,6 +37,8 @@ def __init__(self, torchao_config) -> None:
         os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
         logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
+        self.torchao_config = torchao_config
+        self.skip_modules = skip_modules or []

     def __repr__(self) -> str:
         return f"TorchAOConfig({self.torchao_config})"
@@ -67,11 +70,21 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":

         hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
         assert hf_config is not None, "quant_type must be specified"
-        assert (len(hf_config) == 1 and "default" in hf_config
-                ), "Expected only one key 'default' in quant_type dictionary"
+        assert len(hf_config) == 1 and "default" in hf_config, (
+            "Expected only one key 'default' in quant_type dictionary")
         quant_type = hf_config["default"]
         ao_config = config_from_dict(quant_type)
-        return cls(ao_config)
+
+        # Adds skipped modules defined in "modules_to_not_convert"
+        skip_modules = config.get("modules_to_not_convert", []) or []
+
+        # Adds skipped modules defined in "module_fqn_to_config"
+        module_fqn = quant_type["_data"]["module_fqn_to_config"]
+        for layer in module_fqn:
+            if module_fqn.get(layer, None) is None:
+                skip_modules.append(layer)
+
+        return cls(ao_config, skip_modules)

     def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
@@ -80,13 +93,16 @@ def get_quant_method(self, layer: torch.nn.Module,

         from torchao.quantization import ModuleFqnToConfig

+        if any(s in prefix for s in self.skip_modules):
+            return UnquantizedLinearMethod()
+
         module_fqn = prefix
         if isinstance(self.torchao_config, ModuleFqnToConfig):
             module_fqn_to_config = self.torchao_config.module_fqn_to_config
             c = module_fqn_to_config.get(
                 module_fqn) or module_fqn_to_config.get("_default", None)
             if c is not None:
-                current_torchao_config = TorchAOConfig(c)
+                current_torchao_config = TorchAOConfig(c, self.skip_modules)
                 return TorchAOLinearMethod(current_torchao_config)
             else:
                 return UnquantizedLinearMethod()

Inline review comment on the changed TorchAOConfig(c, self.skip_modules) line: do we need to pass

@@ -108,8 +124,11 @@ def torchao_quantize_param_data(param: torch.Tensor,
     """
     from torchao.core.config import AOBaseConfig
     from torchao.quantization import quantize_

     assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
+    dummy_linear = torch.nn.Linear(1, 1, bias=False)
+    dummy_linear.in_features = param.shape[1]
+    dummy_linear.out_features = param.shape[0]
     dummy_linear.weight = param
     quantize_(dummy_linear, torchao_config)
     return dummy_linear.weight
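To make the new from_config path concrete, here is a minimal standalone sketch of the skip-module derivation. The quantization_config dict is hypothetical and only mirrors the shape the code expects (a serialized torchao config under quant_type["default"], plus an optional modules_to_not_convert list); the "_type" names inside it are illustrative, not taken from a real checkpoint.

```python
# Simplified, standalone re-implementation of the skip-module derivation in
# TorchAOConfig.from_config(). The dict below is a hypothetical HF-style
# quantization_config; the "_type" values are illustrative placeholders.
quantization_config = {
    "quant_type": {
        "default": {
            "_type": "ModuleFqnToConfig",
            "_data": {
                "module_fqn_to_config": {
                    "_default": {"_type": "Int4WeightOnlyConfig", "_data": {}},
                    "model.layers.0.self_attn.o_proj": None,  # None => keep unquantized
                },
            },
        },
    },
    "modules_to_not_convert": ["lm_head"],
}

quant_type = quantization_config["quant_type"]["default"]

# Modules listed explicitly in "modules_to_not_convert" are skipped ...
skip_modules = list(quantization_config.get("modules_to_not_convert", []) or [])

# ... and so is any module whose "module_fqn_to_config" entry is None.
module_fqn = quant_type["_data"]["module_fqn_to_config"]
for layer in module_fqn:
    if module_fqn.get(layer) is None:
        skip_modules.append(layer)

print(skip_modules)  # ['lm_head', 'model.layers.0.self_attn.o_proj']
```

get_quant_method then returns UnquantizedLinearMethod for any layer whose prefix contains one of these names.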
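The torchao_quantize_param_data change is purely an allocation fix: nn.Linear(in_features, out_features) would materialize a full-size random weight that is thrown away immediately, so the new code builds a 1x1 layer and patches its metadata instead. A small self-contained sketch of the idea, with a fake quantize function standing in for torchao's quantize_:

```python
import torch
from torch import nn

def fake_quantize_(module: nn.Module) -> None:
    # Stand-in for torchao's in-place quantize_(module, config); it only
    # reports what the quantizer would see.
    print(module.in_features, module.out_features, tuple(module.weight.shape))

param = nn.Parameter(torch.randn(4096, 11008), requires_grad=False)

# Old approach: nn.Linear(param.shape[1], param.shape[0], bias=False) allocates
# a full 4096x11008 weight that is discarded on the next line.

# New approach: a 1x1 layer whose metadata and weight point at the real tensor.
dummy = nn.Linear(1, 1, bias=False)
dummy.in_features = param.shape[1]
dummy.out_features = param.shape[0]
dummy.weight = param
fake_quantize_(dummy)  # prints: 11008 4096 (4096, 11008)
```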
Conversations

Reviewer: Can we use https://huggingface.co/docs/transformers/main/en/quantization/torchao#per-module-quantization to skip modules?
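For reference, the linked per-module flow looks roughly like the sketch below. This is not code from the PR: it assumes a recent torchao and transformers, the exact config class names should be checked against the installed torchao version, and the model id is a placeholder.

```python
# Rough sketch of HF per-module quantization with torchao (assumptions noted above).
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

per_module = ModuleFqnToConfig({
    "_default": Int4WeightOnlyConfig(group_size=128),  # applied to all other linears
    "model.layers.0.self_attn.o_proj": None,           # None = leave this module unquantized
})

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-llama-like-model",                  # placeholder checkpoint
    quantization_config=TorchAoConfig(quant_type=per_module),
    torch_dtype="auto",
    device_map="auto",
)
```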
Author (mobicham): I can add that too, let me take a look at the configs.
Author (mobicham): Hmm, this is somewhat problematic: the linked example only skips quantization for q_proj, but in vLLM QKV are merged, so that alone won't work. I can still add this logic for layers that are loaded exactly as they are in vLLM (like o_proj).
Author (mobicham): Added here: c4201b6. You can test it with this code; it was breaking before this commit:
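The snippet itself did not survive extraction. As a rough stand-in, a loading test along these lines exercises the same path; the checkpoint name is made up, only the vLLM entry points are real:

```python
# Hypothetical reproducer: load a torchao-quantized Qwen2.5-VL checkpoint whose
# per-module config leaves some layers unquantized.
from vllm import LLM, SamplingParams

llm = LLM(model="some-org/Qwen2.5-VL-3B-Instruct-torchao-int4",  # placeholder name
          dtype="bfloat16")

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```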
Reviewer: Not sure I follow: are you saying that neither skip_modules nor ModuleFqnToConfig can handle this, or that skip_modules can handle it but ModuleFqnToConfig cannot? It seems to me the functionality of skip_modules is already covered by ModuleFqnToConfig, and that configuration is serializable for transformers models such as https://huggingface.co/pytorch/Phi-4-mini-instruct-8da4w/blob/main/config.json#L102, so users don't need to worry about extra configs. I think we should just use that.
Author (mobicham): That's not the issue; the commit above handles that. The issue is that in vLLM, q_proj + k_proj + v_proj are merged into a single tensor, so they all need to use the same quant config or all be None (same for the MLP projections). Otherwise the logic works, as you can see in the example above with o_proj.
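To illustrate the merging problem: vLLM typically fuses q_proj/k_proj/v_proj into one module (and the MLP gate/up projections likewise), so a skip entry for a single projection never matches the merged prefix that get_quant_method sees. The module names below reflect that usual naming and are an assumption about vLLM internals, not something this PR defines.

```python
# Why skipping only q_proj cannot work once vLLM merges the projections.
skip_modules = ["q_proj"]  # per-module entry set to None upstream

merged_prefix = "model.layers.0.self_attn.qkv_proj"   # q/k/v share one tensor in vLLM
print(any(s in merged_prefix for s in skip_modules))  # False: "q_proj" is not a
                                                      # substring of "qkv_proj"

# A layer that is not merged (like o_proj) is matched as expected:
print(any(s in "model.layers.0.self_attn.o_proj" for s in ["o_proj"]))  # True
```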
Reviewer: How is this enforced? Are you saying the torchao config should be aware that q_proj, k_proj, and v_proj are connected, and set them to use the same config automatically?