
Commit

ACoLP
southnx committed Oct 3, 2023
1 parent c7571d5 commit c215111
Showing 8 changed files with 901 additions and 0 deletions.
83 changes: 83 additions & 0 deletions models/Action_Prompt.py
@@ -0,0 +1,83 @@
"""
Action Prompt Network
"""

import torch
from torch import nn

from models.HOIPrompting import MulitHeadAttention, HOIPrompt
# from models.args import get_args

# args = get_args()
# device = torch.device('cuda', args.local_rank)


class PositionEmbed(nn.Module):
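    # NOTE: placeholder module; forward() currently returns a constant 0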
def __init__(self):
super().__init__()

def forward(self):
pos_embd = 0

return pos_embd


class ActionPrompt(nn.Module):
def __init__(self, num_actions = 50, dim = 1024):
super().__init__()
self.num_actions = num_actions
self.dim = dim
self.action_prompt_model = HOIPrompt(embed_dim = self.dim)
self.atten = MulitHeadAttention(dim = self.dim, num_heads = 1)

    def forward(self, comb_fea: torch.Tensor, action_fea: torch.Tensor):
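        # comb_fea: (1, num_comb, dim) human-object combination features
        # action_fea: (1, num_act, dim) action features
        # returns: (num_act, dim) action prompts, averaged over all combinations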
print("comb_fea shape: ", comb_fea.shape)
_, num_comb, _ = comb_fea.shape
_, num_act, _ = action_fea.shape
total_vis_prompt = []
# total_action_prompt = []
# total_action_prompt = torch.zeros(self.dim).to(device)
        total_action_prompt = torch.zeros((1, self.dim), device=comb_fea.device)
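        # the leading zero row above is a placeholder; it is dropped via [1:] in the return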

# generate visual prompts for each human-object combination
for i in range(num_comb):
single_comb_fea = comb_fea[:, i, :].unsqueeze(0)
# print(single_comb_fea.shape) # torch.Size([1, 19, 64])
            single_visual_prompt = self.action_prompt_model(action_fea, single_comb_fea)
# print("single_visual_prompt shape: ", single_visual_prompt.shape) # torch.Size([1, 50, 64])
total_vis_prompt.append(single_visual_prompt) # [torch.Size([1, 50, 64]), ...]

# generate action prompts for each action
for i in range(num_act):
single_act_fea = action_fea[:, i, :].unsqueeze(0)
# learned_act_prompt = torch.zeros((1, 1, self.dim)).to(device)
# avg_act_prompt = torch.zeros(self.dim).to(device)
            avg_act_prompt = torch.zeros(self.dim, device=comb_fea.device)
for j in range(num_comb):
single_act_prompt = self.atten(
total_vis_prompt[j][:, i, :].unsqueeze(0),
single_act_fea,
single_act_fea
).squeeze(0).squeeze(0)
# print("single_act_prompt: ", single_act_prompt.shape) # torch.Size([1, 1, 64])
avg_act_prompt += single_act_prompt
avg_act_prompt /= num_comb
# print("avg_act_prompt: ", avg_act_prompt)
# # total_action_prompt.append(avg_act_prompt)
# if i == 0:
# total_action_prompt = torch.stack((total_action_prompt, avg_act_prompt))
# else:
# total_action_prompt = torch.cat((total_action_prompt, avg_act_prompt.unsqueeze(0)))
total_action_prompt = torch.cat((total_action_prompt, avg_act_prompt.unsqueeze(0)), dim = 0)

del total_vis_prompt
return total_action_prompt[1:]


if __name__ == '__main__':
f_ac = torch.randn(1, 50, 1024)
f_cb = torch.randn(1, 130, 1024)
model = ActionPrompt()
    act_prompt = model(f_cb, f_ac)
    print("act_prompt: ", act_prompt)
    print(act_prompt.shape) # torch.Size([50, 1024])
99 changes: 99 additions & 0 deletions models/HOIPrompting.py
@@ -0,0 +1,99 @@
from timm.models.layers import trunc_normal_
import torch
from torch import nn
import sys
sys.path.append("../")
from clip.model import QuickGELU


class MulitHeadAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads

self.scale = qk_scale or head_dim ** -0.5

self.q_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.k_proj = nn.Linear(dim, dim, bias=qkv_bias)
self.v_proj = nn.Linear(dim, dim, bias=qkv_bias)

self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)

def forward(self, q, k, v):
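        # q: (B, N, C) queries; k, v: (B, M, C) keys/values -> output: (B, N, C)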
B, N, C = q.shape
B, M, C = k.shape
q = self.q_proj(q).reshape(B, N, self.num_heads, C // self.num_heads).permute(0,2,1,3)
k = self.k_proj(k).reshape(B, M, self.num_heads, C // self.num_heads).permute(0,2,1,3)
v = self.v_proj(v).reshape(B, M, self.num_heads, C // self.num_heads).permute(0,2,1,3)

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x


class PromptGeneratorLayer(nn.Module):
def __init__(
self,
d_model,
nhead,
dropout=0.,
):
super().__init__()
self.cross_attn = MulitHeadAttention(d_model, nhead, proj_drop=dropout)

self.norm1 = nn.LayerNorm(d_model)
self.norm3 = nn.LayerNorm(d_model)

self.dropout = nn.Dropout(dropout)

self.mlp = nn.Sequential(
nn.Linear(d_model, d_model * 4),
QuickGELU(),
nn.Dropout(dropout),
nn.Linear(d_model * 4, d_model)
)

def forward(self, x, visual):
        q = self.norm1(x)  # keys and values come from the visual features, so only the query is normalized here
x = x + self.cross_attn(q, visual, visual)
x = x + self.dropout(self.mlp(self.norm3(x)))
return x


class HOIPrompt(nn.Module):
def __init__(self, layers=2, embed_dim=64, alpha=0.1,):
super().__init__()
self.norm = nn.LayerNorm(embed_dim)
self.decoder = nn.ModuleList([PromptGeneratorLayer(embed_dim, embed_dim//64) for _ in range(layers)])
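        # alpha (initialized to 0.1 per channel) scales the generated prompt before it is returned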
self.alpha = nn.Parameter(torch.ones(embed_dim) * alpha)
self.apply(self._init_weights)


def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)


def forward(self, text, visual):
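        # text: (B, N, C) action/text tokens; visual: (B, M, C) visual tokens
        # the text tokens repeatedly cross-attend to the normalized visual tokens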
# B, N, C = visual.shape
visual = self.norm(visual)
for layer in self.decoder:
text = layer(text, visual)
# print("alpha: ", self.alpha)

return self.alpha * text
# return self.alpha * text + text
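
For reference, a minimal smoke test of HOIPrompt on its own. This is not part of the commit; the shapes below are assumptions chosen to match the defaults above:

import torch
from models.HOIPrompting import HOIPrompt

text = torch.randn(1, 50, 64)    # 50 action/text tokens at the default embed_dim of 64
visual = torch.randn(1, 19, 64)  # visual tokens from one human-object combination
prompt = HOIPrompt(embed_dim=64)(text, visual)
print(prompt.shape)  # expected: torch.Size([1, 50, 64])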
