@@ -40,7 +40,8 @@ class LayerNorm(nn.LayerNorm):

def forward(self, x):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
layer_dtype = self.bias.dtype
ret = super().forward(x.type(layer_dtype))
return ret.type(orig_type)


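The hunk above replaces the hard-coded fp32 cast with the dtype of the layer's own parameters, so a model loaded in bf16/fp16 keeps running its LayerNorms in that precision while the caller still gets the input dtype back. A minimal stand-alone sketch of the same idea (not the repo's exact class):

import torch
import torch.nn as nn

class ParamDtypeLayerNorm(nn.LayerNorm):
    def forward(self, x):
        orig_type = x.dtype
        # self.bias carries whatever dtype the checkpoint was loaded in
        ret = super().forward(x.type(self.bias.dtype))
        return ret.type(orig_type)

# a bf16 model no longer silently upcasts every LayerNorm to fp32
ln = ParamDtypeLayerNorm(768).to(torch.bfloat16)
out = ln(torch.randn(2, 16, 768, dtype=torch.bfloat16))
assert out.dtype == torch.bfloat16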
@@ -51,10 +52,10 @@ def forward(self, x):

class ResidualAttentionBlock(nn.Module):
def __init__(
self, d_model, n_head, attn_mask=None, drop_path=0.0,
self, d_model, n_head, attn_mask=None, drop_path=0.0,
):
super().__init__()
super().__init__()

self.n_head = n_head
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
logger.info(f'Drop path rate: {drop_path}')
@@ -151,11 +152,11 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):

class Transformer(nn.Module):
def __init__(
self, width, layers, heads, attn_mask=None, backbone_drop_path_rate=0.,
self, width, layers, heads, attn_mask=None, backbone_drop_path_rate=0.,
use_checkpoint=False, checkpoint_num=[0], t_size=8,
return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
cls_dropout=0.5, num_classes=400,
):
super().__init__()
@@ -165,7 +166,7 @@ def __init__(
b_dpr = [x.item() for x in torch.linspace(0, backbone_drop_path_rate, layers)]
self.resblocks = nn.ModuleList([
ResidualAttentionBlock(
width, heads, attn_mask,
width, heads, attn_mask,
drop_path=b_dpr[i],
) for i in range(layers)
])
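For context on the drop-path plumbing these whitespace-only hunks touch: the backbone spreads backbone_drop_path_rate linearly over its blocks (the b_dpr line above), and each ResidualAttentionBlock only wraps its residual in DropPath when its own rate is non-zero, falling back to nn.Identity otherwise. A small sketch with assumed example values:

import torch

backbone_drop_path_rate = 0.2   # the value the test block at the bottom of this file uses
layers = 24                     # e.g. a ViT-L/14 backbone
b_dpr = [x.item() for x in torch.linspace(0, backbone_drop_path_rate, layers)]
# b_dpr[0] == 0.0  -> first block keeps nn.Identity()
# b_dpr[-1] == 0.2 -> deepest block drops its residual branch most often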
@@ -187,7 +188,7 @@ def __init__(
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layers)]
self.dec = nn.ModuleList([
Extractor(
n_dim, n_head, mlp_factor=mlp_factor,
n_dim, n_head, mlp_factor=mlp_factor,
dropout=mlp_dropout[i], drop_path=dpr[i],
) for i in range(n_layers)
])
@@ -224,7 +225,7 @@ def forward(self, x, mode='video', return_all_feats=False):
_, tmp_feats = tmp_x[:1], tmp_x[1:]
tmp_feats = tmp_feats.permute(1, 3, 2, 0).reshape(N, C, T_down, H, W)
tmp_feats = self.dpe[j](tmp_feats).view(N, C, T_down, L - 1).permute(3, 0, 2, 1)
# tmp_x[1:] = tmp_x[1:] + tmp_feats # memory leak
# tmp_x[1:] = tmp_x[1:] + tmp_feats # memory leak
tmp_x = torch.cat([tmp_x[:1], tmp_x[1:] + tmp_feats], dim=0) # no memory leak
# enhancer
tmp_x = tmp_x.permute(2, 0, 1, 3).flatten(0, 1) # T * L, N, C
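The memory-leak comment refers to the in-place slice assignment that used to write the DPE output back into tmp_x; presumably the in-place write kept extra autograd references alive, so the code rebuilds the tensor out-of-place with torch.cat instead. A toy illustration of the two variants (shapes are made up):

import torch

tmp_x = torch.randn(197, 2, 8, 768)      # L, N, T_down, C  (illustrative sizes)
tmp_feats = torch.randn(196, 2, 8, 768)  # DPE output for the non-cls tokens

# in-place variant flagged in the comment above:
# tmp_x[1:] = tmp_x[1:] + tmp_feats
# out-of-place variant kept by the code:
tmp_x = torch.cat([tmp_x[:1], tmp_x[1:] + tmp_feats], dim=0)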
@@ -242,16 +243,16 @@ def forward(self, x, mode='video', return_all_feats=False):

class VisionTransformer(nn.Module):
def __init__(
self,
self,
# backbone
input_resolution, patch_size, width, layers, heads, output_dim, backbone_drop_path_rate=0.,
use_checkpoint=False, checkpoint_num=[0], t_size=8,
# extractor
return_list=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11],
n_layers=12, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
mlp_dropout=[0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
cls_dropout=0.5, num_classes=400,

):
super().__init__()
self.input_resolution = input_resolution
@@ -265,18 +266,18 @@ def __init__(

self.transformer = Transformer(
width, layers, heads,
backbone_drop_path_rate=backbone_drop_path_rate,
backbone_drop_path_rate=backbone_drop_path_rate,
use_checkpoint=use_checkpoint, checkpoint_num=checkpoint_num, t_size=t_size,
return_list=return_list, n_layers=n_layers, n_dim=n_dim, n_head=n_head,
mlp_factor=mlp_factor, drop_path_rate=drop_path_rate, mlp_dropout=mlp_dropout,
return_list=return_list, n_layers=n_layers, n_dim=n_dim, n_head=n_head,
mlp_factor=mlp_factor, drop_path_rate=drop_path_rate, mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout, num_classes=num_classes,
)

def forward(self, x, mode='video', return_all_feats=False):
x = self.conv1(x) # shape = [*, width, grid, grid]
N, C, T, H, W = x.shape
x = x.permute(0, 2, 3, 4, 1).reshape(N * T, H * W, C)

x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
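As a quick shape walk-through of the patch-embedding path above, with assumed example sizes (ViT-B/16, 224x224 input, 8 frames); purely illustrative, not code from the repo:

import torch

N, T, res, patch, width = 2, 8, 224, 16, 768
grid = res // patch                          # 14
x = torch.randn(N, width, T, grid, grid)     # what the 3D conv1 emits
N_, C, T_, H, W = x.shape
x = x.permute(0, 2, 3, 4, 1).reshape(N_ * T_, H * W, C)  # [N*T, 196, 768]
cls = torch.zeros(x.shape[0], 1, C)          # class token broadcast to every frame
x = torch.cat([cls, x], dim=1)               # [N*T, 197, 768], then + positional embedding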
@@ -313,10 +314,10 @@ def load_state_dict(model, state_dict):

def vit_only_global_b32(
pretrained=True, use_checkpoint=False, checkpoint_num=[0],
t_size=16, backbone_drop_path_rate=0.,
t_size=16, backbone_drop_path_rate=0.,
return_list=[8, 9, 10, 11],
n_layers=4, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
cls_dropout=0.5, num_classes=400,
):
model = VisionTransformer(
@@ -329,15 +330,15 @@ def vit_only_global_b32(
use_checkpoint=use_checkpoint,
checkpoint_num=checkpoint_num,
t_size=t_size,
backbone_drop_path_rate=backbone_drop_path_rate,
return_list=return_list,
n_layers=n_layers,
n_dim=n_dim,
n_head=n_head,
mlp_factor=mlp_factor,
drop_path_rate=drop_path_rate,
mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout,
backbone_drop_path_rate=backbone_drop_path_rate,
return_list=return_list,
n_layers=n_layers,
n_dim=n_dim,
n_head=n_head,
mlp_factor=mlp_factor,
drop_path_rate=drop_path_rate,
mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout,
num_classes=num_classes,
)

@@ -350,10 +351,10 @@ def vit_only_global_b32(

def vit_only_global_b16(
pretrained=True, use_checkpoint=False, checkpoint_num=[0],
t_size=16, backbone_drop_path_rate=0.,
t_size=16, backbone_drop_path_rate=0.,
return_list=[8, 9, 10, 11],
n_layers=4, n_dim=768, n_head=12, mlp_factor=4.0, drop_path_rate=0.,
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
cls_dropout=0.5, num_classes=400,
):
model = VisionTransformer(
Expand All @@ -366,15 +367,15 @@ def vit_only_global_b16(
use_checkpoint=use_checkpoint,
checkpoint_num=checkpoint_num,
t_size=t_size,
backbone_drop_path_rate=backbone_drop_path_rate,
return_list=return_list,
n_layers=n_layers,
n_dim=n_dim,
n_head=n_head,
mlp_factor=mlp_factor,
drop_path_rate=drop_path_rate,
mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout,
backbone_drop_path_rate=backbone_drop_path_rate,
return_list=return_list,
n_layers=n_layers,
n_dim=n_dim,
n_head=n_head,
mlp_factor=mlp_factor,
drop_path_rate=drop_path_rate,
mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout,
num_classes=num_classes,
)

@@ -387,10 +388,10 @@ def vit_only_global_b16(

def vit_only_global_l14(
pretrained=True, use_checkpoint=False, checkpoint_num=[0],
t_size=16, backbone_drop_path_rate=0.,
t_size=16, backbone_drop_path_rate=0.,
return_list=[20, 21, 22, 23],
n_layers=4, n_dim=1024, n_head=16, mlp_factor=4.0, drop_path_rate=0.,
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
cls_dropout=0.5, num_classes=400,
):
model = VisionTransformer(
Expand All @@ -403,15 +404,15 @@ def vit_only_global_l14(
use_checkpoint=use_checkpoint,
checkpoint_num=checkpoint_num,
t_size=t_size,
backbone_drop_path_rate=backbone_drop_path_rate,
return_list=return_list,
n_layers=n_layers,
n_dim=n_dim,
n_head=n_head,
mlp_factor=mlp_factor,
drop_path_rate=drop_path_rate,
mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout,
backbone_drop_path_rate=backbone_drop_path_rate,
return_list=return_list,
n_layers=n_layers,
n_dim=n_dim,
n_head=n_head,
mlp_factor=mlp_factor,
drop_path_rate=drop_path_rate,
mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout,
num_classes=num_classes,
)

@@ -424,10 +425,10 @@ def vit_only_global_l14(

def vit_only_global_l14_336(
pretrained=True, use_checkpoint=False, checkpoint_num=[0],
t_size=16, backbone_drop_path_rate=0.,
t_size=16, backbone_drop_path_rate=0.,
return_list=[20, 21, 22, 23],
n_layers=4, n_dim=1024, n_head=16, mlp_factor=4.0, drop_path_rate=0.,
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
mlp_dropout=[0.5, 0.5, 0.5, 0.5],
cls_dropout=0.5, num_classes=400,
):
model = VisionTransformer(
Expand All @@ -440,15 +441,15 @@ def vit_only_global_l14_336(
use_checkpoint=use_checkpoint,
checkpoint_num=checkpoint_num,
t_size=t_size,
backbone_drop_path_rate=backbone_drop_path_rate,
return_list=return_list,
n_layers=n_layers,
n_dim=n_dim,
n_head=n_head,
mlp_factor=mlp_factor,
drop_path_rate=drop_path_rate,
mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout,
backbone_drop_path_rate=backbone_drop_path_rate,
return_list=return_list,
n_layers=n_layers,
n_dim=n_dim,
n_head=n_head,
mlp_factor=mlp_factor,
drop_path_rate=drop_path_rate,
mlp_dropout=mlp_dropout,
cls_dropout=cls_dropout,
num_classes=num_classes,
)

Expand All @@ -473,12 +474,12 @@ def vit_only_global_l14_336(
num_frames = 8

model = vit_only_global_l14(
pretrained=False,
pretrained=False,
t_size=num_frames, backbone_drop_path_rate=0.2, drop_path_rate=0.4,
use_checkpoint=True, checkpoint_num=[0],
)

flops = FlopCountAnalysis(model, torch.rand(1, 3, num_frames, 224, 224))
s = time.time()
logger.info(flop_count_table(flops, max_depth=1))
logger.info(time.time()-s)
logger.info(time.time()-s)
2 changes: 1 addition & 1 deletion InternVideo2/multi_modality/demo/config.py
@@ -14,7 +14,7 @@

import yaml

from .easydict import EasyDict
from easydict import EasyDict

__all__ = ["Config", "pretty_text"]

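The import now points at the installed easydict package instead of a bundled relative copy, so the demo config works without shipping its own easydict module. A quick sanity check that the package exposes the attribute-style access the config code relies on (pip install easydict):

from easydict import EasyDict

cfg = EasyDict({"inputs": {"image_res": 224}})
print(cfg.inputs.image_res)  # 224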
@@ -15,8 +15,8 @@

origin_num_frames = 4

use_half_precision = False
use_bf16 = False
use_half_precision = True
use_bf16 = True

inputs = dict(
image_res=224,
@@ -55,7 +55,7 @@
checkpoint_num=40,
use_flash_attn=use_half_precision,
use_fused_rmsnorm=use_half_precision,
use_fused_mlp=use_half_precision,
use_fused_mlp=False,
# clip teacher
clip_teacher=None,
clip_input_resolution=224,
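This config turns half precision (and bf16) on and leaves only use_fused_mlp off; the flash-attn and fused-RMSNorm paths stay tied to the half-precision switch. A hedged sketch of how such flags are typically consumed downstream (assumed wiring, not copied from the repo):

import torch

use_half_precision = True
use_bf16 = True

if use_half_precision:
    amp_dtype = torch.bfloat16 if use_bf16 else torch.float16
    autocast_ctx = torch.autocast(device_type="cuda", dtype=amp_dtype)
else:
    autocast_ctx = torch.autocast(device_type="cuda", enabled=False)
# the model forward would run under autocast_ctx; flash-attn kernels require
# fp16/bf16 inputs (see the assert in the flash-attention file below)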
@@ -1,6 +1,6 @@
import torch
import torch.nn as nn

import ipdb
from einops import rearrange

from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
@@ -33,7 +33,7 @@ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
key_padding_mask: a bool tensor of shape (B, S)
"""
assert not need_weights
assert qkv.dtype in [torch.float16, torch.bfloat16]
assert qkv.dtype in [torch.float16, torch.bfloat16], "qkv type is :" + str(qkv.dtype)
assert qkv.is_cuda

if cu_seqlens is None:
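The assert now reports the offending dtype, which makes it obvious that the flash-attn kernel only accepts fp16/bf16 packed qkv. A hedged helper callers might use before invoking the module (name and wiring are illustrative, not from the repo):

import torch

def to_flash_dtype(qkv: torch.Tensor, use_bf16: bool = True) -> torch.Tensor:
    # flash-attn only supports half precision on CUDA
    if qkv.dtype in (torch.float16, torch.bfloat16):
        return qkv
    return qkv.to(torch.bfloat16 if use_bf16 else torch.float16)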
@@ -68,4 +68,4 @@ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
softmax_scale=self.softmax_scale, causal=causal
)

return output, None
return output, None