Description
Describe the bug
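Hi, the clause I highlighted in the link above prevents a model from using gradient checkpointing in eval mode. Checkpointing in eval mode is particularly useful for, e.g., LoRA training, where the base model may be kept in eval mode while gradients still flow to the adapter weights.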
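A minimal sketch of the premise, assuming only plain PyTorch: `eval()` switches layers such as dropout to inference behaviour, but it does not turn off autograd. A model in eval mode can therefore still be trained. The names `base` and `adapter` are hypothetical stand-ins for a frozen base model and LoRA-style trainable weights.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: a frozen "base" layer with dropout, plus a small
# trainable "adapter" layer (loosely analogous to LoRA weights).
base = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
adapter = nn.Linear(8, 8, bias=False)
for p in base.parameters():
    p.requires_grad_(False)

base.eval()  # dropout becomes a no-op, but autograd is untouched

x = torch.randn(4, 8)
out = adapter(base(x)).sum()
out.backward()

print(base.training)                     # False
print(torch.is_grad_enabled())           # True
print(adapter.weight.grad is not None)   # True: gradients still reach the adapter
print(torch.equal(base(x), base(x)))     # True: dropout is inactive in eval mode
```

So `self.training` answers "should dropout/batch-norm behave as in training?", while `torch.is_grad_enabled()` answers "is autograd recording?". Only the latter says whether checkpointing can save memory.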
Perhaps you meant to check something like this instead?
if torch.is_grad_enabled() and self.gradient_checkpointing:
The same applies to every other module in the repo that uses this check.
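A toy sketch contrasting the two conditions. `ToyBlock` is a hypothetical module that only mimics the gating pattern discussed above. It is not diffusers code. It records which path its last forward pass took.

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    """Hypothetical block mimicking the gradient-checkpointing gate."""

    def __init__(self, dim: int, gate_on_grad: bool):
        super().__init__()
        self.layer = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))
        self.gradient_checkpointing = True
        self.gate_on_grad = gate_on_grad
        self.used_checkpoint = False  # which path the last forward took

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.gate_on_grad:
            # Proposed condition: checkpoint whenever autograd is recording.
            use_ckpt = torch.is_grad_enabled() and self.gradient_checkpointing
        else:
            # Reported condition: depends on train/eval mode instead.
            use_ckpt = self.training and self.gradient_checkpointing
        self.used_checkpoint = use_ckpt
        if use_ckpt:
            return torch.utils.checkpoint.checkpoint(self.layer, x, use_reentrant=False)
        return self.layer(x)


x = torch.randn(2, 16, requires_grad=True)
current_gate = ToyBlock(16, gate_on_grad=False).eval()
proposed_gate = ToyBlock(16, gate_on_grad=True).eval()

current_gate(x).sum().backward()
proposed_gate(x).sum().backward()
print(current_gate.used_checkpoint, proposed_gate.used_checkpoint)  # False True

# Under no_grad there is nothing to save, so the proposed gate skips checkpointing.
with torch.no_grad():
    proposed_gate(x)
print(proposed_gate.used_checkpoint)  # False
```

The proposed check still skips checkpointing during ordinary inference (`torch.no_grad()` or `torch.inference_mode()`), so it should not slow down sampling.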
Reproduction
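import torch
from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn

# in_channels=32, temb_channels=32, out_channels=32
block = UNetMidBlock2DCrossAttn(32, 32, 32, cross_attention_dim=32)
block.gradient_checkpointing = True
block.eval()  # autograd is still enabled, so checkpointing would still save memory

# hidden_states of shape (1, 32, 64, 64) and temb of shape (1, 32);
# the checkpointed path is skipped because the block is in eval mode
block(torch.randn((1, 32, 64, 64)), torch.randn((1, 32)))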
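Since the reproduction produces no error, one way to observe the bug is to count calls to `torch.utils.checkpoint.checkpoint`. This is a sketch under the assumption that the code under test looks up `torch.utils.checkpoint.checkpoint` at call time, so patching the module attribute intercepts it. `count_checkpoint_calls` is a hypothetical helper, not part of diffusers or PyTorch.

```python
import contextlib
from unittest import mock

import torch
import torch.utils.checkpoint


@contextlib.contextmanager
def count_checkpoint_calls():
    """Record each call to torch.utils.checkpoint.checkpoint inside the block.

    Assumption: callers resolve torch.utils.checkpoint.checkpoint at call
    time, so replacing the module attribute is enough to intercept them.
    """
    calls = []
    original = torch.utils.checkpoint.checkpoint

    def wrapper(*args, **kwargs):
        calls.append(1)
        return original(*args, **kwargs)

    with mock.patch.object(torch.utils.checkpoint, "checkpoint", wrapper):
        yield calls


# Self-contained demo on a toy layer.
layer = torch.nn.Linear(4, 4)
x = torch.randn(1, 4, requires_grad=True)
with count_checkpoint_calls() as calls:
    torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
print(len(calls))  # 1
```

Applied to the reproduction, wrapping the `block(...)` call in `with count_checkpoint_calls() as calls:` and printing `len(calls)` shows whether the checkpointed path was taken. The exact count depends on the installed diffusers version, so no expected value is given here.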
Logs
No response
System Info
- 🤗 Diffusers version: 0.30.3
- Platform: Windows-10-10.0.22631-SP0
- Running on Google Colab?: No
- Python version: 3.10.13
- PyTorch version (GPU?): 2.4.1+cu118 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.25.1
- Transformers version: 4.45.1
- Accelerate version: 0.34.2
- PEFT version: 0.12.0
- Bitsandbytes version: 0.44.1
- Safetensors version: 0.4.5
- xFormers version: 0.0.28.post3
- Accelerator: NVIDIA GeForce RTX 4060 Laptop GPU, 8188 MiB
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No