docs/source/en/using-diffusers/ip_adapter.md (84 additions, 0 deletions)

@@ -640,3 +640,87 @@ image
<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ipa-controlnet-out.png" />
</div>

### Style & layout control

[InstantStyle](https://arxiv.org/abs/2404.02733) is a plug-and-play method built on top of IP-Adapter that disentangles style and layout from the image prompt to control image generation. It achieves this by inserting IP-Adapters only into specific parts of the model.

By default, IP-Adapters are inserted into all layers of the model. Use the [`~loaders.IPAdapterMixin.set_ip_adapter_scale`] method with a dictionary to assign scales to the IP-Adapter in specific layers.

```py
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image
import torch

pipeline = AutoPipelineForImage2Image.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")

scale = {
    "down": {"block_2": [0.0, 1.0]},
    "up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)
```

This activates the IP-Adapter at the second layer in the model's down-part block 2 and up-part block 0. The former is the layer where the IP-Adapter injects layout information, and the latter injects style. By inserting the IP-Adapter into only these two layers, you can generate images that follow the style and layout of the image prompt while keeping the content more closely aligned with the text prompt.

```py
style_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg")

generator = torch.Generator(device="cpu").manual_seed(42)
image = pipeline(
    prompt="a cat, masterpiece, best quality, high quality",
    image=style_image,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    guidance_scale=5,
    num_inference_steps=30,
    generator=generator,
).images[0]
image
```

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter image</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit_style_layout_cat.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">generated image</figcaption>
</div>
</div>

In contrast, inserting the IP-Adapter into all layers often generates images that focus too heavily on the image prompt and show less diversity.

Activate the IP-Adapter only in the style layer and then call the pipeline again.

```py
scale = {
    "up": {"block_0": [0.0, 1.0, 0.0]},
}
pipeline.set_ip_adapter_scale(scale)

generator = torch.Generator(device="cpu").manual_seed(42)
image = pipeline(
    prompt="a cat, masterpiece, best quality, high quality",
    image=style_image,
    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
    guidance_scale=5,
    num_inference_steps=30,
    generator=generator,
).images[0]
image
```

<div class="flex flex-row gap-4">
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit_style_cat.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter only in style layer</figcaption>
</div>
<div class="flex-1">
<img class="rounded-xl" src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/30518dfe089e6bf50008875077b44cb98fb2065c/diffusers/default_out.png"/>
<figcaption class="mt-2 text-center text-sm text-gray-500">IP-Adapter in all layers</figcaption>
</div>
</div>

Note that you don't have to specify all layers in the dictionary. Layers not included in the dictionary are set to a scale of 0.0, which disables the IP-Adapter in those layers by default.
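
As a quick sanity check, the two calls below are equivalent: any block missing from the first dictionary falls back to a scale of 0.0, which the second dictionary spells out explicitly (the zero scales shown are illustrative):

```py
# Style-only config: every block not listed defaults to a scale of 0.0.
pipeline.set_ip_adapter_scale({"up": {"block_0": [0.0, 1.0, 0.0]}})

# Equivalent config, with the layout block disabled explicitly.
pipeline.set_ip_adapter_scale(
    {
        "down": {"block_2": [0.0, 0.0]},
        "up": {"block_0": [0.0, 1.0, 0.0]},
    }
)
```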
src/diffusers/loaders/ip_adapter.py (40 additions, 9 deletions)

@@ -28,6 +28,7 @@
    is_transformers_available,
    logging,
)
+from .unet_loader_utils import _maybe_expand_lora_scales


if is_transformers_available():
@@ -243,25 +244,55 @@ def load_ip_adapter(

    def set_ip_adapter_scale(self, scale):
        """
-        Sets the conditioning scale between text and image.
+        Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
+        granular control over each IP-Adapter behavior. A config can be a float or a dictionary.

        Example:

        ```py
-        pipeline.set_ip_adapter_scale(0.5)
+        # To use original IP-Adapter
+        scale = 1.0
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style block only
+        scale = {
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style+layout blocks
+        scale = {
+            "down": {"block_2": [0.0, 1.0]},
+            "up": {"block_0": [0.0, 1.0, 0.0]},
+        }
+        pipeline.set_ip_adapter_scale(scale)
+
+        # To use style and layout from 2 reference images
+        scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
+        pipeline.set_ip_adapter_scale(scales)
        ```
        """
        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
-        for attn_processor in unet.attn_processors.values():
+        if not isinstance(scale, list):
+            scale = [scale]
+        scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
+
+        for attn_name, attn_processor in unet.attn_processors.items():
            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
-                if not isinstance(scale, list):
-                    scale = [scale] * len(attn_processor.scale)
-                if len(attn_processor.scale) != len(scale):
+                if len(scale_configs) != len(attn_processor.scale):
                    raise ValueError(
-                        f"`scale` should be a list of same length as the number if ip-adapters "
-                        f"Expected {len(attn_processor.scale)} but got {len(scale)}."
+                        f"Cannot assign {len(scale_configs)} scale_configs to "
+                        f"{len(attn_processor.scale)} IP-Adapter."
                    )
+                elif len(scale_configs) == 1:
+                    scale_configs = scale_configs * len(attn_processor.scale)
+                for i, scale_config in enumerate(scale_configs):
+                    if isinstance(scale_config, dict):
+                        for k, s in scale_config.items():
+                            if attn_name.startswith(k):
+                                attn_processor.scale[i] = s
+                    else:
+                        attn_processor.scale[i] = scale_config

    def unload_ip_adapter(self):
        """
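
A note on how the new matching works: `_maybe_expand_lora_scales` fans the block-level dictionary out into per-attention-layer prefixes, and each attention processor whose name starts with one of those prefixes receives the corresponding scale. The minimal sketch below mimics that prefix match with hypothetical names; the real keys and processor names come from `unet.attn_processors`, so treat these strings as illustrative only:

```py
# Hypothetical expanded config: block-level entries fanned out to
# per-attention-layer prefixes (illustrative names, not the exact keys
# produced by _maybe_expand_lora_scales).
scale_config = {
    "up_blocks.0.attentions.0": 0.0,
    "up_blocks.0.attentions.1": 1.0,  # the "style" layer in the docs example
    "up_blocks.0.attentions.2": 0.0,
}

# Illustrative attention processor names, as they might appear in
# unet.attn_processors.
attn_names = [
    "up_blocks.0.attentions.1.transformer_blocks.0.attn2.processor",
    "down_blocks.2.attentions.0.transformer_blocks.0.attn2.processor",
]

for attn_name in attn_names:
    for k, s in scale_config.items():
        if attn_name.startswith(k):  # same prefix match as set_ip_adapter_scale
            print(f"{attn_name} -> scale {s}")
# Only the up_blocks.0.attentions.1 processor matches; everything else
# keeps whatever default it was given (0.0 here).
```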
src/diffusers/loaders/unet_loader_utils.py (28 additions, 6 deletions)

@@ -38,7 +38,9 @@ def _translate_into_actual_layer_name(name):
    return ".".join((updown, block, attn))


-def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]]):
+def _maybe_expand_lora_scales(
+    unet: "UNet2DConditionModel", weight_scales: List[Union[float, Dict]], default_scale=1.0
+):
    blocks_with_transformer = {
        "down": [i for i, block in enumerate(unet.down_blocks) if hasattr(block, "attentions")],
        "up": [i for i, block in enumerate(unet.up_blocks) if hasattr(block, "attentions")],
@@ -47,7 +49,11 @@ def _maybe_expand_lora_scales(unet: "UNet2DConditionModel", weight_scales: List[

    expanded_weight_scales = [
        _maybe_expand_lora_scales_for_one_adapter(
-            weight_for_adapter, blocks_with_transformer, transformer_per_block, unet.state_dict()
+            weight_for_adapter,
+            blocks_with_transformer,
+            transformer_per_block,
+            unet.state_dict(),
+            default_scale=default_scale,
        )
        for weight_for_adapter in weight_scales
    ]
@@ -60,6 +66,7 @@ def _maybe_expand_lora_scales_for_one_adapter(
    blocks_with_transformer: Dict[str, int],
    transformer_per_block: Dict[str, int],
    state_dict: None,
+    default_scale: float = 1.0,
):
    """
    Expands the inputs into a more granular dictionary. See the example below for more details.
@@ -108,21 +115,36 @@
    scales = copy.deepcopy(scales)

    if "mid" not in scales:
-        scales["mid"] = 1
+        scales["mid"] = default_scale
+    elif isinstance(scales["mid"], list):
+        if len(scales["mid"]) == 1:
+            scales["mid"] = scales["mid"][0]
+        else:
+            raise ValueError(f"Expected 1 scales for mid, got {len(scales['mid'])}.")

    for updown in ["up", "down"]:
        if updown not in scales:
-            scales[updown] = 1
+            scales[updown] = default_scale

        # eg {"down": 1} to {"down": {"block_1": 1, "block_2": 1}}}
        if not isinstance(scales[updown], dict):
-            scales[updown] = {f"block_{i}": scales[updown] for i in blocks_with_transformer[updown]}
+            scales[updown] = {f"block_{i}": copy.deepcopy(scales[updown]) for i in blocks_with_transformer[updown]}

-        # eg {"down": "block_1": 1}} to {"down": "block_1": [1, 1]}}
+        # eg {"down": {"block_1": 1}} to {"down": {"block_1": [1, 1]}}
        for i in blocks_with_transformer[updown]:
            block = f"block_{i}"
-            if block in scales[updown] and not isinstance(scales[updown][block], list):
-                scales[updown][block] = [scales[updown][block]] * transformer_per_block[updown]
+            # set not assigned blocks to default scale
+            if block not in scales[updown]:
+                scales[updown][block] = default_scale
+            if not isinstance(scales[updown][block], list):
+                scales[updown][block] = [scales[updown][block] for _ in range(transformer_per_block[updown])]
+            elif len(scales[updown][block]) == 1:
+                # a list specifying scale to each masked IP input
+                scales[updown][block] = scales[updown][block] * transformer_per_block[updown]
+            elif len(scales[updown][block]) != transformer_per_block[updown]:
+                raise ValueError(
+                    f"Expected {transformer_per_block[updown]} scales for {updown}.{block}, got {len(scales[updown][block])}."
+                )

        # eg {"down": "block_1": [1, 1]}} to {"down.block_1.0": 1, "down.block_1.1": 1}
        for i in blocks_with_transformer[updown]:
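
To make the new `default_scale` behavior concrete, here is a self-contained sketch of the expansion with `default_scale=0.0`, using illustrative SDXL-like block counts. It is a simplification: the real helper also validates block names against the UNet state dict, handles the `mid` block, and translates keys into actual layer names.

```py
import copy
from typing import Dict, Union

# Illustrative SDXL-like layout: down blocks 1-2 and up blocks 0-1 carry
# attention, with 2 attention modules per down block and 3 per up block.
BLOCKS_WITH_TRANSFORMER = {"down": [1, 2], "up": [0, 1]}
TRANSFORMER_PER_BLOCK = {"down": 2, "up": 3}


def expand_scales_sketch(scales: Union[float, Dict], default_scale: float = 0.0) -> Union[float, Dict]:
    """Simplified re-implementation of the expansion, for illustration only."""
    if not isinstance(scales, dict):
        return scales  # a bare number is applied to every layer unchanged
    scales = copy.deepcopy(scales)
    for updown in ["up", "down"]:
        # unspecified halves of the UNet fall back to default_scale
        scales.setdefault(updown, default_scale)
        if not isinstance(scales[updown], dict):
            scales[updown] = {f"block_{i}": scales[updown] for i in BLOCKS_WITH_TRANSFORMER[updown]}
        for i in BLOCKS_WITH_TRANSFORMER[updown]:
            block = f"block_{i}"
            scales[updown].setdefault(block, default_scale)  # unlisted blocks -> default
            if not isinstance(scales[updown][block], list):
                scales[updown][block] = [scales[updown][block]] * TRANSFORMER_PER_BLOCK[updown]
    return scales


# The style-only InstantStyle config from the docs above: with
# default_scale=0.0, every other block expands to all-zero scales.
print(expand_scales_sketch({"up": {"block_0": [0.0, 1.0, 0.0]}}))
```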