Unable to load IP Adapter into FLUX (TypeError: CLIPVisionModelWithProjection.__init__() got an unexpected keyword argument 'dtype') #11581

Open
@Meatfucker

Describe the bug

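Loading an IP Adapter into a FluxPipeline that was created with a dtype (torch_dtype=torch.bfloat16 in the reproduction below) fails: load_ip_adapter raises TypeError: CLIPVisionModelWithProjection.__init__() got an unexpected keyword argument 'dtype'.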

Reproduction

import torch
from diffusers import FluxPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import load_image

model_id = "./models/black-forest-labs/FLUX.1-dev"
dtype = torch.bfloat16
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=dtype,
)

apply_group_offloading(
    pipe.transformer,
    offload_type="leaf_level",
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    use_stream=True,
)
apply_group_offloading(
    pipe.text_encoder,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
    use_stream=True,
)
apply_group_offloading(
    pipe.text_encoder_2,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
    use_stream=True,
)
apply_group_offloading(
    pipe.vae,
    offload_device=torch.device("cpu"),
    onload_device=torch.device("cuda"),
    offload_type="leaf_level",
    use_stream=True,
)

pipe.to("cuda")

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/flux_ip_adapter_input.jpg").resize((1024, 1024))

pipe.load_ip_adapter(
    "./models/XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="./models/openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(1.0)

image = pipe(
    width=1024,
    height=1024,
    prompt="wearing sunglasses",
    negative_prompt="",
    true_cfg=4.0,
    generator=torch.Generator().manual_seed(4444),
    ip_adapter_image=image,
).images[0]

Logs

The module 'CLIPTextModel' is group offloaded and moving it to cuda via `.to()` is not supported.
The module 'T5EncoderModel' is group offloaded and moving it to cuda via `.to()` is not supported.
The module 'FluxTransformer2DModel' is group offloaded and moving it to cuda via `.to()` is not supported.
The module 'AutoencoderKL' is group offloaded and moving it to cuda via `.to()` is not supported.
No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'. This is safe to ignore if LoRA state dict didn't originally have any CLIPTextModel related params. You can also try specifying `prefix=None` to resolve the warning. Otherwise, open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new
No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'. This is safe to ignore if LoRA state dict didn't originally have any CLIPTextModel related params. You can also try specifying `prefix=None` to resolve the warning. Otherwise, open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new
No LoRA keys associated to CLIPTextModel found with the prefix='text_encoder'. This is safe to ignore if LoRA state dict didn't originally have any CLIPTextModel related params. You can also try specifying `prefix=None` to resolve the warning. Otherwise, open an issue if you think it's unexpected: https://github.com/huggingface/diffusers/issues/new
Traceback (most recent call last):
  File "/home/developer/Desktop/image_inference.py", line 138, in <module>
    flux_pipe = add_flux_ip_adapter(add_flux_loras(create_flux_pipeline(img2img=False)))
  File "/home/developer/Desktop/flux_helpers.py", line 194, in add_flux_ip_adapter
    pipeline.load_ip_adapter(
  File "/home/developer/anaconda3/envs/imgInfer3/lib/python3.10/site-packages/huggingface_hub/utils/_validators.py", line 114, in _inner_fn
    return fn(*args, **kwargs)
  File "/home/developer/anaconda3/envs/imgInfer3/lib/python3.10/site-packages/diffusers/loaders/ip_adapter.py", line 523, in load_ip_adapter
    CLIPVisionModelWithProjection.from_pretrained(
  File "/home/developer/anaconda3/envs/imgInfer3/lib/python3.10/site-packages/transformers/modeling_utils.py", line 279, in _wrapper
    return func(*args, **kwargs)
  File "/home/developer/anaconda3/envs/imgInfer3/lib/python3.10/site-packages/transformers/modeling_utils.py", line 4342, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
TypeError: CLIPVisionModelWithProjection.__init__() got an unexpected keyword argument 'dtype'
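
The last two frames of the traceback suggest a likely mechanism: from_pretrained received a keyword it did not consume as a loader option, so the keyword was forwarded to the model constructor via cls(config, *model_args, **model_kwargs), and the constructor rejected it. The following is a minimal, self-contained sketch of that pattern. ToyVisionModel and try_load are hypothetical names made up for illustration; this is not the transformers or diffusers implementation, and which keywords the real loader consumes depends on the installed transformers version.

```python
class ToyVisionModel:
    """Hypothetical stand-in for a model class whose __init__ takes only a config."""

    def __init__(self, config):
        self.config = config
        self.dtype = None

    @classmethod
    def from_pretrained(cls, name, **kwargs):
        # Options the (toy) loader understands are popped before construction.
        torch_dtype = kwargs.pop("torch_dtype", None)
        config = {"name": name}
        # Anything left over is forwarded to __init__, mirroring the
        # `cls(config, *model_args, **model_kwargs)` line in the traceback.
        model = cls(config, **kwargs)
        model.dtype = torch_dtype
        return model


def try_load(**kwargs):
    """Return the loaded model's dtype, or the TypeError message on failure."""
    try:
        return ToyVisionModel.from_pretrained("toy-clip", **kwargs).dtype
    except TypeError as exc:
        return str(exc)


if __name__ == "__main__":
    # A keyword the loader knows about is consumed and applied.
    print(try_load(torch_dtype="bfloat16"))
    # A keyword the loader does not know about reaches __init__ and fails.
    print(try_load(dtype="bfloat16"))
```

If the real failure follows this pattern, the keyword name passed by load_ip_adapter and the keyword names accepted by the installed transformers release are the two things to compare.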

System Info

python 3.10.16 he870216_1
diffusers 0.33.1 pypi_0 pypi
bitsandbytes 0.45.5 pypi_0 pypi
torch 2.6.0+cu118 pypi_0 pypi
torchaudio 2.6.0+cu118 pypi_0 pypi
torchvision 0.21.0+cu118 pypi_0 pypi
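
The traceback passes through transformers and huggingface_hub, whose versions are not listed above. A small sketch for collecting them with the standard library is shown here; collect_versions is a hypothetical helper, and the package list is only an assumption about what is relevant to this report.

```python
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages):
    """Map each distribution name to its installed version string."""
    report = {}
    for name in packages:
        try:
            report[name] = version(name)
        except PackageNotFoundError:
            # Keep the entry so a missing package is visible in the report.
            report[name] = "not installed"
    return report


if __name__ == "__main__":
    names = ["diffusers", "transformers", "huggingface_hub", "torch"]
    for name, ver in collect_versions(names).items():
        print(f"{name} {ver}")
```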

Who can help?

No response

Labels: bug