Skip to content

Commit 866a0b5

Browse files
author
govu
committed
Implemented image_captions_filename and stop_text_encoder_training configuration
1 parent debc74f commit 866a0b5

1 file changed

Lines changed: 54 additions & 5 deletions

File tree

examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint.py

Lines changed: 54 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
import math
55
import os
66
import random
7+
import sys
78
from pathlib import Path
89
from typing import Optional
910

@@ -12,6 +13,7 @@
1213
import torch.nn.functional as F
1314
import torch.utils.checkpoint
1415
from torch.utils.data import Dataset
16+
import subprocess
1517

1618
from accelerate import Accelerator
1719
from accelerate.logging import get_logger
@@ -82,6 +84,12 @@ def random_mask(im_shape, ratio=1, mask_full_image=False):
8284

8385
def parse_args():
8486
parser = argparse.ArgumentParser(description="Simple example of a training script.")
87+
parser.add_argument(
88+
"--image_captions_filename",
89+
action="store_true",
90+
help="Get captions from filename",
91+
)
92+
8593
parser.add_argument(
8694
"--pretrained_model_name_or_path",
8795
type=str,
@@ -164,6 +172,8 @@ def parse_args():
164172
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
165173
)
166174
parser.add_argument("--num_train_epochs", type=int, default=1)
175+
parser.add_argument("--stop_text_encoder_training", type=int, default=sys.maxsize)
176+
167177
parser.add_argument(
168178
"--max_train_steps",
169179
type=int,
@@ -287,6 +297,7 @@ class DreamBoothDataset(Dataset):
287297

288298
def __init__(
289299
self,
300+
args,
290301
instance_data_root,
291302
instance_prompt,
292303
tokenizer,
@@ -298,6 +309,7 @@ def __init__(
298309
self.size = size
299310
self.center_crop = center_crop
300311
self.tokenizer = tokenizer
312+
self.image_captions_filename = None
301313

302314
self.instance_data_root = Path(instance_data_root)
303315
if not self.instance_data_root.exists():
@@ -308,6 +320,9 @@ def __init__(
308320
self.instance_prompt = instance_prompt
309321
self._length = self.num_instance_images
310322

323+
if args.image_captions_filename:
324+
self.image_captions_filename = True
325+
311326
if class_data_root is not None:
312327
self.class_data_root = Path(class_data_root)
313328
self.class_data_root.mkdir(parents=True, exist_ok=True)
@@ -337,16 +352,30 @@ def __len__(self):
337352

338353
def __getitem__(self, index):
339354
example = {}
340-
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
355+
path = self.instance_images_path[index % self.num_instance_images]
356+
instance_image = Image.open(path)
341357
if not instance_image.mode == "RGB":
342358
instance_image = instance_image.convert("RGB")
343-
instance_image = self.image_transforms_resize_and_crop(instance_image)
344359

360+
instance_prompt = self.instance_prompt
361+
362+
if self.image_captions_filename:
363+
filename = Path(path).stem
364+
pt=''.join([i for i in filename if not i.isdigit()])
365+
pt=pt.replace("_"," ")
366+
pt=pt.replace("(","")
367+
pt=pt.replace(")","")
368+
pt=pt.replace("-","")
369+
instance_prompt = pt
370+
sys.stdout.write(" " +instance_prompt+" ")
371+
sys.stdout.flush()
372+
373+
345374
example["PIL_images"] = instance_image
346375
example["instance_images"] = self.image_transforms(instance_image)
347376

348377
example["instance_prompt_ids"] = self.tokenizer(
349-
self.instance_prompt,
378+
instance_prompt,
350379
padding="do_not_pad",
351380
truncation=True,
352381
max_length=self.tokenizer.model_max_length,
@@ -533,6 +562,7 @@ def main():
533562
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
534563

535564
train_dataset = DreamBoothDataset(
565+
args,
536566
instance_data_root=args.instance_data_dir,
537567
instance_prompt=args.instance_prompt,
538568
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
@@ -672,7 +702,7 @@ def collate_fn(examples):
672702
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
673703
progress_bar.set_description("Steps")
674704

675-
for epoch in range(first_epoch, args.num_epochs):
705+
for epoch in range(first_epoch, args.num_train_epochs):
676706
unet.train()
677707
for step, batch in enumerate(train_dataloader):
678708
# Skip steps until we reach the resumed step
@@ -774,12 +804,26 @@ def collate_fn(examples):
774804
progress_bar.set_postfix(**logs)
775805
accelerator.log(logs, step=global_step)
776806

807+
if args.train_text_encoder and global_step == args.stop_text_encoder_training and global_step >= 30:
808+
if accelerator.is_main_process:
809+
print(" " +" Freezing the text_encoder ..."+" ")
810+
frz_dir=args.output_dir + "/text_encoder_frozen"
811+
if os.path.exists(frz_dir):
812+
subprocess.call('rm -r '+ frz_dir, shell=True)
813+
os.mkdir(frz_dir)
814+
pipeline = StableDiffusionPipeline.from_pretrained(
815+
args.pretrained_model_name_or_path,
816+
unet=accelerator.unwrap_model(unet),
817+
text_encoder=accelerator.unwrap_model(text_encoder),
818+
)
819+
pipeline.text_encoder.save_pretrained(frz_dir)
820+
777821
if global_step >= args.max_train_steps:
778822
break
779823

780824
accelerator.wait_for_everyone()
781825

782-
# Create the pipeline using using the trained modules and save it.
826+
# Create the pipeline using the trained modules and save it.
783827
if accelerator.is_main_process:
784828
pipeline = StableDiffusionPipeline.from_pretrained(
785829
args.pretrained_model_name_or_path,
@@ -788,6 +832,11 @@ def collate_fn(examples):
788832
)
789833
pipeline.save_pretrained(args.output_dir)
790834

835+
frz_dir=args.output_dir + "/text_encoder_frozen"
836+
if args.train_text_encoder and os.path.exists(frz_dir):
837+
subprocess.call('mv -f '+frz_dir +'/*.* '+ args.output_dir+'/text_encoder', shell=True)
838+
subprocess.call('rm -r '+ frz_dir, shell=True)
839+
791840
if args.push_to_hub:
792841
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
793842

0 commit comments

Comments
 (0)