Skip to content

Commit 6621845

Browse files
committed
[Quantization] support pass MappingType for TorchAoConfig
1 parent 37a5f1b commit 6621845

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/diffusers/quantizers/quantization_config.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from packaging import version
3434

3535
from ..utils import is_torch_available, is_torchao_available, logging
36+
from torchao.quantization.quant_primitives import MappingType
3637

3738

3839
if is_torch_available():
@@ -46,6 +47,11 @@ class QuantizationMethod(str, Enum):
4647
GGUF = "gguf"
4748
TORCHAO = "torchao"
4849

50+
class CustomJSONEncoder(json.JSONEncoder):
51+
def default(self, obj):
52+
if isinstance(obj, MappingType):
53+
return obj.name
54+
return super().default(obj)
4955

5056
@dataclass
5157
class QuantizationConfigMixin:
@@ -673,4 +679,4 @@ def __repr__(self):
673679
```
674680
"""
675681
config_dict = self.to_dict()
676-
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
682+
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=CustomJSONEncoder)}\n"

0 commit comments

Comments
 (0)