Commit
Add CoTransformerModel
LukasHedegaard committed Dec 10, 2021
1 parent 9cd9d39 commit b5fadc1
Showing 2 changed files with 57 additions and 2 deletions.
47 changes: 47 additions & 0 deletions transformer_models/Transformer.py
@@ -1,5 +1,7 @@
from torch import nn
from .Attention import SelfAttention
import continual as co
from continual_transformers import CoReMultiheadAttention, CoSiMultiheadAttention


class Residual(nn.Module):
@@ -80,6 +82,51 @@ def forward(self, x):
        return self.net(x)


def CoTransformerModel(
    dim,
    depth,
    heads,
    mlp_dim,
    dropout_rate=0.1,
    attn_dropout_rate=0.1,
    sequence_len=64,
):
    # Only one- and two-block encoders are supported.
    assert depth in {1, 2}

    layers = []

    for d in range(depth):
        # In a two-block encoder, the first block uses Continual Retroactive
        # attention (CoReMultiheadAttention); the final block (and a depth-1
        # encoder) uses Continual Single-output attention (CoSiMultiheadAttention).
        CoMHA = (
            CoReMultiheadAttention if d == 0 and depth == 2 else CoSiMultiheadAttention
        )
        layers.extend(
            [
                # Pre-norm multi-head attention block with a residual connection.
                co.Residual(
                    co.Sequential(
                        co.forward_stepping(nn.LayerNorm(dim)),
                        CoMHA(
                            dim,
                            heads,
                            attn_dropout_rate,
                            sequence_len=sequence_len,
                            forward_returns_attn_mask=False,
                        ),
                        nn.Dropout(p=dropout_rate),
                    )
                ),
                # Pre-norm feed-forward block with a residual connection.
                co.Residual(
                    co.Sequential(
                        co.forward_stepping(nn.LayerNorm(dim)),
                        co.forward_stepping(FeedForward(dim, mlp_dim, dropout_rate)),
                    )
                ),
            ]
        )

    net = co.Sequential(*layers)
    return net


def _register_ptflops():
    try:
        from ptflops import flops_counter as fc
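For context, here is a minimal usage sketch of the new function (not part of this commit). It assumes the co.Sequential returned by CoTransformerModel supports both the regular clip-wise forward call and a step-wise forward_step call, and it assumes a (batch, dim, time) input layout; the hyper-parameter values are arbitrary. Both assumptions should be checked against the continual and continual_transformers documentation.

import torch

from transformer_models.Transformer import CoTransformerModel

# Example hyper-parameters (arbitrary, for illustration only).
encoder = CoTransformerModel(
    dim=128,
    depth=2,
    heads=4,
    mlp_dim=256,
    sequence_len=64,
)

# Clip-wise inference over a whole sequence (assumed layout: batch x dim x time).
clip = torch.randn(1, 128, 64)
out = encoder(clip)

# Continual, step-wise inference: one embedding per call (assumed API: forward_step).
step = torch.randn(1, 128)
out_step = encoder.forward_step(step)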
12 changes: 10 additions & 2 deletions transformer_models/ViT.py
@@ -3,7 +3,7 @@
import torch.nn.functional as F
from .decoder import Decoder, DecoderLayer
from .attn import FullAttention, ProbAttention, AttentionLayer
from .Transformer import TransformerModel
from .Transformer import TransformerModel, CoTransformerModel
from ipdb import set_trace
from .PositionalEncoding import (
FixedPositionalEncoding,
@@ -72,7 +72,15 @@ def __init__(

        self.pe_dropout = nn.Dropout(p=self.dropout_rate)

        self.encoder = TransformerModel(
        # self.encoder = TransformerModel(
        #     embedding_dim,
        #     num_layers,
        #     num_heads,
        #     hidden_dim,
        #     self.dropout_rate,
        #     self.attn_dropout_rate,
        # )
        self.encoder = CoTransformerModel(
            embedding_dim,
            num_layers,
            num_heads,