Skip to content

Commit

Permalink
Better per model memory usage estimations.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Aug 2, 2024
1 parent 3a9ee99 commit ea03c9d
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 25 deletions.
29 changes: 4 additions & 25 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
self.concat_keys = ()
logging.info("model_type {}".format(model_type.name))
logging.debug("adm {}".format(self.adm_channels))
self.memory_usage_factor = model_config.memory_usage_factor

def apply_model(self, x, t, c_concat=None, c_crossattn=None, control=None, transformer_options={}, **kwargs):
sigma = t
Expand Down Expand Up @@ -252,11 +253,11 @@ def memory_required(self, input_shape):
dtype = self.manual_cast_dtype
#TODO: this needs to be tweaked
area = input_shape[0] * math.prod(input_shape[2:])
return (area * comfy.model_management.dtype_size(dtype) * 0.01) * (1024 * 1024)
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
else:
#TODO: this formula might be too aggressive since I tweaked the sub-quad and split algorithms to use less memory.
area = input_shape[0] * math.prod(input_shape[2:])
return (area * 0.3) * (1024 * 1024)
return (area * 0.15 * self.memory_usage_factor) * (1024 * 1024)


def unclip_adm(unclip_conditioning, device, noise_augmentor, noise_augment_merge=0.0, seed=None):
Expand Down Expand Up @@ -354,6 +355,7 @@ def encode_adm(self, **kwargs):
flat = torch.flatten(torch.cat(out)).unsqueeze(dim=0).repeat(clip_pooled.shape[0], 1)
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)


class SVD_img2vid(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_EDM, device=None):
super().__init__(model_config, model_type, device=device)
Expand Down Expand Up @@ -594,17 +596,6 @@ def extra_conds(self, **kwargs):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
return out

def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this probably needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(dtype) * 0.012) * (1024 * 1024)
else:
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)

class AuraFlow(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
Expand Down Expand Up @@ -702,15 +693,3 @@ def extra_conds(self, **kwargs):
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['guidance'] = comfy.conds.CONDRegular(torch.FloatTensor([kwargs.get("guidance", 3.5)]))
return out

def memory_required(self, input_shape):
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
dtype = self.get_dtype()
if self.manual_cast_dtype is not None:
dtype = self.manual_cast_dtype
#TODO: this probably needs to be tweaked
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * comfy.model_management.dtype_size(dtype) * 0.026) * (1024 * 1024)
else:
area = input_shape[0] * input_shape[2] * input_shape[3]
return (area * 0.3) * (1024 * 1024)
11 changes: 11 additions & 0 deletions comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class SD15(supported_models_base.BASE):
}

latent_format = latent_formats.SD15
memory_usage_factor = 1.0

def process_clip_state_dict(self, state_dict):
k = list(state_dict.keys())
Expand Down Expand Up @@ -77,6 +78,7 @@ class SD20(supported_models_base.BASE):
}

latent_format = latent_formats.SD15
memory_usage_factor = 1.0

def model_type(self, state_dict, prefix=""):
if self.unet_config["in_channels"] == 4: #SD2.0 inpainting models are not v prediction
Expand Down Expand Up @@ -140,6 +142,7 @@ class SDXLRefiner(supported_models_base.BASE):
}

latent_format = latent_formats.SDXL
memory_usage_factor = 1.0

def get_model(self, state_dict, prefix="", device=None):
return model_base.SDXLRefiner(self, device=device)
Expand Down Expand Up @@ -178,6 +181,8 @@ class SDXL(supported_models_base.BASE):

latent_format = latent_formats.SDXL

memory_usage_factor = 0.7

def model_type(self, state_dict, prefix=""):
if 'edm_mean' in state_dict and 'edm_std' in state_dict: #Playground V2.5
self.latent_format = latent_formats.SDXL_Playground_2_5()
Expand Down Expand Up @@ -505,6 +510,9 @@ class SD3(supported_models_base.BASE):

unet_extra_config = {}
latent_format = latent_formats.SD3

memory_usage_factor = 1.2

text_encoder_key_prefix = ["text_encoders."]

def get_model(self, state_dict, prefix="", device=None):
Expand Down Expand Up @@ -631,6 +639,9 @@ class Flux(supported_models_base.BASE):

unet_extra_config = {}
latent_format = latent_formats.Flux

memory_usage_factor = 2.6

supported_inference_dtypes = [torch.bfloat16, torch.float32]

vae_key_prefix = ["vae."]
Expand Down
2 changes: 2 additions & 0 deletions comfy/supported_models_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class BASE:
text_encoder_key_prefix = ["cond_stage_model."]
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]

memory_usage_factor = 2.0

manual_cast_dtype = None

@classmethod
Expand Down

0 comments on commit ea03c9d

Please sign in to comment.