Fix AOPerModuleConfig name changes #18869


Merged
3 changes: 3 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -424,6 +424,9 @@ steps:
- vllm/model_executor/layers/quantization
- tests/quantization
commands:
# temporary install here since we need nightly, will move to requirements/test.in
# after torchao 0.12 release
- pip install --pre torchao --index-url https://download.pytorch.org/whl/nightly/cu126
- VLLM_TEST_FORCE_LOAD_FORMAT=auto pytest -v -s quantization

- label: LM Eval Small Models # 53min
6 changes: 3 additions & 3 deletions tests/quantization/test_torchao.py
@@ -13,7 +13,7 @@

@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
def test_pre_quantized_model(vllm_runner):
with vllm_runner("drisspg/float8_dynamic_act_float8_weight-opt-125m",
with vllm_runner("drisspg/fp8-opt-125m",
Collaborator: Can we have a unified place to store all the models? :-)

Contributor (author): Any recommendations? Other quantization methods are also using random places:

models_4bit_to_test = [
("facebook/opt-125m", "quantize opt model inflight"),
("mistralai/Mistral-7B-Instruct-v0.3",
"quantize inflight model with both HF and Mistral format weights")
]
models_4bit_to_embedding_test = [
("intfloat/e5-mistral-7b-instruct", "quantize embedding model inflight"),
]
models_pre_qaunt_4bit_to_test = [
('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
'read pre-quantized 4-bit FP4 model'),
('poedator/opt-125m-bnb-4bit', 'read pre-quantized 4-bit NF4 opt model'),
]
models_pre_quant_8bit_to_test = [
('meta-llama/Llama-Guard-3-8B-INT8',
'read pre-quantized llama 8-bit model'),
("yec019/fbopt-350m-8bit", "read pre-quantized 8-bit opt model"),
]

Collaborator: How about having a central repo, like torchao/fp8-opt-125m?

Contributor (author): Just for test models? I feel that might be a bit overkill.

We do release official torchao models under pytorch, e.g.: https://huggingface.co/collections/pytorch/torchao-quantized-phi-4-mini-instruct-681566f123acc6fed345cb1a

quantization="torchao",
dtype="bfloat16",
enforce_eager=True) as llm:
@@ -30,10 +30,10 @@ def test_pre_quantized_model(vllm_runner):
"cuda:0",
# {"": "cuda"},
])
-def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
+def test_opt_125m_int8wo_model_loading_with_params(vllm_runner,
pt_load_map_location):
torch._dynamo.reset()
model_name = "jerryzh168/opt-125m-int4wo"
model_name = "jerryzh168/opt-125m-int8wo-partial-quant"
with vllm_runner(model_name=model_name,
quantization="torchao",
dtype="bfloat16",
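The renamed model in this test points to a checkpoint where only part of the network is int8 weight-only quantized. A minimal sketch of how such a partially quantized checkpoint could be produced with torchao's per-module config follows; the base model, module FQNs, and the saving step are illustrative assumptions, not details taken from this PR.

# Sketch only: quantize just two projections of opt-125m to int8 weight-only,
# leaving the rest of the model in bfloat16 ("partial quant").
# Assumes torchao >= 0.12 (ModuleFqnToConfig) and transformers.
import torch
from transformers import AutoModelForCausalLM
from torchao.quantization import (Int8WeightOnlyConfig, ModuleFqnToConfig,
                                  quantize_)

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m",
                                             torch_dtype=torch.bfloat16)

# Only the modules listed here are quantized; modules not listed (and with no
# "_default" entry) keep their original weights.
partial_config = ModuleFqnToConfig({
    "model.decoder.layers.0.self_attn.q_proj": Int8WeightOnlyConfig(),
    "model.decoder.layers.0.self_attn.k_proj": Int8WeightOnlyConfig(),
})
quantize_(model, partial_config)

# The quantized model can then be saved/pushed (tensor subclasses are not
# safetensors-compatible, so use safe_serialization=False) and loaded through
# vLLM's torchao path, like the pre-quantized test models above.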
21 changes: 19 additions & 2 deletions vllm/model_executor/layers/quantization/torchao.py
@@ -6,19 +6,36 @@
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs

logger = init_logger(__name__)


class TorchAOConfig(QuantizationConfig):
"""Config class for torchao."""

def __init__(self, torchao_config) -> None:
self.torchao_config = torchao_config
"""
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
if is_torch_equal_or_newer("2.8.0"):
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
logger.info(
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")

# TODO: remove after the torch dependency is updated to 2.8
if is_torch_equal_or_newer(
"2.7.0") and not is_torch_equal_or_newer("2.8.0"):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""

def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})"
@@ -61,10 +78,10 @@ def get_quant_method(self, layer: torch.nn.Module,
if not isinstance(layer, LinearBase):
return None

-from torchao.quantization import AOPerModuleConfig
+from torchao.quantization import ModuleFqnToConfig

module_fqn = prefix
-if isinstance(self.torchao_config, AOPerModuleConfig):
+if isinstance(self.torchao_config, ModuleFqnToConfig):
module_fqn_to_config = self.torchao_config.module_fqn_to_config
c = module_fqn_to_config.get(
module_fqn) or module_fqn_to_config.get("_default", None)
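The lookup above is the consumer side of the renamed config: a ModuleFqnToConfig carries a module_fqn_to_config dict mapping module FQNs to quantization configs, with an optional "_default" entry for everything else. Below is a minimal sketch of that resolution logic, assuming torchao >= 0.12; the helper function and layer names are hypothetical, and only the attribute name and the "_default" convention come from the diff.

# Sketch of the per-module resolution implemented in get_quant_method above.
from torchao.quantization import (Float8DynamicActivationFloat8WeightConfig,
                                  Int8WeightOnlyConfig, ModuleFqnToConfig)


def resolve_module_config(config: ModuleFqnToConfig, module_fqn: str):
    """Return the config that applies to module_fqn: an exact match wins,
    otherwise fall back to "_default" (which may be absent, meaning the
    module is left unquantized)."""
    mapping = config.module_fqn_to_config
    return mapping.get(module_fqn) or mapping.get("_default", None)


per_module = ModuleFqnToConfig({
    "_default": Int8WeightOnlyConfig(),
    "model.decoder.layers.0.self_attn.q_proj":
        Float8DynamicActivationFloat8WeightConfig(),
})

# The explicitly listed module resolves to its own config...
assert isinstance(
    resolve_module_config(per_module,
                          "model.decoder.layers.0.self_attn.q_proj"),
    Float8DynamicActivationFloat8WeightConfig)
# ...and any other module falls back to the int8 weight-only default.
assert isinstance(
    resolve_module_config(per_module, "model.decoder.layers.1.fc1"),
    Int8WeightOnlyConfig)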