Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions comfy/ldm/hunyuan_video/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class HunyuanVideoParams:
qkv_bias: bool
guidance_embed: bool
byt5: bool
meanflow: bool


class SelfAttentionRef(nn.Module):
Expand Down Expand Up @@ -256,6 +257,11 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
else:
self.byt5_in = None

if params.meanflow:
self.time_r_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size, dtype=dtype, device=device, operations=operations)
else:
self.time_r_in = None

if final_layer:
self.final_layer = LastLayer(self.hidden_size, self.patch_size[-1], self.out_channels, dtype=dtype, device=device, operations=operations)

Expand All @@ -282,6 +288,14 @@ def forward_orig(
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256, time_factor=1.0).to(img.dtype))

if self.time_r_in is not None:
w = torch.where(transformer_options['sigmas'][0] == transformer_options['sample_sigmas'])[0] # This most likely could be improved
Copy link
Contributor

@RandomGitUser321 RandomGitUser321 Sep 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't you just use torch.isclose() instead of the comparison? Floats can get unreliable in situations like this without tolerance checks.

if len(w) > 0:
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
vec = (vec + vec_r) / 2

if ref_latent is not None:
ref_latent_ids = self.img_ids(ref_latent)
ref_latent = self.img_in(ref_latent)
Expand Down
12 changes: 10 additions & 2 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,20 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
dit_config["in_channels"] = in_w.shape[1] #SkyReels img2video has 32 input channels
dit_config["patch_size"] = list(in_w.shape[2:])
dit_config["out_channels"] = out_w.shape[0] // math.prod(dit_config["patch_size"])
if '{}vector_in.in_layer.weight'.format(key_prefix) in state_dict:
if any(s.startswith('{}vector_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["vec_in_dim"] = 768
dit_config["axes_dim"] = [16, 56, 56]
else:
dit_config["vec_in_dim"] = None

if len(dit_config["patch_size"]) == 2:
dit_config["axes_dim"] = [64, 64]
else:
dit_config["axes_dim"] = [16, 56, 56]

if any(s.startswith('{}time_r_in.'.format(key_prefix)) for s in state_dict_keys):
dit_config["meanflow"] = True
else:
dit_config["meanflow"] = False

dit_config["context_in_dim"] = state_dict['{}txt_in.input_embedder.weight'.format(key_prefix)].shape[1]
dit_config["hidden_size"] = in_w.shape[0]
Expand Down
Loading