
Conversation

@dxqb (Collaborator) commented Oct 11, 2024

This code can be used to preserve the prior model on prompts other than the trained captions. After several more tests, I think this is worth implementing and would be a fairly generic feature:

  • It does not require any regularization image data. It works even when using the same training data for the reg steps as for the regular training steps.
  • It does not require a regularization caption. An empty caption for the reg steps works, indicating that this can preserve all kinds of concepts, whatever you train on.
  • Additionally, it might improve training results on the trained captions, but I am not sure about this yet.

Let me know if I should provide more details here; for now they can be found on the OT Discord.
There is a feature request for SimpleTuner here: bghira/SimpleTuner#1031

This is a draft PR, intended only to gauge interest in a full PR. It only works with batch size one, only for Flux, only for LoRA, and only for the transformer.

It could be implemented generically for all LoRAs. With major effort, it could also be implemented for full finetune, but to avoid keeping the full model in VRAM twice, the reg-step predictions would need to be pre-generated.
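
For readers who want the gist, here is a minimal PyTorch sketch of what a reg step like this boils down to. It is not the code in this PR; `flux_transformer`, `lora_adapters`, and the `enabled` flag are hypothetical placeholders for whatever the trainer actually uses.

```python
import torch
import torch.nn.functional as F

def prior_reg_loss(flux_transformer, lora_adapters, noisy_latents, timesteps, cond):
    """One reg step (batch size 1): the frozen prior model provides the target."""
    # Teacher pass: temporarily unhook the LoRA and predict with the prior model.
    with torch.no_grad():
        for adapter in lora_adapters:
            adapter.enabled = False   # hypothetical flag on each LoRA wrapper
        prior_pred = flux_transformer(noisy_latents, timesteps, cond)
        for adapter in lora_adapters:
            adapter.enabled = True

    # Student pass: same inputs, LoRA active, gradients flow.
    pred = flux_transformer(noisy_latents, timesteps, cond)

    # Pull the trained model back toward the prior on this sample,
    # instead of toward the usual training target.
    return F.mse_loss(pred, prior_pred)
```

Because the target comes from the prior model rather than from the data, this can run on the regular training images and even with an empty caption, as described above.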

@FurkanGozukara

@dxqbYD can you add examples? Your examples are great. Even though I couldn't make it work, maybe once it's properly implemented it will work :D

So, examples of comparisons and how you set up your concepts.

@dxqb (Collaborator, Author) commented Oct 14, 2024

Samples can be found in these SimpleTuner release notes: https://www.reddit.com/r/StableDiffusion/comments/1g2i13s/simpletuner_v112_now_with_masked_loss_training/

@dxqb (Collaborator, Author) commented Oct 19, 2024

kohya implementation: kohya-ss/sd-scripts#1710

@Nerogar (Owner) commented Oct 20, 2024

This sounds like a really good idea to add as an option, but it definitely needs a more generic implementation. There are two issues to solve:

Dataset

How do we select the regularization samples during training? This also needs to work with a higher batch size than 1. Ideally it would mix regularization samples and normal training samples within the same batch.
Regarding "It does not require a regularization caption": I don't think this is strictly true. You need some kind of conditioning for the model; not conditioning the model at all will probably significantly reduce the effect of this training method.
What do you think about adding a new flag to concepts that toggles this loss calculation for specific training samples? Then the user can decide whether to include captions or not, and which images to use.

Unhooking the LoRA

Each model has different sub-modules, so we need a generic method of disabling the LoRA when computing the prior result. A function in the model class to enable/disable all LoRAs could work well.
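
A generic toggle could look something like the sketch below. `LoRAModule` is a stand-in for whatever wrapper class the model uses; the `enabled` flag and the context manager are illustrative, not OneTrainer's actual interface.

```python
from contextlib import contextmanager

import torch.nn as nn

class LoRAModule(nn.Module):
    """Stand-in LoRA wrapper: adds its low-rank delta only while `enabled`."""
    def __init__(self, base: nn.Linear, rank: int = 16):
        super().__init__()
        self.base = base
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        self.enabled = True

    def forward(self, x):
        out = self.base(x)
        if self.enabled:
            out = out + self.up(self.down(x))
        return out

@contextmanager
def loras_disabled(model: nn.Module):
    """Temporarily bypass every LoRA sub-module so a forward pass
    returns the prior (base) model's prediction."""
    loras = [m for m in model.modules() if isinstance(m, LoRAModule)]
    for m in loras:
        m.enabled = False
    try:
        yield model
    finally:
        for m in loras:
            m.enabled = True
```

Used as `with loras_disabled(model): prior_pred = model(...)`, this avoids model-specific unhooking logic entirely.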

@bghira commented Oct 20, 2024

How do you intend to mix regularisation and training samples in a single batch, @Nerogar? That seems non-trivial; the actual target is changed.

@Nerogar (Owner) commented Oct 20, 2024

The only difference between prior preservation and normal training is the prediction target. So what I would do is basically this:

  1. Find the samples in the batch where the prior_preservation flag is set to True
  2. Calculate the prior prediction without the LoRA for those samples
  3. Replace the target of the batch in those samples with the prior prediction
  4. Calculate the loss without any modification
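
A minimal sketch of those four steps for a mixed batch, assuming a per-sample boolean `prior_preservation` flag and a `loras_disabled` context manager like the one sketched above (all names are illustrative, not the final implementation):

```python
import torch
import torch.nn.functional as F

def compute_loss(model, batch):
    # cond is assumed to be a tensor here for simplicity.
    latents, timesteps, cond = batch["latents"], batch["timesteps"], batch["cond"]
    target = batch["target"].clone()          # normal targets, shape (B, ...)
    flags = batch["prior_preservation"]       # bool tensor, shape (B,)

    # Steps 1 + 2: prior prediction (LoRA bypassed, no grad) for flagged samples only.
    if flags.any():
        with torch.no_grad(), loras_disabled(model):
            prior_pred = model(latents[flags], timesteps[flags], cond[flags])
        # Step 3: replace the target of those samples with the prior prediction.
        target[flags] = prior_pred.to(target.dtype)

    # Step 4: one forward pass for the whole batch, loss computed as usual.
    pred = model(latents, timesteps, cond)
    return F.mse_loss(pred, target)
```

The regular samples still learn the new concept, while the flagged samples pull the adapted model back toward its own prior prediction.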

@bghira commented Oct 20, 2024

Yes, unfortunately it just doesn't have the same regularisation effect when done that way. Having an entire batch pull back toward the model works.

@dxqb (Collaborator, Author) commented Oct 20, 2024

> Yes, unfortunately it just doesn't have the same regularisation effect when done that way. Having an entire batch pull back toward the model works.

What are you basing this on?

What Nerogar describes above is what kohya has implemented. So if that's true, it would mean kohya's implementation doesn't work (as well).

@bghira commented Oct 20, 2024

Basing it on numerous tests we've run on a cluster of H100s over the last week.

@dxqb (Collaborator, Author) commented Oct 20, 2024

> How do we select the regularization samples during training? This also needs to work with a higher batch size than 1. Ideally it would mix regularization samples and normal training samples within the same batch. Regarding "It does not require a regularization caption": I don't think this is strictly true. You need some kind of conditioning for the model; not conditioning the model at all will probably significantly reduce the effect of this training method.

It isn't obvious that this would work without captions, but it does. You can see samples in the reddit link above. The right-most column is without captions.

> What do you think about adding a new flag to concepts that toggles this loss calculation for specific training samples? Then the user can decide whether to include captions or not, and which images to use.

Yes, agreed. There are more reasons than captions for having it as a separate concept, for example balancing the regularisation via the number of repeats. In some of my tests, a 1:1 ratio was too much.

@bghira has also found, using his implementation in SimpleTuner, that even though it works with no external data, it works better with high-quality external data.

@dxqb (Collaborator, Author) commented Oct 20, 2024

> Basing it on numerous tests we've run on a cluster of H100s over the last week.

Okay, thanks. Any theory on why that would be? I don't see a theoretical reason for your finding that it works better in a separate batch:

  • Reg gradients are tiny.
  • The regularisation described in the DreamBooth paper was always implemented in the same batch in the early scripts.
  • You could even argue that this type of contrastive training should work better in the same batch.

@O-J1 (Collaborator) commented Oct 21, 2024

> Basing it on numerous tests we've run on a cluster of H100s over the last week.

Could you please provide some evidence of this? I.e., a large enough number of samples that you aren't falling victim to seed RNG.

It's important to get this right.

@dxqb (Collaborator, Author) commented Oct 21, 2024

>> Basing it on numerous tests we've run on a cluster of H100s over the last week.
>
> Could you please provide some evidence of this? I.e., a large enough number of samples that you aren't falling victim to seed RNG.
>
> It's important to get this right.

If this turns out to be right, I'd recommend implementing a feature in the OT concepts like
"try to keep this concept separate from concept Y in batches"
and
"try to combine this concept with concept Y in batches".

It would influence how the batches are built, and the first option would be how ST builds batches.

This could be a useful feature on its own. For example, if you train two concepts, it can be beneficial to have one image of each concept in a batch instead of the same concept twice, especially if the images within a concept are very similar.
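
For illustration only, a greedy batch builder honoring such per-concept preferences could look like the sketch below. This is not an existing OneTrainer feature; all names are made up.

```python
import random

def build_batches(samples, batch_size, separate_from=None, combine_with=None):
    """samples: list of (concept_name, sample) tuples.
    separate_from / combine_with: dicts mapping a concept to the set of
    concepts it should avoid / prefer sharing a batch with."""
    separate_from = separate_from or {}
    combine_with = combine_with or {}
    pool = list(samples)
    random.shuffle(pool)
    batches = []

    def compatible(concept, concepts_in_batch):
        return not any(concept in separate_from.get(other, set()) or
                       other in separate_from.get(concept, set())
                       for other in concepts_in_batch)

    while pool:
        batch = [pool.pop()]
        concepts = {batch[0][0]}
        while len(batch) < batch_size and pool:
            preferred = [i for i, (c, _) in enumerate(pool)
                         if compatible(c, concepts) and
                         any(c in combine_with.get(other, set()) for other in concepts)]
            fallback = [i for i, (c, _) in enumerate(pool) if compatible(c, concepts)]
            candidates = preferred or fallback
            if not candidates:
                break                      # no compatible sample left; allow a short batch
            item = pool.pop(candidates[0])
            batch.append(item)
            concepts.add(item[0])
        batches.append(batch)
    return batches
```

With `separate_from={"concept_A": {"concept_B"}}` this reproduces the "keep separate" behaviour; with `combine_with` it prefers putting one image of each concept into the same batch.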

@bghira commented Oct 21, 2024

I don't have time, sorry; do it however works best for your codebase.

@DriveHabits

This comment was marked as off-topic.

@dxqb (Collaborator, Author) commented Nov 13, 2024

> Any update on this, @dxqbYD?

Nothing usable for OneTrainer users yet.
There are more interesting experiments beyond just preserving the prior knowledge of a separate prompt, as above: it appears this can also be very useful when training a concept, to control what you don't want it to learn. The concept can then be mixed in by prompting, and even mixing with other independently trained LoRAs seems to work better in that case.

I should mention that a paper proposing this technique was apparently published in April of this year; I just didn't know about it: https://arxiv.org/pdf/2404.07554
The authors pointed this out on the PR for kohya's implementation. They coined it "Contrastive Adapter Training".

@TheForgotten69 (Contributor)

Honestly, I would love to have this method implemented with higher batch sizes. It seems to be the best preservation technique so far, coupled with some weight decay.

@dxqb (Collaborator, Author) commented Dec 19, 2024

> Honestly, I would love to have this method implemented with higher batch sizes. It seems to be the best preservation technique so far, coupled with some weight decay.

I don't plan to finish this PR for now, because the teacher-model/student-model approach employed here is much more powerful than simple prior-knowledge preservation, and I want to explore it further.

If anyone wants to finish it in the meantime, kohya's code shows what is necessary to apply it to mixed batches.
What I might eventually submit as a PR will be able to do this and more.

@DriveHabits

It would be nice if someone could finish it and add batch-size support. I don't think kohya's version works well enough, and I can't seem to get good results with it; yours works pretty well.

@FurkanGozukara

This comment was marked as off-topic.

@DriveHabits

This comment was marked as off-topic.

@dxqb

This comment was marked as off-topic.

@DarkViewAI

Any update on this getting pushed to the main branch?

Repository owner deleted a comment from FurkanGozukara Jan 19, 2025
@dxqb

This comment was marked as outdated.

@dxqb dxqb mentioned this pull request Feb 18, 2025
@dxqb (Collaborator, Author) commented Apr 27, 2025

Finally got around to implementing this completely:

  • for batch sizes > 1
  • for all models
  • including UI

@dxqb dxqb marked this pull request as ready for review April 27, 2025 09:00
@dxqb (Collaborator, Author) commented Apr 27, 2025

I have run a few tests now:

  • I find that it works equally well with bs1 as with bs2.
  • This also means it works well with mixed batches, because OT uses mixed batches at bs2
    (regarding the discussion above about whether it might be necessary to separate sample batches from prior reg batches: I don't think so).
  • In none of my experiments did I find it necessary to increase the loss weight of the regularization concept. Users of kohya's implementations have discussed increasing it up to 10 and even 100 to get it to work. There must be something else wrong there; a weight of 1 works well.
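
For context, the "loss weight" here is just a per-sample multiplier on the loss contribution of the regularization concept; a weight of 1 means reg samples and regular samples contribute equally. An illustrative sketch (not the actual implementation):

```python
import torch
import torch.nn.functional as F

def weighted_batch_loss(pred, target, prior_flags, reg_weight=1.0):
    # Per-sample MSE, averaged over all non-batch dimensions.
    per_sample = F.mse_loss(pred, target, reduction="none")
    per_sample = per_sample.mean(dim=tuple(range(1, per_sample.dim())))

    # Weight 1.0 for regular samples, reg_weight for prior-preservation samples.
    weights = torch.where(prior_flags,
                          torch.full_like(per_sample, reg_weight),
                          torch.ones_like(per_sample))
    return (weights * per_sample).mean()
```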

@dxqb (Collaborator, Author) commented Apr 28, 2025

[Image: validation loss curves from the test runs described below]

Various test runs with validation loss. On the left is the trained concept; on the right is a similar concept that is not trained.

  • In purple (without reg): the concept is learned, but the other concept also becomes like your trained concept.
  • In green (with reg): the other concept remains stable, but the trained concept seems a bit harder to learn.
  • It seems to help to increase the LR a bit (orange), or to use reg steps only half as often (black).

Samples confirm these graphs.

@Nerogar Nerogar merged commit 61c7dd1 into Nerogar:master May 2, 2025
1 check passed
@dxqb dxqb deleted the prior_reg branch May 2, 2025 11:00