Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add preference optimization (Diffusion-DPO, MaPO) #1427

Draft
wants to merge 5 commits into
base: dev
Choose a base branch
from

Conversation

rockerBOO
Copy link
Contributor

@rockerBOO rockerBOO commented Jul 13, 2024

Add preference optimization (PO) support
Add paired images in dataset.

Preference Optimization algo implemented:

Currently looking for feedback about implementation.

Decisions I made and why:

Pairing images

Pairing images in ImageSetInfo (exetnding ImageInfo)

  • Pairing images keeps the images from shuffling
  • Paired images should be updated after seeing both.

Batch size of 1 will load 2 captions and 2 images for 1 requested image/caption pair

Dataset

Datasets can be defined as "preference" args.preference --preference and in dataset_config.toml

  [[datasets.subsets]]
  image_dir = "dataset/1_name"
  preference = true

To create a pattern of dataset, we are hard coding in dataset/1_name/w and dataset/1_name/l. You would then have a typical dreambooth-like dataset with the following.

  • dataset/1_name/w/image.png
  • dataset/1_name/w/image.caption
  • dataset/1_name/l/image.png
  • dataset/1_name/l/image.caption

Note w and l are like the typical dataset with image/caption file pairs. They all have the same file name to create the pairs.

Good idea to consider other file dataset patterns.

Preference dataset examples:

Pickapic dataset is a preference between 2 images and showing the pairing and embedding the 2 images into the dataset.

Caption prefix/suffix for preference/non-preference

Prefix/suffix allow some techniques of moving away from some concepts. Allows different ones for preference/non preference, to give flexibility in experimentation.

Training

Added PO into the main training script to allow flexibility but will be moved to the typical functions for these. I have it setup for network training but would work other scripts.

  • Images come in as pairs on the tensors
  • Pass the loss through the PO algorithm for the pairs
  • Log the associated values from the training

Hyperparameters

--beta_dpo = KL-divergence parameter beta for Diffusion-DPO

2500 for 1.5, 5000 for SDXL were what I have found suggested.

--mapo_weight = MaPO contribution factor

Start around 0.1 but adjusting this can be helpful at how much the contribution of the preference optimization will have on the training. See

TODO

  • Cache support
  • ControlNet dataset (For masking)

Possible issues

Preference and regular training datasets mixed

This mixing would need to worked on at higher than 1 batch size. We assume chunking of pairs so unpaired images won't work that way.

The implementations may not be accurate

If you see something not correct, let me know.

Usage

State: This is currently working and producing favorable results.

Images/caption pairs stored in w and l directories.

dataset/1_myimages/w/image.jpg
dataset/1_myimages/w/image.txt

dataset/1_myimages/l/image.jpg
dataset/1_myimages/l/image.txt

NOTE Use the same name for images in w and l directories to make them paired

python train_network.py ... --preference --dataset_dir dataset --mapo_weight=0.1

or in your dataset config

  [[datasets.subsets]]
  image_dir = "dataset/1_name"
  preference = true

Related tickets: #1040

@feffy380
Copy link
Contributor

Do you have any training samples from this?

@rockerBOO
Copy link
Contributor Author

36ee5259 = pickapic dataset sample (500 preferences)
a9a03acb = my own preference dataset from prompts generated on SD 1.5

10 epochs, LoRA 16/16 with Prodigy d_coef at 1.25

MaPO with contribution weight of mapo_weight = 0.1

xyz_grid-0010-1229252821
Dreamshaper 8

xyz_grid-0005-2443401054
3D Animation 1.0

xyz_grid-0002-3672213076
SD 1.5 base model

Papers generally suggesting around 1000 preferences. I have been making a preference creation tool so one could make their own preferences on their own dataset.

@feffy380
Copy link
Contributor

I'm finding I need the learning rate as low as 1e-6 (for an SDXL lora), possibly lower if you have a bigger dataset. I also had text encoder training disabled.

One thing I want to try is using real chosen plus AI generated rejected images, inspired by what LLM folks have been doing to bypass the need for collecting real preference pairs

@rockerBOO
Copy link
Contributor Author

I'm finding I need the learning rate as low as 1e-6 (for an SDXL lora), possibly lower if you have a bigger dataset. I also had text encoder training disabled.

Which did you try Diffusion-DPO or MaPO? I found it to train slowly at 1e-4 on SD 1.5 with MaPO weight of 0.1. Haven't done a full hyperparameter test yet though.

@feffy380
Copy link
Contributor

MaPO with LR 1e-6 and beta 0.1 on sdxl. My dataset consists of real images in the target style and each has a matching AI image without style prompts as the rejected image. The differences are extreme, so maybe that's why I need lower learning rates?

@rockerBOO
Copy link
Contributor Author

the mapo_weight is described as the contribution which is basically the difference between the preference and non-preference. So you could adjust the weight to be like 0.05 and keep your original LR. I'm not sure what is the most efficient though.

@feffy380
Copy link
Contributor

feffy380 commented Aug 5, 2024

A few more observations with adamw and my real chosen, synthetic rejected dataset:

  • Stippling and gridlike artifacts appear when training for longer periods
  • Dropping the margin loss 75% of the time delayed the appearance of artifacts without slowing down style learning too much. This is not equivalent to reducing beta_mapo by 75%
  • Artifacts were less severe with min_snr_gamma=1 compared to without

I haven't managed to completely eliminate the artifacts, so my only option is early stopping. I've seen this with other preference optimization papers (if you look, they all train for very short periods of around 2000 steps) and it's annoying that it's never addressed.

@feffy380
Copy link
Contributor

feffy380 commented Aug 10, 2024

I finally sat down and built a real preference dataset and found pairs of images generated by the same model don't cause as many artifacts (probably because they come from the same distribution and any encoding artifacts cancel out)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants