Skip to content

Commit aa6eb33

Browse files
0x45f0x45f
authored andcommitted
set net.forward to original forward function in flops (#36852)
1 parent 027664e commit aa6eb33

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

python/paddle/hapi/dynamic_flops.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import paddle.nn as nn
1818
import numpy as np
1919
from .static_flops import static_flops, Table
20+
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
2021

2122
__all__ = []
2223

@@ -100,6 +101,10 @@ def count_leaky_relu(m, x, y):
100101
#Total Flops: 347560 Total Params: 61610
101102
"""
102103
if isinstance(net, nn.Layer):
104+
# If net is a dy2stat model, net.forward is StaticFunction instance,
105+
# we set net.forward to original forward function.
106+
_, net.forward = unwrap_decorators(net.forward)
107+
103108
inputs = paddle.randn(input_size)
104109
return dynamic_flops(
105110
net,

0 commit comments

Comments
 (0)