Skip to content

Commit 5e5e46d

Browse files
Not really tested WAN Phantom Support. (Comfy-Org#8321)
1 parent 4eba316 commit 5e5e46d

File tree

3 files changed

+52
-1
lines changed

3 files changed

+52
-1
lines changed

comfy/ldm/wan/model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,13 +539,20 @@ def block_wrap(args):
539539
x = self.unpatchify(x, grid_sizes)
540540
return x
541541

542-
def forward(self, x, timestep, context, clip_fea=None, transformer_options={}, **kwargs):
542+
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
543543
bs, c, t, h, w = x.shape
544544
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
545+
545546
patch_size = self.patch_size
546547
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
547548
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
548549
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
550+
551+
if time_dim_concat is not None:
552+
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
553+
x = torch.cat([x, time_dim_concat], dim=2)
554+
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
555+
549556
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device, dtype=x.dtype)
550557
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).reshape(-1, 1, 1)
551558
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).reshape(1, -1, 1)

comfy/model_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,11 @@ def extra_conds(self, **kwargs):
10571057
clip_vision_output = kwargs.get("clip_vision_output", None)
10581058
if clip_vision_output is not None:
10591059
out['clip_fea'] = comfy.conds.CONDRegular(clip_vision_output.penultimate_hidden_states)
1060+
1061+
time_dim_concat = kwargs.get("time_dim_concat", None)
1062+
if time_dim_concat is not None:
1063+
out['time_dim_concat'] = comfy.conds.CONDRegular(self.process_latent_in(time_dim_concat))
1064+
10601065
return out
10611066

10621067

comfy_extras/nodes_wan.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,44 @@ def encode(self, positive, negative, vae, width, height, length, batch_size, sta
345345
out_latent["samples"] = latent
346346
return (positive, negative, out_latent)
347347

348+
class WanPhantomSubjectToVideo:
349+
@classmethod
350+
def INPUT_TYPES(s):
351+
return {"required": {"positive": ("CONDITIONING", ),
352+
"negative": ("CONDITIONING", ),
353+
"vae": ("VAE", ),
354+
"width": ("INT", {"default": 832, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
355+
"height": ("INT", {"default": 480, "min": 16, "max": nodes.MAX_RESOLUTION, "step": 16}),
356+
"length": ("INT", {"default": 81, "min": 1, "max": nodes.MAX_RESOLUTION, "step": 4}),
357+
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096}),
358+
},
359+
"optional": {"images": ("IMAGE", ),
360+
}}
361+
362+
RETURN_TYPES = ("CONDITIONING", "CONDITIONING", "CONDITIONING", "LATENT")
363+
RETURN_NAMES = ("positive", "negative_text", "negative_img_text", "latent")
364+
FUNCTION = "encode"
365+
366+
CATEGORY = "conditioning/video_models"
367+
368+
def encode(self, positive, negative, vae, width, height, length, batch_size, images):
369+
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
370+
cond2 = negative
371+
if images is not None:
372+
images = comfy.utils.common_upscale(images[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
373+
latent_images = []
374+
for i in images:
375+
latent_images += [vae.encode(i.unsqueeze(0)[:, :, :, :3])]
376+
concat_latent_image = torch.cat(latent_images, dim=2)
377+
378+
positive = node_helpers.conditioning_set_values(positive, {"time_dim_concat": concat_latent_image})
379+
cond2 = node_helpers.conditioning_set_values(negative, {"time_dim_concat": concat_latent_image})
380+
negative = node_helpers.conditioning_set_values(negative, {"time_dim_concat": comfy.latent_formats.Wan21().process_out(torch.zeros_like(concat_latent_image))})
381+
382+
out_latent = {}
383+
out_latent["samples"] = latent
384+
return (positive, cond2, negative, out_latent)
385+
348386
NODE_CLASS_MAPPINGS = {
349387
"WanImageToVideo": WanImageToVideo,
350388
"WanFunControlToVideo": WanFunControlToVideo,
@@ -353,4 +391,5 @@ def encode(self, positive, negative, vae, width, height, length, batch_size, sta
353391
"WanVaceToVideo": WanVaceToVideo,
354392
"TrimVideoLatent": TrimVideoLatent,
355393
"WanCameraImageToVideo": WanCameraImageToVideo,
394+
"WanPhantomSubjectToVideo": WanPhantomSubjectToVideo,
356395
}

0 commit comments

Comments
 (0)