Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
1e8fece
Compile int svd
dxqb Oct 4, 2025
bbf6eb8
- fix cache dir
dxqb Oct 5, 2025
2676d33
hide checkpoints from LoRA saving
dxqb Oct 5, 2025
dc289fa
fix buffer registration
dxqb Oct 6, 2025
73822b0
fix buffer registration
dxqb Oct 6, 2025
4821c9e
various
dxqb Oct 13, 2025
cb02a4c
various
dxqb Oct 13, 2025
efb9073
various
dxqb Oct 13, 2025
35ba023
cleanup
dxqb Oct 13, 2025
5633b21
torch.compile bug workaround
dxqb Oct 14, 2025
c37f805
same workaround for Qwen
dxqb Oct 14, 2025
d7532dc
gguf
dxqb Oct 15, 2025
6883981
merge
dxqb Oct 15, 2025
2de19c9
gguf
dxqb Oct 15, 2025
a3cd936
bugfix
dxqb Oct 15, 2025
bcf0b65
requirements
dxqb Oct 15, 2025
0dcb3dc
merge
dxqb Oct 15, 2025
5a2e590
name changes, axis wise
dxqb Oct 16, 2025
44a969b
Merge branch 'compile_int8_svd' into compile_int8_svd_gguf
dxqb Oct 16, 2025
882154d
merge
dxqb Oct 16, 2025
e5317d3
big type hint
dxqb Oct 16, 2025
7f07748
Merge branch 'compile_int8_svd' into compile_int8_svd_gguf
dxqb Oct 16, 2025
c8cb33b
use axis-wise quantization for both forward and backward
dxqb Oct 16, 2025
c6011d0
Merge branch 'upstream' into gguf
dxqb Oct 16, 2025
b772d42
merge
dxqb Oct 16, 2025
3167b90
merge
dxqb Oct 16, 2025
2d4a0c3
initial
dxqb Oct 16, 2025
68f0f71
merge #1060
dxqb Oct 16, 2025
0e282a2
merge
dxqb Oct 16, 2025
71af1f0
ui fix
dxqb Oct 16, 2025
1606c85
Merge branch 'compile_int8_svd' into compile_int8_svd_gguf
dxqb Oct 16, 2025
0f58a5e
GGUF with DoRA
dxqb Oct 16, 2025
8ca5782
GGUF with DoRA
dxqb Oct 16, 2025
881e7e5
GGUF A8 float bugfix
dxqb Oct 17, 2025
25ccc0c
improve check for #1050
dxqb Oct 17, 2025
cfc1492
Merge branch 't5' into compile_int8_svd
dxqb Oct 17, 2025
b4d8f30
improve check for #1050
dxqb Oct 17, 2025
827fa11
improve check for #1050
dxqb Oct 17, 2025
da296d6
re-enabled int W8A8
dxqb Oct 17, 2025
f273ac3
Merge branch 'compile_int8_svd' into compile_int8_svd_gguf
dxqb Oct 17, 2025
a3de776
merge
dxqb Oct 23, 2025
7351e9a
Merge branch 'upstream' into compile_int8_svd
dxqb Oct 26, 2025
56902b6
only quantize activations if GGUF weights are actually quantized
dxqb Oct 28, 2025
eaf4fe2
make layer filter a component
dxqb Nov 2, 2025
867b84c
quantization layer filter
dxqb Nov 2, 2025
e5d0317
add blocks preset
dxqb Nov 2, 2025
1f45c04
Merge branch 'blocks' into quant_layer_filter
dxqb Nov 2, 2025
30b7cca
merge
dxqb Nov 2, 2025
0242d16
quantization filter in presets
dxqb Nov 2, 2025
e5f5b0b
Merge branch 'quant_layer_filter' into compile_int8_svd
dxqb Nov 2, 2025
9e897b8
#1054
dxqb Nov 2, 2025
9aa7973
Merge branch 'config-prefix' into compile_int8_svd
dxqb Nov 2, 2025
3448801
bugfix
dxqb Nov 2, 2025
40d61c5
Merge branch 'quant_layer_filter' into compile_int8_svd
dxqb Nov 2, 2025
cf70b7a
Merge branch 'upstream' into compile_int8_svd
dxqb Nov 2, 2025
5bc6c5a
smaller eps, because gradients for some models are close to 1e-12
dxqb Nov 3, 2025
fb1e8a8
compile benchmarks
dxqb Nov 4, 2025
4ca84db
remove cast
dxqb Nov 4, 2025
41e44a2
detach dequantized weights
dxqb Nov 4, 2025
b3f69ae
name changes
dxqb Nov 7, 2025
f9c12a8
move code
dxqb Nov 7, 2025
5db4161
fix circular dependency
dxqb Nov 7, 2025
cd3f971
ensure contiguous grad output
dxqb Nov 7, 2025
9159d24
avoid attention mask
dxqb Nov 8, 2025
7760af3
Merge branch 'avoid_attn_mask' into compile_int8_svd
dxqb Nov 8, 2025
9559ebf
disable bug workaround - can currently not be reproduced and because …
dxqb Nov 8, 2025
49d2bc4
pad sequence length if an attention mask is necessary anyway
dxqb Nov 9, 2025
2ecf834
merge
dxqb Nov 9, 2025
dc71ef5
Merge branch 'upstream' into compile_int8_svd
dxqb Nov 11, 2025
c756b2c
merge
dxqb Nov 14, 2025
d6bb1ff
Merge branch 'upstream' into compile_int8_svd
dxqb Nov 14, 2025
27dc59d
merge
dxqb Nov 14, 2025
69f0fa1
merge fix
dxqb Nov 14, 2025
1739bfe
Fixes [Bug]: Layer filter isn't configured correct if a preset is loaded
O-J1 Nov 16, 2025
968b2a9
Simplify tooltip text for layer filter
dxqb Nov 16, 2025
6f2d2f5
Tweak tooltip text a little more
O-J1 Nov 16, 2025
ea29f8f
Merge branch 'upstream' into quant_layer_filter
dxqb Nov 16, 2025
1a7c5a6
initial
dxqb Nov 16, 2025
606b7a8
merge with #1139
dxqb Nov 22, 2025
8ce4604
merge with upstream
dxqb Nov 22, 2025
50982b7
UI update
dxqb Nov 22, 2025
16a6015
UI update
dxqb Nov 22, 2025
fde79ce
UI change and merge
dxqb Nov 22, 2025
b9bf5b4
fix comment
dxqb Nov 22, 2025
261c32c
comment fix
dxqb Nov 23, 2025
6128c52
merge svd and LoRa for efficiency
dxqb Nov 23, 2025
193f5ad
merge svd and LoRa for efficiency
dxqb Nov 23, 2025
d6c6666
change default - 16 is sufficient
dxqb Nov 23, 2025
355f522
merge #1143
dxqb Nov 24, 2025
91fef92
Merge branch 'quant_layer_filter' into compile_int8_svd
dxqb Nov 24, 2025
171629c
switch to quantization config
dxqb Nov 24, 2025
605c34b
remove debug print
dxqb Nov 24, 2025
5729eca
remove debug print
dxqb Nov 24, 2025
ac2e063
- scale matmul result in float32 - more accurate with no performance…
dxqb Nov 24, 2025
d98b2b5
allow A8 for unet
dxqb Nov 24, 2025
8a53199
OFT workaround for torch.compile slicing issue (#23)
dxqb Nov 27, 2025
a767fe5
fix GGUF A8 float
dxqb Nov 27, 2025
f3d9373
fix quantization config
dxqb Nov 27, 2025
9ca1ea7
merge
dxqb Nov 27, 2025
47e330e
Revert "merge"
dxqb Nov 27, 2025
49427ff
merge
dxqb Nov 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions modules/modelLoader/GenericEmbeddingModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.util.config.TrainConfig import QuantizationConfig
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter


def make_embedding_model_loader(
Expand Down Expand Up @@ -33,7 +33,7 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
quantization: QuantizationConfig,
) -> model_class | None:
base_model_loader = model_loader_class()
embedding_loader = embedding_loader_class()
Expand All @@ -43,7 +43,7 @@ def load(
model.model_spec = self._load_default_model_spec(model_type)

if model_names.base_model is not None:
base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
base_model_loader.load(model, model_type, model_names, weight_dtypes, quantization)
embedding_loader.load(model, model_names.embedding.model_name, model_names)

return model
Expand Down
6 changes: 3 additions & 3 deletions modules/modelLoader/GenericFineTuneModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.util.config.TrainConfig import QuantizationConfig
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter


def make_fine_tune_model_loader(
Expand Down Expand Up @@ -33,7 +33,7 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
quantization: QuantizationConfig,
) -> model_class | None:
base_model_loader = model_loader_class()
if embedding_loader_class is not None:
Expand All @@ -44,7 +44,7 @@ def load(
self._load_internal_data(model, model_names.base_model)
model.model_spec = self._load_default_model_spec(model_type)

base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
base_model_loader.load(model, model_type, model_names, weight_dtypes, quantization)
if embedding_loader_class is not None:
embedding_loader.load(model, model_names.base_model, model_names)

Expand Down
6 changes: 3 additions & 3 deletions modules/modelLoader/GenericLoRAModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from modules.modelLoader.BaseModelLoader import BaseModelLoader
from modules.modelLoader.mixin.InternalModelLoaderMixin import InternalModelLoaderMixin
from modules.modelLoader.mixin.ModelSpecModelLoaderMixin import ModelSpecModelLoaderMixin
from modules.util.config.TrainConfig import QuantizationConfig
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter


def make_lora_model_loader(
Expand Down Expand Up @@ -34,7 +34,7 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
quantization: QuantizationConfig,
) -> model_class | None:
base_model_loader = model_loader_class()
lora_model_loader = lora_loader_class()
Expand All @@ -46,7 +46,7 @@ def load(
model.model_spec = self._load_default_model_spec(model_type)

if model_names.base_model is not None:
base_model_loader.load(model, model_type, model_names, weight_dtypes, quant_filters)
base_model_loader.load(model, model_type, model_names, weight_dtypes, quantization)
lora_model_loader.load(model, model_names)
if embedding_loader_class is not None:
embedding_loader.load(model, model_names.lora, model_names)
Expand Down
25 changes: 12 additions & 13 deletions modules/modelLoader/chroma/ChromaModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

from modules.model.ChromaModel import ChromaModel
from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin
from modules.util.enum.DataType import DataType
from modules.util.config.TrainConfig import QuantizationConfig
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter

import torch

Expand All @@ -34,11 +33,11 @@ def __load_internal(
base_model_name: str,
transformer_model_name: str,
vae_model_name: str,
quant_filters: list[ModuleFilter],
quantization: QuantizationConfig,
):
if os.path.isfile(os.path.join(base_model_name, "meta.json")):
self.__load_diffusers(
model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quant_filters,
model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name, quantization,
)
else:
raise Exception("not an internal model")
Expand All @@ -51,7 +50,7 @@ def __load_diffusers(
base_model_name: str,
transformer_model_name: str,
vae_model_name: str,
quant_filters: list[ModuleFilter],
quantization: QuantizationConfig,
):
diffusers_sub = []
if not transformer_model_name:
Expand Down Expand Up @@ -104,10 +103,10 @@ def __load_diffusers(
transformer_model_name,
#avoid loading the transformer in float32:
torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype(),
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer.is_gguf() else None,
)
transformer = self._convert_diffusers_sub_module_to_dtype(
transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters,
transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quantization,
)
else:
transformer = self._load_diffusers_sub_module(
Expand All @@ -116,7 +115,7 @@ def __load_diffusers(
weight_dtypes.train_dtype,
base_model_name,
"transformer",
quant_filters,
quantization,
)

model.model_type = model_type
Expand All @@ -134,7 +133,7 @@ def __load_safetensors(
base_model_name: str,
transformer_model_name: str,
vae_model_name: str,
quant_filters: list[ModuleFilter],
quantization: QuantizationConfig,
):
#no single file .safetensors for Chroma available at the time of writing this code
raise NotImplementedError("Loading of single file Chroma models not supported. Use the diffusers model instead. Optionally, transformer-only safetensor files can be loaded by overriding the transformer.")
Expand All @@ -145,29 +144,29 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
quantization: QuantizationConfig,
):
stacktraces = []

try:
self.__load_internal(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters,
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization,
)
return
except Exception:
stacktraces.append(traceback.format_exc())

try:
self.__load_diffusers(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters,
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization,
)
return
except Exception:
stacktraces.append(traceback.format_exc())

try:
self.__load_safetensors(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quant_filters,
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model, quantization,
)
return
except Exception:
Expand Down
27 changes: 13 additions & 14 deletions modules/modelLoader/flux/FluxModelLoader.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

from modules.model.FluxModel import FluxModel
from modules.modelLoader.mixin.HFModelLoaderMixin import HFModelLoaderMixin
from modules.util.enum.DataType import DataType
from modules.util.config.TrainConfig import QuantizationConfig
from modules.util.enum.ModelType import ModelType
from modules.util.ModelNames import ModelNames
from modules.util.ModelWeightDtypes import ModelWeightDtypes
from modules.util.ModuleFilter import ModuleFilter

import torch

Expand Down Expand Up @@ -37,12 +36,12 @@ def __load_internal(
vae_model_name: str,
include_text_encoder_1: bool,
include_text_encoder_2: bool,
quant_filters: list[ModuleFilter],
quantization: QuantizationConfig,
):
if os.path.isfile(os.path.join(base_model_name, "meta.json")):
self.__load_diffusers(
model, model_type, weight_dtypes, base_model_name, transformer_model_name, vae_model_name,
include_text_encoder_1, include_text_encoder_2, quant_filters,
include_text_encoder_1, include_text_encoder_2, quantization,
)
else:
raise Exception("not an internal model")
Expand All @@ -57,7 +56,7 @@ def __load_diffusers(
vae_model_name: str,
include_text_encoder_1: bool,
include_text_encoder_2: bool,
quant_filters: list[ModuleFilter],
quantization: QuantizationConfig,
):
diffusers_sub = []
transformers_sub = []
Expand Down Expand Up @@ -140,10 +139,10 @@ def __load_diffusers(
transformer_model_name,
#avoid loading the transformer in float32:
torch_dtype = torch.bfloat16 if weight_dtypes.transformer.torch_dtype() is None else weight_dtypes.transformer.torch_dtype(),
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer == DataType.GGUF else None,
quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16) if weight_dtypes.transformer.is_gguf() else None,
)
transformer = self._convert_diffusers_sub_module_to_dtype(
transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters,
transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quantization,
)
else:
transformer = self._load_diffusers_sub_module(
Expand All @@ -152,7 +151,7 @@ def __load_diffusers(
weight_dtypes.train_dtype,
base_model_name,
"transformer",
quant_filters,
quantization,
)

model.model_type = model_type
Expand All @@ -174,7 +173,7 @@ def __load_safetensors(
vae_model_name: str,
include_text_encoder_1: bool,
include_text_encoder_2: bool,
quant_filters: list[ModuleFilter],
quantization: QuantizationConfig,
):
transformer = FluxTransformer2DModel.from_single_file(
#always load transformer separately even though FluxPipeLine.from_single_file() could load it, to avoid loading in float32:
Expand Down Expand Up @@ -227,7 +226,7 @@ def __load_safetensors(
print("text encoder 2 (t5) not loaded, continuing without it")

transformer = self._convert_diffusers_sub_module_to_dtype(
pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quant_filters,
pipeline.transformer, weight_dtypes.transformer, weight_dtypes.train_dtype, quantization,
)

model.model_type = model_type
Expand All @@ -245,14 +244,14 @@ def load(
model_type: ModelType,
model_names: ModelNames,
weight_dtypes: ModelWeightDtypes,
quant_filters: list[ModuleFilter] | None = None,
quantization: QuantizationConfig,
):
stacktraces = []

try:
self.__load_internal(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters,
model_names.include_text_encoder, model_names.include_text_encoder_2, quantization,
)
return
except Exception:
Expand All @@ -261,7 +260,7 @@ def load(
try:
self.__load_diffusers(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters,
model_names.include_text_encoder, model_names.include_text_encoder_2, quantization,
)
return
except Exception:
Expand All @@ -270,7 +269,7 @@ def load(
try:
self.__load_safetensors(
model, model_type, weight_dtypes, model_names.base_model, model_names.transformer_model, model_names.vae_model,
model_names.include_text_encoder, model_names.include_text_encoder_2, quant_filters,
model_names.include_text_encoder, model_names.include_text_encoder_2, quantization,
)
return
except Exception:
Expand Down
Loading