Commit

Add support for metadata parameters (bmaltais#2295)

bmaltais authored Apr 15, 2024
1 parent a8320e3 commit a22d462

Showing 9 changed files with 246 additions and 176 deletions.
2 changes: 1 addition & 1 deletion .release
@@ -1 +1 @@
-v23.1.6
+v24.0.0
5 changes: 3 additions & 2 deletions README.md
@@ -42,7 +42,7 @@ The GUI allows you to set the training parameters and generate and run the requi
 - [SDXL training](#sdxl-training)
 - [Masked loss](#masked-loss)
 - [Change History](#change-history)
-- [2024/04/12 (v23.1.6)](#20240412-v2316)
+- [2024/04/12 (v24.0.0)](#20240412-v2400)
 - [2024/04/10 (v23.1.5)](#20240410-v2315)
 - [Security Improvements](#security-improvements)
 - [2024/04/08 (v23.1.4)](#20240408-v2314)
@@ -402,7 +402,7 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
 
 ## Change History
 
-### 2024/04/12 (v23.1.6)
+### 2024/04/12 (v24.0.0)
 
 - Rewrote significant portions of the code to address security vulnerabilities and remove the `shell=True` parameter from process calls.
 - Enhanced the training and tensorboard buttons to provide a more intuitive and user-friendly experience.
@@ -411,6 +411,7 @@ ControlNet dataset is used to specify the mask. The mask images should be the RG
 - Converted the Graphical User Interface (GUI) to use the configuration TOML file format to pass arguments to sd-scripts. This change improves security by eliminating the need for sensitive information to be passed through the command-line interface (CLI).
 - Made various other minor improvements and bug fixes to enhance the overall functionality and user experience.
 - Disabled LR Warmup when using the Constant LR Scheduler to prevent traceback errors with sd-scripts.
+- Added support for metadata capture to the GUI
 
 ### 2024/04/10 (v23.1.5)
 
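The first change-history bullet above (removing `shell=True` from process calls) boils down to invoking external processes with an argument list instead of a shell string. A minimal, illustrative sketch of that pattern — not the repository's actual code; the script name and argument values are placeholders:

```python
import subprocess

def launch_training(extra_args: list[str]) -> None:
    # With a list of arguments and no shell=True, each value is handed to the
    # OS verbatim instead of being parsed by a shell, so user input cannot
    # inject extra commands.
    subprocess.run(["accelerate", "launch", "train_db.py", *extra_args], check=True)

# A hostile-looking value stays one literal argument:
launch_training(["--metadata_title", "My Model; rm -rf ~"])
```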
7 changes: 7 additions & 0 deletions config example.toml
@@ -173,3 +173,10 @@ train_data_dir = "" # Image folder to caption (contain
 undesired_tags = "" # comma-separated list of tags to remove, e.g. 1girl,1boy
 use_rating_tags = false # Use rating tags
 use_rating_tags_as_last_tag = false # Use rating tags as last tagging tags
+
+[metadata]
+metadata_title = "" # Title for model metadata (default is output_name)
+metadata_author = "" # Author name for model metadata
+metadata_description = "" # Description for model metadata
+metadata_license = "" # License for model metadata
+metadata_tags = "" # Tags for model metadata
195 changes: 45 additions & 150 deletions kohya_gui/class_metadata.py
@@ -1,6 +1,4 @@
 import gradio as gr
-import os
-import shlex
 
 from .class_gui_config import KohyaSSGUIConfig
 
@@ -12,161 +10,58 @@ def __init__(
     ) -> None:
         self.config = config
 
-        with gr.Accordion("Resource Selection", open=True):
-            with gr.Row():
-                self.mixed_precision = gr.Dropdown(
-                    label="Mixed precision",
-                    choices=["no", "fp16", "bf16", "fp8"],
-                    value=self.config.get("accelerate_launch.mixed_precision", "fp16"),
-                    info="Whether or not to use mixed precision training.",
-                )
-                self.num_processes = gr.Number(
-                    label="Number of processes",
-                    value=self.config.get("accelerate_launch.num_processes", 1),
-                    precision=0,
-                    minimum=1,
-                    info="The total number of processes to be launched in parallel.",
-                )
-                self.num_machines = gr.Number(
-                    label="Number of machines",
-                    value=self.config.get("accelerate_launch.num_machines", 1),
-                    precision=0,
-                    minimum=1,
-                    info="The total number of machines used in this training.",
-                )
-                self.num_cpu_threads_per_process = gr.Slider(
-                    minimum=1,
-                    maximum=os.cpu_count(),
-                    step=1,
-                    label="Number of CPU threads per core",
-                    value=self.config.get(
-                        "accelerate_launch.num_cpu_threads_per_process", 2
-                    ),
-                    info="The number of CPU threads per process.",
-                )
-            with gr.Row():
-                self.dynamo_backend = gr.Dropdown(
-                    label="Dynamo backend",
-                    choices=[
-                        "no",
-                        "eager",
-                        "aot_eager",
-                        "inductor",
-                        "aot_ts_nvfuser",
-                        "nvprims_nvfuser",
-                        "cudagraphs",
-                        "ofi",
-                        "fx2trt",
-                        "onnxrt",
-                        "tensorrt",
-                        "ipex",
-                        "tvm",
-                    ],
-                    value=self.config.get("accelerate_launch.dynamo_backend", "no"),
-                    info="The backend to use for the dynamo JIT compiler.",
-                )
-                self.dynamo_mode = gr.Dropdown(
-                    label="Dynamo mode",
-                    choices=[
-                        "default",
-                        "reduce-overhead",
-                        "max-autotune",
-                    ],
-                    value=self.config.get("accelerate_launch.dynamo_mode", "default"),
-                    info="Choose a mode to optimize your training with dynamo.",
-                )
-                self.dynamo_use_fullgraph = gr.Checkbox(
-                    label="Dynamo use fullgraph",
-                    value=self.config.get("accelerate_launch.dynamo_use_fullgraph", False),
-                    info="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
-                )
-                self.dynamo_use_dynamic = gr.Checkbox(
-                    label="Dynamo use dynamic",
-                    value=self.config.get("accelerate_launch.dynamo_use_dynamic", False),
-                    info="Whether to enable dynamic shape tracing.",
-                )
-
-        with gr.Accordion("Hardware Selection", open=True):
-            with gr.Row():
-                self.multi_gpu = gr.Checkbox(
-                    label="Multi GPU",
-                    value=self.config.get("accelerate_launch.multi_gpu", False),
-                    info="Whether or not this should launch a distributed GPU training.",
-                )
-        with gr.Accordion("Distributed GPUs", open=True):
-            with gr.Row():
-                self.gpu_ids = gr.Textbox(
-                    label="GPU IDs",
-                    value=self.config.get("accelerate_launch.gpu_ids", ""),
-                    placeholder="example: 0,1",
-                    info=" What GPUs (by id) should be used for training on this machine as a comma-separated list",
-                )
-                self.main_process_port = gr.Number(
-                    label="Main process port",
-                    value=self.config.get("accelerate_launch.main_process_port", 0),
-                    precision=1,
-                    minimum=0,
-                    maximum=65535,
-                    info="The port to use to communicate with the machine of rank 0.",
-                )
         with gr.Row():
-            self.extra_accelerate_launch_args = gr.Textbox(
-                label="Extra accelerate launch arguments",
-                value=self.config.get(
-                    "accelerate_launch.extra_accelerate_launch_args", ""
-                ),
-                placeholder="example: --same_network --machine_rank 4",
-                info="List of extra parameters to pass to accelerate launch",
+            self.metadata_title = gr.Textbox(
+                label="Metadata title",
+                placeholder="(optional) title for model metadata (default is output_name)",
+                interactive=True,
+                value=self.config.get("metadata.title", ""),
             )
+            self.metadata_author = gr.Textbox(
+                label="Metadata author",
+                placeholder="(optional) author name for model metadata",
+                interactive=True,
+                value=self.config.get("metadata.author", ""),
+            )
+            self.metadata_description = gr.Textbox(
+                label="Metadata description",
+                placeholder="(optional) description for model metadata",
+                interactive=True,
+                value=self.config.get("metadata.description", ""),
+            )
+        with gr.Row():
+            self.metadata_license = gr.Textbox(
+                label="Metadata license",
+                placeholder="(optional) license for model metadata",
+                interactive=True,
+                value=self.config.get("metadata.license", ""),
+            )
+            self.metadata_tags = gr.Textbox(
+                label="Metadata tags",
+                placeholder="(optional) tags for model metadata, separated by comma",
+                interactive=True,
+                value=self.config.get("metadata.tags", ""),
+            )
 
 def run_cmd(run_cmd: list, **kwargs):
-    if "dynamo_backend" in kwargs and kwargs.get("dynamo_backend"):
-        run_cmd.append("--dynamo_backend")
-        run_cmd.append(kwargs["dynamo_backend"])
-
-    if "dynamo_mode" in kwargs and kwargs.get("dynamo_mode"):
-        run_cmd.append("--dynamo_mode")
-        run_cmd.append(kwargs["dynamo_mode"])
-
-    if "dynamo_use_fullgraph" in kwargs and kwargs.get("dynamo_use_fullgraph"):
-        run_cmd.append("--dynamo_use_fullgraph")
-
-    if "dynamo_use_dynamic" in kwargs and kwargs.get("dynamo_use_dynamic"):
-        run_cmd.append("--dynamo_use_dynamic")
-
-    if "extra_accelerate_launch_args" in kwargs and kwargs["extra_accelerate_launch_args"] != "":
-        extra_accelerate_launch_args = kwargs["extra_accelerate_launch_args"].replace('"', "")
-        for arg in extra_accelerate_launch_args.split():
-            run_cmd.append(shlex.quote(arg))
-
-    if "gpu_ids" in kwargs and kwargs.get("gpu_ids") != "":
-        run_cmd.append("--gpu_ids")
-        run_cmd.append(shlex.quote(kwargs["gpu_ids"]))
-
-    if "main_process_port" in kwargs and kwargs.get("main_process_port", 0) > 0:
-        run_cmd.append("--main_process_port")
-        run_cmd.append(str(int(kwargs["main_process_port"])))
-
-    if "mixed_precision" in kwargs and kwargs.get("mixed_precision"):
-        run_cmd.append("--mixed_precision")
-        run_cmd.append(shlex.quote(kwargs["mixed_precision"]))
+    if "metadata_title" in kwargs and kwargs.get("metadata_title") != "":
+        run_cmd.append("--metadata_title")
+        run_cmd.append(kwargs["metadata_title"])
 
-    if "multi_gpu" in kwargs and kwargs.get("multi_gpu"):
-        run_cmd.append("--multi_gpu")
+    if "metadata_author" in kwargs and kwargs.get("metadata_author") != "":
+        run_cmd.append("--metadata_author")
+        run_cmd.append(kwargs["metadata_author"])
 
-    if "num_processes" in kwargs and int(kwargs.get("num_processes", 0)) > 0:
-        run_cmd.append("--num_processes")
-        run_cmd.append(str(int(kwargs["num_processes"])))
+    if "metadata_description" in kwargs and kwargs.get("metadata_description") != "":
+        run_cmd.append("--metadata_description")
+        run_cmd.append(kwargs["metadata_description"])
 
-    if "num_machines" in kwargs and int(kwargs.get("num_machines", 0)) > 0:
-        run_cmd.append("--num_machines")
-        run_cmd.append(str(int(kwargs["num_machines"])))
+    if "metadata_license" in kwargs and kwargs.get("metadata_license") != "":
+        run_cmd.append("--metadata_license")
+        run_cmd.append(kwargs["metadata_license"])
 
-    if (
-        "num_cpu_threads_per_process" in kwargs
-        and int(kwargs.get("num_cpu_threads_per_process", 0)) > 0
-    ):
-        run_cmd.append("--num_cpu_threads_per_process")
-        run_cmd.append(str(int(kwargs["num_cpu_threads_per_process"])))
+    if "metadata_tags" in kwargs and kwargs.get("metadata_tags") != "":
+        run_cmd.append("--metadata_tags")
+        run_cmd.append(kwargs["metadata_tags"])
 
     return run_cmd
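For reference, a short sketch of how the `run_cmd` helper defined above turns GUI values into sd-scripts flags. Calling it directly like this is only for illustration — the command prefix and metadata values are made up — but the skip-empty behaviour follows from the guards shown in the diff:

```python
from kohya_gui.class_metadata import run_cmd

cmd = ["accelerate", "launch", "train_db.py"]  # illustrative command prefix
cmd = run_cmd(
    cmd,
    metadata_title="My LoRA",
    metadata_author="Jane Doe",
    metadata_license="",  # empty strings are not appended
    metadata_tags="style, illustration",
)
print(cmd)
# ['accelerate', 'launch', 'train_db.py',
#  '--metadata_title', 'My LoRA',
#  '--metadata_author', 'Jane Doe',
#  '--metadata_tags', 'style, illustration']
```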
37 changes: 33 additions & 4 deletions kohya_gui/dreambooth_gui.py
@@ -28,6 +28,7 @@
 from .class_folders import Folders
 from .class_command_executor import CommandExecutor
 from .class_huggingface import HuggingFace
+from .class_metadata import MetaData
 
 from .dreambooth_folder_creation_gui import (
     gradio_dreambooth_folder_creation_tab,
@@ -172,6 +173,11 @@ def save_configuration(
     save_state_to_huggingface,
     resume_from_huggingface,
     async_upload,
+    metadata_author,
+    metadata_description,
+    metadata_license,
+    metadata_tags,
+    metadata_title,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -325,6 +331,11 @@ def open_configuration(
     save_state_to_huggingface,
     resume_from_huggingface,
     async_upload,
+    metadata_author,
+    metadata_description,
+    metadata_license,
+    metadata_tags,
+    metadata_title,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -473,6 +484,11 @@ def train_model(
     save_state_to_huggingface,
     resume_from_huggingface,
     async_upload,
+    metadata_author,
+    metadata_description,
+    metadata_license,
+    metadata_tags,
+    metadata_title,
 ):
     # Get list of function parameters and values
     parameters = list(locals().items())
@@ -681,10 +697,10 @@ def train_model(
         "ip_noise_gamma": ip_noise_gamma,
         "ip_noise_gamma_random_strength": ip_noise_gamma_random_strength,
         "keep_tokens": int(keep_tokens),
-        "learning_rate": learning_rate,
-        "learning_rate_te": learning_rate_te,
-        "learning_rate_te1": learning_rate_te1,
-        "learning_rate_te2": learning_rate_te2,
+        "learning_rate": learning_rate, # both for sd1.5 and sdxl
+        "learning_rate_te": learning_rate_te if not sdxl else None, # only for sd1.5
+        "learning_rate_te1": learning_rate_te1 if sdxl else None, # only for sdxl
+        "learning_rate_te2": learning_rate_te2 if sdxl else None, # only for sdxl
         "logging_dir": logging_dir,
         "log_tracker_name": log_tracker_name,
         "log_tracker_config": log_tracker_config,
@@ -703,6 +719,11 @@ def train_model(
         "max_train_epochs": max_train_epochs,
         "max_train_steps": int(max_train_steps),
         "mem_eff_attn": mem_eff_attn,
+        "metadata_author": metadata_author,
+        "metadata_description": metadata_description,
+        "metadata_license": metadata_license,
+        "metadata_tags": metadata_tags,
+        "metadata_title": metadata_title,
         "min_bucket_reso": int(min_bucket_reso),
         "min_snr_gamma": min_snr_gamma,
         "min_timestep": int(min_timestep),
@@ -859,6 +880,9 @@ def dreambooth_tab(
         with gr.Accordion("Folders", open=False), gr.Group():
             folders = Folders(headless=headless, config=config)
 
+        with gr.Accordion("Metadata", open=False), gr.Group():
+            metadata = MetaData(config=config)
+
         with gr.Accordion("Dataset Preparation", open=False):
             gr.Markdown(
                 "This section provide Dreambooth tools to help setup your dataset..."
@@ -1034,6 +1058,11 @@ def dreambooth_tab(
         huggingface.save_state_to_huggingface,
         huggingface.resume_from_huggingface,
         huggingface.async_upload,
+        metadata.metadata_author,
+        metadata.metadata_description,
+        metadata.metadata_license,
+        metadata.metadata_tags,
+        metadata.metadata_title,
     ]
 
     configuration.button_open_config.click(
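The metadata entries added to the training-argument dictionary above travel the same route as every other option: per the change-history note, the GUI now writes its arguments to a TOML configuration file that is handed to sd-scripts rather than placed on the command line. A hedged sketch of that step — the dictionary name, file name, and use of the `toml` package are assumptions for illustration, not the GUI's exact code:

```python
import toml

# Assumed subset of the argument dictionary built in train_model above.
config_args = {
    "metadata_author": "Jane Doe",
    "metadata_title": "My LoRA",
    "metadata_tags": "style, illustration",
    "output_name": "my_lora",
}

# Serializing to a file keeps the values out of the visible command line.
with open("dreambooth_config.toml", "w", encoding="utf-8") as f:
    toml.dump(config_args, f)
```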