44import math
55import os
66import random
7+ import sys
78from pathlib import Path
89from typing import Optional
910
1213import torch .nn .functional as F
1314import torch .utils .checkpoint
1415from torch .utils .data import Dataset
16+ import subprocess
1517
1618from accelerate import Accelerator
1719from accelerate .logging import get_logger
@@ -82,6 +84,12 @@ def random_mask(im_shape, ratio=1, mask_full_image=False):
8284
8385def parse_args ():
8486 parser = argparse .ArgumentParser (description = "Simple example of a training script." )
87+ parser .add_argument (
88+ "--image_captions_filename" ,
89+ action = "store_true" ,
90+ help = "Get captions from filename" ,
91+ )
92+
8593 parser .add_argument (
8694 "--pretrained_model_name_or_path" ,
8795 type = str ,
@@ -164,6 +172,8 @@ def parse_args():
164172 "--sample_batch_size" , type = int , default = 4 , help = "Batch size (per device) for sampling images."
165173 )
166174 parser .add_argument ("--num_train_epochs" , type = int , default = 1 )
175+ parser .add_argument ("--stop_text_encoder_training" , type = int , default = sys .maxsize )
176+
167177 parser .add_argument (
168178 "--max_train_steps" ,
169179 type = int ,
@@ -287,6 +297,7 @@ class DreamBoothDataset(Dataset):
287297
288298 def __init__ (
289299 self ,
300+ args ,
290301 instance_data_root ,
291302 instance_prompt ,
292303 tokenizer ,
@@ -298,6 +309,7 @@ def __init__(
298309 self .size = size
299310 self .center_crop = center_crop
300311 self .tokenizer = tokenizer
312+ self .image_captions_filename = None
301313
302314 self .instance_data_root = Path (instance_data_root )
303315 if not self .instance_data_root .exists ():
@@ -308,6 +320,9 @@ def __init__(
308320 self .instance_prompt = instance_prompt
309321 self ._length = self .num_instance_images
310322
323+ if args .image_captions_filename :
324+ self .image_captions_filename = True
325+
311326 if class_data_root is not None :
312327 self .class_data_root = Path (class_data_root )
313328 self .class_data_root .mkdir (parents = True , exist_ok = True )
@@ -337,16 +352,30 @@ def __len__(self):
337352
338353 def __getitem__ (self , index ):
339354 example = {}
340- instance_image = Image .open (self .instance_images_path [index % self .num_instance_images ])
355+ path = self .instance_images_path [index % self .num_instance_images ]
356+ instance_image = Image .open (path )
341357 if not instance_image .mode == "RGB" :
342358 instance_image = instance_image .convert ("RGB" )
343- instance_image = self .image_transforms_resize_and_crop (instance_image )
344359
360+ instance_prompt = self .instance_prompt
361+
362+ if self .image_captions_filename :
363+ filename = Path (path ).stem
364+ pt = '' .join ([i for i in filename if not i .isdigit ()])
365+ pt = pt .replace ("_" ," " )
366+ pt = pt .replace ("(" ,"" )
367+ pt = pt .replace (")" ,"" )
368+ pt = pt .replace ("-" ,"" )
369+ instance_prompt = pt
370+ sys .stdout .write (" [0;32m" + instance_prompt + " [0m" )
371+ sys .stdout .flush ()
372+
373+
345374 example ["PIL_images" ] = instance_image
346375 example ["instance_images" ] = self .image_transforms (instance_image )
347376
348377 example ["instance_prompt_ids" ] = self .tokenizer (
349- self . instance_prompt ,
378+ instance_prompt ,
350379 padding = "do_not_pad" ,
351380 truncation = True ,
352381 max_length = self .tokenizer .model_max_length ,
@@ -533,6 +562,7 @@ def main():
533562 noise_scheduler = DDPMScheduler .from_pretrained (args .pretrained_model_name_or_path , subfolder = "scheduler" )
534563
535564 train_dataset = DreamBoothDataset (
565+ args ,
536566 instance_data_root = args .instance_data_dir ,
537567 instance_prompt = args .instance_prompt ,
538568 class_data_root = args .class_data_dir if args .with_prior_preservation else None ,
@@ -672,7 +702,7 @@ def collate_fn(examples):
672702 progress_bar = tqdm (range (global_step , args .max_train_steps ), disable = not accelerator .is_local_main_process )
673703 progress_bar .set_description ("Steps" )
674704
675- for epoch in range (first_epoch , args .num_epochs ):
705+ for epoch in range (first_epoch , args .num_train_epochs ):
676706 unet .train ()
677707 for step , batch in enumerate (train_dataloader ):
678708 # Skip steps until we reach the resumed step
@@ -774,12 +804,26 @@ def collate_fn(examples):
774804 progress_bar .set_postfix (** logs )
775805 accelerator .log (logs , step = global_step )
776806
807+ if args .train_text_encoder and global_step == args .stop_text_encoder_training and global_step >= 30 :
808+ if accelerator .is_main_process :
809+ print (" [0;32m" + " Freezing the text_encoder ..." + " [0m" )
810+ frz_dir = args .output_dir + "/text_encoder_frozen"
811+ if os .path .exists (frz_dir ):
812+ subprocess .call ('rm -r ' + frz_dir , shell = True )
813+ os .mkdir (frz_dir )
814+ pipeline = StableDiffusionPipeline .from_pretrained (
815+ args .pretrained_model_name_or_path ,
816+ unet = accelerator .unwrap_model (unet ),
817+ text_encoder = accelerator .unwrap_model (text_encoder ),
818+ )
819+ pipeline .text_encoder .save_pretrained (frz_dir )
820+
777821 if global_step >= args .max_train_steps :
778822 break
779823
780824 accelerator .wait_for_everyone ()
781825
782- # Create the pipeline using using the trained modules and save it.
826+ # Create the pipeline using the trained modules and save it.
783827 if accelerator .is_main_process :
784828 pipeline = StableDiffusionPipeline .from_pretrained (
785829 args .pretrained_model_name_or_path ,
@@ -788,6 +832,11 @@ def collate_fn(examples):
788832 )
789833 pipeline .save_pretrained (args .output_dir )
790834
835+ frz_dir = args .output_dir + "/text_encoder_frozen"
836+ if args .train_text_encoder and os .path .exists (frz_dir ):
837+ subprocess .call ('mv -f ' + frz_dir + '/*.* ' + args .output_dir + '/text_encoder' , shell = True )
838+ subprocess .call ('rm -r ' + frz_dir , shell = True )
839+
791840 if args .push_to_hub :
792841 repo .push_to_hub (commit_message = "End of training" , blocking = False , auto_lfs_prune = True )
793842
0 commit comments