[WIP] SD3.5 IP-Adapter Pipeline Integration #9987

Merged: 55 commits, Dec 20, 2024
Changes from 31 commits
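For orientation: this PR wires IP-Adapter support into the SD3/SD3.5 pipeline via a new `SD3IPAdapterMixin` (`load_ip_adapter`, `set_ip_adapter_scale`, `unload_ip_adapter`), a Siglip-based image encoder path, and `joint_attention_kwargs` plumbing through `JointTransformerBlock`. A minimal usage sketch follows; the checkpoint repo ids and weight file name are assumptions for illustration, not taken from this page.

```python
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import load_image

# Assumed checkpoint ids -- substitute the base model and adapter you actually use.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

# load_ip_adapter/set_ip_adapter_scale come from the SD3IPAdapterMixin added in this PR.
pipe.load_ip_adapter(
    "InstantX/SD3.5-Large-IP-Adapter",  # assumed adapter repo
    subfolder="",
    weight_name="ip-adapter.bin",  # assumed weight file name
)
pipe.set_ip_adapter_scale(0.6)  # 0.0: text prompt only; 1.0: image prompt dominates

image = pipe(
    prompt="a cat wearing a spacesuit",
    ip_adapter_image=load_image("reference.png"),  # the image prompt
).images[0]
image.save("out.png")
```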
Commits (55):

0af910b  Initial pipeline for SD3.5-Large-IP-Adapter (guiyrt, Nov 21, 2024)
5567438  Added support for single IPAdapter on SD3.5 pipeline (guiyrt, Dec 6, 2024)
50d09d9  Merge branch 'main' into sd3.5_IPAdapter (guiyrt, Dec 6, 2024)
0ef36dd  Fixed typo and reverted removal of skip_layers in SD3Transformer2DModel (guiyrt, Dec 7, 2024)
de8909a  Added new SD3IPAdapterMixin loader (guiyrt, Dec 9, 2024)
ab0d904  ip_adapter image embeds now considers num_images_per_prompt (guiyrt, Dec 9, 2024)
d868ddb  Merge branch 'main' into sd3.5_IPAdapter (guiyrt, Dec 9, 2024)
5aed1d3  Removed usage of einops (guiyrt, Dec 9, 2024)
4383175  Reverted joint_attention_kwargs default for consistency (guiyrt, Dec 9, 2024)
461ab73  Corrected einops removal (guiyrt, Dec 9, 2024)
8323240  Quality and style checks (guiyrt, Dec 9, 2024)
89c4e63  Quality and style checks (guiyrt, Dec 9, 2024)
27d574f  Handle None joint_attention_kwargs in JointTransformerBlock (guiyrt, Dec 9, 2024)
0a48648  Fix test_components_function (hlky, Dec 9, 2024)
10d0a06  Remove from img2img/inpaint for now (hlky, Dec 9, 2024)
c78c4fd  Fixed loading ip_adapter state dict (guiyrt, Dec 10, 2024)
0f6c607  Simpler image encoding (guiyrt, Dec 10, 2024)
53fd40d  Style check (guiyrt, Dec 10, 2024)
8039599  Better checks for image prompt considering ip_adapter scale (guiyrt, Dec 10, 2024)
7333bfc  Minor change correcting checking for ip_adapter embeds (guiyrt, Dec 10, 2024)
a87895e  Removing old check of ip_adapter scale (guiyrt, Dec 10, 2024)
4ba374a  Refactor of image_proj (testing) (guiyrt, Dec 10, 2024)
819dd3e  Revert "Removing old check of ip_adapter scale" (guiyrt, Dec 10, 2024)
262a3bb  Merge branch 'main' into sd3.5_IPAdapter (guiyrt, Dec 10, 2024)
ea32e13  Corrected property check (guiyrt, Dec 10, 2024)
f60751f  Corrected forward() of IPAdapterTimeImageProjectionBlock (guiyrt, Dec 11, 2024)
b0aa5cb  IPAdapterTimeImageProjectionBlock now uses original attention impleme… (guiyrt, Dec 12, 2024)
b3dc69a  Clean-up and make style (guiyrt, Dec 12, 2024)
84aa4a3  Minor changes in code structure (guiyrt, Dec 13, 2024)
34793fb  make style && make quality (guiyrt, Dec 13, 2024)
27fe083  Merge branch 'main' into sd3.5_IPAdapter (guiyrt, Dec 13, 2024)
68169f8  Updated docstrings and doc entries (guiyrt, Dec 16, 2024)
d824451  Merge branch 'main' into sd3.5_IPAdapter (hlky, Dec 16, 2024)
24e6880  make (hlky, Dec 16, 2024)
43d2e77  More docs and small refactors (guiyrt, Dec 17, 2024)
05f49e6  Merge remote-tracking branch 'origin' into sd3.5_IPAdapter (guiyrt, Dec 17, 2024)
44e3847  Fix in loading state dict (guiyrt, Dec 18, 2024)
178e513  Enabled cpu offload (guiyrt, Dec 18, 2024)
7899c6a  Merge branch 'main' into sd3.5_IPAdapter (guiyrt, Dec 18, 2024)
8daca65  Renaming from transformers_sd3 to transformer_sd3 (guiyrt, Dec 18, 2024)
7c918db  Missing rename (guiyrt, Dec 18, 2024)
99a6d59  Updated docs for SD3 pipeline (guiyrt, Dec 18, 2024)
3916298  Merge branch 'main' into sd3.5_IPAdapter (guiyrt, Dec 18, 2024)
02a6d90  Update docs/source/en/api/pipelines/stable_diffusion/stable_diffusion… (guiyrt, Dec 18, 2024)
64ab7f9  Minor doc correction (guiyrt, Dec 18, 2024)
b254aa3  Updated img source to hf/documentation-images (guiyrt, Dec 18, 2024)
5c28161  image_proj is now called from SD3Transformer2DModel (guiyrt, Dec 19, 2024)
1313501  Merge branch 'main' into sd3.5_IPAdapter (guiyrt, Dec 19, 2024)
b882e1b  ip_adapter_image_embeds go through joint_attention_kwargs (guiyrt, Dec 19, 2024)
988447f  Warning for sequential cpu offloading with image_encoder (guiyrt, Dec 19, 2024)
66c4866  Merge branch 'main' into sd3.5_IPAdapter (guiyrt, Dec 19, 2024)
98f4521  make style quality (guiyrt, Dec 19, 2024)
5eef7f1  Merge branch 'main' into sd3.5_IPAdapter (yiyixuxu, Dec 19, 2024)
18cd8e4  Update src/diffusers/models/attention.py (yiyixuxu, Dec 19, 2024)
65b477f  Update src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_dif… (yiyixuxu, Dec 20, 2024)
10 changes: 8 additions & 2 deletions src/diffusers/loaders/__init__.py
@@ -71,7 +71,10 @@ def text_encoder_attn_modules(text_encoder):
"Mochi1LoraLoaderMixin",
]
_import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
_import_structure["ip_adapter"] = ["IPAdapterMixin"]
_import_structure["ip_adapter"] = [
"IPAdapterMixin",
"SD3IPAdapterMixin",
]

_import_structure["peft"] = ["PeftAdapterMixin"]

@@ -83,7 +86,10 @@ def text_encoder_attn_modules(text_encoder):
    from .utils import AttnProcsLayers

    if is_transformers_available():
        from .ip_adapter import (
            IPAdapterMixin,
            SD3IPAdapterMixin,
        )
        from .lora_pipeline import (
            AmusedLoraLoaderMixin,
            CogVideoXLoraLoaderMixin,
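These registrations plug into diffusers' deferred-import machinery. A minimal sketch of that pattern, assuming the standard `_LazyModule` arrangement used in the package `__init__.py` (the sketch itself is not part of this diff):

```python
import sys

from diffusers.utils import _LazyModule

# Names listed here are only imported on first attribute access, which keeps
# `import diffusers.loaders` cheap even as more mixins are registered.
_import_structure = {"ip_adapter": ["IPAdapterMixin", "SD3IPAdapterMixin"]}

sys.modules[__name__] = _LazyModule(
    __name__, globals()["__file__"], _import_structure, module_spec=__spec__
)
```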
244 changes: 236 additions & 8 deletions src/diffusers/loaders/ip_adapter.py
@@ -33,15 +33,18 @@


if is_transformers_available():
    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel

    from ..models.attention_processor import (
        AttnProcessor,
        AttnProcessor2_0,
        IPAdapterAttnProcessor,
        IPAdapterAttnProcessor2_0,
        IPAdapterJointAttnProcessor2_0,
        IPAdapterXFormersAttnProcessor,
        JointAttnProcessor2_0,
    )

logger = logging.get_logger(__name__)

@@ -348,3 +351,228 @@ def unload_ip_adapter(self):
                else value.__class__()
            )
        self.unet.set_attn_processor(attn_procs)


class SD3IPAdapterMixin:
    """Mixin for handling StableDiffusion 3 IP-Adapters."""

    @property
    def is_ip_adapter_active(self) -> bool:
        r"""Checks whether any IP-Adapter attention processor has a scale greater than 0.

        The IP-Adapter scale controls the influence of the image prompt versus the text prompt. When this value is
        set to 0, the image prompt is ignored.

        Returns:
            `bool`: True when an IP-Adapter is loaded and any of its attention processors has a scale greater than 0.
        """
        scales = [
            attn_proc.scale
            for attn_proc in self.transformer.attn_processors.values()
            if isinstance(attn_proc, IPAdapterJointAttnProcessor2_0)
        ]

        return len(scales) > 0 and any(scale > 0 for scale in scales)
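    # Illustration (not part of this diff): callers can use this property to skip
    # image-prompt handling when every adapter scale is zero, e.g.:
    #
    #     pipeline.set_ip_adapter_scale(0.0)
    #     assert not pipeline.is_ip_adapter_active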

    @validate_hf_hub_args
    def load_ip_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        subfolder: str,
        weight_name: str,
        image_encoder_folder: Optional[str] = "image_encoder",
        **kwargs,
    ):
        """
        Loads an IP-Adapter into the pipeline's transformer and, optionally, the matching image encoder.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                  with [`ModelMixin.save_pretrained`].
                - A [torch state
                  dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
            subfolder (`str`):
                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
                list is passed, it should have the same length as `weight_name`.
            weight_name (`str`):
                The name of the weight file to load. If a list is passed, it should have the same length as
                `subfolder`.
            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
                The subfolder location of the image encoder within a larger model repository on the Hub or locally.
                Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
                `subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
                `image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
                `subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
                `image_encoder_folder="different_subfolder/image_encoder"`.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
        """
        # Load the main state dict first
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

        if low_cpu_mem_usage and not is_accelerate_available():
            low_cpu_mem_usage = False
            logger.warning(
                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                " install accelerate\n```\n."
            )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            if weight_name.endswith(".safetensors"):
                state_dict = {"image_proj": {}, "ip_adapter": {}}
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    for key in f.keys():
                        if key.startswith("image_proj."):
                            state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
                        elif key.startswith("ip_adapter."):
                            state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
            else:
                state_dict = load_state_dict(model_file)
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        keys = list(state_dict.keys())
        if "image_proj" not in keys and "ip_adapter" not in keys:
            raise ValueError("Required keys (`image_proj` and `ip_adapter`) are missing from the state dict.")

        # Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
            if image_encoder_folder is not None:
                if not isinstance(pretrained_model_name_or_path_or_dict, dict):
                    logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
                    if image_encoder_folder.count("/") == 0:
                        image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
                    else:
                        image_encoder_subfolder = Path(image_encoder_folder).as_posix()

                    # Common kwargs for loading the image encoder and image processor
                    common_kwargs = {
                        "subfolder": image_encoder_subfolder,
                        "cache_dir": cache_dir,
                        "local_files_only": local_files_only,
                    }

                    self.register_modules(
                        feature_extractor=SiglipImageProcessor.from_pretrained(
                            pretrained_model_name_or_path_or_dict, **common_kwargs
                        ),
                        image_encoder=SiglipVisionModel.from_pretrained(
                            pretrained_model_name_or_path_or_dict, low_cpu_mem_usage=low_cpu_mem_usage, **common_kwargs
                        ).to(self.device, dtype=self.dtype),
                    )
                else:
                    raise ValueError(
                        "`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
                    )
            else:
                logger.warning(
                    "image_encoder is not loaded since `image_encoder_folder=None` was passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
                    " Use `ip_adapter_image_embeds` to pass pre-generated image embeddings instead."
                )

        # Load IP-Adapter into transformer
        self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)

    def set_ip_adapter_scale(self, scale: float):
        """
        Controls image/text prompt conditioning. A value of 1.0 means the model is conditioned only on the image
        prompt, and 0.0 means it is conditioned only on the text prompt. Lowering this value encourages the model to
        produce more diverse images, but they may not be as aligned with the image prompt.

        Example:

        ```python
        >>> # Assuming `pipeline` is already loaded with the IP-Adapter weights.
        >>> pipeline.set_ip_adapter_scale(0.6)
        >>> ...
        ```
        """
        for attn_processor in self.transformer.attn_processors.values():
            if isinstance(attn_processor, IPAdapterJointAttnProcessor2_0):
                attn_processor.scale = scale

    def unload_ip_adapter(self):
        """
        Unloads the IP-Adapter weights.

        Example:

        ```python
        >>> # Assuming `pipeline` is already loaded with the IP-Adapter weights.
        >>> pipeline.unload_ip_adapter()
        >>> ...
        ```
        """
        # Remove image encoder
        if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
            self.image_encoder = None
            self.register_to_config(image_encoder=None)

        # Remove feature extractor
        if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
            self.feature_extractor = None
            self.register_to_config(feature_extractor=None)

        # Remove image projection
        self.transformer.image_proj = None

        # Restore original attention processor layers
        attn_procs = {
            name: (JointAttnProcessor2_0() if isinstance(value, IPAdapterJointAttnProcessor2_0) else value.__class__())
            for name, value in self.transformer.attn_processors.items()
        }
        self.transformer.set_attn_processor(attn_procs)
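The loader above splits a flat checkpoint into two sub-dicts keyed by prefix. A minimal sketch of that layout, with placeholder tensor names and shapes (the real keys and shapes come from the adapter checkpoint, not from this page):

```python
import torch

# Flat checkpoint keys as found in an SD3.5 IP-Adapter file: "image_proj.*"
# feeds the image projection model, "ip_adapter.*" feeds the attention processors.
flat = {
    "image_proj.proj_in.weight": torch.zeros(2, 2),  # placeholder name and shape
    "ip_adapter.0.to_k_ip.weight": torch.zeros(2, 2),  # placeholder name and shape
}

# Same grouping the safetensors branch of load_ip_adapter performs:
state_dict = {"image_proj": {}, "ip_adapter": {}}
for key, tensor in flat.items():
    prefix, _, rest = key.partition(".")
    state_dict[prefix][rest] = tensor

assert set(state_dict) == {"image_proj", "ip_adapter"}
```

A pre-split dict in this shape can also be passed directly as `pretrained_model_name_or_path_or_dict`, which skips the download path entirely.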
16 changes: 13 additions & 3 deletions src/diffusers/models/attention.py
@@ -188,7 +188,11 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
        self._chunk_dim = dim

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if self.use_dual_attention:
            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
@@ -204,17 +208,23 @@ def forward(
                encoder_hidden_states, emb=temb
            )

        # Empty dict if None is passed
        if joint_attention_kwargs is None:
            joint_attention_kwargs = {}

        # Attention.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            **joint_attention_kwargs,
        )

        # Process attention outputs for the `hidden_states`.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = hidden_states + attn_output

        if self.use_dual_attention:
            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
            attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
            hidden_states = hidden_states + attn_output2
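The `joint_attention_kwargs` plumbing added above is what lets the SD3 pipeline hand IP-Adapter state to the attention processors. A standalone sketch of the None-to-empty-dict pattern, runnable on its own (the `ip_hidden_states` key is an assumption for illustration, not taken from this diff):

```python
from typing import Any, Dict, Optional


def attn(hidden_states: str, **kwargs: Any):
    # Stand-in for the attention call; extra kwargs reach the attention processor.
    return hidden_states, kwargs


def forward(hidden_states: str, joint_attention_kwargs: Optional[Dict[str, Any]] = None):
    # Defaulting to {} means the call below can always splat the kwargs.
    if joint_attention_kwargs is None:
        joint_attention_kwargs = {}
    return attn(hidden_states, **joint_attention_kwargs)


print(forward("h"))  # ('h', {})
print(forward("h", {"ip_hidden_states": "image-embeds"}))  # ('h', {'ip_hidden_states': 'image-embeds'})
```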