add caching to disk for text encoder outputs
kohya-ss committed Jul 16, 2023
1 parent 62dd99b commit 516f64f
Showing 6 changed files with 537 additions and 142 deletions.
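
The commit message describes caching text encoder outputs to disk. The hunks shown below from library/sdxl_train_util.py only add the CLI flags and their validation; the caching itself lands in the other changed files, which are not expanded here. As a rough orientation, the following is a minimal, hypothetical sketch of what such a disk cache typically looks like; the function names, cache layout, and .npz format are assumptions for illustration, not the repository's actual code.

import hashlib
import os

import numpy as np
import torch


def cache_path_for(caption: str, cache_dir: str) -> str:
    # Hypothetical: derive a stable file name from the caption text.
    digest = hashlib.sha256(caption.encode("utf-8")).hexdigest()
    return os.path.join(cache_dir, digest + "_te_outputs.npz")


def save_text_encoder_outputs(path, hidden_states1, hidden_states2, pool2):
    # Store both encoders' hidden states and the pooled embedding once,
    # so later epochs/runs can skip the text encoder forward pass.
    np.savez(
        path,
        hidden_states1=hidden_states1.float().cpu().numpy(),
        hidden_states2=hidden_states2.float().cpu().numpy(),
        pool2=pool2.float().cpu().numpy(),
    )


def load_text_encoder_outputs(path, device, dtype):
    data = np.load(path)
    return (
        torch.from_numpy(data["hidden_states1"]).to(device, dtype=dtype),
        torch.from_numpy(data["hidden_states2"]).to(device, dtype=dtype),
        torch.from_numpy(data["pool2"]).to(device, dtype=dtype),
    )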
68 changes: 12 additions & 56 deletions library/sdxl_train_util.py
@@ -140,62 +140,6 @@ def load_tokenizers(args: argparse.Namespace):
    return tokeniers


def get_hidden_states(
    args: argparse.Namespace, input_ids1, input_ids2, tokenizer1, tokenizer2, text_encoder1, text_encoder2, weight_dtype=None
):
    # input_ids: b,n,77 -> b*n, 77
    b_size = input_ids1.size()[0]
    input_ids1 = input_ids1.reshape((-1, tokenizer1.model_max_length))  # batch_size*n, 77
    input_ids2 = input_ids2.reshape((-1, tokenizer2.model_max_length))  # batch_size*n, 77

    # text_encoder1
    enc_out = text_encoder1(input_ids1, output_hidden_states=True, return_dict=True)
    hidden_states1 = enc_out["hidden_states"][11]

    # text_encoder2
    enc_out = text_encoder2(input_ids2, output_hidden_states=True, return_dict=True)
    hidden_states2 = enc_out["hidden_states"][-2]  # penultimate layer
    pool2 = enc_out["text_embeds"]

    # b*n, 77, 768 or 1280 -> b, n*77, 768 or 1280
    n_size = 1 if args.max_token_length is None else args.max_token_length // 75
    hidden_states1 = hidden_states1.reshape((b_size, -1, hidden_states1.shape[-1]))
    hidden_states2 = hidden_states2.reshape((b_size, -1, hidden_states2.shape[-1]))

    if args.max_token_length is not None:
        # bs*3, 77, 768 or 1024
        # encoder1: fold the three consecutive <BOS>...<EOS> chunks back into a single <BOS>...<EOS>
        states_list = [hidden_states1[:, 0].unsqueeze(1)]  # <BOS>
        for i in range(1, args.max_token_length, tokenizer1.model_max_length):
            states_list.append(hidden_states1[:, i : i + tokenizer1.model_max_length - 2])  # from after <BOS> to before <EOS>
        states_list.append(hidden_states1[:, -1].unsqueeze(1))  # <EOS>
        hidden_states1 = torch.cat(states_list, dim=1)

        # v2: fold the three consecutive <BOS>...<EOS> <PAD> ... chunks back into a single <BOS>...<EOS> <PAD> ...; honestly not sure this implementation is right
        states_list = [hidden_states2[:, 0].unsqueeze(1)]  # <BOS>
        for i in range(1, args.max_token_length, tokenizer2.model_max_length):
            chunk = hidden_states2[:, i : i + tokenizer2.model_max_length - 2]  # from after <BOS> to before the last token
            # this causes an error:
            # RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
            # if i > 1:
            #     for j in range(len(chunk)):  # batch_size
            #         if input_ids2[n_index + j * n_size, 1] == tokenizer2.eos_token_id:  # empty, i.e. the <BOS> <EOS> <PAD> ... pattern
            #             chunk[j, 0] = chunk[j, 1]  # copy the value of the following <PAD>
            states_list.append(chunk)  # from after <BOS> to before <EOS>
        states_list.append(hidden_states2[:, -1].unsqueeze(1))  # either <EOS> or <PAD>
        hidden_states2 = torch.cat(states_list, dim=1)

        # for the pooled output, use the first of the n chunks
        pool2 = pool2[::n_size]

    if weight_dtype is not None:
        # this is required for additional network training
        hidden_states1 = hidden_states1.to(weight_dtype)
        hidden_states2 = hidden_states2.to(weight_dtype)

    return hidden_states1, hidden_states2, pool2
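
For reference, the helper removed above folds per-chunk encoder outputs for long prompts (max_token_length of 150 or 225, i.e. two or three 75-token chunks, each padded to the tokenizer's 77-token model_max_length) back into a single sequence per sample. A small self-contained sketch of that folding with dummy tensors, assuming max_token_length=225:

import torch

# Toy shapes mirroring the removed get_hidden_states: max_token_length=225 means
# n = 225 // 75 = 3 chunks, each padded to the tokenizer's model_max_length of 77.
b_size, max_token_length, chunk_len, dim = 2, 225, 77, 768
n = max_token_length // 75
hidden = torch.randn(b_size, n * chunk_len, dim)  # b, n*77, dim (after the reshape step)

states_list = [hidden[:, 0].unsqueeze(1)]  # a single leading <BOS>
for i in range(1, max_token_length, chunk_len):
    states_list.append(hidden[:, i : i + chunk_len - 2])  # each chunk without its <BOS>/<EOS>
states_list.append(hidden[:, -1].unsqueeze(1))  # a single trailing <EOS>
folded = torch.cat(states_list, dim=1)

print(folded.shape)  # torch.Size([2, 227, 768]); 227 = 1 + 3*75 + 1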


def timestep_embedding(timesteps, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
@@ -391,6 +335,11 @@ def add_sdxl_training_arguments(parser: argparse.ArgumentParser):
    parser.add_argument(
        "--cache_text_encoder_outputs", action="store_true", help="cache text encoder outputs / text encoderの出力をキャッシュする"
    )
    parser.add_argument(
        "--cache_text_encoder_outputs_to_disk",
        action="store_true",
        help="cache text encoder outputs to disk / text encoderの出力をディスクにキャッシュする",
    )


def verify_sdxl_training_args(args: argparse.Namespace):
@@ -417,6 +366,13 @@ def verify_sdxl_training_args(args: argparse.Namespace):
        not hasattr(args, "weighted_captions") or not args.weighted_captions
    ), "weighted_captions cannot be enabled in SDXL training currently / SDXL学習では今のところweighted_captionsを有効にすることはできません"

    if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
        args.cache_text_encoder_outputs = True
        print(
            "cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled / "
            + "cache_text_encoder_outputs_to_diskが有効になっているためcache_text_encoder_outputsが有効になりました"
        )


def sample_images(*args, **kwargs):
    return train_util.sample_images_common(SdxlStableDiffusionLongPromptWeightingPipeline, *args, **kwargs)
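
Taken together, the last two hunks above wire in a new --cache_text_encoder_outputs_to_disk flag and make it imply --cache_text_encoder_outputs. A minimal, self-contained reproduction of that behavior (the bare argparse parser here is illustrative; in the repository the flags are added by add_sdxl_training_arguments and checked by verify_sdxl_training_args):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--cache_text_encoder_outputs", action="store_true")
parser.add_argument("--cache_text_encoder_outputs_to_disk", action="store_true")

# e.g. the user passes only the disk flag on the command line
args = parser.parse_args(["--cache_text_encoder_outputs_to_disk"])

# mirror of the check added above: caching to disk implies in-memory caching,
# so the other flag is switched on with a notice
if args.cache_text_encoder_outputs_to_disk and not args.cache_text_encoder_outputs:
    args.cache_text_encoder_outputs = True
    print("cache_text_encoder_outputs is enabled because cache_text_encoder_outputs_to_disk is enabled")

print(args.cache_text_encoder_outputs, args.cache_text_encoder_outputs_to_disk)  # True True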