Skip to content

Commit

Permalink
Merge pull request #127 from sovrasov/upd_docs
Browse files Browse the repository at this point in the history
Update docstring for public API
  • Loading branch information
sovrasov authored Dec 21, 2023
2 parents b5c9ef7 + f9067f0 commit 6e5b4d8
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 18 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
# ptflops versions log

## v 0.7.2
- Add type annotations and doc strings to the main API.
- Add support of HuggingFace/Timm VIT transformers.

## v 0.7.1.2
- Fix failure when using input constructor.

## v 0.7.1
- Experimental support of torchvision.ops.DeformConv2d
- Experimantal support of torch.functional.* and tensor.* operators
- Experimental support of torch.functional.* and tensor.* operators

## v 0.7
- Add ConvNext to sample, fix wrong torchvision compatibility requirement.
Expand Down
16 changes: 16 additions & 0 deletions ptflops/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,17 @@
'''
Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''


from .flops_counter import get_model_complexity_info
from .utils import flops_to_string, params_to_string

__all__ = [
"get_model_complexity_info",
"flops_to_string",
"params_to_string",
]
61 changes: 53 additions & 8 deletions ptflops/flops_counter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,70 @@
'''

import sys
from typing import Any, Callable, Dict, TextIO, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, Union

import torch.nn as nn

from .pytorch_engine import get_flops_pytorch
from .utils import flops_to_string, params_to_string


def get_model_complexity_info(model: nn.Module, input_res: Tuple[int, ...],
def get_model_complexity_info(model: nn.Module,
input_res: Tuple[int, ...],
print_per_layer_stat: bool = True,
as_strings: bool = True,
input_constructor: Union[Callable, None] = None,
input_constructor: Optional[Callable[[Tuple], Dict]] = None,
ost: TextIO = sys.stdout,
verbose: bool = False, ignore_modules=[],
verbose: bool = False,
ignore_modules: List[nn.Module] = [],
custom_modules_hooks: Dict[nn.Module, Any] = {},
backend: str = 'pytorch',
flops_units: Union[str, None] = None,
param_units: Union[str, None] = None,
output_precision: int = 2):
flops_units: Optional[str] = None,
param_units: Optional[str] = None,
output_precision: int = 2) -> Tuple[Union[str, int, None],
Union[str, int, None]]:
"""
Analyzes the input model and collects the amounts of parameters and MACs
required to make a forward pass of the model.
:param model: Input model to analyze
:type model: nn.Module
:param input_res: A tuple that sets the input resolution for the model. Batch
dimension is added automatically: (3, 224, 224) -> (1, 3, 224, 224).
:type input_res: Tuple[int, ...]
:param print_per_layer_stat: Flag to enable or disable printing of per-layer
MACs/params statistics. This feature works only for layers derived
from torch.nn.Module. Other operations are ignored.
:type print_per_layer_stat: bool
:param as_strings: Flag that allows to get ready-to-print string representation
of the final params/MACs estimations. Otherwise, a tuple with raw numbers
will be returned.
:type as_strings: bool
:param input_constructor: A callable that takes the :input_res parameter and
returns an output suitable for the model. It can be used if model requires
more than one input tensor or any other kind of irregular input.
:type input_constructor: Optional[Callable[[Tuple], Dict]]
:param ost: A stream to print output.
:type ost: TextIO
:param verbose: Parameter to control printing of extra information and warnings.
:type verbose: bool
:param ignore_modules: A list of torch.nn.Module modules to ignore.
:type ignore_modules: nn.Module
:param custom_modules_hooks: A dict that contains custom hooks on torch modules.
:type custom_modules_hooks: Dict[nn.Module, Any]
:param flops_units: Units for string representation of MACs (GMac, MMac or KMac).
:type flops_units: Optional[str]
:param param_units: Units for string representation of params (M, K or B).
:type param_units: Optional[str]
:param output_precision: Floating point precision for representing MACs/params in
given units.
:type output_precision: int
Returns:
Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple
(macs, params): Nones in case of a failure during computations, or
strings if :as_strings is true or integers otherwise.
"""
assert type(input_res) is tuple
assert len(input_res) >= 1
assert isinstance(model, nn.Module)
Expand All @@ -42,7 +87,7 @@ def get_model_complexity_info(model: nn.Module, input_res: Tuple[int, ...],
else:
raise ValueError('Wrong backend name')

if as_strings:
if as_strings and flops_count is not None and params_count is not None:
flops_string = flops_to_string(
flops_count,
units=flops_units,
Expand Down
16 changes: 10 additions & 6 deletions ptflops/pytorch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import sys
import traceback
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
Expand All @@ -24,9 +25,10 @@ def get_flops_pytorch(model, input_res,
input_constructor=None, ost=sys.stdout,
verbose=False, ignore_modules=[],
custom_modules_hooks={},
output_precision=3,
flops_units='GMac',
param_units='M'):
output_precision=2,
flops_units: Optional[str] = 'GMac',
param_units: Optional[str] = 'M') -> Tuple[Union[int, None],
Union[int, None]]:
global CUSTOM_MODULES_MAPPING
CUSTOM_MODULES_MAPPING = custom_modules_hooks
flops_model = add_flops_counting_methods(model)
Expand Down Expand Up @@ -84,7 +86,7 @@ def reset_environment():
)
reset_environment()

return flops_count, params_count
return int(flops_count), params_count


def accumulate_flops(self):
Expand All @@ -97,8 +99,10 @@ def accumulate_flops(self):
return sum


def print_model_with_flops(model, total_flops, total_params, flops_units='GMac',
param_units='M', precision=3, ost=sys.stdout):
def print_model_with_flops(model, total_flops, total_params,
flops_units: Optional[str] = 'GMac',
param_units: Optional[str] = 'M',
precision=3, ost=sys.stdout):
if total_flops < 1:
total_flops = 1
if total_params < 1:
Expand Down
28 changes: 25 additions & 3 deletions ptflops/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,24 @@
'''
Copyright (C) 2021 Sovrasov V. - All Rights Reserved
Copyright (C) 2021-2023 Sovrasov V. - All Rights Reserved
* You may use, distribute and modify this code under the
* terms of the MIT license.
* You should have received a copy of the MIT license with
* this file. If not visit https://opensource.org/licenses/MIT
'''


def flops_to_string(flops, units=None, precision=2):
from typing import Optional


def flops_to_string(flops: int, units: Optional[str] = None, precision: int = 2) -> str:
"""
Converts integer MACs representation to a readable string.
:param flops: Input MACs.
:param units: Units for string representation of MACs (GMac, MMac or KMac).
:param precision: Floating point precision for representing MACs in
given units.
"""
if units is None:
if flops // 10**9 > 0:
return str(round(flops / 10.**9, precision)) + ' GMac'
Expand All @@ -28,7 +39,16 @@ def flops_to_string(flops, units=None, precision=2):
return str(flops) + ' Mac'


def params_to_string(params_num, units=None, precision=2):
def params_to_string(params_num: int, units: Optional[str] = None,
precision: int = 2) -> str:
"""
Converts integer params representation to a readable string.
:param flops: Input number of parameters.
:param units: Units for string representation of params (M, K or B).
:param precision: Floating point precision for representing params in
given units.
"""
if units is None:
if params_num // 10 ** 6 > 0:
return str(round(params_num / 10 ** 6, precision)) + ' M'
Expand All @@ -41,5 +61,7 @@ def params_to_string(params_num, units=None, precision=2):
return str(round(params_num / 10.**6, precision)) + ' ' + units
elif units == 'K':
return str(round(params_num / 10.**3, precision)) + ' ' + units
elif units == 'B':
return str(round(params_num / 10.**9, precision)) + ' ' + units
else:
return str(params_num)

0 comments on commit 6e5b4d8

Please sign in to comment.