Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/source/en/using-diffusers/automodel.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,5 +97,32 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
> )
> ```

### Saving custom models

Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.

```py
# my_model.py
from diffusers import ModelMixin, ConfigMixin

class MyCustomModel(ModelMixin, ConfigMixin):
...

MyCustomModel.register_for_auto_class("AutoModel")

model = MyCustomModel(...)
model.save_pretrained("./my_model")
```

The saved `config.json` will include the `auto_map` field.

```json
{
"auto_map": {
"AutoModel": "my_model.MyCustomModel"
}
}
```

> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
38 changes: 38 additions & 0 deletions src/diffusers/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,38 @@ class ConfigMixin:
has_compatibles = False

_deprecated_kwargs = []
_auto_class = None

@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
"""
Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
trust_remote_code=True)`.

When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
to this class's module and class name.

Args:
auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
Currently only `"AutoModel"` is supported.

Example:

```python
from diffusers import ModelMixin, ConfigMixin


class MyCustomModel(ModelMixin, ConfigMixin): ...


MyCustomModel.register_for_auto_class("AutoModel")
```
"""
if auto_class != "AutoModel":
raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")

cls._auto_class = auto_class

def register_to_config(self, **kwargs):
if self.config_name is None:
Expand Down Expand Up @@ -621,6 +653,12 @@ def to_json_saveable(value):
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = config_dict.pop("_pre_quantization_dtype", None)

if getattr(self, "_auto_class", None) is not None:
module = self.__class__.__module__.split(".")[-1]
auto_map = config_dict.get("auto_map", {})
auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
config_dict["auto_map"] = auto_map

return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

def to_json_file(self, json_file_path: str | os.PathLike):
Expand Down
53 changes: 53 additions & 0 deletions tests/models/test_models_auto.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import json
import os
import tempfile
import unittest
from unittest.mock import MagicMock, patch

from transformers import CLIPTextModel, LongformerModel

from diffusers import ConfigMixin
from diffusers.models import AutoModel, UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin


class TestAutoModel(unittest.TestCase):
Expand Down Expand Up @@ -100,3 +105,51 @@ def test_from_config_with_model_type_routes_to_transformers(self, mock_get_class
def test_from_config_raises_on_none(self):
with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
AutoModel.from_config(None)


class TestRegisterForAutoClass(unittest.TestCase):
def test_register_for_auto_class_sets_attribute(self):
class DummyModel(ModelMixin, ConfigMixin):
config_name = "config.json"

DummyModel.register_for_auto_class("AutoModel")
self.assertEqual(DummyModel._auto_class, "AutoModel")

def test_register_for_auto_class_rejects_unsupported(self):
class DummyModel(ModelMixin, ConfigMixin):
config_name = "config.json"

with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"):
DummyModel.register_for_auto_class("AutoPipeline")

def test_auto_map_in_saved_config(self):
class DummyModel(ModelMixin, ConfigMixin):
config_name = "config.json"

DummyModel.register_for_auto_class("AutoModel")
model = DummyModel()

with tempfile.TemporaryDirectory() as tmpdir:
model.save_config(tmpdir)
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, "r") as f:
config = json.load(f)

self.assertIn("auto_map", config)
self.assertIn("AutoModel", config["auto_map"])
module_name = DummyModel.__module__.split(".")[-1]
self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel")

def test_no_auto_map_without_register(self):
class DummyModel(ModelMixin, ConfigMixin):
config_name = "config.json"

model = DummyModel()

with tempfile.TemporaryDirectory() as tmpdir:
model.save_config(tmpdir)
config_path = os.path.join(tmpdir, "config.json")
with open(config_path, "r") as f:
config = json.load(f)

self.assertNotIn("auto_map", config)
Loading