@@ -128,17 +128,31 @@ def get_models(
128128 )
129129
130130
131- def text2img_dataloader (train_dataset , train_batch_size , tokenizer , vae , text_encoder ):
131+ @torch .no_grad ()
132+ def text2img_dataloader (
133+ train_dataset ,
134+ train_batch_size ,
135+ tokenizer ,
136+ vae ,
137+ text_encoder ,
138+ cached_latents : bool = False ,
139+ ):
140+
141+ if cached_latents :
142+ cached_latents_dataset = []
143+ for idx in tqdm (range (len (train_dataset ))):
144+ batch = train_dataset [idx ]
145+ # rint(batch)
146+ latents = vae .encode (
147+ batch ["instance_images" ].unsqueeze (0 ).to (dtype = vae .dtype ).to (vae .device )
148+ ).latent_dist .sample ()
149+ latents = latents * 0.18215
150+ batch ["instance_images" ] = latents .squeeze (0 )
151+ cached_latents_dataset .append (batch )
152+
132153 def collate_fn (examples ):
133154 input_ids = [example ["instance_prompt_ids" ] for example in examples ]
134155 pixel_values = [example ["instance_images" ] for example in examples ]
135-
136- # Concat class and instance examples for prior preservation.
137- # We do this to avoid doing two forward passes.
138- if examples [0 ].get ("class_prompt_ids" , None ) is not None :
139- input_ids += [example ["class_prompt_ids" ] for example in examples ]
140- pixel_values += [example ["class_images" ] for example in examples ]
141-
142156 pixel_values = torch .stack (pixel_values )
143157 pixel_values = pixel_values .to (memory_format = torch .contiguous_format ).float ()
144158
@@ -159,33 +173,60 @@ def collate_fn(examples):
159173
160174 return batch
161175
162- train_dataloader = torch .utils .data .DataLoader (
163- train_dataset ,
164- batch_size = train_batch_size ,
165- shuffle = True ,
166- collate_fn = collate_fn ,
167- )
176+ if cached_latents :
177+
178+ train_dataloader = torch .utils .data .DataLoader (
179+ cached_latents_dataset ,
180+ batch_size = train_batch_size ,
181+ shuffle = True ,
182+ collate_fn = collate_fn ,
183+ )
184+
185+ print ("PTI : Using cached latent." )
186+
187+ else :
188+ train_dataloader = torch .utils .data .DataLoader (
189+ train_dataset ,
190+ batch_size = train_batch_size ,
191+ shuffle = True ,
192+ collate_fn = collate_fn ,
193+ )
168194
169195 return train_dataloader
170196
171- def inpainting_dataloader (train_dataset , train_batch_size , tokenizer , vae , text_encoder ):
197+
198+ def inpainting_dataloader (
199+ train_dataset , train_batch_size , tokenizer , vae , text_encoder
200+ ):
172201 def collate_fn (examples ):
173202 input_ids = [example ["instance_prompt_ids" ] for example in examples ]
174203 pixel_values = [example ["instance_images" ] for example in examples ]
175204 mask_values = [example ["instance_masks" ] for example in examples ]
176- masked_image_values = [example ["instance_masked_images" ] for example in examples ]
205+ masked_image_values = [
206+ example ["instance_masked_images" ] for example in examples
207+ ]
177208
178209 # Concat class and instance examples for prior preservation.
179210 # We do this to avoid doing two forward passes.
180211 if examples [0 ].get ("class_prompt_ids" , None ) is not None :
181212 input_ids += [example ["class_prompt_ids" ] for example in examples ]
182213 pixel_values += [example ["class_images" ] for example in examples ]
183214 mask_values += [example ["class_masks" ] for example in examples ]
184- masked_image_values += [example ["class_masked_images" ] for example in examples ]
215+ masked_image_values += [
216+ example ["class_masked_images" ] for example in examples
217+ ]
185218
186- pixel_values = torch .stack (pixel_values ).to (memory_format = torch .contiguous_format ).float ()
187- mask_values = torch .stack (mask_values ).to (memory_format = torch .contiguous_format ).float ()
188- masked_image_values = torch .stack (masked_image_values ).to (memory_format = torch .contiguous_format ).float ()
219+ pixel_values = (
220+ torch .stack (pixel_values ).to (memory_format = torch .contiguous_format ).float ()
221+ )
222+ mask_values = (
223+ torch .stack (mask_values ).to (memory_format = torch .contiguous_format ).float ()
224+ )
225+ masked_image_values = (
226+ torch .stack (masked_image_values )
227+ .to (memory_format = torch .contiguous_format )
228+ .float ()
229+ )
189230
190231 input_ids = tokenizer .pad (
191232 {"input_ids" : input_ids },
@@ -198,7 +239,7 @@ def collate_fn(examples):
198239 "input_ids" : input_ids ,
199240 "pixel_values" : pixel_values ,
200241 "mask_values" : mask_values ,
201- "masked_image_values" : masked_image_values
242+ "masked_image_values" : masked_image_values ,
202243 }
203244
204245 if examples [0 ].get ("mask" , None ) is not None :
@@ -215,6 +256,7 @@ def collate_fn(examples):
215256
216257 return train_dataloader
217258
259+
218260def loss_step (
219261 batch ,
220262 unet ,
@@ -225,23 +267,30 @@ def loss_step(
225267 t_mutliplier = 1.0 ,
226268 mixed_precision = False ,
227269 mask_temperature = 1.0 ,
270+ cached_latents : bool = False ,
228271):
229272 weight_dtype = torch .float32
230-
231- latents = vae .encode (
232- batch ["pixel_values" ].to (dtype = weight_dtype ).to (unet .device )
233- ).latent_dist .sample ()
234- latents = latents * 0.18215
235-
236- if train_inpainting :
237- masked_image_latents = vae .encode (
238- batch ["masked_image_values" ].to (dtype = weight_dtype ).to (unet .device )
273+ if not cached_latents :
274+ latents = vae .encode (
275+ batch ["pixel_values" ].to (dtype = weight_dtype ).to (unet .device )
239276 ).latent_dist .sample ()
240- masked_image_latents = masked_image_latents * 0.18215
241- mask = F .interpolate (
242- batch ["mask_values" ].to (dtype = weight_dtype ).to (unet .device ),
243- scale_factor = 1 / 8
244- )
277+ latents = latents * 0.18215
278+
279+ if train_inpainting :
280+ masked_image_latents = vae .encode (
281+ batch ["masked_image_values" ].to (dtype = weight_dtype ).to (unet .device )
282+ ).latent_dist .sample ()
283+ masked_image_latents = masked_image_latents * 0.18215
284+ mask = F .interpolate (
285+ batch ["mask_values" ].to (dtype = weight_dtype ).to (unet .device ),
286+ scale_factor = 1 / 8 ,
287+ )
288+ else :
289+ latents = batch ["pixel_values" ]
290+
291+ if train_inpainting :
292+ masked_image_latents = batch ["masked_image_latents" ]
293+ mask = batch ["mask_values" ]
245294
246295 noise = torch .randn_like (latents )
247296 bsz = latents .shape [0 ]
@@ -257,7 +306,9 @@ def loss_step(
257306 noisy_latents = scheduler .add_noise (latents , noise , timesteps )
258307
259308 if train_inpainting :
260- latent_model_input = torch .cat ([noisy_latents , mask , masked_image_latents ], dim = 1 )
309+ latent_model_input = torch .cat (
310+ [noisy_latents , mask , masked_image_latents ], dim = 1
311+ )
261312 else :
262313 latent_model_input = noisy_latents
263314
@@ -268,7 +319,9 @@ def loss_step(
268319 batch ["input_ids" ].to (text_encoder .device )
269320 )[0 ]
270321
271- model_pred = unet (latent_model_input , timesteps , encoder_hidden_states ).sample
322+ model_pred = unet (
323+ latent_model_input , timesteps , encoder_hidden_states
324+ ).sample
272325 else :
273326
274327 encoder_hidden_states = text_encoder (
@@ -308,7 +361,12 @@ def loss_step(
308361
309362 target = target * mask
310363
311- loss = F .mse_loss (model_pred .float (), target .float (), reduction = "mean" )
364+ loss = (
365+ F .mse_loss (model_pred .float (), target .float (), reduction = "none" )
366+ .mean ([1 , 2 , 3 ])
367+ .mean ()
368+ )
369+
312370 return loss
313371
314372
@@ -328,6 +386,7 @@ def train_inversion(
328386 tokenizer ,
329387 lr_scheduler ,
330388 test_image_path : str ,
389+ cached_latents : bool ,
331390 accum_iter : int = 1 ,
332391 log_wandb : bool = False ,
333392 wandb_log_prompt_cnt : int = 10 ,
@@ -367,6 +426,7 @@ def train_inversion(
367426 scheduler ,
368427 train_inpainting = train_inpainting ,
369428 mixed_precision = mixed_precision ,
429+ cached_latents = cached_latents ,
370430 )
371431 / accum_iter
372432 )
@@ -375,6 +435,13 @@ def train_inversion(
375435 loss_sum += loss .detach ().item ()
376436
377437 if global_step % accum_iter == 0 :
438+ # print gradient of text encoder embedding
439+ print (
440+ text_encoder .get_input_embeddings ()
441+ .weight .grad [index_updates , :]
442+ .norm (dim = - 1 )
443+ .mean ()
444+ )
378445 optimizer .step ()
379446 optimizer .zero_grad ()
380447
@@ -448,7 +515,11 @@ def train_inversion(
448515 # open all images in test_image_path
449516 images = []
450517 for file in os .listdir (test_image_path ):
451- if file .lower ().endswith (".png" ) or file .lower ().endswith (".jpg" ) or file .lower ().endswith (".jpeg" ):
518+ if (
519+ file .lower ().endswith (".png" )
520+ or file .lower ().endswith (".jpg" )
521+ or file .lower ().endswith (".jpeg" )
522+ ):
452523 images .append (
453524 Image .open (os .path .join (test_image_path , file ))
454525 )
@@ -490,6 +561,7 @@ def perform_tuning(
490561 out_name : str ,
491562 tokenizer ,
492563 test_image_path : str ,
564+ cached_latents : bool ,
493565 log_wandb : bool = False ,
494566 wandb_log_prompt_cnt : int = 10 ,
495567 class_token : str = "person" ,
@@ -526,6 +598,7 @@ def perform_tuning(
526598 t_mutliplier = 0.8 ,
527599 mixed_precision = True ,
528600 mask_temperature = mask_temperature ,
601+ cached_latents = cached_latents ,
529602 )
530603 loss_sum += loss .detach ().item ()
531604
@@ -627,18 +700,12 @@ def train(
627700 train_text_encoder : bool = True ,
628701 pretrained_vae_name_or_path : str = None ,
629702 revision : Optional [str ] = None ,
630- class_data_dir : Optional [str ] = None ,
631- stochastic_attribute : Optional [str ] = None ,
632703 perform_inversion : bool = True ,
633704 use_template : Literal [None , "object" , "style" ] = None ,
634705 train_inpainting : bool = False ,
635706 placeholder_tokens : str = "" ,
636707 placeholder_token_at_data : Optional [str ] = None ,
637708 initializer_tokens : Optional [str ] = None ,
638- class_prompt : Optional [str ] = None ,
639- with_prior_preservation : bool = False ,
640- prior_loss_weight : float = 1.0 ,
641- num_class_images : int = 100 ,
642709 seed : int = 42 ,
643710 resolution : int = 512 ,
644711 color_jitter : bool = True ,
@@ -649,7 +716,6 @@ def train(
649716 save_steps : int = 100 ,
650717 gradient_accumulation_steps : int = 4 ,
651718 gradient_checkpointing : bool = False ,
652- mixed_precision = "fp16" ,
653719 lora_rank : int = 4 ,
654720 lora_unet_target_modules = {"CrossAttention" , "Attention" , "GEGLU" },
655721 lora_clip_target_modules = {"CLIPAttention" },
@@ -663,6 +729,7 @@ def train(
663729 continue_inversion : bool = False ,
664730 continue_inversion_lr : Optional [float ] = None ,
665731 use_face_segmentation_condition : bool = False ,
732+ cached_latents : bool = True ,
666733 use_mask_captioned_data : bool = False ,
667734 mask_temperature : float = 1.0 ,
668735 scale_lr : bool = False ,
@@ -773,11 +840,8 @@ def train(
773840
774841 train_dataset = PivotalTuningDatasetCapation (
775842 instance_data_root = instance_data_dir ,
776- stochastic_attribute = stochastic_attribute ,
777843 token_map = token_map ,
778844 use_template = use_template ,
779- class_data_root = class_data_dir if with_prior_preservation else None ,
780- class_prompt = class_prompt ,
781845 tokenizer = tokenizer ,
782846 size = resolution ,
783847 color_jitter = color_jitter ,
@@ -789,12 +853,19 @@ def train(
789853 train_dataset .blur_amount = 200
790854
791855 if train_inpainting :
856+ assert not cached_latents , "Cached latents not supported for inpainting"
857+
792858 train_dataloader = inpainting_dataloader (
793859 train_dataset , train_batch_size , tokenizer , vae , text_encoder
794860 )
795861 else :
796862 train_dataloader = text2img_dataloader (
797- train_dataset , train_batch_size , tokenizer , vae , text_encoder
863+ train_dataset ,
864+ train_batch_size ,
865+ tokenizer ,
866+ vae ,
867+ text_encoder ,
868+ cached_latents = cached_latents ,
798869 )
799870
800871 index_no_updates = torch .arange (len (tokenizer )) != - 1
@@ -813,6 +884,8 @@ def train(
813884 for param in params_to_freeze :
814885 param .requires_grad = False
815886
887+ if cached_latents :
888+ vae = None
816889 # STEP 1 : Perform Inversion
817890 if perform_inversion :
818891 ti_optimizer = optim .AdamW (
@@ -836,6 +909,7 @@ def train(
836909 text_encoder ,
837910 train_dataloader ,
838911 max_train_steps_ti ,
912+ cached_latents = cached_latents ,
839913 accum_iter = gradient_accumulation_steps ,
840914 scheduler = noise_scheduler ,
841915 index_no_updates = index_no_updates ,
@@ -941,6 +1015,7 @@ def train(
9411015 text_encoder ,
9421016 train_dataloader ,
9431017 max_train_steps_tuning ,
1018+ cached_latents = cached_latents ,
9441019 scheduler = noise_scheduler ,
9451020 optimizer = lora_optimizers ,
9461021 save_steps = save_steps ,
0 commit comments