Description
- The model needs comprehensive profiling
- No FLOP count is available
- Parameter counts are not reported
- Memory usage is unclear
Fix: install fvcore:
pip install fvcore
Then add a profiling helper:
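import torch
from fvcore.nn import FlopCountAnalysis, parameter_count_table

def profile_model(model, input_size=(32, 96, 7)):
    # input_size = (batch, seq_len, n_features)
    device = next(model.parameters()).device
    # Create sample inputs on the same device as the model
    x = torch.randn(input_size, device=device)
    # Time-feature marks: same batch and length, 4 time features
    x_mark = torch.randn(input_size[0], input_size[1], 4, device=device)
    # Count parameters (returns a formatted string table)
    params = parameter_count_table(model)
    # Count FLOPs; assumes forward(x_enc, x_mark_enc, x_dec, x_mark_dec)
    model.eval()  # put dropout/batchnorm into inference mode
    with torch.no_grad():
        total = FlopCountAnalysis(model, (x, x_mark, x, x_mark)).total()
    # Note: fvcore counts one fused multiply-add as one flop
    return {
        "params": params,
        "flops": total,
        "flops_readable": f"{total / 1e9:.2f}G",
    }

Then add the results to the model's repr method.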
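One way to surface profiling info in the model's repr, as the issue suggests, is a minimal sketch assuming a PyTorch `nn.Module`. It overrides `extra_repr()`, the hook that `nn.Module.__repr__` already calls, so that the printed model reports parameter counts and the memory held by its weights and buffers. That also covers part of the memory-usage point. `param_stats` and `TinyForecaster` are hypothetical names used only for illustration. FLOPs are left out because they need a sample input.

```python
import torch
import torch.nn as nn

def param_stats(module: nn.Module) -> dict:
    """Hypothetical helper: count parameters and bytes held by weights/buffers."""
    n_params = sum(p.numel() for p in module.parameters())
    n_trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    # Bytes for parameters plus registered buffers (e.g. BatchNorm running stats);
    # activations, gradients and optimizer state are NOT included.
    n_bytes = sum(t.numel() * t.element_size() for t in module.parameters())
    n_bytes += sum(t.numel() * t.element_size() for t in module.buffers())
    return {"params": n_params, "trainable": n_trainable, "bytes": n_bytes}

class TinyForecaster(nn.Module):
    """Hypothetical stand-in model; only extra_repr matters here."""

    def __init__(self, n_features: int = 7, d_model: int = 16):
        super().__init__()
        self.proj = nn.Linear(n_features, d_model)
        self.head = nn.Linear(d_model, n_features)

    def forward(self, x):
        return self.head(torch.relu(self.proj(x)))

    def extra_repr(self) -> str:
        # nn.Module.__repr__ prints this line above the child modules
        s = param_stats(self)
        return (f"params={s['params']:,}, trainable={s['trainable']:,}, "
                f"param_mem={s['bytes'] / 2**20:.3f} MiB")
```

Calling `print(TinyForecaster())` shows a `params=247, trainable=247, ...` line above the `(proj)` and `(head)` children. Keeping the stats in `extra_repr` rather than overriding `__repr__` preserves PyTorch's usual nested layout.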
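To sanity-check the FLOP number fvcore reports, the `nn.Linear` share can be counted by hand with forward hooks. This is a rough sketch that assumes a single-input model, unlike the four-input call in `profile_model`. It counts one multiply-accumulate per weight per input row, in line with fvcore's convention of counting a fused multiply-add as one flop, and it ignores bias additions and all non-Linear ops. `linear_macs` is a hypothetical helper, not part of fvcore.

```python
import torch
import torch.nn as nn

def linear_macs(model: nn.Module, input_size=(32, 96, 7)) -> int:
    """Hypothetical helper: multiply-accumulates of every nn.Linear, via hooks."""
    total = 0

    def hook(mod, inputs, output):
        nonlocal total
        x = inputs[0]
        rows = x.numel() // x.shape[-1]  # product of all leading (batch/time) dims
        total += rows * mod.in_features * mod.out_features

    handles = [m.register_forward_hook(hook)
               for m in model.modules() if isinstance(m, nn.Linear)]
    try:
        with torch.no_grad():
            model(torch.randn(input_size))
    finally:
        for h in handles:  # always detach hooks, even if forward fails
            h.remove()
    return total
```

For `nn.Sequential(nn.Linear(7, 16), nn.ReLU(), nn.Linear(16, 7))` with the default input, there are 32 × 96 = 3072 rows, so the count is 3072 × (7·16 + 16·7) = 688,128.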