
StableDiffusionSafetyChecker ignores attn_implementation load kwarg #8957

@jambayk

Description

Describe the bug

transformers added sdpa and FA2 support for the CLIP model in huggingface/transformers#31940. The vision model is now initialized so that the requested attention implementation is propagated, see https://github.com/huggingface/transformers/blob/85a1269e19af022e04bc2aad82572cd5a9e8cdd9/src/transformers/models/clip/modeling_clip.py#L1143.

However, StableDiffusionSafetyChecker constructs its vision model directly:

self.vision_model = CLIPVisionModel(config.vision_config)

so the attn_implementation kwarg passed to from_pretrained is silently ignored and the vision model is always initialized with sdpa attention.

Reproduction

from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker

model = StableDiffusionSafetyChecker.from_pretrained(
    "runwayml/stable-diffusion-v1-5", 
    subfolder="safety_checker", 
    attn_implementation="eager"
)
print(type(model.vision_model.vision_model.encoder.layers[0].self_attn))

This prints transformers.models.clip.modeling_clip.CLIPSdpaAttention, but transformers.models.clip.modeling_clip.CLIPAttention was expected.

Logs

No response

System Info

diffusers 0.29.0
transformers 4.43.1

Who can help?

@sayakpaul @dn
