Skip to content

Commit

Permalink
Make supported_dtypes a priority list.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 7, 2024
1 parent cb7c4b4 commit 6969fc9
Showing 1 changed file with 22 additions and 12 deletions.
34 changes: 22 additions & 12 deletions comfy/model_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,12 +562,22 @@ def unet_dtype(device=None, model_params=0, supported_dtypes=[torch.float16, tor
if model_params * 2 > free_model_memory:
return fp8_dtype

if should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16
for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params):
if torch.float16 in supported_dtypes:
return torch.float16
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16

for dt in supported_dtypes:
if dt == torch.float16 and should_use_fp16(device=device, model_params=model_params, manual_cast=True):
if torch.float16 in supported_dtypes:
return torch.float16
if dt == torch.bfloat16 and should_use_bf16(device, model_params=model_params, manual_cast=True):
if torch.bfloat16 in supported_dtypes:
return torch.bfloat16

return torch.float32

# None means no manual cast
Expand All @@ -583,13 +593,13 @@ def unet_manual_cast(weight_dtype, inference_device, supported_dtypes=[torch.flo
if bf16_supported and weight_dtype == torch.bfloat16:
return None

if fp16_supported and torch.float16 in supported_dtypes:
return torch.float16
for dt in supported_dtypes:
if dt == torch.float16 and fp16_supported:
return torch.float16
if dt == torch.bfloat16 and bf16_supported:
return torch.bfloat16

elif bf16_supported and torch.bfloat16 in supported_dtypes:
return torch.bfloat16
else:
return torch.float32
return torch.float32

def text_encoder_offload_device():
if args.gpu_only:
Expand Down

0 comments on commit 6969fc9

Please sign in to comment.