More accurate memory estimation for cosmos and hunyuan video.
comfyanonymous committed Jan 16, 2025
1 parent 6320d05 commit 9d8b6c1
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions comfy/sd.py
@@ -388,8 +388,8 @@ def __init__(self, sd=None, device=None, config=None, dtype=None):
 ddconfig = {'z_channels': 16, 'latent_channels': self.latent_channels, 'z_factor': 1, 'resolution': 1024, 'in_channels': 3, 'out_channels': 3, 'channels': 128, 'channels_mult': [2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [32], 'dropout': 0.0, 'patch_size': 4, 'num_groups': 1, 'temporal_compression': 8, 'spacial_compression': 8}
 self.first_stage_model = comfy.ldm.cosmos.vae.CausalContinuousVideoTokenizer(**ddconfig)
 #TODO: these values are a bit off because this is not a standard VAE
-self.memory_used_decode = lambda shape, dtype: (220 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
-self.memory_used_encode = lambda shape, dtype: (500 * max(shape[2], 2) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+self.memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
+self.memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
 self.working_dtypes = [torch.bfloat16, torch.float32]
 else:
 logging.warning("WARNING: No VAE weights detected, VAE not initalized.")
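For a sense of what the new Cosmos VAE estimators predict, here is a small self-contained sketch. It assumes the decode lambda receives a latent shape laid out as [batch, channels, frames, height, width] and the encode lambda a pixel-space shape with the same layout; the dtype_size helper is a stand-in for comfy.model_management.dtype_size(), and the example shapes are illustrative rather than taken from the commit.

    import torch

    def dtype_size(dtype):
        # Stand-in for comfy.model_management.dtype_size(): 2 bytes for bf16/fp16, 4 for fp32.
        return {torch.bfloat16: 2, torch.float16: 2, torch.float32: 4}[dtype]

    # The two estimators introduced by this commit for the Cosmos causal video VAE.
    memory_used_decode = lambda shape, dtype: (50 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * dtype_size(dtype)
    memory_used_encode = lambda shape, dtype: (50 * (round((shape[2] + 7) / 8) * 8) * shape[3] * shape[4]) * dtype_size(dtype)

    # Decode scales with the latent volume times the 8x8x8 temporal/spatial expansion back to pixels.
    latent = (1, 16, 16, 88, 160)   # illustrative latent for a 121-frame 1280x704 clip
    print(memory_used_decode(latent, torch.bfloat16) / (1024 ** 3), "GiB")

    # Encode scales with the pixel volume, with the frame count rounded to a multiple of 8.
    pixels = (1, 3, 121, 704, 1280)
    print(memory_used_encode(pixels, torch.bfloat16) / (1024 ** 3), "GiB")

The new constant of 50 is much smaller than the old 220 (decode) and 500 (encode), so the estimated working set drops accordingly, which should mean fewer unnecessary offloads or tiled fallbacks on GPUs that can fit the full operation.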
4 changes: 2 additions & 2 deletions comfy/supported_models.py
@@ -788,7 +788,7 @@ class HunyuanVideo(supported_models_base.BASE):
 unet_extra_config = {}
 latent_format = latent_formats.HunyuanVideo

-memory_usage_factor = 2.0 #TODO
+memory_usage_factor = 1.7 #TODO

 supported_inference_dtypes = [torch.bfloat16, torch.float32]

@@ -839,7 +839,7 @@ class CosmosT2V(supported_models_base.BASE):
 unet_extra_config = {}
 latent_format = latent_formats.Cosmos1CV8x8x8

-memory_usage_factor = 2.4 #TODO
+memory_usage_factor = 1.6 #TODO

 supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32] #TODO

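memory_usage_factor is a per-model multiplier on the diffusion model's estimated inference memory, so lowering it (2.0 to 1.7 for HunyuanVideo, 2.4 to 1.6 for CosmosT2V) shrinks the amount of VRAM ComfyUI reserves before sampling. The sketch below only illustrates how such a multiplier enters an area-based estimate; the formula and constants are a simplified stand-in rather than ComfyUI's exact memory_required() code, and the latent shape is illustrative.

    import math
    import torch

    def dtype_size(dtype):
        # Stand-in for comfy.model_management.dtype_size().
        return {torch.bfloat16: 2, torch.float16: 2, torch.float32: 4}[dtype]

    def estimated_inference_memory(latent_shape, dtype, memory_usage_factor):
        # Simplified: bytes grow with batch size times the latent "area"
        # (every dimension past channels), scaled by the per-model factor.
        area = latent_shape[0] * math.prod(latent_shape[2:])
        return area * dtype_size(dtype) * 0.01 * memory_usage_factor * (1024 * 1024)

    latent = (1, 16, 16, 88, 160)   # illustrative Cosmos video latent [B, C, T, H, W]
    for factor in (2.4, 1.6):       # old vs. new CosmosT2V value from this commit
        gib = estimated_inference_memory(latent, torch.bfloat16, factor) / (1024 ** 3)
        print(f"factor {factor}: ~{gib:.2f} GiB")

Since the estimate scales linearly in the factor, the CosmosT2V change alone cuts the reserved amount by a third for the same latent.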