# debiased_hybrid_loss.py
import random

import torch
import torch.nn.functional as F
from torchvision import transforms

from utils.lora import extract_lora_child_module
from utils.func_utils import tensor_to_vae_latent, sample_noise


def DebiasedHybridLoss(
    train_loss_temporal,
    accelerator,
    optimizers,
    lr_schedulers,
    unet,
    vae,
    text_encoder,
    noise_scheduler,
    batch,
    step,
    config,
    random_hflip_img=False,
    spatial_lora_num=1,
):
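    """Hybrid spatial/temporal diffusion loss with appearance debiasing.

    On roughly 20% of steps the spatial-LoRA branch is masked (its LoRA scales
    are zeroed) and no spatial loss is computed. Otherwise, a spatial loss is
    computed on a single randomly chosen frame with the relevant spatial LoRAs
    enabled and the temporal LoRAs disabled. The temporal loss is always
    computed on the full clip and augmented with an appearance-debiased term
    that subtracts a scaled random anchor frame from both prediction and
    target. Gradients are applied through `accelerator`, and the matching
    optimizers and learning-rate schedulers are stepped.

    Returns the temporal loss and the updated running `train_loss_temporal`.
    """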
    mask_spatial_lora = random.uniform(0, 1) < 0.2

    cache_latents = config.train.cache_latents
    if not cache_latents:
        latents = tensor_to_vae_latent(batch["pixel_values"], vae)
    else:
        latents = batch["latents"]

    # Sample noise that we'll add to the latents
    # use_offset_noise = use_offset_noise and not rescale_schedule
    noise = sample_noise(latents, 0.1, False)
    bsz = latents.shape[0]

    # Sample a random timestep for each video
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # *Potentially* fixes gradient checkpointing training.
    # See: https://github.com/prigoyal/pytorch_memonger/blob/master/tutorial/Checkpointing_for_PyTorch_models.ipynb
    # if kwargs.get('eval_train', False):
    #     unet.eval()
    #     text_encoder.eval()

    # Encode text embeddings
    token_ids = batch['prompt_ids']
    encoder_hidden_states = text_encoder(token_ids)[0]
    detached_encoder_state = encoder_hidden_states.clone().detach()

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
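    # Use the detached text embeddings so these losses do not backpropagate into the text encoder.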
    encoder_hidden_states = detached_encoder_state

    # Spatial branch: when masked, zero the spatial LoRA scales and skip the
    # spatial loss; otherwise enable the spatial LoRAs for this step.
    if mask_spatial_lora:
        loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])
        for lora_i in loras:
            lora_i.scale = 0.
        loss_spatial = None
    else:
        loras = extract_lora_child_module(unet, target_replace_module=["Transformer2DModel"])

        if spatial_lora_num == 1:
            for lora_i in loras:
                lora_i.scale = 1.
        else:
            # With multiple spatial LoRA sets, activate only the set selected by `step`.
            for lora_i in loras:
                lora_i.scale = 0.
            for lora_idx in range(0, len(loras), spatial_lora_num):
                loras[lora_idx + step].scale = 1.

        # Disable the temporal LoRAs for the spatial (single-frame) forward pass.
        loras = extract_lora_child_module(unet, target_replace_module=["TransformerTemporalModel"])
        if len(loras) > 0:
            for lora_i in loras:
                lora_i.scale = 0.
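        # Pick one random frame from the clip for the spatial (image) loss.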
        ran_idx = torch.randint(0, noisy_latents.shape[2], (1,)).item()

        if random.uniform(0, 1) < random_hflip_img:
            pixel_values_spatial = transforms.functional.hflip(
                batch["pixel_values"][:, ran_idx, :, :, :]).unsqueeze(1)
            latents_spatial = tensor_to_vae_latent(pixel_values_spatial, vae)
            noise_spatial = sample_noise(latents_spatial, 0.1, False)
            noisy_latents_input = noise_scheduler.add_noise(latents_spatial, noise_spatial, timesteps)
            target_spatial = noise_spatial
            model_pred_spatial = unet(noisy_latents_input, timesteps,
                                      encoder_hidden_states=encoder_hidden_states).sample
            loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
                                      target_spatial[:, :, 0, :, :].float(), reduction="mean")
        else:
            noisy_latents_input = noisy_latents[:, :, ran_idx, :, :]
            target_spatial = target[:, :, ran_idx, :, :]
            model_pred_spatial = unet(noisy_latents_input.unsqueeze(2), timesteps,
                                      encoder_hidden_states=encoder_hidden_states).sample
            loss_spatial = F.mse_loss(model_pred_spatial[:, :, 0, :, :].float(),
                                      target_spatial.float(), reduction="mean")
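    # Temporal branch: predict on the full noisy clip.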
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states).sample
    loss_temporal = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
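    # Appearance debiasing: subtract a scaled random anchor frame from both the
    # prediction and the target, so the extra temporal term emphasizes
    # frame-to-frame differences over per-frame appearance.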
    beta = 1
    alpha = (beta ** 2 + 1) ** 0.5
    ran_idx = torch.randint(0, model_pred.shape[2], (1,)).item()
    model_pred_decent = alpha * model_pred - beta * model_pred[:, :, ran_idx, :, :].unsqueeze(2)
    target_decent = alpha * target - beta * target[:, :, ran_idx, :, :].unsqueeze(2)
    loss_ad_temporal = F.mse_loss(model_pred_decent.float(), target_decent.float(), reduction="mean")
    loss_temporal = loss_temporal + loss_ad_temporal
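    # Gather the temporal loss across processes to accumulate a running average.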
    avg_loss_temporal = accelerator.gather(loss_temporal.repeat(config.train.train_batch_size)).mean()
    train_loss_temporal += avg_loss_temporal.item() / config.train.gradient_accumulation_steps
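    # Backpropagate the spatial loss (if computed) and step its optimizer, then
    # backpropagate the temporal loss and step the remaining optimizer and the
    # learning-rate schedulers.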
    if not mask_spatial_lora:
        accelerator.backward(loss_spatial, retain_graph=True)
        if spatial_lora_num == 1:
            optimizers[1].step()
        else:
            optimizers[step + 1].step()

    accelerator.backward(loss_temporal)
    optimizers[0].step()

    if spatial_lora_num == 1:
        lr_schedulers[1].step()
    else:
        lr_schedulers[1 + step].step()

    lr_schedulers[0].step()

    return loss_temporal, train_loss_temporal