3 changes: 2 additions & 1 deletion comfy/ops.py
@@ -599,6 +599,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
'block_size': qconfig.get("group_size", None),
}
if layout_params['scale'] is not None:
layout_params['scale'] = layout_params['scale'].to(device=device)
manually_loaded_keys.append(weight_scale_key)

self.weight = torch.nn.Parameter(
@@ -611,7 +612,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
_v = state_dict.pop(param_key, None)
if _v is None:
continue
setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False))
self.register_buffer(param_name, _v.to(device=device))
manually_loaded_keys.append(param_key)

super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
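For context on the register_buffer change above: buffers are saved and restored through the state dict and follow the module across .to(device) / .to(dtype) just like parameters, but they are never reported as trainable parameters, which is the right fit for fixed quantization state. A standalone sketch of the difference (illustrative only, not ComfyUI's ops module):

import torch

class QuantLinearSketch(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The quantized weight itself stays a (non-trainable) Parameter.
        self.weight = torch.nn.Parameter(torch.randn(4, 4), requires_grad=False)
        # Per-layer quantization state (e.g. a scale) is better kept as a buffer:
        # it is serialized and moved with the module, but never handed to optimizers.
        self.register_buffer("weight_scale", torch.ones(4, 1))

m = QuantLinearSketch()
print([name for name, _ in m.named_parameters()])  # ['weight']
print([name for name, _ in m.named_buffers()])     # ['weight_scale']
m.to(torch.float16)                                 # buffer follows the module
print(m.weight_scale.dtype)                         # torch.float16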
17 changes: 14 additions & 3 deletions comfy/quant_ops.py
@@ -337,6 +337,16 @@ def generic_copy_(func, args, kwargs):
return qt_dest
return func(*args, **kwargs)

@register_generic_util(torch.ops.aten.to.dtype)
def generic_to_dtype(func, args, kwargs):
"""Handle .to(dtype) calls - dtype conversion only."""
src = args[0]
if isinstance(src, QuantizedTensor):
# For dtype-only conversion, just change the orig_dtype, no real cast is needed
target_dtype = args[1] if len(args) > 1 else kwargs.get('dtype')
src._layout_params["orig_dtype"] = target_dtype
return src
return func(*args, **kwargs)

@register_generic_util(torch.ops.aten._has_compatible_shallow_copy_type.default)
def generic_has_compatible_shallow_copy_type(func, args, kwargs):
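The intent of the new generic_to_dtype handler: a dtype-only .to() on an already-quantized tensor should be essentially free, so instead of dequantizing and recasting it only updates the dtype that a later dequantize will produce. A standalone sketch of that idea with a stand-in class (FakeQuantizedTensor is hypothetical, not ComfyUI's QuantizedTensor, and the float8 dtype needs a recent PyTorch):

import torch

class FakeQuantizedTensor:
    """Stand-in: raw fp8 payload, a scale, and the dtype dequantize should target."""
    def __init__(self, qdata, scale, orig_dtype):
        self.qdata = qdata            # fp8 storage, never touched by .to(dtype)
        self.scale = scale
        self.orig_dtype = orig_dtype

    def to(self, dtype):
        # dtype-only conversion: just retarget dequantization, no real cast
        self.orig_dtype = dtype
        return self

    def dequantize(self):
        return (self.qdata.to(torch.float32) * self.scale).to(self.orig_dtype)

qt = FakeQuantizedTensor(torch.randn(2, 2).to(torch.float8_e4m3fn),
                         torch.tensor(2.0), torch.float16)
qt.to(torch.bfloat16)            # cheap bookkeeping, no data movement
print(qt.dequantize().dtype)     # torch.bfloat16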
@@ -383,10 +393,11 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn):
scale = torch.tensor(scale)
scale = scale.to(device=tensor.device, dtype=torch.float32)

tensor_scaled = tensor * (1.0 / scale).to(tensor.dtype)
tensor_fp32 = tensor.to(torch.float32)
Owner review comment on the tensor_fp32 cast above: casting to fp32 here causes some pretty large slowdowns that make the fp8 ops as slow as 16 bit.

tensor_scaled = tensor_fp32 * (1.0 / scale)
# TODO: uncomment this if it's actually needed because the clamp has a small performance penality'
# lp_amax = torch.finfo(dtype).max
# torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
lp_amax = torch.finfo(dtype).max
torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
qdata = tensor_scaled.to(dtype, memory_format=torch.contiguous_format)

layout_params = {
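Read end to end, the rewritten quantize body does three things: scale in fp32, clamp into the representable range of the target float8 format so the final cast cannot overflow, then cast. The fp32 cast is exactly what the owner's comment flags as the slowdown. A runnable sketch of the same math as a free function (a standalone approximation, not the actual classmethod):

import torch

def quantize_sketch(tensor, scale, dtype=torch.float8_e4m3fn):
    scale = scale.to(device=tensor.device, dtype=torch.float32)
    # Scaling in fp32 is more accurate, but adds the extra cast the review
    # comment calls out as making fp8 ops as slow as 16-bit ones.
    tensor_scaled = tensor.to(torch.float32) * (1.0 / scale)
    # Clamp to the largest finite value the target fp8 format can hold.
    lp_amax = torch.finfo(dtype).max
    torch.clamp(tensor_scaled, min=-lp_amax, max=lp_amax, out=tensor_scaled)
    return tensor_scaled.to(dtype, memory_format=torch.contiguous_format)

q = quantize_sketch(torch.randn(8, 8) * 1000.0, torch.tensor(2.0))
print(q.dtype)  # torch.float8_e4m3fn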
16 changes: 15 additions & 1 deletion comfy/sd.py
@@ -23,6 +23,7 @@
import yaml
import math
import os
import json

import comfy.utils

@@ -917,7 +918,20 @@ class CLIPType(Enum):
def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
clip_data = []
for p in ckpt_paths:
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True))
clip_data.append(comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True))
if type(clip_data[-1]) == tuple:
model, metadata = clip_data.pop()
if metadata is not None and "_quantization_metadata" in metadata:
try:
quant_metadata = metadata.pop("_quantization_metadata")
quant_metadata = json.loads(quant_metadata)
if "layers" in quant_metadata:
layer_quant_config = quant_metadata["layers"]
model_options["layer_quant_config"] = layer_quant_config
logging.info(f"Detected quantized text encoder: {len(layer_quant_config)} layers with quantization")
except Exception as e:
logging.warning(f"Failed to parse quantization metadata: {e}")
clip_data.append(model)
return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options)


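The load_clip change relies on load_torch_file returning the safetensors header metadata alongside the state dict; the _quantization_metadata entry is a JSON string whose "layers" field becomes layer_quant_config. The same header can also be read directly with the safetensors library; a minimal sketch (the file name is a placeholder, the metadata layout is taken from the diff above):

import json
from safetensors import safe_open

def read_layer_quant_config(path):
    # Safetensors keeps free-form string metadata in the file header.
    with safe_open(path, framework="pt") as f:
        metadata = f.metadata() or {}
    raw = metadata.get("_quantization_metadata")
    if raw is None:
        return None
    return json.loads(raw).get("layers")  # per-layer quant config, if present

# layer_quant_config = read_layer_quant_config("some_fp8_text_encoder.safetensors")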
17 changes: 13 additions & 4 deletions comfy/sd1_clip.py
@@ -109,13 +109,22 @@ def __init__(self, device="cpu", max_length=77,

operations = model_options.get("custom_operations", None)
scaled_fp8 = None
layer_quant_config = model_options.get("layer_quant_config", None)

if operations is None:
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
# Use MixedPrecisionOps if layer_quant_config is present (for FP8 text encoders)
if layer_quant_config is not None:
operations = comfy.ops.MixedPrecisionOps
comfy.ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
comfy.ops.MixedPrecisionOps._compute_dtype = dtype
logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers")
else:
operations = comfy.ops.manual_cast
# Fallback to scaled_fp8_ops for backward compatibility
scaled_fp8 = model_options.get("scaled_fp8", None)
if scaled_fp8 is not None:
operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)
else:
operations = comfy.ops.manual_cast

self.operations = operations
self.transformer = model_class(config, dtype, device, self.operations)
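Put together, the modified __init__ picks the text-encoder ops class in this priority order: an explicit custom_operations override, then MixedPrecisionOps when layer_quant_config was recovered from checkpoint metadata, then the legacy scaled_fp8 path, and finally manual_cast. A compact restatement of that branching as a helper (pick_text_encoder_ops is a hypothetical function, using the same names as in the diff):

import comfy.ops

def pick_text_encoder_ops(model_options, dtype):
    operations = model_options.get("custom_operations", None)
    if operations is not None:
        return operations                       # explicit override wins

    layer_quant_config = model_options.get("layer_quant_config", None)
    if layer_quant_config is not None:
        # Per-layer quantized text encoder (e.g. FP8 with quantization metadata)
        comfy.ops.MixedPrecisionOps._layer_quant_config = layer_quant_config
        comfy.ops.MixedPrecisionOps._compute_dtype = dtype
        return comfy.ops.MixedPrecisionOps

    scaled_fp8 = model_options.get("scaled_fp8", None)
    if scaled_fp8 is not None:
        # Backward-compatible whole-model scaled fp8 path
        return comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8)

    return comfy.ops.manual_cast                # default fallback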