FLUX.2-klein: CUDA OOM when running in a for loop with different images #13079

@dangph-alala

Describe the bug

When using Flux2KleinPipeline for inference, a single run works fine and uses ~14GB of VRAM. However, running the exact same call inside a for loop immediately raises a CUDA out-of-memory error on the first iteration while trying to allocate additional memory.
The OOM surfaces in apply_rotary_emb during the transformer's attention computation, where the bfloat16 query is upcast to float32 (temporarily doubling that tensor's footprint), which is likely just where the remaining headroom runs out:

File "/venv/FluxApp/lib/python3.11/site-packages/diffusers/models/embeddings.py", line 1232, in apply_rotary_emb
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 936.00 MiB. GPU 0 has a total capacity of 15.48 GiB of which 227.25 MiB is free. Including non-PyTorch memory, this process has 15.25 GiB memory in use. Of the allocated memory 14.78 GiB is allocated by PyTorch, and 148.94 MiB is reserved by PyTorch but unallocated.
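
For reference, a minimal diagnostic sketch to separate a true leak from allocator fragmentation; log_vram is a hypothetical helper, not a diffusers API:

import torch

def log_vram(tag: str) -> None:
    # A growing "allocated" figure across iterations indicates retained
    # tensors; a large allocated-vs-reserved gap points at fragmentation.
    allocated = torch.cuda.memory_allocated() / 2**30
    reserved = torch.cuda.memory_reserved() / 2**30
    print(f"{tag}: allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB")

Calling log_vram before and after each pipeline() call should show whether the allocated figure actually grows per iteration.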

Reproduction

Working code (single inference, ~14GB VRAM):

import torch
from diffusers import (
    Flux2KleinPipeline,
    Flux2Transformer2DModel,
    GGUFQuantizationConfig,
)
from PIL import Image

device = "cuda"
dtype = torch.bfloat16

ckpt_path = (
    "https://huggingface.co/unsloth/FLUX.2-klein-4B-GGUF/blob/main/flux-2-klein-4b-Q6_K.gguf"
)

transformer = Flux2Transformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    config="black-forest-labs/FLUX.2-klein-4B",
    subfolder="transformer",
)

pipeline = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    transformer=transformer,
    torch_dtype=dtype,
).to(device)

image = Image.open("input.png").convert("RGB")  # any single input image

image = pipeline(
    prompt="anime style",
    image=image,
    height=image.height,
    width=image.width,
    guidance_scale=1.0,
    num_inference_steps=4,
).images[0]

Failing code (for loop, OOM on first iteration):

# imports, device, and dtype as in the working example above

ckpt_path = (
    "https://huggingface.co/unsloth/FLUX.2-klein-4B-GGUF/blob/main/flux-2-klein-4b-Q6_K.gguf"
)

transformer = Flux2Transformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
    config="black-forest-labs/FLUX.2-klein-4B",
    subfolder="transformer",
)

pipeline = Flux2KleinPipeline.from_pretrained(
    "black-forest-labs/FLUX.2-klein-4B",
    transformer=transformer,
    torch_dtype=dtype,
).to(device)

for image_file in image_files:  # image_files: list of input image paths
    base_image_obj = Image.open(image_file).convert("RGB")
    image = pipeline(
        prompt="anime style",
        image=base_image_obj,
        height=base_image_obj.height,
        width=base_image_obj.width,
        guidance_scale=1.0,
        num_inference_steps=4,
    ).images[0]
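
A workaround sketch, not a root-cause fix: it assumes the extra allocation comes from references and allocator cache held across iterations, and frees both before each new image. It reuses pipeline and image_files from the snippet above; the output path is a placeholder:

import gc

import torch
from PIL import Image

for image_file in image_files:
    base_image_obj = Image.open(image_file).convert("RGB")
    result = pipeline(
        prompt="anime style",
        image=base_image_obj,
        height=base_image_obj.height,
        width=base_image_obj.width,
        guidance_scale=1.0,
        num_inference_steps=4,
    ).images[0]
    result.save("output.png")  # hypothetical output path; adjust as needed
    # Drop Python references, then return cached blocks to the driver so the
    # next iteration starts from a clean allocator state.
    del result, base_image_obj
    gc.collect()
    torch.cuda.empty_cache()

Since the pipeline alone already sits at ~14 GB on a 16 GB card, replacing .to(device) with pipeline.enable_model_cpu_offload() may also leave enough headroom, at the cost of slower inference.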

Logs

image = pipeline(
            ^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/diffusers/pipelines/flux2/pipeline_flux2_klein.py", line 843, in __call__
    noise_pred = self.transformer(
                 ^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux2.py", line 874, in forward
    encoder_hidden_states, hidden_states = block(
                                           ^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux2.py", line 521, in forward
    attention_outputs = self.attn(
                        ^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux2.py", line 256, in forward
    return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/diffusers/models/transformers/transformer_flux2.py", line 155, in __call__
    query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/venv/FluxApp/lib/python3.11/site-packages/diffusers/models/embeddings.py", line 1232, in apply_rotary_emb
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
           ~~~~~~~~~~^~~~~
torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 936.00 MiB. GPU 0 has a total capacity of 15.48 GiB of which 227.25 MiB is free. Including non-PyTorch memory, this process has 15.25 GiB memory in use. Of the allocated memory 14.78 GiB is allocated by PyTorch, and 148.94 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
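
As the error text itself suggests, enabling expandable segments can mitigate the fragmentation component (the ~149 MiB reserved-but-unallocated gap). The variable has to be set before the first CUDA allocation:

import os

# Must run before building the pipeline (i.e., before any CUDA allocation).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

This only addresses fragmentation; it does not explain why the loop variant needs more memory than the identical single call.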

System Info

  • OS: Ubuntu 24.04
  • Python: 3.11
  • PyTorch: 2.9
  • diffusers: latest (main branch)
  • GPU: RTX 5080 (16GB VRAM)
  • CUDA: 12.8

Who can help?

@DN6
