78 changes: 76 additions & 2 deletions library/config_util.py
@@ -73,6 +73,8 @@ class BaseSubsetParams:
token_warmup_min: int = 1
token_warmup_step: float = 0
custom_attributes: Optional[Dict[str, Any]] = None
validation_seed: int = 0
validation_split: float = 0.0


@dataclass
@@ -102,6 +104,8 @@ class BaseDatasetParams:
resolution: Optional[Tuple[int, int]] = None
network_multiplier: float = 1.0
debug_dataset: bool = False
validation_seed: Optional[int] = None
validation_split: float = 0.0


@dataclass
@@ -478,9 +482,27 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, **asdict(dataset_blueprint.params))
dataset = dataset_klass(subsets=subsets, is_train=True, **asdict(dataset_blueprint.params))
datasets.append(dataset)

val_datasets:List[Union[DreamBoothDataset, FineTuningDataset, ControlNetDataset]] = []
for dataset_blueprint in dataset_group_blueprint.datasets:
if dataset_blueprint.params.validation_split <= 0.0:
continue
if dataset_blueprint.is_controlnet:
subset_klass = ControlNetSubset
dataset_klass = ControlNetDataset
elif dataset_blueprint.is_dreambooth:
subset_klass = DreamBoothSubset
dataset_klass = DreamBoothDataset
else:
subset_klass = FineTuningSubset
dataset_klass = FineTuningDataset

subsets = [subset_klass(**asdict(subset_blueprint.params)) for subset_blueprint in dataset_blueprint.subsets]
dataset = dataset_klass(subsets=subsets, is_train=False, **asdict(dataset_blueprint.params))
val_datasets.append(dataset)

# print info
info = ""
for i, dataset in enumerate(datasets):
@@ -566,6 +588,50 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu

logger.info(f"{info}")

if len(val_datasets) > 0:
info = ""

for i, dataset in enumerate(val_datasets):
info += dedent(
f"""\
[Validation Dataset {i}]
batch_size: {dataset.batch_size}
resolution: {(dataset.width, dataset.height)}
enable_bucket: {dataset.enable_bucket}
network_multiplier: {dataset.network_multiplier}
"""
)

if dataset.enable_bucket:
info += indent(
dedent(
f"""\
min_bucket_reso: {dataset.min_bucket_reso}
max_bucket_reso: {dataset.max_bucket_reso}
bucket_reso_steps: {dataset.bucket_reso_steps}
bucket_no_upscale: {dataset.bucket_no_upscale}
\n"""
),
" ",
)
else:
info += "\n"

for j, subset in enumerate(dataset.subsets):
info += indent(
dedent(
f"""\
[Subset {j} of Validation Dataset {i}]
image_dir: "{subset.image_dir}"
image_count: {subset.img_count}
num_repeats: {subset.num_repeats}
"""
),
" ",
)

logger.info(f"{info}")

# make buckets first because it determines the length of dataset
# and set the same seed for all datasets
seed = random.randint(0, 2**31) # actual seed is seed + epoch_no
@@ -574,7 +640,15 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
dataset.make_buckets()
dataset.set_seed(seed)

return DatasetGroup(datasets)
for i, dataset in enumerate(val_datasets):
logger.info(f"[Validation Dataset {i}]")
dataset.make_buckets()
dataset.set_seed(seed)

return (
DatasetGroup(datasets),
DatasetGroup(val_datasets) if val_datasets else None
)


def generate_dreambooth_subsets_config_by_subdirs(train_data_dir: Optional[str] = None, reg_data_dir: Optional[str] = None):
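With this change, generate_dataset_group_by_blueprint returns a (train, validation) pair instead of a single DatasetGroup; the second element is None unless at least one dataset sets validation_split above 0.0. A minimal caller-side sketch, assuming the blueprint comes from the existing config-parsing flow (the helper name and everything outside this diff are illustrative, not part of the PR):

    from library import config_util

    # Assumption: "blueprint" is produced by the existing BlueprintGenerator /
    # ConfigSanitizer flow used by the training scripts; its construction is
    # stubbed out here with a hypothetical helper.
    blueprint = make_blueprint_from_user_config()

    train_group, val_group = config_util.generate_dataset_group_by_blueprint(blueprint.dataset_group)

    if val_group is None:
        print("no dataset requested a validation split")
    else:
        print("validation DatasetGroup built from is_train=False copies of the datasets")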
54 changes: 52 additions & 2 deletions library/train_util.py
@@ -145,6 +145,17 @@
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX = "_te_outputs.npz"
TEXT_ENCODER_OUTPUTS_CACHE_SUFFIX_SD3 = "_sd3_te.npz"

def split_train_val(paths: List[str], validation_split: float, validation_seed: int) -> List[str]:
if validation_seed is not None:
print(f"Using validation seed: {validation_seed}")
prevstate = random.getstate()
random.seed(validation_seed)
random.shuffle(paths)
random.setstate(prevstate)
else:
random.shuffle(paths)

return paths[len(paths) - round(len(paths) * validation_split):]

class ImageInfo:
def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool, absolute_path: str) -> None:
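split_train_val shuffles the path list in place (deterministically when validation_seed is given, restoring the global random state afterwards) and returns only the last round(len(paths) * validation_split) entries, i.e. the slice that becomes the validation images. A toy usage sketch with illustrative values:

    # Assumes the function is importable from library.train_util once this change lands.
    from library.train_util import split_train_val

    paths = [f"img_{i:02d}.png" for i in range(10)]
    val_paths = split_train_val(paths, validation_split=0.2, validation_seed=42)

    # round(10 * 0.2) == 2, so two of the shuffled paths are returned as the
    # validation set; note that the shuffle also reorders the caller's list.
    print(val_paths)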
@@ -397,6 +408,8 @@ def __init__(
token_warmup_min: int,
token_warmup_step: Union[float, int],
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
self.image_dir = image_dir
self.alpha_mask = alpha_mask if alpha_mask is not None else False
@@ -424,6 +437,9 @@ def __init__(

self.img_count = 0

self.validation_seed = validation_seed
self.validation_split = validation_split


class DreamBoothSubset(BaseSubset):
def __init__(
@@ -453,6 +469,8 @@ def __init__(
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

@@ -478,6 +496,8 @@ def __init__(
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)

self.is_reg = is_reg
@@ -518,6 +538,8 @@ def __init__(
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"

@@ -543,6 +565,8 @@ def __init__(
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)

self.metadata_file = metadata_file
@@ -579,6 +603,8 @@ def __init__(
token_warmup_min,
token_warmup_step,
custom_attributes: Optional[Dict[str, Any]] = None,
validation_seed: Optional[int] = None,
validation_split: Optional[float] = 0.0,
) -> None:
assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"

@@ -604,6 +630,8 @@ def __init__(
token_warmup_min,
token_warmup_step,
custom_attributes=custom_attributes,
validation_seed=validation_seed,
validation_split=validation_split,
)

self.conditioning_data_dir = conditioning_data_dir
@@ -1799,6 +1827,9 @@ def __init__(
bucket_no_upscale: bool,
prior_loss_weight: float,
debug_dataset: bool,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)

@@ -1808,6 +1839,9 @@ def __init__(
self.size = min(self.width, self.height) # 短いほう
self.prior_loss_weight = prior_loss_weight
self.latents_cache = None
self.is_train = is_train
self.validation_seed = validation_seed
self.validation_split = validation_split

self.enable_bucket = enable_bucket
if self.enable_bucket:
@@ -1992,6 +2026,9 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
)
continue

if self.is_train == False:
img_paths = split_train_val(img_paths, self.validation_split, self.validation_seed)

if subset.is_reg:
num_reg_images += subset.num_repeats * len(img_paths)
else:
@@ -2009,7 +2046,11 @@ def load_dreambooth_dir(subset: DreamBoothSubset):
subset.img_count = len(img_paths)
self.subsets.append(subset)

logger.info(f"{num_train_images} train images with repeating.")
if self.is_train:
logger.info(f"{num_train_images} train images with repeating.")
else:
logger.info(f"{num_train_images} validation images with repeating.")

self.num_train_images = num_train_images

logger.info(f"{num_reg_images} reg images.")
@@ -2050,6 +2091,9 @@ def __init__(
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset: bool,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)

@@ -2276,6 +2320,9 @@ def __init__(
bucket_reso_steps: int,
bucket_no_upscale: bool,
debug_dataset: float,
is_train: bool,
validation_seed: int,
validation_split: float,
) -> None:
super().__init__(resolution, network_multiplier, debug_dataset)

@@ -2324,6 +2371,9 @@ def __init__(
bucket_no_upscale,
1.0,
debug_dataset,
is_train,
validation_seed,
validation_split,
)

# config_util等から参照される値をいれておく(若干微妙なのでなんとかしたい)
@@ -4887,7 +4937,7 @@ def get_optimizer(args, trainable_params) -> tuple[str, str, object]:
import schedulefree as sf
except ImportError:
raise ImportError("No schedulefree / schedulefreeがインストールされていないようです")

if optimizer_type == "RAdamScheduleFree".lower():
optimizer_class = sf.RAdamScheduleFree
logger.info(f"use RAdamScheduleFree optimizer | {optimizer_kwargs}")