
[Feature Request] Add a new way to manage timestep distribution and a new loss offset feature #1375

Anzhc opened this issue Jun 16, 2024 · 17 comments

@Anzhc commented Jun 16, 2024

We've been working for some time on a way to create a timestep distribution specific to the particular dataset used in training.

We've been testing it for the last couple of months using Derrian's easy scripts trainer with hardcoded changes, but we're not familiar enough with the sd-scripts structure to fully integrate it as a toggleable feature on our own at this point, hence this feature request.

https://github.com/Anzhc/Timestep-Attention-and-other-shenanigans

The repo contains the functions structured roughly for use with sd-scripts variables, along with examples of how we implemented it in a hardcoded way (see the "Kohya file examples" folder).
It builds a distribution based on the received loss values, similar to this:
[image: example loss-based timestep distribution]
but unique to each dataset (as it's based on the particular loss values received during training).

In the majority of cases it improved convergence speed and output quality in our tests.

We also provide code for a Loss Offset Curve feature that works with said distribution and can be scaled and scheduled. It is also supposed to improve convergence speed.

Some experimental loss functions are included too, but those are really just experiments.

@cheald commented Jun 18, 2024

This looks like it could pair really well with an approach that I'm working on and about to publish (the short version is "synchronise noise std/mean to the underlying model"), which also dramatically improves convergence speed. I'm going to see if I can adapt your approach into mine, because I have a theory that mine will flatten the curves you're seeing there, which could mean that the two approaches together lead to even faster convergence.

Thanks!

@Anzhc (Author) commented Jun 18, 2024

> This looks like it could pair really well with an approach that I'm working on and about to publish (the short version is "synchronise noise std/mean to the underlying model"), which also dramatically improves convergence speed. I'm going to see if I can adapt your approach into mine, because I have a theory that mine will flatten the curves you're seeing there, which could mean that the two approaches together lead to even faster convergence.
>
> Thanks!

We've actually been following your development of the new loss approach and its weighting for the last couple of days, and would love to see this incorporated with your really fancy tech! :D

@cheald commented Jun 18, 2024

Can you elaborate a bit on the underlying theory of operation of your approach? A quick read-through seemed to suggest that you're just tracking losses per timestep, then weighting timestep selection towards higher-loss timesteps over the course of the run. Do I have that roughly correct?

@Anzhc (Author) commented Jun 18, 2024

Yes, that's the mechanism. We initialize a Loss Map that contains each timestep and its loss. We initialize it with an equal loss of 1 for each, but that can be changed to emulate different distributions at the start, like lognorm(0,1).
Whenever we receive a new batch of timesteps, we update the loss of each sampled timestep and a specified range of nearby timesteps by decaying portions of weight, where the last number corresponds to the main timestep and each preceding one to a timestep one step further away (i.e. if you put 5 numbers in, it changes timesteps up to ±4 from the received one by the specified portions). This helps to build the map faster, but can be adjusted to be more precise for longer trainings.
It is also possible to vary the averaging by increasing or decreasing the change value, which by default is a 60% move towards the newly acquired loss for that timestep.

Then we adjust that distribution by a small margin, dampening high-chance timesteps and boosting low-chance timesteps, so the latter still have a meaningful chance to be sampled once in a while.

That's the mechanism, so you are correct.
That Timestep Loss Map can then be used throughout the training loop for weighting and adjusting various things, like the loss curve function we have there, or clustered MSE, which selects the number of clusters based on that map.
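Roughly, in code, the idea looks something like this (a minimal sketch only; the class name, spread portions, and temperature are illustrative, not the actual code from the repo):

```python
import torch

class TimestepLossMap:
    def __init__(self, num_timesteps=1000, init_loss=1.0, ema_rate=0.6,
                 spread=(0.2, 0.4, 0.6, 0.8, 1.0)):
        # Start with an equal loss of 1 for every timestep (uniform sampling).
        self.losses = torch.full((num_timesteps,), init_loss)
        self.ema_rate = ema_rate  # default: 60% move towards the newly observed loss
        self.spread = spread      # last entry = main timestep, earlier entries = further away

    def update(self, timesteps, losses):
        # Move each sampled timestep (and nearby timesteps, with smaller portions)
        # towards the newly observed loss.
        for t, loss in zip(timesteps.tolist(), losses.tolist()):
            for dist, portion in enumerate(reversed(self.spread)):  # dist = offset from t
                for idx in {t - dist, t + dist}:
                    if 0 <= idx < len(self.losses):
                        rate = self.ema_rate * portion
                        self.losses[idx] = (1 - rate) * self.losses[idx] + rate * loss

    def sample(self, batch_size, temperature=1.1):
        # Dampen high-chance timesteps and boost low-chance ones by flattening the
        # distribution slightly (temperature > 1), then sample from it.
        probs = self.losses / self.losses.sum()
        probs = probs ** (1.0 / temperature)
        probs = probs / probs.sum()
        return torch.multinomial(probs, batch_size, replacement=True)
```

The actual spread, decay, and dampening values are configurable; the hardcoded Kohya examples in the repo show how we wired it in.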

I have a gif of how it changes over the course of training:
[animation: timestep distribution evolving over the course of training]

@Anzhc (Author) commented Jun 18, 2024

Different datasets led to meaningfully different graphs, so I think a dynamic approach to this is currently best.

@Anzhc (Author) commented Jun 18, 2024

[screenshots: timestep distributions from two generally similar training runs]
The shape is generally the same, though I think this will change on an architecture-to-architecture basis, but the spread in chances is roughly ~10-15% between those generally similar trainings (I'm usually just training some anime styles).

@cheald commented Jun 18, 2024

Are you using min_snr in your examples there? Based on my experiments, I would not expect that sharp downward trend in loss observations in the earliest timesteps unless you're using a mechanism that artificially dampens it. Loss should be highest in the earliest timesteps (particularly for SD1.5), as that's where the true noise/predicted noise standard deviation is the highest.

You might be able to simplify things by just using a gaussian kernel over the raw observations, which would give you configurable bandwidth for tuning the resolution/smoothness of the histogram.
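For instance, something along these lines (just a sketch; the function and argument names are made up):

```python
import torch

def smoothed_loss_by_timestep(observed_ts, observed_losses, num_timesteps=1000, bandwidth=25.0):
    # Kernel-weighted average of the raw loss observations over the timestep axis;
    # the bandwidth controls how smooth the resulting curve is.
    grid = torch.arange(num_timesteps, dtype=torch.float32)        # (T,)
    diffs = grid[:, None] - observed_ts.float()[None, :]           # (T, N)
    weights = torch.exp(-0.5 * (diffs / bandwidth) ** 2)           # gaussian kernel
    return (weights * observed_losses[None, :]).sum(dim=1) / weights.sum(dim=1).clamp_min(1e-8)
```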

@Anzhc (Author) commented Jun 18, 2024

I do indeed use min SNR of 8 in training. I don't really want to drop it, but in hindsight it would be better to record loss_for_timesteps before the loss is altered by min SNR and other things.
Right now I'm placing it after min SNR and the other mods, but before loss_curve_offset.

I'm not sure simplifying would be the goal here, since I wasn't aiming for a smooth timestep distribution, but rather for a balance where we can still see meaningful differences in sampling chance between relatively close timesteps, if they do reliably produce higher loss.

As a note, when I was testing with some amount of Huber loss mixed in, the trend was more even.

I should place the saving of loss for timesteps before min SNR and see how it behaves.
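For reference, the ordering being discussed would look roughly like this (a hypothetical sketch; none of these names are the actual sd-scripts functions):

```python
import torch

def compute_weighted_loss(model_pred, target, timesteps, loss_map,
                          min_snr_weight_fn=None, loss_offset_fn=None):
    # Hypothetical helper: record the raw per-timestep loss into the loss map
    # *before* min SNR or any other weighting modifies it.
    loss = torch.nn.functional.mse_loss(model_pred, target, reduction="none")
    loss = loss.mean(dim=list(range(1, loss.dim())))   # per-sample loss

    loss_map.update(timesteps, loss.detach())          # raw loss, pre-weighting

    if min_snr_weight_fn is not None:
        loss = min_snr_weight_fn(loss, timesteps)      # min SNR and other mods after
    if loss_offset_fn is not None:
        loss = loss_offset_fn(loss, timesteps)         # loss curve offset last
    return loss.mean()
```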

@Anzhc (Author) commented Jun 19, 2024

My graphs were created from training with the loss for timesteps being recorded at this point:
[screenshot: the point in the training code where loss_for_timesteps is recorded]

I'll put it before min SNR and see where that leads whenever I have time to train some style again. Thanks for pointing that out xd

@cheald commented Jun 19, 2024

Was this an SD1.5 model you generated the graphs on?

Absent the min_snr depression, this looks an awful lot like my observed std() distributions from my latest post in the loss thread, which makes me wonder if I could just use something like:

# turn the per-timestep std measurements into sampling probabilities
prob = torch.nn.functional.softmax(-std_by_ts.mean(dim=1).reshape(-1), dim=0)
# build a categorical distribution over timesteps and sample a batch from it
cat = torch.distributions.Categorical(probs=prob.float())
timesteps = cat.sample([b_size]).to(device=latents.device)

[images]

From a quick first test, it seems to be VERY promising. I'll run some more comprehensive tests, but it might be that the std discrepancy accounts for enough of the overall loss that we could just sample timesteps with it directly.

Edit: This is my result on my SD1.5 test harness after only 20 epochs. I'd normally need 40-50 epochs to get this kind of result. This isn't dynamically learning the loss map like you are, but it might be able to jump right to a good-enough approximation using the std measurements from the underlying model.

[image: sample generation from the SD1.5 test run]

It's worth noting that if you're using a mechanism like this to compensate for the loss distribution imbalance, min_snr becomes useless-to-harmful. My example there is without any kind of min_snr, noise offset, or otherwise, and it seems to be working splendidly.

@Anzhc (Author) commented Jun 19, 2024

My graphs are from SDXL, Pony Diffusion v6 in particular.

It's hard for me to measure results by... your face, I think? Since I don't know it xd
On 1.5 I have very few tests, but the dynamically sampled distribution did converge quite a bit better and faster for the particular concept I tested.
[image: comparison of the dynamically sampled distribution vs. the standard distribution across epochs]
i.e. at epoch 15 it already looks better than the result of the standard distribution even at epoch 50.

It would probably be quite beneficial to initialize our TA with your distribution to make it more efficient in the early phase, and then let it modify itself to conform to the dataset/model combo, since from your graphs I see that it does indeed change quite drastically based not only on the dataset but also on the underlying model, as I thought.

I was thinking of doing a general lognorm(0,1) initialization instead of uniform, or even a shifted lognorm(0.5, 1) or so, which did show better performance in plain training on XL for us, but I never tested that with the dynamic approach.
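As a very rough sketch of what that initialization could look like (one possible reading of a lognorm prior over normalized timesteps; purely illustrative, not the tested setup):

```python
import torch

def lognorm_init(num_timesteps=1000, mu=0.0, sigma=1.0, eps=1e-3):
    # Seed the timestep loss map with a log-normal(mu, sigma) density over
    # normalized timesteps instead of a flat value of 1.
    t = torch.linspace(eps, 1.0, num_timesteps)
    pdf = torch.exp(-((torch.log(t) - mu) ** 2) / (2 * sigma ** 2)) / (t * sigma * (2 * torch.pi) ** 0.5)
    return pdf / pdf.mean()   # keep the average near 1 so it drops into the existing map
```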

Right now I've started training a style with loss_for_timesteps moved to before min SNR, but it'll take some hours to run fully, and then I'll need to re-train another one with it moved back for reference.
Here is the distribution I'm at after 220 steps (1320 real steps):
[image: timestep distribution at step 220]
It doesn't seem like min SNR affected the distribution much in my case? This looks fairly similar to the graphs I've shown before. Unless I messed something up, but I don't think that's the case?

@cheald commented Jun 19, 2024

Fair enough! Suffice it to say that both model quality and fidelity are way better than in previous approaches.

I'll need to generate stats for ponyv6. It's a very different model than most and is probably going to exhibit some characteristics not found in other models. My sd1.5 result is conclusive enough that I think there's something very worth pursuing here!

@cheald commented Jun 19, 2024

Here's the noise std/mean for Ponyv6:

[image: noise std/mean by timestep for Ponyv6]

As expected, it's quite different from SD1.5 or base SDXL models. The lack of mean convergence towards 0 is particularly interesting.

@bluvoll commented Jun 20, 2024

> Here's the noise std/mean for Ponyv6:
>
> [image: noise std/mean by timestep for Ponyv6]
>
> As expected, it's quite different from SD1.5 or base SDXL models. The lack of mean convergence towards 0 is particularly interesting.

As you said, SDXL behaves very differently; in fact we came to the conclusion that "SDXL is very rigid": even with extreme values it didn't reach unusable results like SD1.5 did. We found that this approach still yielded better and faster results; in fact some things started to gain more detail when using a lower LoRA learning rate than what we used to use with PonyV6. Sadly we couldn't test much finetuning due to only having x090-class hardware, but results on rented A6000 and A100 cards showed the same tendency of "being better". We lack a lot of knowledge about ML in general, so hopefully whatever you cook up with this plus your noise shenanigans ends up being something very, very nice for the community.

@feffy380 (Contributor) commented Jul 2, 2024

@Anzhc

> It doesn't seem like min SNR affected the distribution much in my case? This looks fairly similar to the graphs I've shown before. Unless I messed something up, but I don't think that's the case?

I think it's the way you spread the loss weights. Values at the edges only receive updates from one side. Maybe use a 1D convolution with padding set to "replicate".
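Something like this, for example (just a sketch, not tested against the repo code):

```python
import torch
import torch.nn.functional as F

def spread_updates(update, kernel=(0.2, 0.4, 1.0, 0.4, 0.2)):
    # update: (T,) tensor of per-timestep loss deltas for this batch.
    # Replicate-pad the edges so timesteps 0 and 999 receive full-strength
    # contributions from "both sides" instead of being under-updated.
    k = torch.tensor(kernel).view(1, 1, -1)
    pad = (k.shape[-1] - 1) // 2
    x = F.pad(update.view(1, 1, -1), (pad, pad), mode="replicate")
    return F.conv1d(x, k).view(-1)
```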

@Anzhc (Author) commented Jul 3, 2024

> I think it's the way you spread the loss weights. Values at the edges only receive updates from one side. Maybe use a 1D convolution with padding set to "replicate".

You mean let 999 be updated by 1 and vice versa? Might be an interesting idea. Tests from some people suggest that setting the curve from the get-go, instead of letting it learn from uniform, leads to worse performance in short trainings, so I think there is some grounding that needs to be done by learning the earlier timesteps, and this will likely slightly increase hits on them.

@Anzhc (Author) commented Jul 3, 2024

But I won't be able to test anything for the next 4 days e_e
