Fix TorchAOConfig skip layers #19147

Closed · wants to merge 9 commits
13 changes: 13 additions & 0 deletions tests/quantization/test_torchao.py
@@ -59,6 +59,19 @@ def test_opt_125m_int4wo_model_per_module_quant(vllm_runner):
    assert output
    print(output)

@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_qwenvl_int8wo_model_loading_with_params(vllm_runner):
    torch._dynamo.reset()
    model_name = "mobicham/Qwen2.5-VL-3B-Instruct_int8wo_ao"
    with vllm_runner(model_name=model_name,
                     quantization="torchao",
                     dtype="bfloat16",
                     pt_load_map_location="cuda:0") as llm:
        output = llm.generate_greedy(["The capital of France is"],
                                     max_tokens=32)

    assert output
    print(output)

if __name__ == "__main__":
    pytest.main([__file__])
33 changes: 26 additions & 7 deletions vllm/model_executor/layers/quantization/torchao.py
@@ -20,8 +20,9 @@
class TorchAOConfig(QuantizationConfig):
    """Config class for torchao."""

    def __init__(self, torchao_config) -> None:
        self.torchao_config = torchao_config
    def __init__(self,
                 torchao_config,
                 skip_modules: Optional[list[str]] = None) -> None:
        """
        # TorchAO quantization relies on tensor subclasses. In order,
        # to enable proper caching this needs standalone compile
@@ -36,6 +37,8 @@ def __init__(self, torchao_config) -> None:
        os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
        logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
        """
        self.torchao_config = torchao_config
        self.skip_modules = skip_modules or []

Contributor Author:
I can add that too, lemme take a look at the configs

Contributor Author:
Hmm, this is kinda problematic: this example only skips quant for q_proj, but in vLLM QKV are merged, so it's not gonna work. I can still add this logic for layers that are loaded exactly as they are in vLLM (like o_proj).

Contributor Author:
Added here: c4201b6

You can test it with this code; it was breaking before this commit:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig

model_id = "meta-llama/Llama-3.1-8B-Instruct"

config = Int4WeightOnlyConfig(group_size=128)
quant_config = ModuleFqnToConfig({
    "_default": config,
    "model.layers.0.self_attn.o_proj": None,
    "model.layers.13.self_attn.o_proj": None,
})
quantization_config = TorchAoConfig(quant_type=quant_config)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
)

print("quantized model:", quantized_model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

q_model_id = "quant_model_test"
quantized_model.save_pretrained(q_model_id, safe_serialization=False)
tokenizer.save_pretrained(q_model_id)
######################################################################################

from vllm import LLM
from vllm.sampling_params import SamplingParams

llm = LLM(model="quant_model_test", dtype=torch.bfloat16)
sampling_params = SamplingParams(max_tokens=1024, temperature=0.5,
                                 repetition_penalty=1.1, ignore_eos=False)
messages = [
    {"content": "You are a helpful assistant", "role": "system"},
    {"content": "Solve this equation x^2 + 1 = -1.", "role": "user"},
]
outputs = llm.chat(messages, sampling_params,
                   chat_template=llm.get_tokenizer().chat_template)
print(outputs[0].outputs[0].text)

Contributor:
> Hmm, this is kinda problematic: this example only skips quant for q_proj, but in vLLM QKV are merged, so it's not gonna work. I can still add this logic for layers that are loaded exactly as they are in vLLM (like o_proj).

Not sure I follow; are you saying both skipped_modules and ModuleFqnToConfig can't handle this, or that skipped_modules can handle this but ModuleFqnToConfig can't?

It seems to me the functionality of skipped_modules is already covered by ModuleFqnToConfig, and that configuration is serializable to transformers models (e.g. https://huggingface.co/pytorch/Phi-4-mini-instruct-8da4w/blob/main/config.json#L102), so users don't need to worry about having extra configs. I think we should just use that.
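
For illustration only (not part of the PR), a minimal sketch of how a skip list can be derived from a serialized module_fqn_to_config mapping, in the spirit of the from_config change below. Only the "_data", "module_fqn_to_config" and "_default" keys come from this PR; the "_default" payload is a made-up placeholder, not a real serialized AOBaseConfig:

# Hedged sketch: the "_default" payload is a placeholder, not a real
# serialized AOBaseConfig.
quant_type = {
    "_data": {
        "module_fqn_to_config": {
            "_default": {"group_size": 128},          # placeholder config dict
            "model.layers.0.self_attn.o_proj": None,  # None => leave unquantized
        }
    }
}

module_fqn = quant_type["_data"]["module_fqn_to_config"]
skip_modules = [fqn for fqn, cfg in module_fqn.items() if cfg is None]
print(skip_modules)  # ['model.layers.0.self_attn.o_proj']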

Contributor Author:
That's not the issue; the commit above handles that. The issue is that, in vLLM, q_proj + k_proj + v_proj are merged into a single tensor, so they need to use the same quant config or all be None (same for the MLP). Otherwise the logic works, as you can see in the example above with o_proj.
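
As an illustration (not part of the PR), a minimal sketch of a per-module config that respects that constraint: q_proj, k_proj and v_proj of the same layer are all set to None, so the merged QKV tensor in vLLM sees one consistent decision. Module names follow the Llama example above.

from torchao.quantization import Int4WeightOnlyConfig, ModuleFqnToConfig

config = Int4WeightOnlyConfig(group_size=128)

# q/k/v of layer 0 are all None, so the merged qkv_proj can be left
# unquantized as a whole; skipping only one of them would conflict with
# the merged tensor.
quant_config = ModuleFqnToConfig({
    "_default": config,
    "model.layers.0.self_attn.q_proj": None,
    "model.layers.0.self_attn.k_proj": None,
    "model.layers.0.self_attn.v_proj": None,
})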

Contributor:
> so they need to use the same quant config or all be None (same for the MLP).

How is this enforced? Are you saying the torchao config should be aware that q_proj, k_proj and v_proj are connected and set them to use the same config automatically?


    def __repr__(self) -> str:
        return f"TorchAOConfig({self.torchao_config})"
@@ -67,11 +70,21 @@ def from_config(cls, config: dict[str, Any]) -> "TorchAOConfig":

        hf_config = cls.get_from_keys_or(config, ["quant_type"], None)
        assert hf_config is not None, "quant_type must be specified"
        assert (len(hf_config) == 1 and "default" in hf_config
                ), "Expected only one key 'default' in quant_type dictionary"
        assert len(hf_config) == 1 and "default" in hf_config, (
            "Expected only one key 'default' in quant_type dictionary")
        quant_type = hf_config["default"]
        ao_config = config_from_dict(quant_type)
        return cls(ao_config)

        # Adds skipped modules defined in "modules_to_not_convert"
        skip_modules = config.get("modules_to_not_convert", []) or []

        # Adds skipped modules defined in "module_fqn_to_config"
        module_fqn = quant_type["_data"]["module_fqn_to_config"]
        for layer in module_fqn:
            if module_fqn.get(layer, None) is None:
                skip_modules.append(layer)

        return cls(ao_config, skip_modules)

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["QuantizeMethodBase"]:
@@ -80,13 +93,16 @@ def get_quant_method(self, layer: torch.nn.Module,

        from torchao.quantization import ModuleFqnToConfig

        if any(s in prefix for s in self.skip_modules):
            return UnquantizedLinearMethod()

        module_fqn = prefix
        if isinstance(self.torchao_config, ModuleFqnToConfig):
            module_fqn_to_config = self.torchao_config.module_fqn_to_config
            c = module_fqn_to_config.get(
                module_fqn) or module_fqn_to_config.get("_default", None)
            if c is not None:
                current_torchao_config = TorchAOConfig(c)
                current_torchao_config = TorchAOConfig(c, self.skip_modules)

Contributor:
Do we need to pass skip_modules here? It's already been used in L86, right?

                return TorchAOLinearMethod(current_torchao_config)
            else:
                return UnquantizedLinearMethod()
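
For reference, a small standalone sketch (with hypothetical module prefixes) of the substring check added above; because it uses "in", an entry such as "visual" would skip every module whose prefix contains that string:

# Hedged sketch; the module prefixes below are made up for illustration.
skip_modules = ["visual", "model.layers.0.self_attn.o_proj"]

def is_skipped(prefix: str) -> bool:
    # Mirrors the any(s in prefix for s in self.skip_modules) check above.
    return any(s in prefix for s in skip_modules)

assert is_skipped("visual.blocks.3.attn.qkv")             # parent name matches
assert is_skipped("model.layers.0.self_attn.o_proj")      # exact FQN matches
assert not is_skipped("model.layers.1.self_attn.o_proj")  # different layer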
@@ -108,8 +124,11 @@ def torchao_quantize_param_data(param: torch.Tensor,
"""
from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_

assert isinstance(torchao_config, AOBaseConfig), f"{torchao_config}"
dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
dummy_linear = torch.nn.Linear(1, 1, bias=False)
dummy_linear.in_features = param.shape[1]
dummy_linear.out_features = param.shape[0]
dummy_linear.weight = param
quantize_(dummy_linear, torchao_config)
return dummy_linear.weight
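
For context, a standalone sketch of the dummy-linear change above (tensor sizes are arbitrary): patching a 1x1 Linear's metadata appears intended to avoid allocating a full out_features x in_features weight that would be discarded as soon as the real parameter is attached.

import torch

param = torch.nn.Parameter(torch.randn(256, 128, dtype=torch.bfloat16),
                           requires_grad=False)

# Old approach: allocates a full 256x128 weight that is thrown away as soon
# as the real param is assigned.
old = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)

# New approach: a 1x1 Linear whose metadata is patched, so no large temporary
# weight is created before the real param is attached.
dummy_linear = torch.nn.Linear(1, 1, bias=False)
dummy_linear.in_features = param.shape[1]
dummy_linear.out_features = param.shape[0]
dummy_linear.weight = param

assert dummy_linear.weight.shape == (256, 128)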