@@ -117,6 +117,8 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
117117 if weight_has_function or weight .dtype != dtype :
118118 with wf_context :
119119 weight = weight .to (dtype = dtype )
120+ if isinstance (weight , QuantizedTensor ):
121+ weight = weight .dequantize ()
120122 for f in s .weight_function :
121123 weight = f (weight )
122124
@@ -502,7 +504,7 @@ def convert_weight(self, weight, inplace=False, **kwargs):
502504 weight *= self .scale_weight .to (device = weight .device , dtype = weight .dtype )
503505 return weight
504506 else :
505- return weight * self .scale_weight .to (device = weight .device , dtype = weight . dtype )
507+ return weight . to ( dtype = torch . float32 ) * self .scale_weight .to (device = weight .device , dtype = torch . float32 )
506508
507509 def set_weight (self , weight , inplace_update = False , seed = None , return_weight = False , ** kwargs ):
508510 weight = comfy .float .stochastic_rounding (weight / self .scale_weight .to (device = weight .device , dtype = weight .dtype ), self .weight .dtype , seed = seed )
@@ -643,6 +645,24 @@ def forward(self, input, *args, **kwargs):
643645 not isinstance (input , QuantizedTensor )):
644646 input = QuantizedTensor .from_float (input , self .layout_type , scale = self .input_scale , dtype = self .weight .dtype )
645647 return self ._forward (input , self .weight , self .bias )
648+
def convert_weight(self, weight, inplace=False, **kwargs):
    """Return *weight* as a plain tensor.

    A QuantizedTensor is materialized via its dequantize() method; any
    other tensor is passed through untouched. ``inplace`` and ``**kwargs``
    are accepted for interface compatibility with the other ops classes
    and are not used here.
    """
    if not isinstance(weight, QuantizedTensor):
        return weight
    return weight.dequantize()
654+
def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs):
    """Convert *weight* to this layer's storage format and install it.

    When a quantization layout is configured on the instance, the incoming
    float weight is re-quantized into a QuantizedTensor (``seed`` drives the
    stochastic rounding); otherwise it is simply cast to the stored dtype.
    With ``return_weight=True`` the converted tensor is returned instead of
    being assigned to ``self.weight``.
    """
    layout = getattr(self, 'layout_type', None)
    if layout is None:
        converted = weight.to(self.weight.dtype)
    else:
        converted = QuantizedTensor.from_float(weight, layout, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True)

    if return_weight:
        return converted

    assert inplace_update is False  # TODO: eventually remove the inplace_update stuff
    self.weight = torch.nn.Parameter(converted, requires_grad=False)
665+
646666 return MixedPrecisionOps
647667
648668def pick_operations (weight_dtype , compute_dtype , load_device = None , disable_fast_fp8 = False , fp8_optimizations = False , scaled_fp8 = None , model_config = None ):
0 commit comments