This repository was archived by the owner on Sep 18, 2024. It is now read-only.

[Common] add flop counter pass for nni.fx. #5344

Merged · 122 commits · Jun 9, 2023

Commits
a12acb9
Fix issue #5299.
super-dainiu Feb 9, 2023
653ad19
Fix issue #5299.
super-dainiu Feb 9, 2023
82b3f55
[Common] add flop counter pass for nni.fx.
super-dainiu Feb 10, 2023
ca8791f
Merge branch 'microsoft:master' into flop_count
super-dainiu Feb 10, 2023
19a2680
remove depreciate.
super-dainiu Feb 10, 2023
184566f
Merge branch 'master' of https://github.com/super-dainiu/nni into flo…
super-dainiu Feb 10, 2023
53c924f
Merge branch 'flop_count' of https://github.com/super-dainiu/nni into…
super-dainiu Feb 10, 2023
1ea2e9c
Merge branch 'microsoft:master' into flop_count
super-dainiu Feb 22, 2023
b619ca9
[pylint]
super-dainiu Feb 22, 2023
b0faa3d
[pylint]
super-dainiu Feb 22, 2023
d806d7f
[pyright] fix.
super-dainiu Mar 7, 2023
7a71cd3
[pyright] fix.
super-dainiu Mar 7, 2023
f1520e7
[Remote v3] Stage 0 - WebSocket command channel server (#5346)
liuzhe-lz Feb 23, 2023
33601cb
New trial runner and import trial command channel (#5398)
ultmaster Feb 23, 2023
4352139
Refactor nas.nn (Stage 4) - Model Space Hub (#5282)
ultmaster Feb 24, 2023
13874c4
NAS execution engine (stage 1) - Engine interface (#5358)
ultmaster Feb 24, 2023
3d4b256
Tuner Command Channel (#5364)
ultmaster Feb 27, 2023
0da7804
Standardize and improve styles (#5402)
Lijiaoa Feb 27, 2023
5208ab7
NAS execution engine (stage 2) - Sequential and TS (#5359)
ultmaster Feb 27, 2023
7ac8c18
NAS execution engine (stage 3) - CGO (#5360)
ultmaster Feb 27, 2023
c611480
[Compression] Improve quantization speedup with TensorRT (#5039)
QuanluZhang Feb 28, 2023
299a842
[MAGIC] Debug MacOS simple integration test (#5413)
ultmaster Mar 1, 2023
b563280
Rewrite resume and view for experiment (#5365)
ultmaster Mar 1, 2023
db46379
NAS strategy (stage 1) - interface (#5371)
ultmaster Mar 1, 2023
f0caa49
NAS profiler (stage 1) - interface (#5361)
ultmaster Mar 1, 2023
4863827
[quant] add comments (#5410)
QuanluZhang Mar 1, 2023
47120dd
add Python API to kill a trail job (#5411)
HuangLuGuang Mar 2, 2023
8ee826d
NAS profiler (stage 2) - Shape inference and utils (#5362)
ultmaster Mar 2, 2023
0ad807e
NAS oneshot (stage 1) - Expression and profiler utils (#5366)
ultmaster Mar 2, 2023
c63f777
NAS profiler (stage 3) - FLOPs and nn-Meter (#5363)
ultmaster Mar 2, 2023
5d4f417
NAS strategy (stage 2) - utils and bruteforce strategy (#5367)
ultmaster Mar 3, 2023
148db32
NAS strategy (stage 3) - RL strategy (#5368)
ultmaster Mar 3, 2023
03471ee
NAS oneshot (stage 4) - Fix issues with ProxylessNAS (#5369)
ultmaster Mar 3, 2023
16fa0a8
NAS oneshot (stage 5) - strategy (#5370)
ultmaster Mar 3, 2023
e060677
NAS oneshot (stage 2) - Supernet modules (#5372)
ultmaster Mar 3, 2023
2f88ec4
[Compression] pruning stage 0: update base class & pruning utils (#5386)
J-shang Mar 3, 2023
417e07e
[Compression] pruning speedup: support input/output masks (#5385)
J-shang Mar 3, 2023
cf09329
NAS strategy (stage 5) - middleware (#5375)
ultmaster Mar 3, 2023
95f5c29
Bug fix for trial runner and etc. (#5422)
ultmaster Mar 6, 2023
6437970
[Compression] pruning stage 1: add pruning tools (#5387)
J-shang Mar 7, 2023
2b157b9
NAS oneshot (stage 3) - Supernet lightning modules (#5373)
ultmaster Mar 7, 2023
87f581a
NAS strategy (stage 4) - Evolution, HPO and oneshot strategy (#5374)
ultmaster Mar 7, 2023
257a076
NAS experiment (stage 1) - checkpoint serialization utils (#5376)
ultmaster Mar 7, 2023
5c41db2
[Compression] pruning stage 2: add basic/slim/taylor pruner (#5388)
J-shang Mar 7, 2023
0265052
[Compression] pruning stage 3: add scheduled/movement pruner (#5389)
J-shang Mar 7, 2023
61cf4ed
[Compression] pruning stage 4: update speedup support for embedding (…
J-shang Mar 7, 2023
a3dc1ef
NAS experiment (stage 2) - config (#5377)
ultmaster Mar 8, 2023
c98f5d6
[Compression] add basic distiller (#5397)
J-shang Mar 8, 2023
40b7cc7
Patch optimizer for adding param_groups (#5399)
Bonytu Mar 8, 2023
a3f9e37
[Compression] Quantization: add module fusion (#5400)
Bonytu Mar 9, 2023
8567d0d
[Compression] modify concrete_trace_utils for producing fx graph (#5403)
junesnow86 Mar 9, 2023
8da4cec
NAS experiment (stage 3) - main body (#5378)
ultmaster Mar 10, 2023
5b06c0d
NAS benchmark (stage 1) - move files (#5379)
ultmaster Mar 10, 2023
147def3
[Compression] Quantization: add quantizers (#5401)
Bonytu Mar 14, 2023
1834667
feat(speedup): support to call "super()" with concrete trace util. (…
Louis-J Mar 15, 2023
396f2f4
update webui doc (#5419)
Lijiaoa Mar 15, 2023
6aaa9a7
[Compression] allow dummy input as dict. (#5440)
super-dainiu Mar 15, 2023
eb7c9e0
[Common] add warnings and retain the module after inappropriate weigh…
super-dainiu Mar 15, 2023
1a434b0
Fix two bugs of experiment resume and improve unittest of nnimanager …
QuanluZhang Mar 15, 2023
b48eede
Remote training service v3 (#5428)
liuzhe-lz Mar 15, 2023
fd03b07
[Compression] fix tests & bugs (#5444)
J-shang Mar 15, 2023
bda96bc
Bump deps version and migrate to npm (#5443)
liuzhe-lz Mar 15, 2023
206b00c
Enable and fix pipeline issues in NAS (#5439)
ultmaster Mar 16, 2023
ea87c16
NAS benchmark (stage 2) - space and evaluator (#5380)
ultmaster Mar 20, 2023
454418e
Merge command channel APIs and add unit tests (#5450)
liuzhe-lz Mar 20, 2023
5f4f527
[Compression] fix bug in evaluator (#5461)
Bonytu Mar 20, 2023
c622ea9
[Common] let **kwargs be possible during intermediate tracing (#5462)
super-dainiu Mar 20, 2023
830a166
[Compression] optimize the logic about adding param group (#5463)
Bonytu Mar 21, 2023
740f2f0
Reduce NNI manager's poll interval when using v3 training services (#…
liuzhe-lz Mar 21, 2023
43e1846
Increase mac os typescript ut time limit (#5465)
liuzhe-lz Mar 22, 2023
46cf056
[Common] allow preserving self.training in F.dropout by hacking modul…
super-dainiu Mar 23, 2023
116cf2f
Migrate tuner command channel to v3 channel (#5475)
liuzhe-lz Mar 24, 2023
5745db9
[Bugbash] fix bugs in compression (#5472)
J-shang Mar 24, 2023
b4d22cc
[Compression] huggingface transformers extension (#5137)
J-shang Mar 25, 2023
a0731f7
[bug bash]webui improvements (#5478)
Lijiaoa Mar 29, 2023
73add92
[Support] Remote and local had the same log mode (#5480)
Lijiaoa Mar 29, 2023
0c7ff5b
[Bugbash] bugfix & add support for pytorch 2.0 (#5484)
J-shang Mar 29, 2023
fb5c6a7
[Compression] example & test (#5429)
J-shang Mar 29, 2023
3083616
[Common] Fix some discovered issues during test. (#5487)
super-dainiu Mar 30, 2023
e017ae9
[Bugbash] Bug fix (#5467)
Bonytu Mar 30, 2023
a618ec6
add AMC MIT license (#5498)
ethanhe42 Apr 3, 2023
cbfb561
[Compression] Add support for torch 2.0 (#5492)
Bonytu Apr 7, 2023
ef8d6d0
[Hotfix] update mmdet to 3.0 (#5515)
super-dainiu Apr 12, 2023
d3b29ed
[Compatible py37] replace walrus op (#5500)
J-shang Apr 12, 2023
0f2bdb1
[Compression] optimize masks memory cost & support mask dynamic tenso…
J-shang Apr 12, 2023
b1c308f
[Compression] v2.5 pruning tutorial (#5476)
J-shang Apr 13, 2023
778481f
webui some changes as the description (#5496)
Lijiaoa Apr 13, 2023
c670d09
[Compression] add conv1d replacer. (#5512)
super-dainiu Apr 14, 2023
85f70a9
Fix wrong examples in assessor docs (#5520)
dofuuz Apr 17, 2023
74ea5ad
[Webui] Remove trial Add_resumed status (#5519)
Lijiaoa Apr 17, 2023
cbf31cf
NAS 3.0 docs (quick update) (#5509)
ultmaster Apr 17, 2023
a03166c
[Compression] Add quantization tutorial (#5454)
Bonytu Apr 23, 2023
76a2151
[Compression] preview doc (#5491)
J-shang Apr 25, 2023
887052d
[Compression] Quantization Preview Doc (#5516)
Bonytu Apr 25, 2023
1fbc648
Refactor openRow components (#5525)
Lijiaoa Apr 26, 2023
beccc52
[Common] Fix diffusers.Unet trace (#5527)
super-dainiu Apr 26, 2023
b537858
NAS doc update (#5529)
liuzhe-lz Apr 26, 2023
b678b98
Fix training service bugs (#5528)
liuzhe-lz Apr 26, 2023
9178a30
[Hot-fix] fix link click error (#5532)
Lijiaoa Apr 28, 2023
dbc3adb
Bump release pipeline image (#5530)
liuzhe-lz Apr 28, 2023
cbbb6fe
[Bugfix] fix pruning speedup value inplace change issue (#5534)
J-shang May 5, 2023
849bf70
Fix compatibility with filelock 3.12.0 (#5541)
liuzhe-lz May 8, 2023
8e5336c
Pin filelock version (#5543)
J-shang May 8, 2023
7db1142
Remove old compression doc (#5546)
J-shang May 8, 2023
a143785
Fix local IT (#5544)
liuzhe-lz May 8, 2023
2cef011
Fix `html_logo` for sphinx docs (#5545)
ultmaster May 9, 2023
7a727e8
[Release] update webui gif (#5542)
Lijiaoa May 9, 2023
94bdf95
Release Note of v3.0 Preview (#5548)
liuzhe-lz May 10, 2023
c0e4c0d
[Tracer] remove an unused attribute that causes error. (#5550)
super-dainiu May 10, 2023
e8b9fd8
Get rid of IoC and remove unused training services (#5567)
liuzhe-lz May 18, 2023
d446abd
add avgpool2_formula to shape_formula.py (#5565)
avisinghal6 May 20, 2023
5a9be89
[Compression] avoid duplicated replacement for a single target submod…
super-dainiu May 24, 2023
dd89611
fix example config list (#5554) (#5557)
J-shang May 24, 2023
a7d6d20
add feat
super-dainiu May 27, 2023
072b463
lint
super-dainiu May 27, 2023
30d863d
Merge branch 'microsoft:master' into flop_count
super-dainiu May 27, 2023
03315c1
fix
super-dainiu May 27, 2023
eab3ce3
fix
super-dainiu May 27, 2023
f470f40
fix
super-dainiu May 27, 2023
7970cda
by type
super-dainiu May 29, 2023
47aa662
pyright
super-dainiu May 29, 2023
9e17ceb
Merge branch 'master' into flop_count
super-dainiu Jun 2, 2023
1 change: 1 addition & 0 deletions nni/common/concrete_trace_utils/__init__.py
@@ -12,3 +12,4 @@
More information about concrete tracing can be found in the :func:`concrete_trace` documentation.
"""
from .concrete_tracer import ConcreteTracer, concrete_trace
from .counter import counter_pass
275 changes: 275 additions & 0 deletions nni/common/concrete_trace_utils/counter.py
@@ -0,0 +1,275 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Any, Callable, Dict, List, Tuple, Optional, Union
from dataclasses import dataclass, field

import torch
import torch.fx
from torch.fx import Interpreter
from torch.fx.node import Argument, Node

from .flop_utils import flop_count

# pyright: reportUnboundVariable=false
# pyright: reportGeneralTypeIssues=false

Target = Union[Callable[..., Any], str]

def _format_flops(flops: float) -> str:
    """Returns a formatted flops string"""
    if flops > 1e12:
        return f'{flops / 1e12:.2f} TFLOPs'
    elif flops > 1e9:
        return f'{flops / 1e9:.2f} GFLOPs'
    elif flops > 1e6:
        return f'{flops / 1e6:.2f} MFLOPs'
    elif flops > 1e3:
        return f'{flops / 1e3:.2f} kFLOPs'
    return f'{flops} FLOPs'
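
# For example (illustrative): _format_flops(2.5e9) -> '2.50 GFLOPs',
# _format_flops(512) -> '512 FLOPs'.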


def _format_memory(nbytes: int) -> str:
    """Returns a formatted memory size string"""
    KB = 1024
    MB = 1024 * KB
    GB = 1024 * MB
    if abs(nbytes) >= GB:
        return '{:.2f} GB'.format(nbytes * 1.0 / GB)
    elif abs(nbytes) >= MB:
        return '{:.2f} MB'.format(nbytes * 1.0 / MB)
    elif abs(nbytes) >= KB:
        return '{:.2f} KB'.format(nbytes * 1.0 / KB)
    else:
        return str(nbytes) + ' B'

def compute_size_in_bytes(elem: Union[torch.Tensor, Dict, List, Tuple, int]) -> int:
    """Compute the size of a tensor or a collection of tensors in bytes.

    Args:
        elem (Union[torch.Tensor, Dict, List, Tuple, int]): Arbitrary nested ``torch.Tensor`` data structure.

    Returns:
        int: The size of the tensor or the collection of tensors in bytes.
    """
    nbytes = 0
    if isinstance(elem, torch.Tensor):
        if elem.is_quantized:
            nbytes += elem.numel() * torch._empty_affine_quantized([], dtype=elem.dtype).element_size()
        else:
            nbytes += elem.numel() * torch.tensor([], dtype=elem.dtype).element_size()
    elif isinstance(elem, dict):
        nbytes += compute_size_in_bytes(list(elem.values()))
    elif isinstance(elem, (tuple, list, set)):
        for e in elem:
            nbytes += compute_size_in_bytes(e)
    return nbytes
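
# Illustrative check (not part of the pass): a dense float32 tensor costs
# numel() * 4 bytes, and nested containers are summed recursively, e.g.
#     compute_size_in_bytes({'w': [torch.zeros(8), torch.zeros(2, 2)]}) == 48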


@dataclass
class NInfo:
    r"""
    The base class to store all profiling and static graph analysis information
    needed for a ``Node`` in an FX ``Graph``.
    """

    # bound back to the ``Node``
    node: Node

    # directory
    mod_dir: str = ''  # TODO: trace this in concrete_trace

    # parameters within this ``Node``
    parameters: Dict[str, torch.nn.Parameter] = field(default_factory=lambda: {})

    # compute cost
    flops: Optional[int] = 0

    def __new__(cls, node: Node, **kwargs):
        # If the node already carries an ``NInfo``, return the cached instance
        # and temporarily disable ``__init__`` so the recorded statistics are
        # not overwritten by this call.
        if node.meta.get('info', None) is not None:
            orig_init = cls.__init__

            def _dummy(self, *args, **kwargs):
                # restore the real __init__ for subsequent instantiations
                cls.__init__ = orig_init

            cls.__init__ = _dummy
            return node.meta['info']
        return super().__new__(cls)

    def __post_init__(self):
        self.node.meta['info'] = self

    @property
    def param_size(self) -> int:
        return compute_size_in_bytes(self.parameters)
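
# Note on the pattern above: ``NInfo(node)`` on a node that already holds an
# info object acts as a cache lookup rather than a constructor, because
# ``__new__`` returns ``node.meta['info']`` and swaps ``__init__`` out for a
# no-op. Statistics recorded during interpretation therefore survive later
# lookups (e.g. in ``summarize`` and ``as_dict`` below).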


class GraphCounter(Interpreter):
    _to_profile = ['call_module', 'call_function']
    _maybe_profile = ['get_attr']

    def __init__(self, module):
        super().__init__(module)

    def run_node(self, node: Node) -> Any:
        """Dispatch to the appropriate method for running a node.

        This method overrides ``run_node`` in ``Interpreter`` and adds the following behavior:

        - ``call_module`` and ``call_function`` are the only two types of nodes that are profiled.
        - ``call_method``, ``placeholder``, and ``output`` are not profiled because they are not
          computationally intensive.
        - ``get_attr`` is a maybe-profiled node: it is profiled because its result can be an ``nn.Parameter``.

        Parameters
        ----------
        node: Node
            The node to run.

        Returns
        -------
        rst: Any
            The result of running the node.
        """
        rst = super().run_node(node)
        if node.op in self._to_profile:
            NInfo(node, flops=rst[1], parameters=rst[2])
            rst = rst[0]
        elif node.op in self._maybe_profile:
            NInfo(node, parameters=rst[2])
            rst = rst[0]
        else:
            NInfo(node)
        return rst
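
    # Each profiled handler below returns a 3-tuple ``(result, flops, parameters)``;
    # ``run_node`` above unpacks it and forwards only the real result, so the
    # interpreter's value propagation is unaffected by the bookkeeping.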

    def call_function(self, target: Callable, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        rst = super().call_function(target, args, kwargs)
        return rst, flop_count(target, *args, **kwargs), {}  # FIXME: call_function might also have flops

    def call_module(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        # Execute the method and return the result
        assert isinstance(target, str)
        submod = self.fetch_attr(target)
        return (
            submod(*args, **kwargs),
            flop_count(submod, *args, **kwargs),
            {k: v for k, v in submod.named_parameters()}
        )

    def get_attr(self, target: Target, args: Tuple[Argument, ...], kwargs: Dict[str, Any]) -> Any:
        assert isinstance(target, str)
        rst = self.fetch_attr(target)
        if isinstance(rst, torch.nn.Parameter):
            return rst, 0, {target: rst}
        return rst, 0, {}

    def summarize(self) -> str:
        """
        Summarizes the profiled statistics of the ``GraphModule`` in
        tabular format. Note that this API requires the ``tabulate`` module
        to be installed.

        Returns:
            str: The summary of the profiled statistics
        """
        # https://github.com/pytorch/pytorch/blob/master/torch/fx/graph.py
        try:
            from tabulate import tabulate
        except ImportError as e:
            raise ImportError(
                "`summarize` relies on the library `tabulate`, which could not "
                "be found on this machine. Run `pip install tabulate` to "
                "install the library.") from e

        # Build up a list of summary information for each node
        node_summaries: List[List[Any]] = []

        for node in self.module.graph.nodes:
            node: Node
            n_info = NInfo(node)
            node_summaries.append([
                node.op,
                str(node),
                _format_memory(n_info.param_size),
                _format_flops(n_info.flops),  # type: ignore
            ])

        # Use the ``tabulate`` library to create a well-formatted table
        # presenting our summary information
        headers: List[str] = [
            'Op type',
            'Op',
            'Param size',
            'FLOPs',
        ]

        return tabulate(node_summaries, headers=headers, stralign='right')

    def as_dict(self, by_type: bool = False) -> Dict[str, Dict[str, Union[int, None]]]:
        """
        Returns the profiled statistics as a dictionary.

        Parameters
        ----------
        by_type: bool
            Whether to aggregate the statistics by module type. If ``False``, the
            statistics are keyed by module name; if ``True``, by module type.
        """
        if by_type:
            ret = {'flops': {}, 'params': {}}
            for node in self.module.graph.nodes:
                node: Node
                if node.op == 'call_module':
                    module = self.fetch_attr(node.target)
                    name = type(module).__name__
                    if name not in ret['flops']:
                        ret['flops'][name] = 0
                    if name not in ret['params']:
                        ret['params'][name] = 0
                    ret['flops'][name] += NInfo(node).flops
                    ret['params'][name] += NInfo(node).param_size
            return ret
        else:
            return {
                'flops':
                    {node.name: NInfo(node).flops for node in self.module.graph.nodes if node.op == 'call_module'},
                'params':
                    {node.name: NInfo(node).param_size for node in self.module.graph.nodes if node.op == 'call_module'}
            }
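
    # Shape of the returned dictionary (keys are illustrative):
    #     by_type=False -> {'flops': {'conv1': <int>, ...}, 'params': {'conv1': <bytes>, ...}}
    #     by_type=True  -> {'flops': {'Conv2d': <int>, ...}, 'params': {'Conv2d': <bytes>, ...}}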


def counter_pass(module: torch.fx.GraphModule,
                 *args,
                 verbose: bool = False,
                 by_type: bool = False) -> Dict[str, Dict[str, Union[int, None]]]:
    """A pass that counts the number of FLOPs and parameters in a model.

    Parameters
    ----------
    module: torch.fx.GraphModule
        The module to be profiled.
    args
        The sample inputs used to run ``module`` once for profiling.
    verbose: bool
        Whether to print the summary of the profiled statistics. Default: False.
    by_type: bool
        Whether to aggregate the returned statistics by module type. Default: False.

    Returns
    -------
    dictionary: A dictionary that contains the profiled statistics.
    """
    interp = GraphCounter(module)
    interp.run(*args)
    if verbose:
        print(interp.summarize())
    return interp.as_dict(by_type)
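
# Minimal usage sketch (illustrative; the model and input are hypothetical, and
# the module is assumed to have been traced into a ``torch.fx.GraphModule``
# first, e.g. via ``concrete_trace`` from this package):
#
#     import torch
#     import torchvision.models as models
#     from nni.common.concrete_trace_utils import concrete_trace, counter_pass
#
#     model = models.resnet18()
#     dummy_input = torch.rand(1, 3, 224, 224)
#     traced = concrete_trace(model, {'x': dummy_input})
#     stats = counter_pass(traced, dummy_input, verbose=True, by_type=True)
#     print(stats['flops'])  # FLOPs aggregated per module type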