[core / PEFT / LoRA] Integrate PEFT into Unet
#5151
Changes from all commits
`src/diffusers/models/attention.py`:

```diff
@@ -17,6 +17,7 @@
 import torch.nn.functional as F
 from torch import nn

+from ..utils import USE_PEFT_BACKEND
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import get_activation
 from .attention_processor import Attention

@@ -300,6 +301,7 @@ def __init__(
         super().__init__()
         inner_dim = int(dim * mult)
         dim_out = dim_out if dim_out is not None else dim
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear

         if activation_fn == "gelu":
             act_fn = GELU(dim, inner_dim)

@@ -316,14 +318,15 @@
         # project dropout
         self.net.append(nn.Dropout(dropout))
         # project out
-        self.net.append(LoRACompatibleLinear(inner_dim, dim_out))
+        self.net.append(linear_cls(inner_dim, dim_out))
         # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
         if final_dropout:
             self.net.append(nn.Dropout(dropout))

     def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
         for module in self.net:
-            if isinstance(module, (LoRACompatibleLinear, GEGLU)):
+            if isinstance(module, compatible_cls):
                 hidden_states = module(hidden_states, scale)
             else:
                 hidden_states = module(hidden_states)

@@ -368,7 +371,9 @@ class GEGLU(nn.Module):
     def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
-        self.proj = LoRACompatibleLinear(dim_in, dim_out * 2)
+        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
+
+        self.proj = linear_cls(dim_in, dim_out * 2)

     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         if gate.device.type != "mps":

@@ -377,7 +382,8 @@ def gelu(self, gate: torch.Tensor) -> torch.Tensor:
         return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)

     def forward(self, hidden_states, scale: float = 1.0):
-        hidden_states, gate = self.proj(hidden_states, scale).chunk(2, dim=-1)
+        args = () if USE_PEFT_BACKEND else (scale,)
+        hidden_states, gate = self.proj(hidden_states, *args).chunk(2, dim=-1)
         return hidden_states * self.gelu(gate)
```

Member (on the new `args = () if USE_PEFT_BACKEND else (scale,)` line): Nice.
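Taken together, the `attention.py` changes implement a single switch: when the PEFT backend is available, the blocks use plain `nn.Linear` (PEFT later injects its own LoRA layers into those modules) and stop forwarding the diffusers-specific `scale` argument, since PEFT handles scaling itself. A minimal, self-contained sketch of that pattern (only `USE_PEFT_BACKEND` and the class names come from the diff; the toy module is illustrative):

```python
import torch
from torch import nn

USE_PEFT_BACKEND = True  # in diffusers this flag is derived from the installed peft version


class LoRACompatibleLinear(nn.Linear):
    """Stand-in for diffusers' legacy LoRA-aware linear layer: unlike nn.Linear,
    its forward accepts an extra `scale` argument."""

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # The legacy path would add scale * lora(hidden_states) here.
        return super().forward(hidden_states)


class TinyFeedForward(nn.Module):
    def __init__(self, dim: int, inner_dim: int):
        super().__init__()
        # Same switch as in the diff: plain nn.Linear under the PEFT backend,
        # the LoRA-aware subclass otherwise.
        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
        self.proj = linear_cls(dim, inner_dim)

    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # Only the legacy class understands the extra `scale` argument.
        args = () if USE_PEFT_BACKEND else (scale,)
        return self.proj(hidden_states, *args)


x = torch.randn(2, 8)
print(TinyFeedForward(8, 16)(x).shape)  # torch.Size([2, 16])
```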
`src/diffusers/models/modeling_utils.py`:

```diff
@@ -32,10 +32,12 @@
     DIFFUSERS_CACHE,
     FLAX_WEIGHTS_NAME,
     HF_HUB_OFFLINE,
+    MIN_PEFT_VERSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
     _get_model_file,
+    check_peft_version,
     deprecate,
     is_accelerate_available,
     is_torch_version,

@@ -187,6 +189,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _hf_peft_config_loaded = False

     def __init__(self):
         super().__init__()
```
```diff
@@ -292,6 +295,153 @@ def disable_xformers_memory_efficient_attention(self):
         """
         self.set_use_memory_efficient_attention_xformers(False)

+    def add_adapter(self, adapter_config, adapter_name: str = "default") -> None:
+        r"""
+        Adds a new adapter to the current model for training. If no adapter name is passed, a default name is
+        assigned to the adapter to follow the convention of the PEFT library.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them in the PEFT
+        [documentation](https://huggingface.co/docs/peft).
+
+        Args:
+            adapter_config (`~peft.PeftConfig`):
+                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption
+                prompt methods.
+            adapter_name (`str`, *optional*, defaults to `"default"`):
+                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        from peft import PeftConfig, inject_adapter_in_model
+
+        if not self._hf_peft_config_loaded:
+            self._hf_peft_config_loaded = True
+        elif adapter_name in self.peft_config:
+            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")
+
+        if not isinstance(adapter_config, PeftConfig):
+            raise ValueError(
+                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
+            )
+
+        # Unlike transformers, here we don't need to retrieve the name_or_path of the unet as the loading logic is
+        # handled by `load_lora_layers` or `LoraLoaderMixin`. Therefore we set it to `None` here.
+        adapter_config.base_model_name_or_path = None
+        inject_adapter_in_model(adapter_config, self, adapter_name)
+        self.set_adapter(adapter_name)
```

Member (on `adapter_config.base_model_name_or_path = None`): I don't quite get this. Does it hurt to have …

Author: No, it does not, but I think there is no equivalent of it in diffusers, per my understanding.
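As a hedged usage sketch of the new method (the checkpoint, adapter name, and LoRA hyperparameters below are illustrative; a sufficiently recent `peft` must be installed):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Target the attention projections; rank and alpha are illustrative choices.
lora_config = LoraConfig(
    r=4,
    lora_alpha=4,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)

# Injects LoRA layers into the matching modules and activates the adapter.
unet.add_adapter(lora_config, adapter_name="my-lora")
```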
```diff
+    def set_adapter(self, adapter_name: Union[str, List[str]]) -> None:
+        """
+        Sets a specific adapter by forcing the model to only use that adapter and disables the other adapters.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+
+        Args:
+            adapter_name (`Union[str, List[str]]`):
+                The list of adapters to set, or the adapter name in the case of a single adapter.
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        if isinstance(adapter_name, str):
+            adapter_name = [adapter_name]
+
+        missing = set(adapter_name) - set(self.peft_config)
+        if len(missing) > 0:
+            raise ValueError(
+                f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)."
+                f" Currently loaded adapters are: {list(self.peft_config.keys())}"
+            )
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        _adapters_has_been_set = False
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "set_adapter"):
+                    module.set_adapter(adapter_name)
+                # Previous versions of PEFT do not support multi-adapter inference
+                elif not hasattr(module, "set_adapter") and len(adapter_name) != 1:
+                    raise ValueError(
+                        "You are trying to set multiple adapters and you have a PEFT version that does not support multi-adapter inference. Please upgrade to the latest version of PEFT."
+                        " `pip install -U peft` or `pip install -U git+https://github.com/huggingface/peft.git`"
+                    )
+                else:
+                    module.active_adapter = adapter_name
+                _adapters_has_been_set = True
+
+        if not _adapters_has_been_set:
+            raise ValueError(
+                "Did not succeed in setting the adapter. Please make sure you are using a model that supports adapters."
+            )
```
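Continuing the sketch above, switching between adapters is then a one-liner (the adapter names are illustrative and assumed to have been added via `add_adapter` beforehand):

```python
unet.set_adapter("my-lora")                   # only this adapter is active
unet.set_adapter(["my-lora", "other-lora"])   # combine both (recent PEFT only)
```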
```diff
+    def disable_adapters(self) -> None:
+        r"""
+        Disable all adapters attached to the model and fall back to inference with the base model only.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=False)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = True
+
+    def enable_adapters(self) -> None:
+        """
+        Enable adapters that are attached to the model. The model will use `self.active_adapters()` to retrieve the
+        list of adapters to enable.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                if hasattr(module, "enable_adapters"):
+                    module.enable_adapters(enabled=True)
+                else:
+                    # support for older PEFT versions
+                    module.disable_adapters = False
```
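For instance, the adapters can be toggled off to compare against the base model and re-enabled afterwards. A hedged sketch (the input tensors are illustrative and follow the shapes of a Stable Diffusion v1.x UNet; a freshly injected, untrained adapter would still produce near-identical outputs):

```python
import torch

sample = torch.randn(1, 4, 64, 64)            # latent input
timestep = torch.tensor([10])                  # diffusion timestep
encoder_hidden_states = torch.randn(1, 77, 768)  # text-encoder features

unet.disable_adapters()  # forward passes now use only the base weights
base_out = unet(sample, timestep, encoder_hidden_states).sample

unet.enable_adapters()   # re-enable the adapters reported by `active_adapters()`
lora_out = unet(sample, timestep, encoder_hidden_states).sample
```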
```diff
+    def active_adapters(self) -> List[str]:
+        """
+        Gets the current list of active adapters of the model.
+
+        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
+        official documentation: https://huggingface.co/docs/peft
+        """
+        check_peft_version(min_version=MIN_PEFT_VERSION)
+
+        if not self._hf_peft_config_loaded:
+            raise ValueError("No adapter loaded. Please load an adapter first.")
+
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        for _, module in self.named_modules():
+            if isinstance(module, BaseTunerLayer):
+                return module.active_adapter
+
```
```diff
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
```
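Putting the new `ModelMixin` API together, an end-to-end sketch (all names and hyperparameters are illustrative, and the return shape of `active_adapters()` may vary with the installed PEFT version):

```python
from diffusers import UNet2DConditionModel
from peft import LoraConfig

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

# Two independently named adapters on the same UNet; use separate configs,
# since `add_adapter` mutates the config it receives.
unet.add_adapter(LoraConfig(r=4, lora_alpha=4, target_modules=["to_q", "to_v"]),
                 adapter_name="adapter-1")
unet.add_adapter(LoraConfig(r=8, lora_alpha=8, target_modules=["to_q", "to_v"]),
                 adapter_name="adapter-2")

unet.set_adapter("adapter-2")   # only adapter-2 influences the forward pass
print(unet.active_adapters())   # adapter name(s) reported by the first PEFT layer found

unet.disable_adapters()         # base model only
unet.enable_adapters()          # back to the active set
```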