@@ -722,7 +722,7 @@ def trim_and_resize_if_required(self, subset: BaseSubset, image, reso, resized_s
     def is_latent_cacheable(self):
         return all([not subset.color_aug and not subset.random_crop for subset in self.subsets])

-    def cache_latents(self, vae, vae_batch_size=1):
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         # made this a bit faster
         print("caching latents.")
@@ -740,11 +740,38 @@ def cache_latents(self, vae, vae_batch_size=1):
             if info.latents_npz is not None:
                 info.latents = self.load_latents_from_npz(info, False)
                 info.latents = torch.FloatTensor(info.latents)
-                info.latents_flipped = self.load_latents_from_npz(info, True)  # might be None
+
+                # might be None, but that's ok because check is done in dataset
+                info.latents_flipped = self.load_latents_from_npz(info, True)
                 if info.latents_flipped is not None:
                     info.latents_flipped = torch.FloatTensor(info.latents_flipped)
                 continue

+            # check disk cache exists and size of latents
+            if cache_to_disk:
+                # TODO: refactor to unify with FineTuningDataset
+                info.latents_npz = os.path.splitext(info.absolute_path)[0] + ".npz"
+                info.latents_npz_flipped = os.path.splitext(info.absolute_path)[0] + "_flip.npz"
+                if not is_main_process:
+                    continue
+
+                cache_available = False
+                expected_latents_size = (info.bucket_reso[1] // 8, info.bucket_reso[0] // 8)  # note: bucket_reso is (W, H)
+                if os.path.exists(info.latents_npz):
+                    cached_latents = np.load(info.latents_npz)
+                    if cached_latents["latents"].shape[1:3] == expected_latents_size:
+                        cache_available = True
+
+                        if subset.flip_aug:
+                            cache_available = False
+                            if os.path.exists(info.latents_npz_flipped):
+                                cached_latents_flipped = np.load(info.latents_npz_flipped)
+                                if cached_latents_flipped["latents"].shape[1:3] == expected_latents_size:
+                                    cache_available = True
+
+                if cache_available:
+                    continue
+
             # if last member of batch has different resolution, flush the batch
             if len(batch) > 0 and batch[-1].bucket_reso != info.bucket_reso:
                 batches.append(batch)
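
# A minimal standalone sketch (not part of this commit) of the disk-cache validity
# check added above: the .npz must exist and its latents must match the bucket
# resolution at 1/8 scale. The helper name and the "latents" key mirror the hunk
# above and are illustrative assumptions only.
import os
import numpy as np

def npz_cache_is_usable(npz_path, bucket_reso):
    """bucket_reso is (width, height); latents are 1/8 of the pixel resolution."""
    if not os.path.exists(npz_path):
        return False
    expected_size = (bucket_reso[1] // 8, bucket_reso[0] // 8)  # (H/8, W/8)
    cached = np.load(npz_path)
    return cached["latents"].shape[1:3] == expected_size
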
@@ -760,6 +787,9 @@ def cache_latents(self, vae, vae_batch_size=1):
         if len(batch) > 0:
             batches.append(batch)

+        if cache_to_disk and not is_main_process:  # don't cache latents in non-main process, set to info only
+            return
+
         # iterate batches
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
             images = []
@@ -773,14 +803,21 @@ def cache_latents(self, vae, vae_batch_size=1):
             img_tensors = img_tensors.to(device=vae.device, dtype=vae.dtype)

             latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
+
             for info, latent in zip(batch, latents):
-                info.latents = latent
+                if cache_to_disk:
+                    np.savez(info.latents_npz, latent.float().numpy())
+                else:
+                    info.latents = latent

             if subset.flip_aug:
                 img_tensors = torch.flip(img_tensors, dims=[3])
                 latents = vae.encode(img_tensors).latent_dist.sample().to("cpu")
                 for info, latent in zip(batch, latents):
-                    info.latents_flipped = latent
+                    if cache_to_disk:
+                        np.savez(info.latents_npz_flipped, latent.float().numpy())
+                    else:
+                        info.latents_flipped = latent

     def get_image_size(self, image_path):
         image = Image.open(image_path)
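
# A small illustration (not part of this commit): torch.flip(img_tensors, dims=[3])
# mirrors the (B, C, H, W) batch along the width axis, i.e. a horizontal flip, so
# the "_flip.npz" files written above hold latents of horizontally mirrored images.
import torch

x = torch.arange(6).reshape(1, 1, 2, 3)  # (B=1, C=1, H=2, W=3)
flipped = torch.flip(x, dims=[3])        # mirror along the width dimension
print(x[0, 0])        # tensor([[0, 1, 2], [3, 4, 5]])
print(flipped[0, 0])  # tensor([[2, 1, 0], [5, 4, 3]])
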
@@ -873,10 +910,10 @@ def __getitem__(self, index):
             loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

             # process image/latents
-            if image_info.latents is not None:
+            if image_info.latents is not None:  # when cache_latents=True
                 latents = image_info.latents if not subset.flip_aug or random.random() < 0.5 else image_info.latents_flipped
                 image = None
-            elif image_info.latents_npz is not None:
+            elif image_info.latents_npz is not None:  # for FineTuningDataset, or when cache_latents_to_disk=True
                 latents = self.load_latents_from_npz(image_info, subset.flip_aug and random.random() >= 0.5)
                 latents = torch.FloatTensor(latents)
                 image = None
@@ -1340,10 +1377,10 @@ def enable_XTI(self, *args, **kwargs):
         for dataset in self.datasets:
             dataset.enable_XTI(*args, **kwargs)

-    def cache_latents(self, vae, vae_batch_size=1):
+    def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_process=True):
         for i, dataset in enumerate(self.datasets):
             print(f"[Dataset {i}]")
-            dataset.cache_latents(vae, vae_batch_size)
+            dataset.cache_latents(vae, vae_batch_size, cache_to_disk, is_main_process)

     def is_latent_cacheable(self) -> bool:
         return all([dataset.is_latent_cacheable() for dataset in self.datasets])
@@ -2144,9 +2181,14 @@ def add_dataset_arguments(
     parser.add_argument(
         "--cache_latents",
         action="store_true",
-        help="cache latents to reduce memory (augmentations must be disabled) / メモリ削減のためにlatentをcacheする (augmentationは使用不可)",
+        help="cache latents to main memory to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをメインメモリにcacheする (augmentationは使用不可)",
     )
     parser.add_argument("--vae_batch_size", type=int, default=1, help="batch size for caching latents / latentのcache時のバッチサイズ")
+    parser.add_argument(
+        "--cache_latents_to_disk",
+        action="store_true",
+        help="cache latents to disk to reduce VRAM usage (augmentations must be disabled) / VRAM削減のためにlatentをディスクにcacheする(augmentationは使用不可)",
+    )
     parser.add_argument(
         "--enable_bucket", action="store_true", help="enable buckets for multi aspect ratio training / 複数解像度学習のためのbucketを有効にする"
     )
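
# A hedged sketch (not part of this commit) of how a training script might wire the
# new arguments into cache_latents(); the names args, train_dataset_group,
# accelerator, and weight_dtype are assumptions for illustration, not code added here.
if args.cache_latents:
    vae.to(accelerator.device, dtype=weight_dtype)
    vae.requires_grad_(False)
    vae.eval()
    with torch.no_grad():
        train_dataset_group.cache_latents(
            vae, args.vae_batch_size, args.cache_latents_to_disk, accelerator.is_main_process
        )
    vae.to("cpu")
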
@@ -3203,4 +3245,4 @@ def __call__(self, examples):
         # set epoch and step
         dataset.set_current_epoch(self.current_epoch.value)
         dataset.set_current_step(self.current_step.value)
-        return examples[0]
+        return examples[0]