
About the tT_loss #63

Open
zzc681 opened this issue Apr 16, 2023 · 2 comments


@zzc681

zzc681 commented Apr 16, 2023

Hi, thanks for your excellent work! I have a small question about the loss function. While reading the code, I found that tT_loss computes a loss between the mean of x_T and 0. Is there a reason for doing this?
The code is in gaussian_diffusion.py, in the function training_losses_e2e of class GaussianDiffusion:

```python
out_mean, _, _ = self.q_mean_variance(x_start, torch.LongTensor([self.num_timesteps - 1]).to(x_start.device))
tT_loss = mean_flat(out_mean ** 2)
```

@ryuliuxiaodong

Same question for me.

The other loss terms written in training_losses_e2e are clear, and they are also described in the paper (Equation 2). But I don't quite understand this tT_loss: why is a loss computed on the forward diffusion process at all?

@K0ntact

K0ntact commented Sep 24, 2024

It seems that tT_loss measures how well the input is diffused into standard Gaussian noise, since it is computed by diffusing x_start to the last diffusion step via q_mean_variance.

```python
def q_mean_variance(self, x_start, t):
    """
    Get the distribution q(x_t | x_0).

    :param x_start: the [N x C x ...] tensor of noiseless inputs.
    :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
    :return: A tuple (mean, variance, log_variance), all of x_start's shape.
    """
```

tT_loss is the squared difference between out_mean and the mean of a standard normal distribution (0), which explains why it is written as out_mean ** 2.

```python
out_mean, _, _ = self.q_mean_variance(x_start, th.LongTensor([self.num_timesteps - 1]).to(x_start.device))
tT_loss = mean_flat(out_mean ** 2)
```
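To make this concrete, here is a minimal, self-contained sketch (not the repo's actual code) of why `mean_flat(out_mean ** 2)` measures how close q(x_T | x_0) is to N(0, I). The closed form of the forward process gives q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) I), so at the last step the mean is sqrt(alpha_bar_T) * x_0, which should be near zero. The linear `betas` schedule and tensor shapes below are illustrative assumptions, not taken from the repository:

```python
import torch

num_timesteps = 1000
# Assumed linear noise schedule, for illustration only.
betas = torch.linspace(1e-4, 0.02, num_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # alpha_bar_t

def q_mean_variance(x_start, t):
    # Closed form of the forward process:
    # q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) I)
    mean = alphas_cumprod[t].sqrt().view(-1, 1) * x_start
    variance = (1.0 - alphas_cumprod[t]).view(-1, 1).expand_as(x_start)
    return mean, variance

def mean_flat(x):
    # Average over all non-batch dimensions.
    return x.flatten(1).mean(dim=1)

x_start = torch.randn(4, 16)  # toy batch of "embeddings"
t_last = torch.LongTensor([num_timesteps - 1])

out_mean, _ = q_mean_variance(x_start, t_last)
tT_loss = mean_flat(out_mean ** 2)

# alpha_bar_T is tiny after 1000 steps, so out_mean is close to 0 and
# tT_loss is near zero: it acts as the prior-matching term, penalizing any
# residual signal left in x_T that the forward process failed to destroy.
```

In other words, tT_loss corresponds (up to constants) to the KL term between q(x_T | x_0) and the standard-normal prior: it is zero exactly when the forward diffusion fully erases x_0 by the final timestep.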
