Skip to content

Commit

Permalink
fix sd3 training to work without cachine TE outputs bmaltais#1465
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Aug 17, 2024
1 parent e45d3f8 commit 7367584
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions sd3_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,9 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
input_ids_clip_l, input_ids_clip_g, _, l_attn_mask, g_attn_mask, _ = batch["input_ids_list"]
with torch.set_grad_enabled(args.train_text_encoder):
# TODO support weighted captions
input_ids_clip_l = input_ids_clip_l.to(accelerator.device)
input_ids_clip_g = input_ids_clip_g.to(accelerator.device)
# text models in sd3_models require "cpu" for input_ids
input_ids_clip_l = input_ids_clip_l.to("cpu")
input_ids_clip_g = input_ids_clip_g.to("cpu")
lg_out, _, lg_pooled = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy,
[clip_l, clip_g, None],
Expand All @@ -770,7 +771,7 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
if t5_out is None:
_, _, input_ids_t5xxl, _, _, t5_attn_mask = batch["input_ids_list"]
with torch.no_grad():
input_ids_t5xxl = input_ids_t5xxl.to(accelerator.device) if t5_out is None else None
input_ids_t5xxl = input_ids_t5xxl.to("cpu") if t5_out is None else None
_, t5_out, _ = text_encoding_strategy.encode_tokens(
sd3_tokenize_strategy, [None, None, t5xxl], [None, None, input_ids_t5xxl, None, None, t5_attn_mask]
)
Expand Down

0 comments on commit 7367584

Please sign in to comment.