Fix TorchAOConfig skip layers #19265

Merged
merged 12 commits on Jun 12, 2025
15 changes: 15 additions & 0 deletions tests/quantization/test_torchao.py
@@ -60,5 +60,20 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
print(output)


@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
torch._dynamo.reset()
model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao"
with vllm_runner(model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location="cuda:0") as llm:
output = llm.generate_greedy(["The capital of France is"],
max_tokens=32)

assert output
print(output)


if __name__ == "__main__":
pytest.main([__file__])
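
For context, the skip information that the new from_config logic below consumes comes from the checkpoint's quantization_config. A rough sketch of such a dict is shown here; the key layout ("modules_to_not_convert", plus "_data" / "module_fqn_to_config" under "quant_type") mirrors what the parsing code in this PR reads, but the concrete module names and inner quantization config are illustrative assumptions rather than the contents of any real checkpoint:

example_quant_config = {
    "quant_method": "torchao",
    "quant_type": {
        "default": {
            # Serialized torchao config; only "_data" is inspected by the
            # new skip-module collection in from_config below.
            "_type": "ModuleFqnToConfig",
            "_data": {
                "module_fqn_to_config": {
                    # A None value means "leave this module unquantized";
                    # from_config gathers such keys into skip_modules.
                    "visual": None,
                    "_default": {"_type": "Int8WeightOnlyConfig", "_data": {}},
                },
            },
        }
    },
    # Modules listed here are skipped as well.
    "modules_to_not_convert": ["lm_head"],
}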
64 changes: 57 additions & 7 deletions vllm/model_executor/layers/quantization/torchao.py
@@ -17,11 +17,30 @@
logger = init_logger(__name__)


def should_skip(prefix: str, skip_modules: list[str]) -> bool:
"""
Robust skipping logic:
should_skip("model.model.layers.1.q_proj",
["model.model.layers.1.q_proj"]) # True
should_skip("model.model.layers.10.o_proj", ["o_proj"]) -> True
should_skip("visual.model.layers.1.q_proj", ["visual"]) -> True
should_skip("model.model.layers.1.q_proj", ["layers.1"]) -> True
should_skip("model.model.layers.11.q_proj", ["layers.1"]) -> False
"""
for s in skip_modules:
if prefix == s:
return True
if f".{s}." in f".{prefix}.":
return True
return False


class TorchAOConfig(QuantizationConfig):
"""Config class for torchao."""

def __init__(self, torchao_config) -> None:
self.torchao_config = torchao_config
def __init__(self,
torchao_config,
skip_modules: Optional[list[str]] = None) -> None:
"""
# TorchAO quantization relies on tensor subclasses. In order
# to enable proper caching this needs standalone compile
@@ -36,6 +55,8 @@ def __init__(self, torchao_config) -> None:
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
self.torchao_config = torchao_config
self.skip_modules = skip_modules or []

def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})"
@@ -67,11 +88,28 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":

hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
assert hf_config is not None, "quant_type must be specified"
assert (len(hf_config) == 1 and "default" in hf_config
), "Expected only one key 'default' in quant_type dictionary"
assert len(hf_config) == 1 and "default" in hf_config, (
"Expected only one key 'default' in quant_type dictionary")
quant_type = hf_config["default"]
ao_config = config_from_dict(quant_type)
return cls(ao_config)

# Adds skipped modules defined in "modules_to_not_convert"
skip_modules = config.get("modules_to_not_convert", []) or []

# Adds skipped modules defined in "module_fqn_to_config"
_data = quant_type.get("_data", {})
if not isinstance(_data, dict):
_data = {}

module_fqn = _data.get("module_fqn_to_config", {})
if not isinstance(module_fqn, dict):
module_fqn = {}

for layer, layer_cfg in module_fqn.items():
if layer_cfg is None:
skip_modules.append(layer)

return cls(ao_config, skip_modules)

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
@@ -80,13 +118,16 @@ def get_quant_method(self, layer: torch.nn.Module,

from torchao.quantization import ModuleFqnToConfig

if should_skip(prefix, self.skip_modules):
return UnquantizedLinearMethod()

module_fqn = prefix
if isinstance(self.torchao_config, ModuleFqnToConfig):
module_fqn_to_config = self.torchao_config.module_fqn_to_config
c = module_fqn_to_config.get(
module_fqn) or module_fqn_to_config.get("_default", None)
if c is not None:
current_torchao_config = TorchAOConfig(c)
current_torchao_config = TorchAOConfig(c, self.skip_modules)
return TorchAOLinearMethod(current_torchao_config)
else:
return UnquantizedLinearMethod()
@@ -108,8 +149,17 @@ def torchao_quantize_param_data(param: torch.Tensor,
"""
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_

assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
"""
Avoid real weight allocation for faster load, since we will
end up setting it to param.
"""
with torch.device("meta"):
dummy_linear = torch.nn.Linear(param.shape[1],
param.shape[0],
bias=False)

dummy_linear.weight = param
quantize_(dummy_linear, torchao_config)
return dummy_linear.weight
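
As a quick sanity check on the matching rule, here is a minimal standalone sketch of should_skip (reusing the examples from its docstring); it is a mirror of the helper added in this PR, not a replacement for it:

def should_skip(prefix: str, skip_modules: list[str]) -> bool:
    # Skip on an exact FQN match, or when the skip entry appears as a
    # dot-delimited segment of the module's fully qualified name.
    return any(prefix == s or f".{s}." in f".{prefix}." for s in skip_modules)

# Segment matching avoids prefix false positives such as "layers.1" vs "layers.11":
assert should_skip("model.model.layers.10.o_proj", ["o_proj"])
assert should_skip("visual.model.layers.1.q_proj", ["visual"])
assert should_skip("model.model.layers.1.q_proj", ["layers.1"])
assert not should_skip("model.model.layers.11.q_proj", ["layers.1"])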