-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[Common] add flop counter pass for nni.fx. #5344
Conversation
BWD = auto() | ||
|
||
|
||
def _format_flops(flops: float) -> str: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
duplicate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid import of a private function?
Returns | ||
------- | ||
module: torch.fx.GraphModule | ||
The same module as the input. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The implementation looks good to me, could we return a dict like {'flops': {'total': xxx, 'encoder': xxx}, 'params': {'total': xxx, 'encoder': xxx}}
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There could be another function to calculate these metrics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think returning a dict might be confusing because it's not clear to which granularity should we make these results.
May it be fx.node or nn.Module. should depend on the clear use case though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think returning a dict might be confusing because it's not clear to which granularity should we make these results.
May it be fx.node or nn.Module. should depend on the clear use case though.
Good idea! nn.Module
or op level is better!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hopefully now it returns what is correct.
Description
Add a flop count and param count pass for nni.fx
Test Options
Checklist
How to test