-
Notifications
You must be signed in to change notification settings - Fork 103
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
158 additions
and
21 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
import torch | ||
from einops import rearrange | ||
from .globals import get_enhance_weight, get_num_frames | ||
|
||
def get_feta_scores(query, key): | ||
img_q, img_k = query, key | ||
|
||
num_frames = get_num_frames() | ||
|
||
B, S, N, C = img_q.shape | ||
|
||
# Calculate spatial dimension | ||
spatial_dim = S // num_frames | ||
|
||
# Add time dimension between spatial and head dims | ||
query_image = img_q.reshape(B, spatial_dim, num_frames, N, C) | ||
key_image = img_k.reshape(B, spatial_dim, num_frames, N, C) | ||
|
||
# Expand time dimension | ||
query_image = query_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C] | ||
key_image = key_image.expand(-1, -1, num_frames, -1, -1) # [B, S, T, N, C] | ||
|
||
# Reshape to match feta_score input format: [(B S) N T C] | ||
query_image = rearrange(query_image, "b s t n c -> (b s) n t c") #torch.Size([3200, 24, 5, 128]) | ||
key_image = rearrange(key_image, "b s t n c -> (b s) n t c") | ||
|
||
return feta_score(query_image, key_image, C, num_frames) | ||
|
||
def feta_score(query_image, key_image, head_dim, num_frames): | ||
scale = head_dim**-0.5 | ||
query_image = query_image * scale | ||
attn_temp = query_image @ key_image.transpose(-2, -1) # translate attn to float32 | ||
attn_temp = attn_temp.to(torch.float32) | ||
attn_temp = attn_temp.softmax(dim=-1) | ||
|
||
# Reshape to [batch_size * num_tokens, num_frames, num_frames] | ||
attn_temp = attn_temp.reshape(-1, num_frames, num_frames) | ||
|
||
# Create a mask for diagonal elements | ||
diag_mask = torch.eye(num_frames, device=attn_temp.device).bool() | ||
diag_mask = diag_mask.unsqueeze(0).expand(attn_temp.shape[0], -1, -1) | ||
|
||
# Zero out diagonal elements | ||
attn_wo_diag = attn_temp.masked_fill(diag_mask, 0) | ||
|
||
# Calculate mean for each token's attention matrix | ||
# Number of off-diagonal elements per matrix is n*n - n | ||
num_off_diag = num_frames * num_frames - num_frames | ||
mean_scores = attn_wo_diag.sum(dim=(1, 2)) / num_off_diag | ||
|
||
enhance_scores = mean_scores.mean() * (num_frames + get_enhance_weight()) | ||
enhance_scores = enhance_scores.clamp(min=1) | ||
return enhance_scores |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
NUM_FRAMES = None | ||
FETA_WEIGHT = None | ||
ENABLE_FETA_SINGLE = False | ||
ENABLE_FETA_DOUBLE = False | ||
|
||
|
||
def set_num_frames(num_frames: int): | ||
global NUM_FRAMES | ||
NUM_FRAMES = num_frames | ||
|
||
|
||
def get_num_frames() -> int: | ||
return NUM_FRAMES | ||
|
||
|
||
def enable_enhance(single, double): | ||
global ENABLE_FETA_SINGLE, ENABLE_FETA_DOUBLE | ||
ENABLE_FETA_SINGLE = single | ||
ENABLE_FETA_DOUBLE = double | ||
|
||
def disable_enhance(): | ||
global ENABLE_FETA_SINGLE, ENABLE_FETA_DOUBLE | ||
ENABLE_FETA_SINGLE = False | ||
ENABLE_FETA_DOUBLE = False | ||
|
||
def is_enhance_enabled_single() -> bool: | ||
return ENABLE_FETA_SINGLE | ||
|
||
def is_enhance_enabled_double() -> bool: | ||
return ENABLE_FETA_DOUBLE | ||
|
||
def set_enhance_weight(feta_weight: float): | ||
global FETA_WEIGHT | ||
FETA_WEIGHT = feta_weight | ||
|
||
|
||
def get_enhance_weight() -> float: | ||
return FETA_WEIGHT |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters