Skip to content

Commit 635406e

Browse files
Only enable fp16 on z image models that actually support it. (Comfy-Org#12065)
1 parent ed6002c commit 635406e

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

comfy/ldm/lumina/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,7 @@ def __init__(
451451
device=None,
452452
dtype=None,
453453
operations=None,
454+
**kwargs,
454455
) -> None:
455456
super().__init__()
456457
self.dtype = dtype

comfy/model_detection.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
444444
dit_config["ffn_dim_multiplier"] = (8.0 / 3.0)
445445
dit_config["z_image_modulation"] = True
446446
dit_config["time_scale"] = 1000.0
447+
try:
448+
dit_config["allow_fp16"] = torch.std(state_dict['{}layers.{}.ffn_norm1.weight'.format(key_prefix, dit_config["n_layers"] - 2)], unbiased=False).item() < 0.42
449+
except Exception:
450+
pass
447451
if '{}cap_pad_token'.format(key_prefix) in state_dict_keys:
448452
dit_config["pad_tokens_multiple"] = 32
449453
sig_weight = state_dict.get('{}siglip_embedder.0.weight'.format(key_prefix), None)

comfy/supported_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1093,7 +1093,7 @@ class ZImage(Lumina2):
10931093

10941094
def __init__(self, unet_config):
10951095
super().__init__(unet_config)
1096-
if comfy.model_management.extended_fp16_support():
1096+
if comfy.model_management.extended_fp16_support() and unet_config.get("allow_fp16", False):
10971097
self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
10981098
self.supported_inference_dtypes.insert(1, torch.float16)
10991099

0 commit comments

Comments
 (0)