Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 31 additions & 43 deletions modules/ui/ModelTab.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __setup_hi_dream_ui(self, frame):
allow_legacy_safetensors=self.train_config.training_method == TrainingMethod.LORA,
)

def __create_dtype_options(self, include_none: bool=True, include_gguf: bool=False, include_a8: bool=False) -> list[tuple[str, DataType]]:
def __create_dtype_options(self, include_gguf: bool=False, include_a8: bool=False) -> list[tuple[str, DataType]]:
options = [
("float32", DataType.FLOAT_32),
("bfloat16", DataType.BFLOAT_16),
Expand All @@ -318,9 +318,6 @@ def __create_dtype_options(self, include_none: bool=True, include_gguf: bool=Fal
("GGUF A8 int", DataType.GGUF_A8_INT),
]

if include_none:
options.insert(0, ("", DataType.NONE))

return options

def __create_base_dtype_components(self, frame, row: int) -> int:
Expand All @@ -341,11 +338,10 @@ def __create_base_dtype_components(self, frame, row: int) -> int:
path_modifier=lambda x: Path(x).parent.absolute() if x.endswith(".json") else x
)

# weight dtype
components.label(frame, row, 3, "Weight Data Type",
tooltip="The base model weight data type used for training. This can reduce memory consumption, but reduces precision")
components.options_kv(frame, row, 4, self.__create_dtype_options(False),
self.ui_state, "weight_dtype")
# compile
components.label(frame, row, 3, "Compile transformer blocks",
tooltip="Uses torch.compile and Triton to significantly speed up training. Only applies to transformer/unet. Disable in case of compatibility issues.")
components.switch(frame, row, 4, self.ui_state, "compile")

row += 1

Expand All @@ -370,8 +366,8 @@ def __create_base_components(
) -> int:
if has_unet:
# unet weight dtype
components.label(frame, row, 3, "Override UNet Data Type",
tooltip="Overrides the unet weight data type")
components.label(frame, row, 3, "UNet Data Type",
tooltip="The unet weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(include_a8=True),
self.ui_state, "unet.weight_dtype")

Expand All @@ -388,8 +384,8 @@ def __create_base_components(
)

# prior weight dtype
components.label(frame, row, 3, "Override Prior Data Type",
tooltip="Overrides the prior weight data type")
components.label(frame, row, 3, "Prior Data Type",
tooltip="The prior weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "prior.weight_dtype")

Expand All @@ -406,8 +402,8 @@ def __create_base_components(
)

# transformer weight dtype
components.label(frame, row, 3, "Override Transformer Data Type",
tooltip="Overrides the transformer weight data type")
components.label(frame, row, 3, "Transformer Data Type",
tooltip="The transformer weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(include_gguf=True, include_a8=True),
self.ui_state, "transformer.weight_dtype")

Expand Down Expand Up @@ -439,14 +435,6 @@ def __create_base_components(
else:
presets = {"full": []}

# compile
components.label(frame, row, 3, "Compile transformer blocks",
tooltip="Uses torch.compile and Triton to significantly speed up training. Only applies to transformer/unet. Disable in case of compatibility issues.")
components.switch(frame, row, 4, self.ui_state, "compile")

row += 1


components.label(frame, row, 0, "Quantization")
components.layer_filter_entry(frame, row, 1, self.ui_state,
preset_var_name="quantization.layer_filter_preset", presets=presets,
Expand Down Expand Up @@ -476,35 +464,35 @@ def __create_base_components(

if has_text_encoder:
# text encoder weight dtype
components.label(frame, row, 3, "Override Text Encoder Data Type",
tooltip="Overrides the text encoder weight data type")
components.label(frame, row, 3, "Text Encoder Data Type",
tooltip="The text encoder weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "text_encoder.weight_dtype")

row += 1

if has_text_encoder_1:
# text encoder 1 weight dtype
components.label(frame, row, 3, "Override Text Encoder 1 Data Type",
tooltip="Overrides the text encoder 1 weight data type")
components.label(frame, row, 3, "Text Encoder 1 Data Type",
tooltip="The text encoder 1 weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "text_encoder.weight_dtype")

row += 1

if has_text_encoder_2:
# text encoder 2 weight dtype
components.label(frame, row, 3, "Override Text Encoder 2 Data Type",
tooltip="Overrides the text encoder 2 weight data type")
components.label(frame, row, 3, "Text Encoder 2 Data Type",
tooltip="The text encoder 2 weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "text_encoder_2.weight_dtype")

row += 1

if has_text_encoder_3:
# text encoder 3 weight dtype
components.label(frame, row, 3, "Override Text Encoder 3 Data Type",
tooltip="Overrides the text encoder 3 weight data type")
components.label(frame, row, 3, "Text Encoder 3 Data Type",
tooltip="The text encoder 3 weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "text_encoder_3.weight_dtype")

Expand All @@ -521,8 +509,8 @@ def __create_base_components(
)

# text encoder 4 weight dtype
components.label(frame, row, 3, "Override Text Encoder 4 Data Type",
tooltip="Overrides the text encoder 4 weight data type")
components.label(frame, row, 3, "Text Encoder 4 Data Type",
tooltip="The text encoder 4 weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "text_encoder_4.weight_dtype")

Expand All @@ -538,8 +526,8 @@ def __create_base_components(
)

# vae weight dtype
components.label(frame, row, 3, "Override VAE Data Type",
tooltip="Overrides the vae weight data type")
components.label(frame, row, 3, "VAE Data Type",
tooltip="The vae weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "vae.weight_dtype")

Expand All @@ -557,8 +545,8 @@ def __create_effnet_encoder_components(self, frame, row: int):
)

# effnet encoder weight dtype
components.label(frame, row, 3, "Override Effnet Encoder Data Type",
tooltip="Overrides the effnet encoder weight data type")
components.label(frame, row, 3, "Effnet Encoder Data Type",
tooltip="The effnet encoder weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "effnet_encoder.weight_dtype")

Expand All @@ -581,25 +569,25 @@ def __create_decoder_components(
)

# decoder weight dtype
components.label(frame, row, 3, "Override Decoder Data Type",
tooltip="Overrides the decoder weight data type")
components.label(frame, row, 3, "Decoder Data Type",
tooltip="The decoder weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "decoder.weight_dtype")

row += 1

if has_text_encoder:
# decoder text encoder weight dtype
components.label(frame, row, 3, "Override Decoder Text Encoder Data Type",
tooltip="Overrides the decoder text encoder weight data type")
components.label(frame, row, 3, "Decoder Text Encoder Data Type",
tooltip="The decoder text encoder weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "decoder_text_encoder.weight_dtype")

row += 1

# decoder vqgan weight dtype
components.label(frame, row, 3, "Override Decoder VQGAN Data Type",
tooltip="Overrides the decoder vqgan weight data type")
components.label(frame, row, 3, "Decoder VQGAN Data Type",
tooltip="The decoder vqgan weight data type")
components.options_kv(frame, row, 4, self.__create_dtype_options(),
self.ui_state, "decoder_vqgan.weight_dtype")

Expand Down
69 changes: 39 additions & 30 deletions modules/util/config/TrainConfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def default_values():
data.append(("stop_training_after", None, int, True))
data.append(("stop_training_after_unit", TimeUnit.NEVER, TimeUnit, False))
data.append(("learning_rate", None, float, True))
data.append(("weight_dtype", DataType.NONE, DataType, False))
data.append(("weight_dtype", DataType.FLOAT_32, DataType, False))
data.append(("dropout_probability", 0.0, float, False))
data.append(("train_embedding", True, bool, False))
data.append(("attention_mask", False, bool, False))
Expand Down Expand Up @@ -342,7 +342,6 @@ class TrainConfig(BaseConfig):

# model settings
base_model_name: str
weight_dtype: DataType
output_dtype: DataType
output_model_format: ModelFormat
output_model_destination: str
Expand Down Expand Up @@ -532,7 +531,7 @@ class TrainConfig(BaseConfig):
def __init__(self, data: list[(str, Any, type, bool)]):
super().__init__(
data,
config_version=9,
config_version=10,
config_migrations={
0: self.__migration_0,
1: self.__migration_1,
Expand All @@ -543,6 +542,7 @@ def __init__(self, data: list[(str, Any, type, bool)]):
6: self.__migration_6,
7: self.__migration_7,
8: self.__migration_8,
9: self.__migration_9,
}
)

Expand Down Expand Up @@ -740,24 +740,46 @@ def __migration_8(self, data: dict) -> dict:

return migrated_data

def __migration_9(self, data: dict) -> dict:
migrated_data = data.copy()

def replace_dtype(part: str):
if part in migrated_data and migrated_data[part]["weight_dtype"] == DataType.NONE:
migrated_data[part]["weight_dtype"] = migrated_data["weight_dtype"]
replace_dtype("unet")
replace_dtype("prior")
replace_dtype("transformer")
replace_dtype("text_encoder")
replace_dtype("text_encoder_2")
replace_dtype("text_encoder_3")
replace_dtype("text_encoder_4")
replace_dtype("vae")
replace_dtype("effnet_encoder")
replace_dtype("decoder")
replace_dtype("decoder_text_encoder")
replace_dtype("decoder_vqgan")
migrated_data.pop("weight_dtype")

return migrated_data

def weight_dtypes(self) -> ModelWeightDtypes:
return ModelWeightDtypes(
self.train_dtype,
self.fallback_train_dtype,
self.weight_dtype if self.unet.weight_dtype == DataType.NONE else self.unet.weight_dtype,
self.weight_dtype if self.prior.weight_dtype == DataType.NONE else self.prior.weight_dtype,
self.weight_dtype if self.transformer.weight_dtype == DataType.NONE else self.transformer.weight_dtype,
self.weight_dtype if self.text_encoder.weight_dtype == DataType.NONE else self.text_encoder.weight_dtype,
self.weight_dtype if self.text_encoder_2.weight_dtype == DataType.NONE else self.text_encoder_2.weight_dtype,
self.weight_dtype if self.text_encoder_3.weight_dtype == DataType.NONE else self.text_encoder_3.weight_dtype,
self.weight_dtype if self.text_encoder_4.weight_dtype == DataType.NONE else self.text_encoder_4.weight_dtype,
self.weight_dtype if self.vae.weight_dtype == DataType.NONE else self.vae.weight_dtype,
self.weight_dtype if self.effnet_encoder.weight_dtype == DataType.NONE else self.effnet_encoder.weight_dtype,
self.weight_dtype if self.decoder.weight_dtype == DataType.NONE else self.decoder.weight_dtype,
self.weight_dtype if self.decoder_text_encoder.weight_dtype == DataType.NONE else self.decoder_text_encoder.weight_dtype,
self.weight_dtype if self.decoder_vqgan.weight_dtype == DataType.NONE else self.decoder_vqgan.weight_dtype,
self.weight_dtype if self.lora_weight_dtype == DataType.NONE else self.lora_weight_dtype,
self.weight_dtype if self.embedding_weight_dtype == DataType.NONE else self.embedding_weight_dtype,
self.unet.weight_dtype,
self.prior.weight_dtype,
self.transformer.weight_dtype,
self.text_encoder.weight_dtype,
self.text_encoder_2.weight_dtype,
self.text_encoder_3.weight_dtype,
self.text_encoder_4.weight_dtype,
self.vae.weight_dtype,
self.effnet_encoder.weight_dtype,
self.decoder.weight_dtype,
self.decoder_text_encoder.weight_dtype,
self.decoder_vqgan.weight_dtype,
self.lora_weight_dtype,
self.embedding_weight_dtype,
)

def model_names(self) -> ModelNames:
Expand Down Expand Up @@ -905,7 +927,6 @@ def default_values() -> 'TrainConfig':

# model settings
data.append(("base_model_name", "stable-diffusion-v1-5/stable-diffusion-v1-5", str, False))
data.append(("weight_dtype", DataType.FLOAT_32, DataType, False))
data.append(("output_dtype", DataType.FLOAT_32, DataType, False))
data.append(("output_model_format", ModelFormat.SAFETENSORS, ModelFormat, False))
data.append(("output_model_destination", "models/model.safetensors", str, False))
Expand Down Expand Up @@ -980,7 +1001,6 @@ def default_values() -> 'TrainConfig':
unet.train = True
unet.stop_training_after = 0
unet.learning_rate = None
unet.weight_dtype = DataType.NONE
data.append(("unet", unet, TrainModelPartConfig, False))

# prior
Expand All @@ -989,7 +1009,6 @@ def default_values() -> 'TrainConfig':
prior.train = True
prior.stop_training_after = 0
prior.learning_rate = None
prior.weight_dtype = DataType.NONE
data.append(("prior", prior, TrainModelPartConfig, False))

# transformer
Expand All @@ -998,7 +1017,6 @@ def default_values() -> 'TrainConfig':
transformer.train = True
transformer.stop_training_after = 0
transformer.learning_rate = None
transformer.weight_dtype = DataType.NONE
data.append(("transformer", transformer, TrainModelPartConfig, False))

#quantization layer filter
Expand All @@ -1011,7 +1029,6 @@ def default_values() -> 'TrainConfig':
text_encoder.stop_training_after = 30
text_encoder.stop_training_after_unit = TimeUnit.EPOCH
text_encoder.learning_rate = None
text_encoder.weight_dtype = DataType.NONE
data.append(("text_encoder", text_encoder, TrainModelPartConfig, False))
data.append(("text_encoder_layer_skip", 0, int, False))

Expand All @@ -1021,7 +1038,6 @@ def default_values() -> 'TrainConfig':
text_encoder_2.stop_training_after = 30
text_encoder_2.stop_training_after_unit = TimeUnit.EPOCH
text_encoder_2.learning_rate = None
text_encoder_2.weight_dtype = DataType.NONE
data.append(("text_encoder_2", text_encoder_2, TrainModelPartConfig, False))
data.append(("text_encoder_2_layer_skip", 0, int, False))
data.append(("text_encoder_2_sequence_length", 77, int, True))
Expand All @@ -1032,7 +1048,6 @@ def default_values() -> 'TrainConfig':
text_encoder_3.stop_training_after = 30
text_encoder_3.stop_training_after_unit = TimeUnit.EPOCH
text_encoder_3.learning_rate = None
text_encoder_3.weight_dtype = DataType.NONE
data.append(("text_encoder_3", text_encoder_3, TrainModelPartConfig, False))
data.append(("text_encoder_3_layer_skip", 0, int, False))

Expand All @@ -1042,36 +1057,30 @@ def default_values() -> 'TrainConfig':
text_encoder_4.stop_training_after = 30
text_encoder_4.stop_training_after_unit = TimeUnit.EPOCH
text_encoder_4.learning_rate = None
text_encoder_4.weight_dtype = DataType.NONE
data.append(("text_encoder_4", text_encoder_4, TrainModelPartConfig, False))
data.append(("text_encoder_4_layer_skip", 0, int, False))

# vae
vae = TrainModelPartConfig.default_values()
vae.model_name = ""
vae.weight_dtype = DataType.FLOAT_32
data.append(("vae", vae, TrainModelPartConfig, False))

# effnet encoder
effnet_encoder = TrainModelPartConfig.default_values()
effnet_encoder.model_name = ""
effnet_encoder.weight_dtype = DataType.NONE
data.append(("effnet_encoder", effnet_encoder, TrainModelPartConfig, False))

# decoder
decoder = TrainModelPartConfig.default_values()
decoder.model_name = ""
decoder.weight_dtype = DataType.NONE
data.append(("decoder", decoder, TrainModelPartConfig, False))

# decoder text encoder
decoder_text_encoder = TrainModelPartConfig.default_values()
decoder_text_encoder.weight_dtype = DataType.NONE
data.append(("decoder_text_encoder", decoder_text_encoder, TrainModelPartConfig, False))

# decoder vqgan
decoder_vqgan = TrainModelPartConfig.default_values()
decoder_vqgan.weight_dtype = DataType.NONE
data.append(("decoder_vqgan", decoder_vqgan, TrainModelPartConfig, False))

# masked training
Expand Down
4 changes: 2 additions & 2 deletions training_presets/#flux LoRA.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
"weight_dtype": "NFLOAT_4"
},
"text_encoder": {
"train": false
"train": false,
"weight_dtype": "BFLOAT_16"
},
"text_encoder_2": {
"train": false,
Expand All @@ -23,7 +24,6 @@
"weight_dtype": "FLOAT_32"
},
"train_dtype": "BFLOAT_16",
"weight_dtype": "BFLOAT_16",
"timestep_distribution": "LOGIT_NORMAL",
"dynamic_timestep_shifting": false
}
Loading