Skip to content

Commit

Permalink
small fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ddPn08 committed Apr 2, 2023
1 parent c4a11e5 commit 8bfa50e
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 0 deletions.
6 changes: 6 additions & 0 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2851,6 +2851,8 @@ def save_du():
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(out_dir, args, "/" + model_name)

def remove_du(old_epoch_no):
out_dir_old = os.path.join(args.output_dir, EPOCH_DIFFUSERS_DIR_NAME.format(model_name, old_epoch_no))
Expand Down Expand Up @@ -2906,6 +2908,8 @@ def save_sd_model_on_train_end(
model_util.save_stable_diffusion_checkpoint(
args.v2, ckpt_file, text_encoder, unet, src_path, epoch, global_step, save_dtype, vae
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
else:
out_dir = os.path.join(args.output_dir, model_name)
os.makedirs(out_dir, exist_ok=True)
Expand All @@ -2914,6 +2918,8 @@ def save_sd_model_on_train_end(
model_util.save_diffusers_checkpoint(
args.v2, out_dir, text_encoder, unet, src_path, vae=vae, use_safetensors=use_safetensors
)
if args.huggingface_repo_id is not None:
huggingface_util.upload(out_dir, args, "/" + model_name)


def save_state_on_train_end(args: argparse.Namespace, accelerator):
Expand Down
2 changes: 2 additions & 0 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,8 @@ def remove_old_func(old_epoch_no):

print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
print("model saved.")


Expand Down
2 changes: 2 additions & 0 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,8 @@ def remove_old_func(old_epoch_no):

print(f"save trained model to {ckpt_file}")
save_weights(ckpt_file, updated_embs, save_dtype)
if args.huggingface_repo_id is not None:
huggingface_util.upload(ckpt_file, args, "/" + ckpt_name)
print("model saved.")


Expand Down

0 comments on commit 8bfa50e

Please sign in to comment.