Skip to content

Commit 39baddf

Browse files
Add LoRA support from sd_script repo
1 parent 2cdf4cf commit 39baddf

6 files changed

+4324
-3
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@ cudnn_windows
55
.vscode
66
*.egg-info
77
build
8-
wd14_tagger_model
8+
wd14_tagger_model
9+
.DS_Store

finetune/prepare_buckets_latents.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -130,14 +130,16 @@ def main(args):
130130
latents = get_latents(vae, [img for _, _, img in bucket], weight_dtype)
131131

132132
for (image_key, reso, _), latent in zip(bucket, latents):
133-
np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0]), latent)
133+
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
134+
np.savez(os.path.join(args.train_data_dir, npz_file_name), latent)
134135

135136
# flip
136137
if args.flip_aug:
137138
latents = get_latents(vae, [img[:, ::-1].copy() for _, _, img in bucket], weight_dtype) # copyがないとTensor変換できない
138139

139140
for (image_key, reso, _), latent in zip(bucket, latents):
140-
np.savez(os.path.join(args.train_data_dir, os.path.splitext(os.path.basename(image_key))[0] + '_flip'), latent)
141+
npz_file_name = os.path.splitext(os.path.basename(image_key))[0] if args.full_path else image_key
142+
np.savez(os.path.join(args.train_data_dir, npz_file_name + '_flip'), latent)
141143

142144
bucket.clear()
143145

0 commit comments

Comments
 (0)