Add Enhance-A-Video

kijai committed Dec 21, 2024
1 parent 26aaa26 commit b9ea9bf
Showing 6 changed files with 158 additions and 21 deletions.
Empty file added enhance_a_video/__init__.py
53 changes: 53 additions & 0 deletions enhance_a_video/enhance.py
@@ -0,0 +1,53 @@
import torch
from einops import rearrange
from .globals import get_enhance_weight, get_num_frames

def get_feta_scores(query, key):
img_q, img_k = query, key

num_frames = get_num_frames()

B, S, N, C = img_q.shape

# Calculate spatial dimension
spatial_dim = S // num_frames

# Add time dimension between spatial and head dims
query_image = img_q.reshape(B, spatial_dim, num_frames, N, C)
key_image = img_k.reshape(B, spatial_dim, num_frames, N, C)

# Expand time dimension
query_image = query_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C]
key_image = key_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C]

# Reshape to match feta_score input format: [(B S) N T C]
query_image = rearrange(query_image, "b s t n c -> (b s) n t c") #torch.Size([3200, 24, 5, 128])
key_image = rearrange(key_image, "b s t n c -> (b s) n t c")

return feta_score(query_image, key_image, C, num_frames)

def feta_score(query_image, key_image, head_dim, num_frames):
scale = head_dim**-0.5
query_image = query_image * scale
    attn_temp = query_image @ key_image.transpose(-2, -1)  # temporal attention logits; cast to float32 below for a stable softmax
attn_temp = attn_temp.to(torch.float32)
attn_temp = attn_temp.softmax(dim=-1)

# Reshape to [batch_size * num_tokens, num_frames, num_frames]
attn_temp = attn_temp.reshape(-1, num_frames, num_frames)

# Create a mask for diagonal elements
diag_mask = torch.eye(num_frames, device=attn_temp.device).bool()
diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1)

# Zero out diagonal elements
attn_wo_diag = attn_temp.masked_fill(diag_mask, 0)

# Calculate mean for each token's attention matrix
# Number of off-diagonal elements per matrix is n*n - n
num_off_diag = num_frames * num_frames - num_frames
mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag

enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight())
enhance_scores = enhance_scores.clamp(min=1)
return enhance_scores
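For reference, a minimal standalone sketch of how this module can be exercised outside the transformer (the import path from the repository root, the shapes, and the example values are assumptions, not part of the commit):

import torch
from enhance_a_video.globals import set_num_frames, set_enhance_weight
from enhance_a_video.enhance import get_feta_scores

set_num_frames(5)        # temporal length of the latent video
set_enhance_weight(2.0)  # matches the node's default weight

# query/key as the blocks see them: [batch, seq, heads, head_dim],
# with seq = num_frames * spatial tokens per frame
B, T, spatial, N, C = 1, 5, 16, 24, 128
q = torch.randn(B, T * spatial, N, C)
k = torch.randn(B, T * spatial, N, C)

score = get_feta_scores(q, k)  # scalar tensor, clamped to >= 1
print(float(score))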
38 changes: 38 additions & 0 deletions enhance_a_video/globals.py
@@ -0,0 +1,38 @@
NUM_FRAMES = None
FETA_WEIGHT = None
ENABLE_FETA_SINGLE = False
ENABLE_FETA_DOUBLE = False


def set_num_frames(num_frames: int):
global NUM_FRAMES
NUM_FRAMES = num_frames


def get_num_frames() -> int:
return NUM_FRAMES


def enable_enhance(single, double):
global ENABLE_FETA_SINGLE, ENABLE_FETA_DOUBLE
ENABLE_FETA_SINGLE = single
ENABLE_FETA_DOUBLE = double

def disable_enhance():
global ENABLE_FETA_SINGLE, ENABLE_FETA_DOUBLE
ENABLE_FETA_SINGLE = False
ENABLE_FETA_DOUBLE = False

def is_enhance_enabled_single() -> bool:
return ENABLE_FETA_SINGLE

def is_enhance_enabled_double() -> bool:
return ENABLE_FETA_DOUBLE

def set_enhance_weight(feta_weight: float):
global FETA_WEIGHT
FETA_WEIGHT = feta_weight


def get_enhance_weight() -> float:
return FETA_WEIGHT
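These module-level flags are the only state shared between the pipeline and the attention blocks; a short sketch of the intended call order (the argument values are assumed):

from enhance_a_video.globals import (
    enable_enhance, disable_enhance,
    is_enhance_enabled_single, is_enhance_enabled_double,
)

enable_enhance(single=False, double=True)   # e.g. taken from feta_args
assert is_enhance_enabled_double() and not is_enhance_enabled_single()

disable_enhance()                           # e.g. outside the start/end window
assert not is_enhance_enabled_double() and not is_enhance_enabled_single()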
20 changes: 17 additions & 3 deletions hyvideo/diffusion/pipelines/pipeline_hunyuan_video.py
@@ -37,6 +37,7 @@

EXAMPLE_DOC_STRING = """"""
from ...modules.posemb_layers import get_nd_rotary_pos_embed
from ....enhance_a_video.globals import enable_enhance, disable_enhance, set_enhance_weight

def get_rotary_pos_embed(transformer, latent_video_length, height, width):
target_ndim = 3
@@ -416,6 +417,7 @@ def __call__(
stg_start_percent: Optional[float] = 0.0,
stg_end_percent: Optional[float] = 1.0,
context_options: Optional[Dict[str, Any]] = None,
feta_args: Optional[Dict] = None,
**kwargs,
):
r"""
@@ -556,10 +558,16 @@ def __call__(
**extra_set_timesteps_kwargs,
)

#if "884" in vae_ver:

latent_video_length = (video_length - 1) // 4 + 1
# elif "888" in vae_ver:
# video_length = (video_length - 1) // 8 + 1
if feta_args is not None:
set_enhance_weight(feta_args["weight"])
feta_start_percent = feta_args["start_percent"]
feta_end_percent = feta_args["end_percent"]
enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
else:
disable_enhance()


# context windows
use_context_schedule = False
@@ -667,6 +675,12 @@ def __call__(
input_prompt_embeds = prompt_embeds[1].unsqueeze(0)
input_prompt_mask = prompt_mask[1].unsqueeze(0)
input_prompt_embeds_2 = prompt_embeds_2[1].unsqueeze(0)

if feta_args is not None:
if feta_start_percent <= current_step_percentage <= feta_end_percent:
enable_enhance(feta_args["single_blocks"], feta_args["double_blocks"])
else:
disable_enhance()

latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

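The window check inside the denoising loop reduces to a simple percentage comparison; a hedged sketch of that check (the helper name and the exact definition of the step percentage are assumptions — the hunk above only shows the comparison itself):

def feta_active(step_index: int, num_steps: int, start_percent: float, end_percent: float) -> bool:
    # roughly what the loop computes as current_step_percentage
    current_step_percentage = step_index / num_steps
    return start_percent <= current_step_percentage <= end_percent

# With the node defaults (0.0 .. 1.0) the effect stays enabled for every step
assert feta_active(0, 30, 0.0, 1.0) and feta_active(29, 30, 0.0, 1.0)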
39 changes: 23 additions & 16 deletions hyvideo/modules/models.py
Expand Up @@ -16,7 +16,8 @@
from .mlp_layers import MLP, MLPEmbedder, FinalLayer
from .modulate_layers import ModulateDiT, modulate, apply_gate
from .token_refiner import SingleTokenRefiner
import comfy.model_management as mm
from ...enhance_a_video.enhance import get_feta_scores
from ...enhance_a_video.globals import is_enhance_enabled_single, is_enhance_enabled_double, set_num_frames

class MMDoubleStreamBlock(nn.Module):
"""
@@ -176,12 +177,7 @@ def forward(
# Apply RoPE if needed.
if freqs_cis is not None:
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
#img_q, img_k = img_qq, img_kk
#assert (
# img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
#), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"


# Prepare txt for attention.
txt_modulated = self.txt_norm1(txt)
txt_modulated = modulate(
@@ -196,13 +192,14 @@ def forward(
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v)
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v)

if is_enhance_enabled_double():
feta_scores = get_feta_scores(img_q, img_k)

# Run actual attention.
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)
v = torch.cat((img_v, txt_v), dim=1)
#assert (
# cu_seqlens_q.shape[0] == 2 * img.shape[0] + 1
#), f"cu_seqlens_q.shape:{cu_seqlens_q.shape}, img.shape[0]:{img.shape[0]}"

attn = attention(
q,
k,
@@ -218,6 +215,8 @@ def forward(
)

img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :]
if is_enhance_enabled_double():
img_attn *= feta_scores

        # Calculate the img blocks.
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate)
@@ -346,14 +345,16 @@ def forward(
if freqs_cis is not None:
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :]
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :]
img_qq, img_kk = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
assert (
img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
img_q, img_k = img_qq, img_kk
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, head_first=False)
# assert (
# img_qq.shape == img_q.shape and img_kk.shape == img_k.shape
# ), f"img_kk: {img_qq.shape}, img_q: {img_q.shape}, img_kk: {img_kk.shape}, img_k: {img_k.shape}"
q = torch.cat((img_q, txt_q), dim=1)
k = torch.cat((img_k, txt_k), dim=1)

if is_enhance_enabled_single():
feta_scores = get_feta_scores(img_q, img_k)

# Compute attention.
#assert (
# cu_seqlens_q.shape[0] == 2 * x.shape[0] + 1
@@ -411,10 +412,15 @@ def forward(
batch_size=x.shape[0],
attn_mask=attn_mask
)

if is_enhance_enabled_single():
attn *= feta_scores

# Compute activation in mlp stream, cat again and run second linear layer.
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
return x + apply_gate(output, gate=mod_gate)
output = x + apply_gate(output, gate=mod_gate)


return output


class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin):
@@ -673,6 +679,7 @@ def forward(
oh // self.patch_size[1],
ow // self.patch_size[2],
)
set_num_frames(img.shape[2])

# Prepare modulation vectors.
vec = self.time_in(t)
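Both block types follow the same pattern: score the cross-frame attention from the image q/k before running attention, then multiply only the image part of the attention output by the resulting scalar. A simplified, self-contained sketch of that pattern (plain scaled-dot-product attention and all shapes are assumptions; the real blocks use the wrapper's own attention path):

import torch
import torch.nn.functional as F
from enhance_a_video.enhance import get_feta_scores
from enhance_a_video.globals import (
    set_num_frames, set_enhance_weight, enable_enhance, is_enhance_enabled_double,
)

set_num_frames(5); set_enhance_weight(2.0); enable_enhance(single=False, double=True)

B, S_img, S_txt, N, C = 1, 80, 32, 24, 128
img_q, img_k, img_v = (torch.randn(B, S_img, N, C) for _ in range(3))
txt_q, txt_k, txt_v = (torch.randn(B, S_txt, N, C) for _ in range(3))

if is_enhance_enabled_double():
    feta_scores = get_feta_scores(img_q, img_k)              # scalar score

# concatenate image and text tokens and run attention ([B, N, S, C] layout for SDPA)
q = torch.cat((img_q, txt_q), dim=1).transpose(1, 2)
k = torch.cat((img_k, txt_k), dim=1).transpose(1, 2)
v = torch.cat((img_v, txt_v), dim=1).transpose(1, 2)
attn = F.scaled_dot_product_attention(q, k, v).transpose(1, 2)  # back to [B, S, N, C]

img_attn, txt_attn = attn[:, :S_img], attn[:, S_img:]
if is_enhance_enabled_double():
    img_attn = img_attn * feta_scores                          # Enhance-A-Video scaling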
29 changes: 27 additions & 2 deletions nodes.py
@@ -132,6 +132,27 @@ def INPUT_TYPES(s):

def setargs(self, **kwargs):
return (kwargs, )

class HyVideoEnhanceAVideo:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"weight": ("FLOAT", {"default": 2.0, "min": 0, "max": 100, "step": 0.01, "tooltip": "The feta Weight of the Enhance-A-Video"}),
"single_blocks": ("BOOLEAN", {"default": False, "tooltip": "Enable Enhance-A-Video for single blocks"}),
"double_blocks": ("BOOLEAN", {"default": False, "tooltip": "Enable Enhance-A-Video for double blocks"}),
"start_percent": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "Start percentage of the steps to apply Enhance-A-Video"}),
"end_percent": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "End percentage of the steps to apply Enhance-A-Video"}),
},
}
RETURN_TYPES = ("FETAARGS",)
RETURN_NAMES = ("feta_args",)
FUNCTION = "setargs"
CATEGORY = "HunyuanVideoWrapper"
DESCRIPTION = "https://github.com/NUS-HPC-AI-Lab/Enhance-A-Video"

def setargs(self, **kwargs):
return (kwargs, )

class HyVideoSTG:
@classmethod
@@ -1020,7 +1041,7 @@ def INPUT_TYPES(s):
"num_frames": ("INT", {"default": 49, "min": 1, "max": 1024, "step": 4}),
"steps": ("INT", {"default": 30, "min": 1}),
"embedded_guidance_scale": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"flow_shift": ("FLOAT", {"default": 9.0, "min": 0.0, "max": 30.0, "step": 0.01}),
"flow_shift": ("FLOAT", {"default": 9.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
"seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
"force_offload": ("BOOLEAN", {"default": True}),

@@ -1030,6 +1051,7 @@ def INPUT_TYPES(s):
"denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"stg_args": ("STGARGS", ),
"context_options": ("COGCONTEXT", ),
"feta_args": ("FETAARGS", ),
}
}

@@ -1039,7 +1061,7 @@ def INPUT_TYPES(s):
CATEGORY = "HunyuanVideoWrapper"

def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scale, seed, width, height, num_frames,
samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None):
samples=None, denoise_strength=1.0, force_offload=True, stg_args=None, context_options=None, feta_args=None):
model = model.model

device = mm.get_torch_device()
@@ -1128,6 +1150,7 @@ def process(self, model, hyvid_embeds, flow_shift, steps, embedded_guidance_scal
stg_start_percent=stg_args["stg_start_percent"] if stg_args is not None else 0.0,
stg_end_percent=stg_args["stg_end_percent"] if stg_args is not None else 1.0,
context_options=context_options,
feta_args=feta_args,
)

print_memory(device)
@@ -1375,6 +1398,7 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):
"HyVideoTextEmbedsSave": HyVideoTextEmbedsSave,
"HyVideoTextEmbedsLoad": HyVideoTextEmbedsLoad,
"HyVideoContextOptions": HyVideoContextOptions,
"HyVideoEnhanceAVideo": HyVideoEnhanceAVideo,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"HyVideoSampler": "HunyuanVideo Sampler",
@@ -1396,4 +1420,5 @@ def sample(self, samples, seed, min_val, max_val, r_bias, g_bias, b_bias):
"HyVideoTextEmbedsSave": "HunyuanVideo TextEmbeds Save",
"HyVideoTextEmbedsLoad": "HunyuanVideo TextEmbeds Load",
"HyVideoContextOptions": "HunyuanVideo Context Options",
"HyVideoEnhanceAVideo": "HunyuanVideo Enhance A Video",
}
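End to end, the new node only packages its widget values into a dict on the FETAARGS socket, and the sampler forwards that dict untouched as the pipeline's feta_args. The values below are just the node defaults with double blocks switched on, shown for illustration:

# What HyVideoEnhanceAVideo.setargs(**kwargs) returns: the kwargs in a 1-tuple
node_output = ({
    "weight": 2.0,
    "single_blocks": False,
    "double_blocks": True,
    "start_percent": 0.0,
    "end_percent": 1.0,
},)
feta_args = node_output[0]
# HyVideoSampler.process(..., feta_args=feta_args) then hands this dict straight
# to the pipeline __call__ shown in pipeline_hunyuan_video.py above.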
