
[Core] Add PAG support for PixArtSigma #8921


Merged
merged 22 commits into main from pag-pixart-sigma on Aug 2, 2024

Conversation

@sayakpaul (Member) commented Jul 21, 2024

What does this PR do?

Part of #8785.

Since PixArt Sigma is considerably more popular than PixArt Alpha (longer sequence length, generally better quality, etc.), it makes sense to add PAG support just for PixArt Sigma. If there's a need, we could open it up to the community.

TODO

  • Tests
  • Docs

Code

from diffusers import PixArtSigmaPAGPipeline
import torch 

pipe = PixArtSigmaPAGPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", 
    torch_dtype=torch.float16, 
    pag_applied_layers=[14]
).to("cuda")

image = pipe(
    "A small cactus with a happy face in the Sahara desert.", 
    guidance_scale=0.0, 
    pag_scale=2.0, 
    generator=torch.manual_seed(0)
).images[0]
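
PAG can also be combined with regular classifier-free guidance. A minimal variant of the call above, reusing the same pipe (the values mirror the CFG_2_PG_2 setting from the ablation below):

image = pipe(
    "A small cactus with a happy face in the Sahara desert.",
    guidance_scale=2.0,  # non-zero CFG alongside PAG
    pag_scale=2.0,
    generator=torch.manual_seed(0),
).images[0]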

A small ablation study is in the comment below.

@asomoza, I would appreciate it if you could run some experiments with this :)

@sunovivid for awareness.

@sayakpaul sayakpaul requested a review from yiyixuxu July 21, 2024 07:53
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul (Member Author):

Ablation

Code

import torch
import argparse
from diffusers import PixArtSigmaPAGPipeline, PixArtSigmaPipeline

def load_pipeline(args):
    if args.pag:
        pipe = PixArtSigmaPAGPipeline.from_pretrained(
            "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16, pag_applied_layers=args.pag_applied_layers
        ).to("cuda")
    else:
        pipe = PixArtSigmaPipeline.from_pretrained("PixArt-alpha/PixArt-Sigma-XL-2-1024-MS", torch_dtype=torch.float16).to("cuda")
    return pipe 

def run_pipeline(pipe, args):
    if args.pag:
        image = pipe(
            args.prompt, guidance_scale=args.cfg, pag_scale=args.pag_scale, generator=torch.manual_seed(0)
        ).images[0]
    else:
        image = pipe(args.prompt, generator=torch.manual_seed(0)).images[0]
    
    # Derive the output file name from the prompt and the PAG configuration.
    img_name = "_".join(args.prompt.split(" "))
    if args.pag:
        img_name += "_pag"
        pag_applied_layers = "_".join(map(str, args.pag_applied_layers))
        img_name += f"_cfg@{args.cfg}_pg@{args.pag_scale}_layers@{pag_applied_layers}"
    img_name += ".png"
    image.save(img_name)

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pag", type=int, default=0)
    parser.add_argument("--prompt", type=str, default="A small cactus with a happy face in the Sahara desert.")
    parser.add_argument("--cfg", type=float, default=0.0)
    parser.add_argument("--pag_scale", type=float, default=4.0)
    parser.add_argument("--pag_applied_layers", metavar="N", type=int, nargs="*", default=[14])
    args = parser.parse_args()

    pipe = load_pipeline(args)
    run_pipeline(pipe, args)
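
Example invocations (the file name ablate_pag.py is a stand-in; the script above was shared inline):

# baseline, no PAG
python ablate_pag.py --pag 0

# PAG on layers 13 and 14 with CFG 0.0 and PAG scale 2.0
python ablate_pag.py --pag 1 --cfg 0.0 --pag_scale 2.0 --pag_applied_layers 13 14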

Results

Prompt: A small cactus with a happy face in the Sahara desert

[Images: W/O PAG | PAG]

Prompt: Astronaut on Mars During sunset

[Images: W/O PAG | PAG_CFG_0_PG_2_LAYERS_14 | PAG_CFG_0_PG_2_LAYERS_13_14 | PAG_CFG_2_PG_2_LAYERS_14]

@@ -258,3 +258,91 @@ def pag_attn_processors(self):
if proc.__class__ in (PAGCFGIdentitySelfAttnProcessor2_0, PAGIdentitySelfAttnProcessor2_0):
processors[name] = proc
return processors


class PixArtPAGMixin(PAGMixin):
sayakpaul (Member Author):

Let me know if I should copy-paste the other methods from PAGMixin and not make it a subclass of PAGMixin.

PixArtPAGMixin differs only in terms of the methods I implemented here.

Contributor:

I'm also curious about this. subclass looks neater than copy-paste.

yiyixuxu (Collaborator) commented Jul 23, 2024:

I would prefer not to have a Mixin nested inside another Mixin if possible. We can either:

  1. copy-paste the common methods from PAGMixin (maybe we can rename PAGMixin to SDPAGMixin in that case); a rough sketch of this option follows this list, or
  2. keep a common PAGMixin and move the methods that are potentially model-specific into pipeline methods.
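
A rough sketch of what option 1 could look like: a standalone PixArtPAGMixin that duplicates the shared helper instead of subclassing. Names and validation logic below are illustrative, not the actual diffusers implementation.

class PixArtPAGMixin:
    @staticmethod
    def _check_input_pag_applied_layer(layer):
        # PixArt's transformer blocks are addressed by integer index here.
        if not isinstance(layer, int) or layer < 0:
            raise ValueError(f"Invalid PAG applied layer: {layer}")

    def set_pag_applied_layers(self, pag_applied_layers):
        # Duplicated from PAGMixin rather than inherited (option 1).
        if not isinstance(pag_applied_layers, list):
            pag_applied_layers = [pag_applied_layers]
        for layer in pag_applied_layers:
            self._check_input_pag_applied_layer(layer)
        self.pag_applied_layers = pag_applied_layers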

sayakpaul (Member Author):

I will go with the first one. I will defer any refactoring to @yiyixuxu if she feels the need to do one.

sayakpaul (Member Author):

Done in 27fceb3. LMK.

Collaborator:

cc @a-r-r-o-w here too, I think we can reuse this one for hunyuan too

@yiyixuxu (Collaborator):

Maybe we can try the "robot preparing a meal" example? https://huggingface.co/docs/diffusers/main/en/using-diffusers/pag#general-tasks

I cannot tell if there is an improvement or not in the example provided (but I'm also not very good at looking at this, so cc @asomoza for help here too)

@sayakpaul (Member Author):

Yeah, I am not good at that either, so I requested @asomoza to chime in :D

Here are the "insect cooking a meal" results.

[Images: No PAG | CFG_1_PG_3_L14 | CFG_4.5_PG_3_L13_14_15 | CFG_4.5_PG_3_L14]

CFG_1_PG_3_L14 means we are applying PAG with the following config (see the sketch after this list):

  • CFG of 1.0
  • PAG scale of 3.0
  • pag_applied_layers is set to [14]
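
As a sketch, CFG_1_PG_3_L14 can be reproduced directly with the pipeline. The default prompt from the ablation script is reused here; the actual insect prompt isn't included in this thread.

import torch
from diffusers import PixArtSigmaPAGPipeline

pipe = PixArtSigmaPAGPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    torch_dtype=torch.float16,
    pag_applied_layers=[14],  # L14
).to("cuda")
image = pipe(
    "A small cactus with a happy face in the Sahara desert.",
    guidance_scale=1.0,  # CFG of 1.0
    pag_scale=3.0,       # PAG scale of 3.0
    generator=torch.manual_seed(0),
).images[0]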

@asomoza (Member) commented Jul 23, 2024

This is the most interesting of the new models. PAG definitely helps with the generation: it cleans up the image and fixes errors, but, as with Kolors, it takes away some details, so it depends on what you want.

For example in this image, it fixes the hands of the robot and the shape of the bowls:

[Images: W/O PAG | PAG]

I also found time to play with the model, and it's really good; I'm impressed with the generations it can produce.

Anyway, what I found interesting is how the PAG layers behave; they strongly affect the generation. For example:

[Images: layer 1 | layer 2 | layer 6]
[Images: layer 18 | layer 19 | layer 21]

As with SDXL, I can see some relation between the layers and the generations, but I'll need to experiment more to be sure. For example, layers 19 and 21 seem to affect the background more, while others seem to affect the shape or the colors. Layers 1 and 6 can also be used for composition, followed by a second pass over them.

@sayakpaul (Member Author):

Thank you very much, Alvaro! So, what I am hearing is that it makes sense to have PAG supported for PixArt Sigma, yeah?

Cc @lawrence-cj too; you might find this nice.

@yiyixuxu (Collaborator):

@asomoza very nice! thanks!

@sayakpaul (Member Author):

@yiyixuxu can I add tests and docs and get this ready for final review?

@yiyixuxu (Collaborator):

@sayakpaul sure! Looks good to me.
Let's merge this soon.

Comment on lines +53 to +55
params = TEXT_TO_IMAGE_PARAMS.union({"pag_scale", "pag_adaptive_scale"})
params = set(params)
params.remove("cross_attention_kwargs")
sayakpaul (Member Author):

LMK if this needs to be handled differently. PixArt doesn't have cross_attention_kwargs and doesn't need it yet.

Collaborator:

ok!

Comment on lines +266 to +268
# Because the PAG PixArt Sigma has `pag_applied_layers`.
# Also, we shouldn't be doing `set_default_attn_processor()` after loading
# the pipeline with `pag_applied_layers`.
sayakpaul (Member Author):

Let me know if this and subsequent tests should be handled differently.
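
For context, an illustrative sketch of the concern (not the actual test code): loading with pag_applied_layers installs PAG attention processors on the listed blocks, so resetting the processors afterwards would silently disable PAG.

import torch
from diffusers import PixArtSigmaPAGPipeline

pipe = PixArtSigmaPAGPipeline.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    torch_dtype=torch.float16,
    pag_applied_layers=[14],
)
# This would replace the PAG processors installed above:
# pipe.transformer.set_default_attn_processor()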

Collaborator:

looks good!

@sayakpaul (Member Author):

@yiyixuxu I have left some questions for you regarding the tests. LMK.

@yiyixuxu (Collaborator) commented Aug 1, 2024

@sayakpaul looks good to me! feel free to merge once the conflicts are resolved

@sayakpaul sayakpaul merged commit 7b98c4c into main Aug 2, 2024
18 checks passed
@sayakpaul sayakpaul deleted the pag-pixart-sigma branch August 2, 2024 01:57
@yiyixuxu yiyixuxu added the PAG label Sep 4, 2024
sayakpaul added a commit that referenced this pull request Dec 23, 2024
* feat: add pixart sigma pag.

* inits.

* fixes

* fix

* remove print.

* copy paste methods to the pixart pag mixin

* fix-copies

* add documentation.

* add tests.

* remove correction file.

* remove pag_applied_layers

* empty