Fix cooadtr merge
LukasHedegaard committed Dec 30, 2021
1 parent b1165eb commit 63473e6
Showing 2 changed files with 51 additions and 13 deletions.
23 changes: 10 additions & 13 deletions main.py
@@ -59,9 +59,10 @@ def main(args):
     np.random.seed(seed)
     random.seed(seed)
 
-    model = transformer_models.VisionTransformer_v3(
+    # model = transformer_models.VisionTransformer_v3(
+    model = transformer_models.CoVisionTransformer(
         args=args,
-        img_dim=args.enc_layers,  # VisionTransformer_v3
+        img_dim=args.enc_layers,
         patch_dim=args.patch_dim,
         out_dim=args.numclass,
         embedding_dim=args.embedding_dim,
@@ -78,18 +79,12 @@
     try:
         from ptflops import get_model_complexity_info
 
-        def input_constructor(*largs, **lkwargs):
-            return {
-                "sequence_input_rgb": torch.ones(()).new_empty(
-                    (1, args.enc_layers, args.dim_feature // 3 * 2)
-                ),
-                "sequence_input_flow": torch.ones(()).new_empty(
-                    (1, args.enc_layers, args.dim_feature // 3)
-                ),
-            }
+        # Warm up model
+        model.forward_steps(torch.randn(1, args.dim_feature, args.enc_layers))
+        model.call_mode = "forward_step"
 
         flops, params = get_model_complexity_info(
-            model, (0, 0), input_constructor=input_constructor, as_strings=False
+            model, (args.dim_feature,), as_strings=False
         )
         print(f"Model FLOPs: {flops}")
         print(f"Model params: {params}")
@@ -98,11 +93,12 @@ def input_constructor(*largs, **lkwargs):
         print(e)
         ...
 
+    model.call_mode = "forward"
     model.to(device)
 
     loss_need = [
         "labels_encoder",
-        "labels_decoder",
+        # "labels_decoder",
     ]
     criterion = utl.SetCriterion(
         num_classes=args.numclass, losses=loss_need, args=args
@@ -200,6 +196,7 @@ def input_constructor(*largs, **lkwargs):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             sampler_train.set_epoch(epoch)
+
         train_stats = train_one_epoch(
             model,
             criterion,
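
Note on the FLOPs-measurement change in main.py above: the dict-based input_constructor is replaced by a plain input shape, after the model is warmed up with a full dummy clip via forward_steps and switched into "forward_step" call mode. A minimal sketch of that pattern, assuming a continual-inference style model exposing forward_steps and call_mode; the helper below is illustrative and not part of this commit:

import torch
from ptflops import get_model_complexity_info

def measure_step_flops(model, seq_len, dim_feature):
    # Fill the model's internal state with a full dummy clip first ...
    model.forward_steps(torch.randn(1, dim_feature, seq_len))
    # ... then measure the cost of processing one new frame.
    model.call_mode = "forward_step"
    flops, params = get_model_complexity_info(
        model, (dim_feature,), as_strings=False, print_per_layer_stat=False
    )
    model.call_mode = "forward"  # restore clip-wise inference afterwards
    return flops, params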
41 changes: 41 additions & 0 deletions transformer_models/Transformer.py
@@ -1,5 +1,6 @@
 from torch import nn
 from .Attention import SelfAttention
+from continual_transformers import CoSiTransformerEncoder, CoReSiTransformerEncoder
 
 
 class Residual(nn.Module):
@@ -80,6 +81,46 @@ def forward(self, x):
         return self.net(x)
 
 
+def CoTransformerModel(
+    dim,
+    depth,
+    heads,
+    mlp_dim,
+    dropout_rate=0.1,
+    attn_dropout_rate=0.1,
+    sequence_len=64,
+):
+    assert depth in {1, 2}
+
+    if depth == 1:
+        return CoSiTransformerEncoder(
+            sequence_len=sequence_len,
+            embed_dim=dim,
+            num_heads=heads,
+            dropout=dropout_rate,
+            in_proj_bias=False,
+            query_index=-1,
+            ff_hidden_dim=mlp_dim,
+            ff_activation=nn.GELU(),
+            device=None,
+            dtype=None,
+        )
+
+    # depth == 2
+    return CoReSiTransformerEncoder(
+        sequence_len=sequence_len,
+        embed_dim=dim,
+        num_heads=heads,
+        dropout=dropout_rate,
+        in_proj_bias=False,
+        query_index=-1,
+        ff_hidden_dim=mlp_dim,
+        ff_activation=nn.GELU(),
+        device=None,
+        dtype=None,
+    )
+
+
 def _register_ptflops():
     try:
         from ptflops import flops_counter as fc
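
For context, the new CoTransformerModel factory selects a one-block (CoSiTransformerEncoder) or two-block (CoReSiTransformerEncoder) continual encoder from the continual_transformers package. A hedged usage sketch follows; the parameter values, the (batch, embed_dim, sequence_len) input layout, and the call_mode/forward_step interface are assumptions based on the continual-inference convention rather than details spelled out in this diff:

import torch
from transformer_models.Transformer import CoTransformerModel

# Illustrative sizes only; dim must be divisible by heads.
enc = CoTransformerModel(dim=192, depth=2, heads=3, mlp_dim=384, sequence_len=64)

clip = torch.randn(1, 192, 64)   # (batch, embed_dim, sequence_len)
out = enc.forward_steps(clip)    # process a whole clip and fill the internal state

enc.call_mode = "forward_step"
frame = torch.randn(1, 192)      # one new time step
out = enc(frame)                 # continual, per-step inference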
