Commit
Merge pull request bmaltais#2342 from bmaltais/dev
v24.0.4
bmaltais authored Apr 19, 2024
2 parents 05cf164 + 98d826c commit 6c69b89
Showing 35 changed files with 351 additions and 258 deletions.
2 changes: 1 addition & 1 deletion .release
@@ -1 +1 @@
-v24.0.3
+v24.0.4
4 changes: 4 additions & 0 deletions README.md
@@ -42,6 +42,7 @@ The GUI allows you to set the training parameters and generate and run the requi
 - [SDXL training](#sdxl-training)
 - [Masked loss](#masked-loss)
 - [Change History](#change-history)
+- [2024/04/25 (v24.0.4)](#20240425-v2404)
 - [2024/04/24 (v24.0.3)](#20240424-v2403)
 - [2024/04/24 (v24.0.2)](#20240424-v2402)
 - [2024/04/17 (v24.0.1)](#20240417-v2401)
@@ -408,6 +409,9 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG

 ## Change History
 
+### 2024/04/25 (v24.0.4)
+
+- ...
 
 ### 2024/04/24 (v24.0.3)
 
150 changes: 57 additions & 93 deletions docs/LoRA/options.md

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions kohya_gui.py
@@ -107,6 +107,7 @@ def UI(**kwargs):
     share = kwargs.get("share", False)
     do_not_share = kwargs.get("do_not_share", False)
     server_name = kwargs.get("listen")
+    root_path = kwargs.get("root_path", None)
 
     launch_kwargs["server_name"] = server_name
     if username and password:
@@ -120,6 +121,8 @@
     else:
         if share:
             launch_kwargs["share"] = share
+    if root_path:
+        launch_kwargs["root_path"] = root_path
     launch_kwargs["debug"] = True
     interface.launch(**launch_kwargs)

@@ -172,6 +175,10 @@ def UI(**kwargs):
"--do_not_share", action="store_true", help="Do not share the gradio UI"
)

parser.add_argument(
"--root_path", type=str, default=None, help="`root_path` for Gradio to enable reverse proxy support. e.g. /kohya_ss"
)

args = parser.parse_args()

# Set up logging
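Gradio's root_path tells the app it is being served beneath a URL prefix, so its generated asset and API routes stay correct behind a reverse proxy. A minimal standalone sketch of the idea (not from this repo; the greet interface is illustrative):

import gradio as gr

def greet(name):
    return f"Hello, {name}!"

# With a proxy forwarding https://example.com/kohya_ss/ to this server,
# root_path="/kohya_ss" makes Gradio emit its URLs under that prefix.
demo = gr.Interface(fn=greet, inputs="text", outputs="text")
demo.launch(server_name="127.0.0.1", server_port=7860, root_path="/kohya_ss")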
2 changes: 1 addition & 1 deletion kohya_gui/class_advanced_training.py
@@ -402,7 +402,7 @@ def full_options_update(full_fp16, full_bf16):
         )
         self.multires_noise_discount = gr.Slider(
             label="Multires noise discount",
-            value=self.config.get("advanced.multires_noise_discount", 0),
+            value=self.config.get("advanced.multires_noise_discount", 0.3),
             minimum=0,
             maximum=1,
             step=0.01,
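The new 0.3 default matches the value sd-scripts recommends for multires (pyramid) noise. As a rough sketch of what the discount controls (illustrative only, not this repo's implementation, which lives in sd-scripts): each coarser noise octave is weighted by discount**i, so 0.3 retains some low-frequency noise while 0 removes it entirely.

import torch
import torch.nn.functional as F

def pyramid_noise(shape, discount=0.3, octaves=4):
    # Base white noise plus progressively coarser octaves, each damped by discount**i.
    b, c, h, w = shape
    noise = torch.randn(b, c, h, w)
    for i in range(1, octaves):
        scale = 2 ** i
        if h // scale < 1 or w // scale < 1:
            break
        coarse = torch.randn(b, c, h // scale, w // scale)
        noise += F.interpolate(coarse, size=(h, w), mode="bilinear") * discount ** i
    return noise / noise.std()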
5 changes: 4 additions & 1 deletion kohya_gui/class_source_model.py
@@ -102,6 +102,9 @@ def list_dataset_config_dirs(path: str) -> list:

         with gr.Accordion("Model", open=True):
             with gr.Column(), gr.Group():
+                model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
+                model_ext_name = gr.Textbox(value="Model types", visible=False)
+
                 # Define the input elements
                 with gr.Row():
                     with gr.Column(), gr.Row():
@@ -129,7 +132,7 @@
                 )
                 self.pretrained_model_name_or_path_file.click(
                     get_file_path,
-                    inputs=self.pretrained_model_name_or_path,
+                    inputs=[self.pretrained_model_name_or_path, model_ext, model_ext_name],
                     outputs=self.pretrained_model_name_or_path,
                     show_progress=False,
                 )
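The two invisible Textbox components are a stock Gradio idiom: .click() only passes component values to its callback, so constants such as a file-extension filter are smuggled in as hidden inputs. A self-contained sketch of the pattern (names and the pick_file stand-in are illustrative; the real handler here is kohya_ss's get_file_path):

import gradio as gr

def pick_file(current_path, extensions, description):
    # Stand-in for a file-dialog helper; a real one would open a picker
    # filtered to `extensions` and fall back to `current_path` on cancel.
    print(f"Browsing for {description}: {extensions}")
    return current_path

with gr.Blocks() as demo:
    model_ext = gr.Textbox(value="*.safetensors *.ckpt", visible=False)
    model_ext_name = gr.Textbox(value="Model types", visible=False)
    path = gr.Textbox(label="Pretrained model name or path")
    browse = gr.Button("Browse")
    browse.click(pick_file, inputs=[path, model_ext, model_ext_name], outputs=path)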
92 changes: 58 additions & 34 deletions kohya_gui/dreambooth_gui.py
@@ -49,7 +49,11 @@
 huggingface = None
 use_shell = False
 
-TRAIN_BUTTON_VISIBLE = [gr.Button(visible=True), gr.Button(visible=False), gr.Textbox(value=time.time())]
+TRAIN_BUTTON_VISIBLE = [
+    gr.Button(visible=True),
+    gr.Button(visible=False),
+    gr.Textbox(value=time.time()),
+]
 
 
 def save_configuration(
@@ -528,7 +532,7 @@ def train_model(
         lr_warmup_steps = 0
     else:
         lr_warmup_steps = 0
-    
+
     if max_train_steps == 0:
         max_train_steps_info = f"Max train steps: 0. sd-scripts will therefore default to 1600. Please specify a different value if required."
     else:
@@ -589,9 +593,9 @@ def train_model(
"Regularisation images are used... Will double the number of steps required..."
)
reg_factor = 2

log.info(f"Regulatization factor: {reg_factor}")

if max_train_steps == 0:
# calculate max_train_steps
max_train_steps = int(
@@ -614,16 +618,16 @@
         lr_warmup_steps = round(float(int(lr_warmup) * int(max_train_steps) / 100))
     else:
         lr_warmup_steps = 0
 
     log.info(f"Total steps: {total_steps}")
 
     log.info(f"Train batch size: {train_batch_size}")
     log.info(f"Gradient accumulation steps: {gradient_accumulation_steps}")
     log.info(f"Epoch: {epoch}")
     log.info(max_train_steps_info)
     log.info(f"lr_warmup_steps = {lr_warmup_steps}")
 
-    run_cmd = [fr'"{get_executable_path("accelerate")}"', "launch"]
+    run_cmd = [rf'"{get_executable_path("accelerate")}"', "launch"]
 
     run_cmd = AccelerateLaunch.run_cmd(
         run_cmd=run_cmd,
@@ -642,9 +646,9 @@
     )
 
     if sdxl:
-        run_cmd.append(fr'"{scriptdir}/sd-scripts/sdxl_train.py"')
+        run_cmd.append(rf'"{scriptdir}/sd-scripts/sdxl_train.py"')
     else:
-        run_cmd.append(fr'"{scriptdir}/sd-scripts/train_db.py"')
+        run_cmd.append(rf'"{scriptdir}/sd-scripts/train_db.py"')
 
     if max_data_loader_n_workers == "" or None:
         max_data_loader_n_workers = 0
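The fr-to-rf swaps throughout this file are purely cosmetic: Python accepts raw/f-string prefix letters in either order, so both literals compile to the same string.

assert fr'"{1 + 1}"' == rf'"{1 + 1}"'  # both are raw f-strings; value is "2"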
@@ -659,13 +663,6 @@
     # def save_huggingface_to_toml(self, toml_file_path: str):
     config_toml_data = {
         # Update the values in the TOML data
-        "huggingface_repo_id": huggingface_repo_id,
-        "huggingface_token": huggingface_token,
-        "huggingface_repo_type": huggingface_repo_type,
-        "huggingface_repo_visibility": huggingface_repo_visibility,
-        "huggingface_path_in_repo": huggingface_path_in_repo,
-        "save_state_to_huggingface": save_state_to_huggingface,
-        "resume_from_huggingface": resume_from_huggingface,
         "async_upload": async_upload,
         "adaptive_noise_scale": adaptive_noise_scale if not 0 else None,
         "bucket_no_upscale": bucket_no_upscale,
@@ -690,13 +687,24 @@
"gradient_checkpointing": gradient_checkpointing,
"huber_c": huber_c,
"huber_schedule": huber_schedule,
"huggingface_repo_id": huggingface_repo_id,
"huggingface_token": huggingface_token,
"huggingface_repo_type": huggingface_repo_type,
"huggingface_repo_visibility": huggingface_repo_visibility,
"huggingface_path_in_repo": huggingface_path_in_repo,
"ip_noise_gamma": ip_noise_gamma if ip_noise_gamma != 0 else None,
"ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
"keep_tokens": int(keep_tokens),
"learning_rate": learning_rate, # both for sd1.5 and sdxl
"learning_rate_te": learning_rate_te if not sdxl and not 0 else None, # only for sd1.5 and not 0
"learning_rate_te1": learning_rate_te1 if sdxl and not 0 else None, # only for sdxl and not 0
"learning_rate_te2": learning_rate_te2 if sdxl and not 0 else None, # only for sdxl and not 0
"learning_rate": learning_rate, # both for sd1.5 and sdxl
"learning_rate_te": (
learning_rate_te if not sdxl and not 0 else None
), # only for sd1.5 and not 0
"learning_rate_te1": (
learning_rate_te1 if sdxl and not 0 else None
), # only for sdxl and not 0
"learning_rate_te2": (
learning_rate_te2 if sdxl and not 0 else None
), # only for sdxl and not 0
"logging_dir": logging_dir,
"log_tracker_name": log_tracker_name,
"log_tracker_config": log_tracker_config,
@@ -709,8 +717,7 @@
"lr_scheduler_power": lr_scheduler_power,
"lr_warmup_steps": lr_warmup_steps,
"max_bucket_reso": max_bucket_reso,
"max_data_loader_n_workers": max_data_loader_n_workers,
"max_timestep": max_timestep if max_timestep!= 0 else None,
"max_timestep": max_timestep if max_timestep != 0 else None,
"max_token_length": int(max_token_length),
"max_train_epochs": max_train_epochs if max_train_epochs != 0 else None,
"max_train_steps": max_train_steps if max_train_steps != 0 else None,
@@ -725,9 +732,9 @@
"min_timestep": min_timestep if min_timestep != 0 else None,
"mixed_precision": mixed_precision,
"multires_noise_discount": multires_noise_discount,
"multires_noise_iterations": multires_noise_iterations if multires_noise_iterations != 0 else None,
"multires_noise_iterations": multires_noise_iterations if not 0 else None,
"no_token_padding": no_token_padding,
"noise_offset": noise_offset if noise_offset != 0 else None,
"noise_offset": noise_offset if not 0 else None,
"noise_offset_random_strength": noise_offset_random_strength,
"noise_offset_type": noise_offset_type,
"optimizer_type": optimizer,
@@ -745,23 +752,35 @@
"reg_data_dir": reg_data_dir,
"resolution": max_resolution,
"resume": resume,
"sample_every_n_epochs": sample_every_n_epochs if sample_every_n_epochs != 0 else None,
"sample_every_n_steps": sample_every_n_steps if sample_every_n_steps != 0 else None,
"resume_from_huggingface": resume_from_huggingface,
"sample_every_n_epochs": (
sample_every_n_epochs if sample_every_n_epochs != 0 else None
),
"sample_every_n_steps": (
sample_every_n_steps if sample_every_n_steps != 0 else None
),
"sample_prompts": create_prompt_file(sample_prompts, output_dir),
"sample_sampler": sample_sampler,
"save_every_n_epochs": save_every_n_epochs if save_every_n_epochs!= 0 else None,
"save_every_n_epochs": (
save_every_n_epochs if save_every_n_epochs != 0 else None
),
"save_every_n_steps": save_every_n_steps if save_every_n_steps != 0 else None,
"save_last_n_steps": save_last_n_steps if save_last_n_steps != 0 else None,
"save_last_n_steps_state": save_last_n_steps_state if save_last_n_steps_state != 0 else None,
"save_last_n_steps_state": (
save_last_n_steps_state if save_last_n_steps_state != 0 else None
),
"save_model_as": save_model_as,
"save_precision": save_precision,
"save_state": save_state,
"save_state_on_train_end": save_state_on_train_end,
"save_state_to_huggingface": save_state_to_huggingface,
"scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
"sdpa": True if xformers == "sdpa" else None,
"seed": seed if seed != 0 else None,
"shuffle_caption": shuffle_caption,
"stop_text_encoder_training": stop_text_encoder_training if stop_text_encoder_training!= 0 else None,
"stop_text_encoder_training": (
stop_text_encoder_training if stop_text_encoder_training != 0 else None
),
"train_batch_size": train_batch_size,
"train_data_dir": train_data_dir,
"use_wandb": use_wandb,
Expand All @@ -781,8 +800,13 @@ def train_model(
     config_toml_data = {
         key: value
         for key, value in config_toml_data.items()
-        if value != "" and value is not False
+        if value not in ["", False, None]
     }
 
+    config_toml_data["max_data_loader_n_workers"] = max_data_loader_n_workers
+
+    # Sort the dictionary by keys
+    config_toml_data = dict(sorted(config_toml_data.items()))
+
     tmpfilename = "./outputs/tmpfiledbooth.toml"
     # Save the updated TOML data back to the file
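Because Python evaluates 0 == False as true, the membership test `value not in ["", False, None]` silently discards zeroes as well, which is presumably why `max_data_loader_n_workers` (legitimately 0 in many setups) is re-inserted unconditionally after the comprehension. A quick illustration:

entries = {"a": "", "b": False, "c": None, "d": 0, "e": 5}
kept = {k: v for k, v in entries.items() if v not in ["", False, None]}
print(kept)  # {'e': 5} -- the 0 under 'd' is dropped along with "", False and None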
@@ -793,7 +817,7 @@ def train_model(
             log.error(f"Failed to write TOML file: {toml_file.name}")
 
     run_cmd.append(f"--config_file")
-    run_cmd.append(fr'"{tmpfilename}"')
+    run_cmd.append(rf'"{tmpfilename}"')
 
     # Initialize a dictionary with always-included keyword arguments
     kwargs_for_training = {
@@ -851,7 +875,7 @@ def dreambooth_tab(
     dummy_db_true = gr.Checkbox(value=True, visible=False)
     dummy_db_false = gr.Checkbox(value=False, visible=False)
     dummy_headless = gr.Checkbox(value=headless, visible=False)
-    
+
     global use_shell
     use_shell = use_shell_flag
 
@@ -873,7 +897,7 @@

     with gr.Accordion("Metadata", open=False), gr.Group():
         metadata = MetaData(config=config)
-    
+
     with gr.Accordion("Dataset Preparation", open=False):
         gr.Markdown(
             "This section provide Dreambooth tools to help setup your dataset..."
@@ -1120,4 +1144,4 @@ def dreambooth_tab(
         folders.reg_data_dir,
         folders.output_dir,
         folders.logging_dir,
-    )
\ No newline at end of file
+    )
