Skip to content

Conversation

@sayakpaul
Copy link
Member

What does this PR do?

The AOBaseConfig classes introduced in torchao (since 0.9.0) are more flexible. Similar to Transformers, this PR adds support for allowing them in Diffusers:

from diffusers import DiffusionPipeline, TorchAoConfig, PipelineQuantizationConfig
from torchao.quantization import Int8WeightOnlyConfig
import torch 

ckpt_id = "black-forest-labs/FLUX.1-dev"
pipeline_quant_config = PipelineQuantizationConfig(
    quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())}
)
pipe = DiffusionPipeline.from_pretrained(
    ckpt_id, quantization_config=pipeline_quant_config, torch_dtype=torch.bfloat16
).to("cuda")
_ = pipe("dog", num_inference_steps=2)

@stevhliu, would it be possible for you to propagate the relevant changes to our TorchAO docs from Transformers? Can happen in a later PR.

@sayakpaul sayakpaul requested a review from a-r-r-o-w September 3, 2025 07:11
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@stevhliu stevhliu mentioned this pull request Sep 8, 2025
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
# String-based config

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can deprecate this one since this is less scalable than AOBaseConfig

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, will do so after this PR. Meanwhile, if you could review the PR, it'd be helpful.

@sayakpaul sayakpaul requested review from DN6 and SunMarc September 9, 2025 02:06
init

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for updating this !

f"Please upgrade to torchao > 0.9.0 to use `AOBaseConfig` instances."
)

if isinstance(self.quant_type, str):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: merge this with the branch in L521 to keep relevant things relevant together?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done!

if isinstance(self.quant_type, str):
methods = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should probably clean this up in the future..

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe not even needed since we will deprecate this codepath after this PR is merged.

Copy link

@jerryzh168 jerryzh168 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks good, would be good to put up a deprecation plan for torchao version support and clean up the old code for str support I think, I don't think there is any reason why people can't upgrade to the most recent version currently

@sayakpaul
Copy link
Member Author

@jerryzh168 thanks for your reviews! I agree that it will make the library codebase cleaner if we started a deprecation cycle to promote the AoBaseConfig more and more. Will let @DN6 also review and then merge.


self.post_init()

def post_init(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think it would be cleaner to run the checks in the begining and exit/fail early

def post_init(self):
    if not isinstance(self.quant_type, str):
        if not is_torchao_version(">=", "0.9.0"):
            raise ValueError(
                f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
                f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
            )
        
        from torchao.quantization.quant_api import AOBaseConfig
        if not isinstance(self.quant_type, AOBaseConfig):
            raise TypeError(
                f"quant_type must be a string or AOBaseConfig instance, got {type(self.quant_type).__name__}"
            )
        return
    
    TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
    
    if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS:
        # remaining str type validation

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not isinstance(self.quant_type, str):
        if not is_torchao_version(">=", "0.9.0"):
            raise ValueError(
                f"torchao <= 0.9.0 only supports string quant_type, got {type(self.quant_type).__name__}. "
                f"Upgrade to torchao > 0.9.0 to use AOBaseConfig."
            )

This will be a breaking change. I think we should introduce a deprecation cycle before enforcing this.

Copy link
Collaborator

@DN6 DN6 Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why breaking? The error is only raised if quant_type is not a string and torchao<=0.9.0? Can change the second check to if is_torchao_version("<=", "0.9.0"): if that is more clear?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@sayakpaul sayakpaul requested a review from DN6 September 22, 2025 11:59
Copy link
Collaborator

@DN6 DN6 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks 👍🏽

@sayakpaul sayakpaul merged commit 64a5187 into main Sep 29, 2025
13 of 14 checks passed
@sayakpaul sayakpaul deleted the aobaseconfig branch September 29, 2025 12:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants