
Update debiased estimation loss function to accommodate V-pred #1715

Merged
3 commits merged into kohya-ss:dev on Oct 25, 2024

Conversation

@catboxanon commented Oct 21, 2024

This PR:

1. Updates the debiased estimation loss function to support V-prediction; the previous function was intended only for epsilon prediction (a minimal sketch follows below). For reference:
   [Feature Request] Update apply_debiased_estimation to work properly with v-prediction. #1058 (comment)

2. Adds a deprecation notice for scale_v_pred_loss_like_noise_pred. (Removed per the discussion below.) For reference:
   #934
   https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/main/guided_diffusion/gaussian_diffusion.py#L864

cc @feffy380 @sdbds, let me know if I'm missing something.
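For quick reference, here is a minimal sketch of the updated weighting as I understand it from the discussion below. The per-timestep SNR lookup (`noise_scheduler.all_snr`) and the clamping constant are assumptions for illustration; the merged code in the repository is authoritative.

```python
import torch

def apply_debiased_estimation(loss, timesteps, noise_scheduler, v_prediction=False):
    # Per-sample SNR(t) for the sampled timesteps; all_snr is assumed to be a
    # precomputed tensor of SNR values on the scheduler. Clamp to avoid inf
    # when the terminal SNR is (near) zero.
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    snr_t = torch.minimum(snr_t, torch.full_like(snr_t, 1000.0))
    if v_prediction:
        # v-pred variant: bounded in (0, 1], so low-SNR timesteps are not blown up
        weight = 1.0 / (snr_t + 1.0)
    else:
        # epsilon variant: original debiased estimation weighting, 1/sqrt(SNR)
        weight = 1.0 / torch.sqrt(snr_t)
    return loss * weight
```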

1) Updates debiased estimation loss function for V-pred.
2) Prevents now-deprecated scaling of loss if ztSNR is enabled.
@liesened commented Oct 21, 2024

At the current time, I believe you never actually want to use scale_v_prediction_loss_like_noise_prediction: it was originally a "fix" for an already outdated and incorrectly implemented version of MinSNR (which has since been fixed). Tying it to ZSNR doesn't really make sense.

@catboxanon (Author)
You're right. I guess a warning alone will have to suffice.

@catboxanon catboxanon changed the title Fix training for V-pred and ztSNR Update debiased estimation loss function to accommodate V-pred Oct 21, 2024
@kohya-ss (Owner)
Thank you for this! I plotted the weights of each loss; is this correct? "Debiased estimation vpred" is based on the implementation of the new apply_debiased_estimation.

(Figure_1, Figure_2: plots of each loss weighting over timesteps)

Based on this chart, I think there may be some use for scale_v_pred_loss_like_noise_pred as well.
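For readers unfamiliar with that option, it rescales the v-prediction loss by roughly SNR/(SNR+1) so its magnitude behaves like an epsilon-prediction loss. A hedged sketch follows; the per-timestep SNR lookup on the scheduler is an assumption, and the repository code is authoritative.

```python
import torch

def scale_v_prediction_loss_like_noise_prediction(loss, timesteps, noise_scheduler):
    # A v-prediction MSE is (SNR + 1) / SNR times the corresponding epsilon MSE,
    # so multiplying by SNR / (SNR + 1) makes its scale match noise prediction.
    snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    snr_t = torch.minimum(snr_t, torch.full_like(snr_t, 1000.0))  # guard against inf SNR
    return loss * (snr_t / (snr_t + 1.0))
```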

@catboxanon (Author)
> is this correct?

Difficult for me to say affirmatively. I'd prefer to wait for a second or third opinion on it.

> Based on this chart, I think there may be some use for scale_v_pred_loss_like_noise_pred as well.

I can remove the deprecation notice if you'd like.

@catboxanon catboxanon marked this pull request as draft October 21, 2024 13:06
@liesened commented Oct 21, 2024

Your plots are a bit confusing. I assume "Debiased estimation" is the one in this PR with v_prediction=False and "Debiased estimation vpred" is with v_prediction=True; so far so good, but with "SNR" it gets unclear. If "Scale vpred like noise pred" is used together with "SNR weighted loss vpred", then "Scale vpred like noise pred" definitely doesn't look right to me: the weighting shouldn't decrease as the timesteps increase.

The purpose of this change is to clamp the debiased weighting between (0, 1) to accommodate the v-prediction loss, which looks right.
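To make the "(0, 1)" point concrete, a quick numeric check on an SD-style scaled-linear schedule (the schedule parameters here are an illustrative assumption, and the printed ranges are approximate):

```python
import torch

# SD-style scaled-linear beta schedule (illustrative values)
betas = torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2
alphas_bar = torch.cumprod(1.0 - betas, dim=0)
snr = alphas_bar / (1.0 - alphas_bar)

eps_weight = 1.0 / torch.sqrt(snr)   # original debiased weight: grows large as SNR -> 0
vpred_weight = 1.0 / (snr + 1.0)     # v-pred variant: always within (0, 1]

print(eps_weight.min().item(), eps_weight.max().item())      # roughly 0.03 ... 14.6
print(vpred_weight.min().item(), vpred_weight.max().item())  # stays below 1 everywhere
```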

@liesened commented Oct 21, 2024

Some A/B/C tests from bluvoll trying to convert an eps-based checkpoint to vpred on a tiny dataset:
MinSNR | broken Debiased | fixed Debiased
(XY/Z grid images: xyz_grid-0007-4018816163, xyz_grid-0010-2740262790, xyz_grid-0011-4062610619, xyz_grid-0012-1817032482)

@kohya-ss (Owner) commented Oct 21, 2024

Sorry for the confusion. Each corresponds as follows.

  • Debiased estimation: apply_debiased_estimation, v_pred=False
  • Debiased estimation vpred: apply_debiased_estimation, v_pred=True
  • Scale vpred like noise pred: scale_v_prediction_loss_like_noise_prediction
  • SNR weighted loss, gamma=5: apply_snr_weight, min_snr_gamma=5, vpred=False
  • SNR weighted loss vpred, gamma=5: apply_snr_weight, min_snr_gamma=5, vpred=True

As I understand it, scale_v_prediction_loss_like_noise_prediction is not intended to be used in conjunction with apply_snr_weight.
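For anyone reproducing these curves, a rough sketch of the Min-SNR weighting referenced above; the per-timestep SNR lookup is again an assumption, and the exact implementation in the repository may differ in details.

```python
import torch

def apply_snr_weight(loss, timesteps, noise_scheduler, gamma, v_prediction=False):
    # Min-SNR-gamma (Hang et al.): weight the epsilon-prediction loss by min(SNR, gamma) / SNR.
    # For v-prediction the denominator becomes SNR + 1 so the weight stays finite as SNR -> 0.
    snr = torch.stack([noise_scheduler.all_snr[t] for t in timesteps])
    min_snr_gamma = torch.minimum(snr, torch.full_like(snr, float(gamma)))
    if v_prediction:
        snr_weight = min_snr_gamma / (snr + 1.0)
    else:
        snr_weight = min_snr_gamma / snr
    return loss * snr_weight
```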

@liesened commented Oct 21, 2024

Tried to do some math and plot this as well.
(plot: loss weightings on the default noise schedule)

I think it's important to see how each weighting strategy changes the effective SNR. However, since in practice no one trains v-pred without ZSNR, it makes more sense to plot using the ZSNR schedule. Notice that the "Debiased estimation vpred" weighting multiplied by SNR(t) overlaps with the vpred-like loss:
(plot: the same comparison on a ZSNR schedule)

If there's anything of note here, it's that Debiased+vpred WITH vpred-like loss looks suspiciously similar to MinSNR+vpred at higher timesteps, which may (or may not) suggest that scale_v_prediction_loss_like_noise_prediction is useful for Debiased+vpred. I didn't include some options since readability is already quite poor.

Either way, discarding scale_v_prediction_loss_like_noise_prediction right away was a poor choice. More tests would likely be needed.
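For anyone who wants to redo these plots on a ZSNR schedule, the usual zero-terminal-SNR rescaling (from "Common Diffusion Noise Schedules and Sample Steps Are Flawed") can be applied to the betas before computing SNR(t). A sketch, not taken from this repository; the schedule parameters are illustrative assumptions.

```python
import torch

def rescale_zero_terminal_snr(betas):
    # Shift and scale sqrt(alpha_bar) so the final value is exactly 0 (terminal SNR = 0)
    # while keeping the first value unchanged, per Lin et al. 2023.
    alphas = 1.0 - betas
    alphas_bar_sqrt = torch.cumprod(alphas, dim=0).sqrt()
    a_first, a_last = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt = (alphas_bar_sqrt - a_last) * a_first / (a_first - a_last)
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = torch.cat([alphas_bar[0:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas

# Example: SNR(t) on a ZSNR schedule (the last entry is exactly 0),
# plus one effective curve, e.g. a 1/(SNR+1) weight multiplied by SNR(t).
betas = rescale_zero_terminal_snr(torch.linspace(0.00085 ** 0.5, 0.012 ** 0.5, 1000) ** 2)
alphas_bar = torch.cumprod(1.0 - betas, dim=0)
snr = alphas_bar / (1.0 - alphas_bar)
effective = snr / (snr + 1.0)
```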

@liesened
Additionally, here's a plot comparing variants of debiased estimation.
(plot: variants of the debiased estimation weighting)

The non-vpred variant is kind of useless with that spike at the beginning.

What I'd suggest is to leave scale_v_pred_loss_like_noise_pred as is. The v-prediction variant of debiased estimation seems useful.

@kohya-ss (Owner)
Thank you for the great diagrams and insight!

I did not expect scale_v_pred_loss_like_noise_pred to be combined with ZSNR, but it's interesting that it seems to make sense to combine them.

> The non-vpred variant is kind of useless with that spike at the beginning.

It's true that they are meant for noise pred, so it can't be helped that they have no meaning with v-pred.

@catboxanon
I think it would be better to remove the deprecation warning for scale_v_pred_loss_like_noise_pred. Then this PR may be ready to be merged.

@catboxanon (Author)
@kohya-ss

> I think it would be better to remove the deprecation warning for scale_v_pred_loss_like_noise_pred. Then this PR may be ready to be merged.

Done.

@catboxanon catboxanon marked this pull request as ready for review October 22, 2024 16:17
@sdbds (Contributor) commented Oct 23, 2024

> Some A/B/C tests from bluvoll trying to convert an eps-based checkpoint to vpred on a tiny dataset: MinSNR | broken Debiased | fixed Debiased (xyz_grid-0007-4018816163, xyz_grid-0010-2740262790, xyz_grid-0011-4062610619, xyz_grid-0012-1817032482)

It looks good from the results.
For some reason, it seems that the debiased estimation loss function causes color contamination late in training.

@sdbds (Contributor) commented Oct 23, 2024

Considering that the original paper was published quite a while ago, I thought it would be a good idea to look through the papers citing it for updates.
https://scholar.google.com/scholar?cites=6450976606823846518&as_sdt=2005&sciodt=0,5&oi=gsb

@liesened commented Oct 23, 2024

@sdbds

> For some reason, it seems that the debiased estimation loss function causes color contamination late in training.

I'm not exactly sure what you mean by "color contamination", but if it's about weird color splotches, I do think this is rather strange. My theory is that they appear because the new prediction target and schedule are wonky at first, and that this will eventually go away with sufficient training. In fact, this model was trained on 300k samples using 1/(snr+1) and I don't see any splotches there. It likely doesn't happen with MinSNR on the test examples because MinSNR rescales neither the mid nor the high timesteps, and the mid timesteps receive more training with MinSNR compared to debiased. Notice how "grey" the MinSNR results look. This is only speculation though, and it may not be true.

> I thought it would be a good idea to look through the papers citing it for updates.

I only found these loosely related papers, and it looks like they don't focus on what this PR attempts to do at all.

@kohya-ss kohya-ss merged commit c632af8 into kohya-ss:dev Oct 25, 2024
1 check passed
@kohya-ss (Owner)

I've merged. Sorry for the delay.

@catboxanon catboxanon deleted the vpred-ztsnr-fixes branch October 25, 2024 12:44