Commit 2e9f7b5
Parent: 5050971

cache latents to disk in dreambooth method

6 files changed: +67, -15 lines

fine_tune.py  (+3, -1)

@@ -142,12 +142,14 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     # 学習を準備する:モデルを適切な状態にする
     training_models = []
     if args.gradient_checkpointing:
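Note: the same two-part change is applied to each of the five training scripts in this commit: the cache_latents() call now forwards args.cache_latents_to_disk and accelerator.is_main_process, and accelerator.wait_for_everyone() is added as a barrier so non-main processes only continue once the main process has finished writing the cache files. A minimal sketch of the pattern (not part of the commit; dataset_group and vae are placeholders for the objects used in the scripts):

```python
from accelerate import Accelerator

accelerator = Accelerator()

# Every rank calls cache_latents(); inside, non-main ranks only record the
# .npz paths and return early, while the main rank encodes and writes files.
dataset_group.cache_latents(vae, vae_batch_size=1,
                            cache_to_disk=True,
                            is_main_process=accelerator.is_main_process)

# Barrier: non-main ranks wait here so they never read half-written caches.
accelerator.wait_for_everyone()
```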

library/train_util.py  (+52, -10)

@@ -722,7 +722,7 @@ def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_s
     def is_latent_cacheable(self):
         return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])
 
-    def cache_latents(self, vae, vae_batch_size=1):
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         # ちょっと速くした
         print("caching latents.")
 
@@ -740,11 +740,38 @@ def cache_latents(self, vae, vae_batch_size=1):
             if info.latents_npz is not None:
                 info.latents = self.load_latents_from_npz(info, False)
                 info.latents = torch.FloatTensor(info.latents)
-                info.latents_flipped = self.load_latents_from_npz(info, True)  # might be None
+
+                # might be None, but that's ok because check is done in dataset
+                info.latents_flipped = self.load_latents_from_npz(info, True)
                 if info.latents_flipped is not None:
                     info.latents_flipped = torch.FloatTensor(info.latents_flipped)
                 continue
 
+            # check disk cache exists and size of latents
+            if cache_to_disk:
+                # TODO: refactor to unify with FineTuningDataset
+                info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
+                info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
+                if not is_main_process:
+                    continue
+
+                cache_available = False
+                expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8)  # bucket_resoはWxHなので注意
+                if os.path.exists(info.latents_npz):
+                    cached_latents = np.load(info.latents_npz)
+                    if cached_latents["latents"].shape[1:3] == expected_latents_size:
+                        cache_available = True
+
+                        if subset.flip_aug:
+                            cache_available = False
+                            if os.path.exists(info.latents_npz_flipped):
+                                cached_latents_flipped = np.load(info.latents_npz_flipped)
+                                if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size:
+                                    cache_available = True
+
+                if cache_available:
+                    continue
+
             # if last member of batch has different resolution, flush the batch
             if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
                 batches.append(batch)
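For the disk cache, latents are stored next to each training image as <basename>.npz (plus <basename>_flip.npz when flip_aug is enabled), and an existing file is only reused when its spatial shape matches the image's bucket. bucket_reso is (width, height) while the VAE downscales by a factor of 8, so the expected latent height/width is (H // 8, W // 8). A standalone sketch of that validation (illustrative only; absolute_path and bucket_reso stand in for fields of the image info object, and the archive key is read generically):

```python
import os
import numpy as np

def cache_paths(absolute_path):
    # latents live next to the image: image.png -> image.npz / image_flip.npz
    base = os.path.splitext(absolute_path)[0]
    return base + ".npz", base + "_flip.npz"

def cache_is_valid(npz_path, bucket_reso):
    # bucket_reso is (W, H); a cached latent is shaped (C, H // 8, W // 8)
    if not os.path.exists(npz_path):
        return False
    npz = np.load(npz_path)
    latents = npz[npz.files[0]]  # whichever key the archive was saved under
    expected_hw = (bucket_reso[1] // 8, bucket_reso[0] // 8)
    return latents.shape[1:3] == expected_hw
```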
@@ -760,6 +787,9 @@ def cache_latents(self, vae, vae_batch_size=1):
         if len(batch) > 0:
             batches.append(batch)
 
+        if cache_to_disk and not is_main_process:  # don't cache latents in non-main process, set to info only
+            return
+
         # iterate batches
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
             images = []
@@ -773,14 +803,21 @@ def cache_latents(self, vae, vae_batch_size=1):
             img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)
 
             latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+
             for info, latent in zip(batch, latents):
-                info.latents = latent
+                if cache_to_disk:
+                    np.savez(info.latents_npz, latent.float().numpy())
+                else:
+                    info.latents = latent
 
             if subset.flip_aug:
                 img_tensors = torch.flip(img_tensors, dims=[3])
                 latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
                 for info, latent in zip(batch, latents):
-                    info.latents_flipped = latent
+                    if cache_to_disk:
+                        np.savez(info.latents_npz_flipped, latent.float().numpy())
+                    else:
+                        info.latents_flipped = latent
 
     def get_image_size(self, image_path):
         image = Image.open(image_path)
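np.savez is called here with the latent as a positional argument, so NumPy stores it under its default key "arr_0" inside the archive. A minimal round-trip sketch (illustrative, not part of the commit):

```python
import numpy as np
import torch

latent = torch.randn(4, 64, 64)                   # e.g. a 512x512 image -> 4x64x64 latent
np.savez("example.npz", latent.float().numpy())   # positional argument -> saved as "arr_0"

npz = np.load("example.npz")
restored = torch.FloatTensor(npz[npz.files[0]])   # npz.files == ["arr_0"]
assert restored.shape == latent.shape
```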
@@ -873,10 +910,10 @@ def __getitem__(self, index):
             loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)
 
             # image/latentsを処理する
-            if image_info.latents is not None:
+            if image_info.latents is not None:  # cache_latents=Trueの場合
                 latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
                 image = None
-            elif image_info.latents_npz is not None:
+            elif image_info.latents_npz is not None:  # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
                 latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
                 latents = torch.FloatTensor(latents)
                 image = None
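In __getitem__, the flipped latent is used with probability 0.5 whenever flip_aug is enabled, regardless of whether the cache lives in main memory or on disk. A condensed sketch of the selection (illustrative; load_npz stands in for load_latents_from_npz):

```python
import random

def pick_cached_latents(info, flip_aug, load_npz):
    if info.latents is not None:          # --cache_latents: held in main memory
        use_flip = flip_aug and random.random() >= 0.5
        return info.latents_flipped if use_flip else info.latents
    if info.latents_npz is not None:      # FineTuningDataset or --cache_latents_to_disk
        return load_npz(info, flip_aug and random.random() >= 0.5)
    return None                           # no cache: the image is encoded during training
```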
@@ -1340,10 +1377,10 @@ def enable_XTI(self, *args, **kwargs):
         for dataset in self.datasets:
             dataset.enable_XTI(*args, **kwargs)
 
-    def cache_latents(self, vae, vae_batch_size=1):
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         for i, dataset in enumerate(self.datasets):
             print(f"[Dataset {i}]")
-            dataset.cache_latents(vae, vae_batch_size)
+            dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)
 
     def is_latent_cacheable(self) -> bool:
         return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -2144,9 +2181,14 @@ def add_dataset_arguments(
     parser.add_argument(
         "--cache_latents",
         action="store_true",
-        help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする(augmentationは使用不可)",
+        help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする(augmentationは使用不可)",
     )
     parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
+    parser.add_argument(
+        "--cache_latents_to_disk",
+        action="store_true",
+        help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
+    )
     parser.add_argument(
         "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
     )
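--cache_latents keeps the encoded latents in main memory, while the new --cache_latents_to_disk writes them next to the training images so later runs can reuse them and skip the VAE encode pass. The disk flag is forwarded through the same cache_latents() call, so it only takes effect when latent caching runs at all. A small sketch of the two flags (illustrative; the real definitions are in add_dataset_arguments above):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache_latents", action="store_true",
                    help="cache latents to main memory to reduce VRAM usage")
parser.add_argument("--cache_latents_to_disk", action="store_true",
                    help="cache latents to disk to reduce VRAM usage")

# e.g. pass both: encode once, then reuse the .npz files on the next run
args = parser.parse_args(["--cache_latents", "--cache_latents_to_disk"])
print(args.cache_latents, args.cache_latents_to_disk)  # True True
```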
@@ -3203,4 +3245,4 @@ def __call__(self, examples):
         # set epoch and step
         dataset.set_current_epoch(self.current_epoch.value)
         dataset.set_current_step(self.current_step.value)
-        return examples[0]
+        return examples[0]

train_db.py  (+3, -1)

@@ -117,12 +117,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     # 学習を準備する:モデルを適切な状態にする
     train_text_encoder = args.stop_text_encoder_training is None or args.stop_text_encoder_training >= 0
     unet.requires_grad_(True)  # 念のため追加

train_network.py  (+3, -1)

@@ -172,12 +172,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     # prepare network
     import sys

train_textual_inversion.py  (+3, -1)

@@ -233,12 +233,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()

train_textual_inversion_XTI.py  (+3, -1)

@@ -267,12 +267,14 @@ def train(args):
         vae.requires_grad_(False)
         vae.eval()
         with torch.no_grad():
-            train_dataset_group.cache_latents(vae, args.vae_batch_size)
+            train_dataset_group.cache_latents(vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process)
         vae.to("cpu")
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
         gc.collect()
 
+        accelerator.wait_for_everyone()
+
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
