diff --git a/main.py b/main.py
index 1f5f589..86ddf8f 100644
--- a/main.py
+++ b/main.py
@@ -119,8 +119,6 @@ def input_constructor(*largs, **lkwargs):
         model, device_ids=[int(iii) for iii in args.gpu.split(",")]
     )
     n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
-    logger.output_print("number of params: {}".format(n_parameters))
-    # logger.output_print(args)
 
     optimizer = torch.optim.Adam(
         model.parameters(),
diff --git a/transformer_models/Transformer.py b/transformer_models/Transformer.py
index c1c9cb6..cf66e9a 100644
--- a/transformer_models/Transformer.py
+++ b/transformer_models/Transformer.py
@@ -71,12 +71,64 @@ def __init__(
                             ),
                         )
                     ),
-                    Residual(
-                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
-                    ),
+                    Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))),
                 ]
             )
         self.net = nn.Sequential(*layers)
 
     def forward(self, x):
         return self.net(x)
+
+
+def _register_ptflops():
+    try:
+        from ptflops import flops_counter as fc
+
+        def self_attention_counter_hook(module, input, output):
+            flops = 0
+
+            q = input[0]
+            k = input[0]
+            v = input[0]
+            batch_size = q.shape[1]
+
+            num_heads = module.num_heads
+            embed_dim = module.qkv.in_features
+            kdim = embed_dim
+            vdim = embed_dim
+
+            # initial projections
+            flops = (
+                q.shape[0] * q.shape[2] * embed_dim
+                + k.shape[0] * k.shape[2] * kdim
+                + v.shape[0] * v.shape[2] * vdim
+            )
+            if module.qkv.bias is not None:
+                flops += (q.shape[0] + k.shape[0] + v.shape[0]) * embed_dim
+
+            # attention heads: scale, matmul, softmax, matmul
+            head_dim = embed_dim // num_heads
+            head_flops = (
+                q.shape[0] * head_dim
+                + head_dim * q.shape[0] * k.shape[0]
+                + q.shape[0] * k.shape[0]
+                + q.shape[0] * k.shape[0] * head_dim
+            )
+
+            flops += num_heads * head_flops
+
+            # final projection, bias is always enabled
+            flops += q.shape[0] * embed_dim * (embed_dim + 1)
+
+            flops *= batch_size
+            module.__flops__ += int(flops)
+
+        fc.MODULES_MAPPING[SelfAttention] = self_attention_counter_hook
+
+    except ModuleNotFoundError:  # pragma: no cover
+        pass
+    except Exception as e:  # pragma: no cover
+        print(f"Failed to add flops_counter_hook: {e}")
+
+
+_register_ptflops()
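
Usage sketch (not part of the diff): the hook registered above only takes effect when ptflops profiles the model, e.g. via get_model_complexity_info with an input_constructor, which the main.py hunk context suggests this repo uses. The class name TransformerModel, its constructor arguments, and the (seq_len, batch, dim) placeholder input shape below are assumptions not confirmed by the diff; adjust them to the repo's actual API.

# Minimal sketch, assuming the module in Transformer.py exposes a class
# named TransformerModel with (dim, depth, heads, mlp_dim) arguments.
import torch
from ptflops import get_model_complexity_info

# Importing the module runs _register_ptflops(), which maps SelfAttention
# to the custom FLOPs hook in ptflops' MODULES_MAPPING.
from transformer_models.Transformer import TransformerModel


def input_constructor(*largs, **lkwargs):
    # ptflops passes the input resolution through; return the kwargs that
    # forward(x) expects. (64, 1, 1024) is a placeholder shape laid out as
    # (seq_len, batch, dim), matching how the hook reads q.shape[1] as batch.
    return {"x": torch.randn(64, 1, 1024)}


model = TransformerModel(dim=1024, depth=3, heads=8, mlp_dim=2048)
macs, params = get_model_complexity_info(
    model,
    (64, 1024),  # required positionally, but ignored by input_constructor above
    input_constructor=input_constructor,
    as_strings=True,
    print_per_layer_stat=False,
)
print(macs, params)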