
Commit

Add ptflops hook for SelfAttention
LukasHedegaard committed Dec 10, 2021
1 parent bfd9b5f commit 9cd9d39
Showing 2 changed files with 55 additions and 5 deletions.
main.py (2 changes: 0 additions & 2 deletions)
@@ -119,8 +119,6 @@ def input_constructor(*largs, **lkwargs):
         model, device_ids=[int(iii) for iii in args.gpu.split(",")]
     )
 n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
-logger.output_print("number of params: {}".format(n_parameters))
-# logger.output_print(args)
 
 optimizer = torch.optim.Adam(
     model.parameters(),
transformer_models/Transformer.py (58 changes: 55 additions & 3 deletions)
@@ -71,12 +71,64 @@ def __init__(
                             ),
                         )
                     ),
-                    Residual(
-                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
-                    ),
+                    Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))),
                 ]
             )
         self.net = nn.Sequential(*layers)
 
     def forward(self, x):
         return self.net(x)
+
+
+def _register_ptflops():
+    try:
+        from ptflops import flops_counter as fc
+
+        def self_attention_counter_hook(module, input, output):
+            flops = 0
+
+            q = input[0]
+            k = input[0]
+            v = input[0]
+            batch_size = q.shape[1]
+
+            num_heads = module.num_heads
+            embed_dim = module.qkv.in_features
+            kdim = embed_dim
+            vdim = embed_dim
+
+            # initial projections
+            flops = (
+                q.shape[0] * q.shape[2] * embed_dim
+                + k.shape[0] * k.shape[2] * kdim
+                + v.shape[0] * v.shape[2] * vdim
+            )
+            if module.qkv.bias is not None:
+                flops += (q.shape[0] + k.shape[0] + v.shape[0]) * embed_dim
+
+            # attention heads: scale, matmul, softmax, matmul
+            head_dim = embed_dim // num_heads
+            head_flops = (
+                q.shape[0] * head_dim
+                + head_dim * q.shape[0] * k.shape[0]
+                + q.shape[0] * k.shape[0]
+                + q.shape[0] * k.shape[0] * head_dim
+            )
+
+            flops += num_heads * head_flops
+
+            # final projection, bias is always enabled
+            flops += q.shape[0] * embed_dim * (embed_dim + 1)
+
+            flops *= batch_size
+            module.__flops__ += int(flops)
+
+        fc.MODULES_MAPPING[SelfAttention] = self_attention_counter_hook
+
+    except ModuleNotFoundError:  # pragma: no cover
+        pass
+    except Exception as e:  # pragma: no cover
+        print(f"Failed to add flops_counter_hook: {e}")
+
+
+_register_ptflops()
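
With this mapping registered at import time, ptflops can attribute a cost to each SelfAttention block, including the attention matrix multiplications that its built-in per-layer hooks would otherwise miss. Below is a minimal usage sketch; the TransformerModel class name, its constructor arguments, and the input size are illustrative assumptions rather than values taken from this repository.

# Hypothetical usage sketch (not part of this commit): importing Transformer.py
# runs _register_ptflops(), so the custom hook is in place before counting.
from ptflops import get_model_complexity_info

from transformer_models.Transformer import TransformerModel  # assumed class name

# Illustrative hyperparameters; the real constructor signature may differ.
model = TransformerModel(dim=512, depth=3, heads=8, mlp_dim=1024)

macs, params = get_model_complexity_info(
    model,
    (64, 512),  # (sequence_length, dim); ptflops adds a batch dimension of 1
    as_strings=True,
    print_per_layer_stat=True,  # per-module breakdown, SelfAttention included
)
print(macs, params)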
