Commit bdd51b0

Merge pull request #180 from cloneofsimo/develop

v0.1.7

2 parents 799c17a + e48cbbb

File tree: 4 files changed, +135 −90 lines

lora_diffusion/cli_lora_pti.py

Lines changed: 125 additions & 50 deletions
@@ -128,17 +128,31 @@ def get_models(
     )


-def text2img_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+@torch.no_grad()
+def text2img_dataloader(
+    train_dataset,
+    train_batch_size,
+    tokenizer,
+    vae,
+    text_encoder,
+    cached_latents: bool = False,
+):
+
+    if cached_latents:
+        cached_latents_dataset = []
+        for idx in tqdm(range(len(train_dataset))):
+            batch = train_dataset[idx]
+            # rint(batch)
+            latents = vae.encode(
+                batch["instance_images"].unsqueeze(0).to(dtype=vae.dtype).to(vae.device)
+            ).latent_dist.sample()
+            latents = latents * 0.18215
+            batch["instance_images"] = latents.squeeze(0)
+            cached_latents_dataset.append(batch)
+
     def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
-
-        # Concat class and instance examples for prior preservation.
-        # We do this to avoid doing two forward passes.
-        if examples[0].get("class_prompt_ids", None) is not None:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
         pixel_values = torch.stack(pixel_values)
         pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

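Note: the new `cached_latents` path runs the whole dataset through the VAE once, up front and under `@torch.no_grad()`, replacing each sample's pixel tensor with its latent (scaled by Stable Diffusion's 0.18215 factor) so later training steps never touch the VAE. The same hunk also drops the prior-preservation concat from this collate path, matching the removal of the `class_*` arguments further down in the commit. A minimal standalone sketch of the idea, assuming a diffusers `AutoencoderKL` and a dataset of CHW image tensors in [-1, 1] (`cache_latents` and `image_dataset` are illustrative names, not from this file):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="vae"
).eval().to("cuda")

@torch.no_grad()
def cache_latents(image_dataset):
    cache = []
    for image in image_dataset:
        # Encode once; 0.18215 is Stable Diffusion's latent scale factor.
        latent = vae.encode(
            image.unsqueeze(0).to(vae.device, dtype=vae.dtype)
        ).latent_dist.sample()
        cache.append((latent * 0.18215).squeeze(0).cpu())
    return cache
```

One trade-off worth noting: `latent_dist.sample()` is stochastic, so caching freezes a single latent sample per image for all epochs, and any random image augmentations (color jitter, flips) are likewise baked in at cache time.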
@@ -159,33 +173,60 @@ def collate_fn(examples):

         return batch

-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=train_batch_size,
-        shuffle=True,
-        collate_fn=collate_fn,
-    )
+    if cached_latents:
+
+        train_dataloader = torch.utils.data.DataLoader(
+            cached_latents_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )
+
+        print("PTI : Using cached latent.")
+
+    else:
+        train_dataloader = torch.utils.data.DataLoader(
+            train_dataset,
+            batch_size=train_batch_size,
+            shuffle=True,
+            collate_fn=collate_fn,
+        )

     return train_dataloader

-def inpainting_dataloader(train_dataset, train_batch_size, tokenizer, vae, text_encoder):
+
+def inpainting_dataloader(
+    train_dataset, train_batch_size, tokenizer, vae, text_encoder
+):
     def collate_fn(examples):
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
         mask_values = [example["instance_masks"] for example in examples]
-        masked_image_values = [example["instance_masked_images"] for example in examples]
+        masked_image_values = [
+            example["instance_masked_images"] for example in examples
+        ]

         # Concat class and instance examples for prior preservation.
         # We do this to avoid doing two forward passes.
         if examples[0].get("class_prompt_ids", None) is not None:
             input_ids += [example["class_prompt_ids"] for example in examples]
             pixel_values += [example["class_images"] for example in examples]
             mask_values += [example["class_masks"] for example in examples]
-            masked_image_values += [example["class_masked_images"] for example in examples]
+            masked_image_values += [
+                example["class_masked_images"] for example in examples
+            ]

-        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
-        mask_values = torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
-        masked_image_values = torch.stack(masked_image_values).to(memory_format=torch.contiguous_format).float()
+        pixel_values = (
+            torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
+        )
+        mask_values = (
+            torch.stack(mask_values).to(memory_format=torch.contiguous_format).float()
+        )
+        masked_image_values = (
+            torch.stack(masked_image_values)
+            .to(memory_format=torch.contiguous_format)
+            .float()
+        )

         input_ids = tokenizer.pad(
             {"input_ids": input_ids},
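Note: in the cached branch the `DataLoader` is built directly over a plain Python list. That works because `DataLoader` only requires a map-style dataset, i.e. `__getitem__` and `__len__`, both of which `list` provides. A tiny self-contained illustration (the dict samples are stand-ins):

```python
import torch

# A list of dicts is a valid map-style dataset; the default collate
# stacks the tensors under each key into a batch.
samples = [{"x": torch.randn(3)} for _ in range(8)]
loader = torch.utils.data.DataLoader(samples, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch["x"].shape)  # torch.Size([4, 3])
```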
@@ -198,7 +239,7 @@ def collate_fn(examples):
             "input_ids": input_ids,
             "pixel_values": pixel_values,
             "mask_values": mask_values,
-            "masked_image_values": masked_image_values
+            "masked_image_values": masked_image_values,
         }

         if examples[0].get("mask", None) is not None:
@@ -215,6 +256,7 @@ def collate_fn(examples):

     return train_dataloader

+
 def loss_step(
     batch,
     unet,
@@ -225,23 +267,30 @@ def loss_step(
     t_mutliplier=1.0,
     mixed_precision=False,
     mask_temperature=1.0,
+    cached_latents: bool = False,
 ):
     weight_dtype = torch.float32
-
-    latents = vae.encode(
-        batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
-    ).latent_dist.sample()
-    latents = latents * 0.18215
-
-    if train_inpainting:
-        masked_image_latents = vae.encode(
-            batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+    if not cached_latents:
+        latents = vae.encode(
+            batch["pixel_values"].to(dtype=weight_dtype).to(unet.device)
         ).latent_dist.sample()
-        masked_image_latents = masked_image_latents * 0.18215
-        mask = F.interpolate(
-            batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
-            scale_factor=1/8
-        )
+        latents = latents * 0.18215
+
+        if train_inpainting:
+            masked_image_latents = vae.encode(
+                batch["masked_image_values"].to(dtype=weight_dtype).to(unet.device)
+            ).latent_dist.sample()
+            masked_image_latents = masked_image_latents * 0.18215
+            mask = F.interpolate(
+                batch["mask_values"].to(dtype=weight_dtype).to(unet.device),
+                scale_factor=1 / 8,
+            )
+    else:
+        latents = batch["pixel_values"]
+
+        if train_inpainting:
+            masked_image_latents = batch["masked_image_latents"]
+            mask = batch["mask_values"]

     noise = torch.randn_like(latents)
     bsz = latents.shape[0]
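Note: with `cached_latents=True`, the batch's `"pixel_values"` key already holds scaled latents, so `loss_step` skips the VAE entirely; the key name simply stays the same in both modes. An illustrative helper (not in the repo) that isolates the branch:

```python
import torch

def get_latents(batch, vae, device, cached_latents=False, dtype=torch.float32):
    """Return latents either straight from the batch (precomputed) or
    from a live VAE encode. Illustrative only; names mirror the diff."""
    if cached_latents:
        # The dataloader already swapped pixels for scaled latents.
        return batch["pixel_values"].to(device)
    with torch.no_grad():  # no gradients flow into the frozen VAE anyway
        latents = vae.encode(
            batch["pixel_values"].to(dtype=dtype).to(device)
        ).latent_dist.sample()
    return latents * 0.18215
```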
@@ -257,7 +306,9 @@ def loss_step(
     noisy_latents = scheduler.add_noise(latents, noise, timesteps)

     if train_inpainting:
-        latent_model_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
+        latent_model_input = torch.cat(
+            [noisy_latents, mask, masked_image_latents], dim=1
+        )
     else:
         latent_model_input = noisy_latents

@@ -268,7 +319,9 @@ def loss_step(
             batch["input_ids"].to(text_encoder.device)
         )[0]

-        model_pred = unet(latent_model_input, timesteps, encoder_hidden_states).sample
+        model_pred = unet(
+            latent_model_input, timesteps, encoder_hidden_states
+        ).sample
     else:

         encoder_hidden_states = text_encoder(
@@ -308,7 +361,12 @@ def loss_step(

         target = target * mask

-    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+    loss = (
+        F.mse_loss(model_pred.float(), target.float(), reduction="none")
+        .mean([1, 2, 3])
+        .mean()
+    )
+
     return loss


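Note: for equally-shaped inputs the new reduction is numerically identical to `reduction="mean"`: averaging each example over its channel and spatial dims, then averaging those per-example means, equals the global mean. The reshuffle just exposes a per-example loss, which makes per-sample weighting or inspection easy to add later. A quick check:

```python
import torch
import torch.nn.functional as F

pred, target = torch.randn(4, 4, 64, 64), torch.randn(4, 4, 64, 64)

# Per-example MSE (shape (4,)), then mean over the batch.
per_example = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3])
assert torch.allclose(per_example.mean(), F.mse_loss(pred, target, reduction="mean"))
```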
@@ -328,6 +386,7 @@ def train_inversion(
     tokenizer,
     lr_scheduler,
     test_image_path: str,
+    cached_latents: bool,
     accum_iter: int = 1,
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
@@ -367,6 +426,7 @@ def train_inversion(
                         scheduler,
                         train_inpainting=train_inpainting,
                         mixed_precision=mixed_precision,
+                        cached_latents=cached_latents,
                     )
                     / accum_iter
                 )
@@ -375,6 +435,13 @@ def train_inversion(
             loss_sum += loss.detach().item()

             if global_step % accum_iter == 0:
+                # print gradient of text encoder embedding
+                print(
+                    text_encoder.get_input_embeddings()
+                    .weight.grad[index_updates, :]
+                    .norm(dim=-1)
+                    .mean()
+                )
                 optimizer.step()
                 optimizer.zero_grad()

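Note: the added `print` reports the mean L2 norm of the gradient on the placeholder-token embedding rows before every optimizer step, a quick sanity check that the inversion tokens are receiving signal. Since the script already supports wandb, a hedged variant (not in the commit) could log the same statistic instead of printing it; `text_encoder`, `index_updates`, `log_wandb`, and `global_step` are the surrounding names from this function:

```python
import wandb

grad = text_encoder.get_input_embeddings().weight.grad
token_grad_norm = grad[index_updates, :].norm(dim=-1).mean()
if log_wandb:
    wandb.log({"ti/token_grad_norm": token_grad_norm.item()}, step=global_step)
```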
@@ -448,7 +515,11 @@ def train_inversion(
                     # open all images in test_image_path
                     images = []
                     for file in os.listdir(test_image_path):
-                        if file.lower().endswith(".png") or file.lower().endswith(".jpg") or file.lower().endswith(".jpeg"):
+                        if (
+                            file.lower().endswith(".png")
+                            or file.lower().endswith(".jpg")
+                            or file.lower().endswith(".jpeg")
+                        ):
                             images.append(
                                 Image.open(os.path.join(test_image_path, file))
                             )
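Note: the reformatted condition is behavior-preserving. As a stylistic alternative (not what the commit does), `str.endswith` accepts a tuple of suffixes, which collapses the or-chain:

```python
import os
from PIL import Image

images = [
    Image.open(os.path.join(test_image_path, f))
    for f in os.listdir(test_image_path)
    # endswith with a tuple checks all three suffixes at once
    if f.lower().endswith((".png", ".jpg", ".jpeg"))
]
```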
@@ -490,6 +561,7 @@ def perform_tuning(
     out_name: str,
     tokenizer,
     test_image_path: str,
+    cached_latents: bool,
     log_wandb: bool = False,
     wandb_log_prompt_cnt: int = 10,
     class_token: str = "person",
@@ -526,6 +598,7 @@ def perform_tuning(
                 t_mutliplier=0.8,
                 mixed_precision=True,
                 mask_temperature=mask_temperature,
+                cached_latents=cached_latents,
             )
             loss_sum += loss.detach().item()
@@ -627,18 +700,12 @@ def train(
     train_text_encoder: bool = True,
     pretrained_vae_name_or_path: str = None,
     revision: Optional[str] = None,
-    class_data_dir: Optional[str] = None,
-    stochastic_attribute: Optional[str] = None,
     perform_inversion: bool = True,
     use_template: Literal[None, "object", "style"] = None,
     train_inpainting: bool = False,
     placeholder_tokens: str = "",
     placeholder_token_at_data: Optional[str] = None,
     initializer_tokens: Optional[str] = None,
-    class_prompt: Optional[str] = None,
-    with_prior_preservation: bool = False,
-    prior_loss_weight: float = 1.0,
-    num_class_images: int = 100,
     seed: int = 42,
     resolution: int = 512,
     color_jitter: bool = True,
@@ -649,7 +716,6 @@ def train(
     save_steps: int = 100,
     gradient_accumulation_steps: int = 4,
     gradient_checkpointing: bool = False,
-    mixed_precision="fp16",
     lora_rank: int = 4,
     lora_unet_target_modules={"CrossAttention", "Attention", "GEGLU"},
     lora_clip_target_modules={"CLIPAttention"},
@@ -663,6 +729,7 @@ def train(
     continue_inversion: bool = False,
     continue_inversion_lr: Optional[float] = None,
     use_face_segmentation_condition: bool = False,
+    cached_latents: bool = True,
     use_mask_captioned_data: bool = False,
     mask_temperature: float = 1.0,
     scale_lr: bool = False,
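Note: `cached_latents` defaults to `True`, so existing text2img runs pick up caching automatically; inpainting runs must disable it (see the assert added further down). A hedged usage sketch calling `train` directly; the paths, model id, and `output_dir` argument are placeholders, only `cached_latents` is taken from this commit:

```python
from lora_diffusion.cli_lora_pti import train

train(
    pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5",
    instance_data_dir="./data/my_subject",  # placeholder path
    output_dir="./output",                  # placeholder path
    use_template="object",
    cached_latents=True,  # default per this commit; False restores per-step encoding
)
```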
@@ -773,11 +840,8 @@ def train(

     train_dataset = PivotalTuningDatasetCapation(
         instance_data_root=instance_data_dir,
-        stochastic_attribute=stochastic_attribute,
         token_map=token_map,
         use_template=use_template,
-        class_data_root=class_data_dir if with_prior_preservation else None,
-        class_prompt=class_prompt,
         tokenizer=tokenizer,
         size=resolution,
         color_jitter=color_jitter,
@@ -789,12 +853,19 @@ def train(
     train_dataset.blur_amount = 200

     if train_inpainting:
+        assert not cached_latents, "Cached latents not supported for inpainting"
+
         train_dataloader = inpainting_dataloader(
             train_dataset, train_batch_size, tokenizer, vae, text_encoder
         )
     else:
         train_dataloader = text2img_dataloader(
-            train_dataset, train_batch_size, tokenizer, vae, text_encoder
+            train_dataset,
+            train_batch_size,
+            tokenizer,
+            vae,
+            text_encoder,
+            cached_latents=cached_latents,
         )

     index_no_updates = torch.arange(len(tokenizer)) != -1
@@ -813,6 +884,8 @@ def train(
     for param in params_to_freeze:
         param.requires_grad = False

+    if cached_latents:
+        vae = None
     # STEP 1 : Perform Inversion
     if perform_inversion:
         ti_optimizer = optim.AdamW(
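Note: once latents are cached the VAE is never used again, so the commit drops the reference, letting Python reclaim its weights. On CUDA the freed blocks stay in PyTorch's caching allocator until explicitly released; a sketch of the fuller cleanup (the extra two calls are an assumption, not in the commit):

```python
import gc
import torch

vae = None                # drop the last reference to the VAE module
gc.collect()              # collect the now-unreachable parameters
torch.cuda.empty_cache()  # return freed blocks to the CUDA driver
```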
@@ -836,6 +909,7 @@ def train(
             text_encoder,
             train_dataloader,
             max_train_steps_ti,
+            cached_latents=cached_latents,
             accum_iter=gradient_accumulation_steps,
             scheduler=noise_scheduler,
             index_no_updates=index_no_updates,
@@ -941,6 +1015,7 @@ def train(
             text_encoder,
             train_dataloader,
             max_train_steps_tuning,
+            cached_latents=cached_latents,
             scheduler=noise_scheduler,
             optimizer=lora_optimizers,
             save_steps=save_steps,