Skip to content

Commit

Permalink
Merge branch 'sdxl' of https://github.com/kohya-ss/sd-scripts into dev2
Browse files Browse the repository at this point in the history
  • Loading branch information
bmaltais committed Jul 16, 2023
2 parents 1fd18e9 + 62dd99b commit 4478fc3
Show file tree
Hide file tree
Showing 6 changed files with 447 additions and 282 deletions.
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ The feature of SDXL training is now available in sdxl branch as an experimental

Summary of the feature:

- `tools/cache_latents.py` is added. This script can be used to cache the latents in advance.
- The options are almost the same as `sdxl_train.py`. See the help message for the usage.
- Please launch the script as follows:
`accelerate launch --num_cpu_threads_per_process 1 tools/cache_latents.py ...`
- This script should work with multi-GPU, but it is not tested in my environment.

- `sdxl_train.py` is a script for SDXL fine-tuning. The usage is almost the same as `fine_tune.py`, but it also supports DreamBooth datasets.
- `--full_bf16` option is added. Thanks to KohakuBlueleaf!
- This option enables full bfloat16 training (including gradients). It is useful for reducing GPU memory usage.
Expand Down
114 changes: 12 additions & 102 deletions finetune/prepare_buckets_latents.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,36 +34,18 @@ def collate_fn_remove_corrupted(batch):
return batch


def get_latents(vae, key_and_images, weight_dtype):
img_tensors = [IMAGE_TRANSFORMS(image) for _, image in key_and_images]
img_tensors = torch.stack(img_tensors)
img_tensors = img_tensors.to(DEVICE, weight_dtype)
with torch.no_grad():
latents = vae.encode(img_tensors).latent_dist.sample()

# check NaN
for (key, _), latents1 in zip(key_and_images, latents):
if torch.isnan(latents1).any():
raise ValueError(f"NaN detected in latents of {key}")

return latents


def get_npz_filename_wo_ext(data_dir, image_key, is_full_path, flip, recursive):
def get_npz_filename(data_dir, image_key, is_full_path, recursive):
if is_full_path:
base_name = os.path.splitext(os.path.basename(image_key))[0]
relative_path = os.path.relpath(os.path.dirname(image_key), data_dir)
else:
base_name = image_key
relative_path = ""

if flip:
base_name += "_flip"

if recursive and relative_path:
return os.path.join(data_dir, relative_path, base_name)
return os.path.join(data_dir, relative_path, base_name) + ".npz"
else:
return os.path.join(data_dir, base_name)
return os.path.join(data_dir, base_name) + ".npz"


def main(args):
Expand Down Expand Up @@ -113,36 +95,7 @@ def main(args):
def process_batch(is_last):
for bucket in bucket_manager.buckets:
if (is_last and len(bucket) > 0) or len(bucket) >= args.batch_size:
latents = get_latents(vae, [(key, img) for key, img, _, _ in bucket], weight_dtype)
assert (
latents.shape[2] == bucket[0][1].shape[0] // 8 and latents.shape[3] == bucket[0][1].shape[1] // 8
), f"latent shape {latents.shape}, {bucket[0][1].shape}"

for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive)
train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)

# flip
if args.flip_aug:
latents = get_latents(
vae, [(key, img[:, ::-1].copy()) for key, img, _, _ in bucket], weight_dtype
) # copyがないとTensor変換できない

for (image_key, _, original_size, crop_left_top), latent in zip(bucket, latents):
npz_file_name = get_npz_filename_wo_ext(
args.train_data_dir, image_key, args.full_path, True, args.recursive
)
train_util.save_latents_to_disk(npz_file_name, latent, original_size, crop_left_top)
else:
# remove existing flipped npz
for image_key, _ in bucket:
npz_file_name = (
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)
if os.path.isfile(npz_file_name):
print(f"remove existing flipped npz / 既存のflipされたnpzファイルを削除します: {npz_file_name}")
os.remove(npz_file_name)

train_util.cache_batch_latents(vae, True, bucket, args.flip_aug, False)
bucket.clear()

# 読み込みの高速化のためにDataLoaderを使うオプション
Expand Down Expand Up @@ -203,61 +156,18 @@ def process_batch(is_last):
), f"internal error resized size is small: {resized_size}, {reso}"

# 既に存在するファイルがあればshape等を確認して同じならskipする
npz_file_name = get_npz_filename(args.train_data_dir, image_key, args.full_path, args.recursive)
if args.skip_existing:
npz_files = [get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, False, args.recursive) + ".npz"]
if args.flip_aug:
npz_files.append(
get_npz_filename_wo_ext(args.train_data_dir, image_key, args.full_path, True, args.recursive) + ".npz"
)

found = True
for npz_file in npz_files:
if not os.path.exists(npz_file):
found = False
break

latents, _, _ = train_util.load_latents_from_disk(npz_file)
if latents is None: # old version
found = False
break

if latents.shape[1] != reso[1] // 8 or latents.shape[2] != reso[0] // 8: # latentsのshapeを確認
found = False
break
if found:
if train_util.is_disk_cached_latents_is_expected(reso, npz_file_name, args.flip_aug):
continue

# 画像をリサイズしてトリミングする
# PILにinter_areaがないのでcv2で……
image = np.array(image)
if resized_size[0] != image.shape[1] or resized_size[1] != image.shape[0]: # リサイズ処理が必要?
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA)

trim_left = 0
if resized_size[0] > reso[0]:
trim_size = resized_size[0] - reso[0]
image = image[:, trim_size // 2 : trim_size // 2 + reso[0]]
trim_left = trim_size // 2

trim_top = 0
if resized_size[1] > reso[1]:
trim_size = resized_size[1] - reso[1]
image = image[trim_size // 2 : trim_size // 2 + reso[1]]
trim_top = trim_size // 2

original_size_wh = (resized_size[0], resized_size[1])
# target_size_wh = (reso[0], reso[1])
crop_left_top = (trim_left, trim_top)

assert (
image.shape[0] == reso[1] and image.shape[1] == reso[0]
), f"internal error, illegal trimmed size: {image.shape}, {reso}"

# # debug
# cv2.imwrite(f"r:\\test\\img_{len(img_ar_errors)}.jpg", image[:, :, ::-1])

# バッチへ追加
bucket_manager.add_image(reso, (image_key, image, original_size_wh, crop_left_top))
image_info = train_util.ImageInfo(image_key, 1, "", False, image_path)
image_info.latents_npz = npz_file_name
image_info.bucket_reso = reso
image_info.resized_size = resized_size
image_info.image = image
bucket_manager.add_image(reso, image_info)

# バッチを推論するか判定して推論する
process_batch(False)
Expand Down
33 changes: 13 additions & 20 deletions library/lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,9 +446,7 @@ def prepare_controlnet_image(

for image_ in image:
image_ = image_.convert("RGB")
image_ = image_.resize(
(width, height), resample=PIL_INTERPOLATION["lanczos"]
)
image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
image_ = np.array(image_)
image_ = image_[None, :]
images.append(image_)
Expand Down Expand Up @@ -479,6 +477,7 @@ def prepare_controlnet_image(

return image


class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
r"""
Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
Expand Down Expand Up @@ -889,8 +888,9 @@ def __call__(
mask = None

if controlnet_image is not None:
controlnet_image = prepare_controlnet_image(controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False)

controlnet_image = prepare_controlnet_image(
controlnet_image, width, height, batch_size, 1, self.device, controlnet.dtype, do_classifier_free_guidance, False
)

# 5. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
Expand Down Expand Up @@ -930,8 +930,8 @@ def __call__(
guess_mode=False,
return_dict=False,
)
unet_additional_args['down_block_additional_residuals'] = down_block_res_samples
unet_additional_args['mid_block_additional_residual'] = mid_block_res_sample
unet_additional_args["down_block_additional_residuals"] = down_block_res_samples
unet_additional_args["mid_block_additional_residual"] = mid_block_res_sample

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings, **unet_additional_args).sample
Expand All @@ -956,20 +956,13 @@ def __call__(
if is_cancelled_callback is not None and is_cancelled_callback():
return None

# 9. Post-processing
image = self.decode_latents(latents)

# 10. Run safety checker
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)

# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)

if not return_dict:
return image, has_nsfw_concept
return latents

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def latents_to_image(self, latents):
# 9. Post-processing
image = self.decode_latents(latents.to(self.vae.dtype))
image = self.numpy_to_pil(image)
return image

def text2img(
self,
Expand Down
21 changes: 7 additions & 14 deletions library/sdxl_lpw_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ def __call__(

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, text_embedding, vector_embedding)
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training
noise_pred = noise_pred.to(dtype) # U-Net changes dtype in LoRA training

# perform guidance
if do_classifier_free_guidance:
Expand All @@ -1027,20 +1027,13 @@ def __call__(
if is_cancelled_callback is not None and is_cancelled_callback():
return None

# 9. Post-processing
image = self.decode_latents(latents.to(torch.float32))

# 10. Run safety checker
image, has_nsfw_concept = image, None # self.run_safety_checker(image, device, text_embeddings.dtype)

# 11. Convert to PIL
if output_type == "pil":
image = self.numpy_to_pil(image)
return latents

if not return_dict:
return image, has_nsfw_concept

return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
def latents_to_image(self, latents):
# 9. Post-processing
image = self.decode_latents(latents.to(self.vae.dtype))
image = self.numpy_to_pil(image)
return image

# copy from pil_utils.py
def numpy_to_pil(self, images: np.ndarray) -> Image.Image:
Expand Down
Loading

0 comments on commit 4478fc3

Please sign in to comment.