Skip to content
This repository was archived by the owner on Sep 18, 2024. It is now read-only.

[Common] add flop counter pass for nni.fx. #5344

Merged
merged 122 commits into from
Jun 9, 2023

Conversation

super-dainiu
Copy link
Contributor

Description

Add a flop count and param count pass for nni.fx

model = symbolic_trace(model)
model = counter_pass(model, torch.randn(8, 3, 224, 224), verbose=True)

Test Options

  • fast test
  • full test - HPO
  • full test - NAS
  • full test - compression

Checklist

  • test case
  • doc

How to test

@liuzhe-lz liuzhe-lz requested review from J-shang and Bonytu February 13, 2023 02:42
@J-shang J-shang added the v3.1 label Feb 20, 2023
BWD = auto()


def _format_flops(flops: float) -> str:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

duplicate?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To avoid import of a private function?

Returns
-------
module: torch.fx.GraphModule
The same module as the input.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The implementation looks good to me, could we return a dict like {'flops': {'total': xxx, 'encoder': xxx}, 'params': {'total': xxx, 'encoder': xxx}}?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There could be another function to calculate these metrics.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think returning a dict might be confusing because it's not clear to which granularity should we make these results.

May it be fx.node or nn.Module. should depend on the clear use case though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think returning a dict might be confusing because it's not clear to which granularity should we make these results.

May it be fx.node or nn.Module. should depend on the clear use case though.

Good idea! nn.Module or op level is better!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully now it returns what is correct.

@Bonytu Bonytu merged commit 928575b into microsoft:master Jun 9, 2023
@super-dainiu super-dainiu deleted the flop_count branch June 10, 2023 17:28
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.