From afdec0e66b4556ef5ea838edbd788313dce757e1 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 28 Nov 2023 06:31:34 +0900 Subject: [PATCH 01/10] Add type anno to utils --- ptflops/utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/ptflops/utils.py b/ptflops/utils.py index 82092db..2e5d508 100644 --- a/ptflops/utils.py +++ b/ptflops/utils.py @@ -1,5 +1,5 @@ ''' -Copyright (C) 2021 Sovrasov V. - All Rights Reserved +Copyright (C) 2021-2023 Sovrasov V. - All Rights Reserved * You may use, distribute and modify this code under the * terms of the MIT license. * You should have received a copy of the MIT license with @@ -7,7 +7,10 @@ ''' -def flops_to_string(flops, units=None, precision=2): +from typing import Union + + +def flops_to_string(flops: int, units: Union[str, None] = None, precision: int = 2): if units is None: if flops // 10**9 > 0: return str(round(flops / 10.**9, precision)) + ' GMac' @@ -28,7 +31,7 @@ def flops_to_string(flops, units=None, precision=2): return str(flops) + ' Mac' -def params_to_string(params_num, units=None, precision=2): +def params_to_string(params_num: int, units: Union[str, None] = None, precision: int = 2): if units is None: if params_num // 10 ** 6 > 0: return str(round(params_num / 10 ** 6, precision)) + ' M' @@ -41,5 +44,7 @@ def params_to_string(params_num, units=None, precision=2): return str(round(params_num / 10.**6, precision)) + ' ' + units elif units == 'K': return str(round(params_num / 10.**3, precision)) + ' ' + units + elif units == 'B': + return str(round(params_num / 10.**9, precision)) + ' ' + units else: return str(params_num) From 8781ecff9f56b64e3618314cb0df15d6a76f3bc6 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Tue, 28 Nov 2023 06:38:24 +0900 Subject: [PATCH 02/10] Clarify types in main API --- ptflops/flops_counter.py | 2 +- ptflops/pytorch_engine.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/ptflops/flops_counter.py b/ptflops/flops_counter.py index 296ea49..a9c69ba 100644 --- a/ptflops/flops_counter.py +++ b/ptflops/flops_counter.py @@ -42,7 +42,7 @@ def get_model_complexity_info(model: nn.Module, input_res: Tuple[int, ...], else: raise ValueError('Wrong backend name') - if as_strings: + if as_strings and flops_count is not None and params_count is not None: flops_string = flops_to_string( flops_count, units=flops_units, diff --git a/ptflops/pytorch_engine.py b/ptflops/pytorch_engine.py index a2e71e4..e724cd6 100644 --- a/ptflops/pytorch_engine.py +++ b/ptflops/pytorch_engine.py @@ -9,6 +9,7 @@ import sys import traceback from functools import partial +from typing import Union import torch import torch.nn as nn @@ -25,8 +26,9 @@ def get_flops_pytorch(model, input_res, verbose=False, ignore_modules=[], custom_modules_hooks={}, output_precision=3, - flops_units='GMac', - param_units='M'): + flops_units: Union[str, None] = 'GMac', + param_units: Union[str, None] = 'M') \ + -> Union[tuple[None, None], tuple[int, int]]: global CUSTOM_MODULES_MAPPING CUSTOM_MODULES_MAPPING = custom_modules_hooks flops_model = add_flops_counting_methods(model) From 688f94f6e5775220a875622124c4956285667e8b Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 22 Dec 2023 10:30:14 +0900 Subject: [PATCH 03/10] Add a docstring for get_model_complexity_info --- ptflops/flops_counter.py | 53 ++++++++++++++++++++++++++++++++++------ 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/ptflops/flops_counter.py b/ptflops/flops_counter.py index a9c69ba..f86d8a7 100644 --- a/ptflops/flops_counter.py +++ b/ptflops/flops_counter.py @@ -7,7 +7,7 @@ ''' import sys -from typing import Any, Callable, Dict, TextIO, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple import torch.nn as nn @@ -15,17 +15,56 @@ from .utils import flops_to_string, params_to_string -def get_model_complexity_info(model: nn.Module, input_res: Tuple[int, ...], +def get_model_complexity_info(model: nn.Module, + input_res: Tuple[int, ...], print_per_layer_stat: bool = True, as_strings: bool = True, - input_constructor: Union[Callable, None] = None, + input_constructor: Optional[Callable] = None, ost: TextIO = sys.stdout, - verbose: bool = False, ignore_modules=[], + verbose: bool = False, + ignore_modules: List[nn.Module] = [], custom_modules_hooks: Dict[nn.Module, Any] = {}, backend: str = 'pytorch', - flops_units: Union[str, None] = None, - param_units: Union[str, None] = None, - output_precision: int = 2): + flops_units: Optional[str] = None, + param_units: Optional[str] = None, + output_precision: int = 2) -> Tuple: + """ + Analyzes the input model and collects the amounts of parameters and MACs + required to make a forward pass of the model. + + :param model: Input model to analyze + :type model: nn.Module + :param input_res: A tuple that sets the input resolution for the model. Batch dimension is added automatically: + (3, 224, 224) -> (1, 3, 224, 224). + :type input_res: Tuple[int, ...] + :param print_per_layer_stat: Flag to enable or disable printing of per-layer MACs/params statistics. + This feature works only for layers derived from torch.nn.Module. Other operations are ignored. + :type print_per_layer_stat: bool + :param as_strings: Flag that allows to get ready-to-print string representation of the final params/MACs estimations. + Otherwise, a tuple with raw numbers will be returned. + :type as_strings: bool + :param input_constructor: A callable that takes the :input_res parameter and returns an output suitable for the model. + It can be used if model requires more than one input tensor or any other kind of irregular input. + :type input_constructor: Callable + :param ost: A stream to print output. + :type ost: TextIO + :param verbose: Parameter to control printing of extra information and warnings. + :type verbose: bool + :param ignore_modules: A list of torch.nn.Module modules to ignore. + :type ignore_modules: nn.Module + :param custom_modules_hooks: A dict that contains custom hooks on torch modules. + :type custom_modules_hooks: Dict[nn.Module, Any] + :param flops_units: Units for string representation of MACs (GMac, MMac or KMac). + :type flops_units: Optional[str] + :param param_units: Units for string representation of params (M, K or B). + :type param_units: Optional[str] + :param output_precision: Floating point precision for representing MACs/params in given units. + :type output_precision: int + + Returns: + Tuple: Return value is a tuple (macs, params): Nones in case of a failure during computations, or + strings if :as_strings is true or integers otherwise. + """ assert type(input_res) is tuple assert len(input_res) >= 1 assert isinstance(model, nn.Module) From 412b245cbf302e25aa5f9d10a1f8f6e76a33daf7 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 22 Dec 2023 10:30:36 +0900 Subject: [PATCH 04/10] Update changelog --- CHANGELOG.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ad19c76..317c094 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,11 +1,15 @@ # ptflops versions log +## v 0.7.2 +- Add type annotations and doc strings to the main API. +- Add support of HuggingFace/Timm VIT transformers. + ## v 0.7.1.2 - Fix failure when using input constructor. ## v 0.7.1 - Experimental support of torchvision.ops.DeformConv2d -- Experimantal support of torch.functional.* and tensor.* operators +- Experimental support of torch.functional.* and tensor.* operators ## v 0.7 - Add ConvNext to sample, fix wrong torchvision compatibility requirement. From 1a2e5680ce6037ceae47aa9458c257d4fcde902f Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 22 Dec 2023 10:31:19 +0900 Subject: [PATCH 05/10] Ensure flops are returned as int --- ptflops/pytorch_engine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ptflops/pytorch_engine.py b/ptflops/pytorch_engine.py index e724cd6..1df2a3f 100644 --- a/ptflops/pytorch_engine.py +++ b/ptflops/pytorch_engine.py @@ -86,7 +86,7 @@ def reset_environment(): ) reset_environment() - return flops_count, params_count + return int(flops_count), params_count def accumulate_flops(self): From 86b0d2d13f392acca58ef62cdb571a70d1dd31a4 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 22 Dec 2023 10:35:11 +0900 Subject: [PATCH 06/10] Update return type --- ptflops/flops_counter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ptflops/flops_counter.py b/ptflops/flops_counter.py index f86d8a7..6dc8893 100644 --- a/ptflops/flops_counter.py +++ b/ptflops/flops_counter.py @@ -7,7 +7,7 @@ ''' import sys -from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple +from typing import Any, Callable, Dict, List, Optional, TextIO, Tuple, Union import torch.nn as nn @@ -27,7 +27,7 @@ def get_model_complexity_info(model: nn.Module, backend: str = 'pytorch', flops_units: Optional[str] = None, param_units: Optional[str] = None, - output_precision: int = 2) -> Tuple: + output_precision: int = 2) -> Tuple[Union[str, int, None], Union[str, int, None]]: """ Analyzes the input model and collects the amounts of parameters and MACs required to make a forward pass of the model. @@ -62,7 +62,7 @@ def get_model_complexity_info(model: nn.Module, :type output_precision: int Returns: - Tuple: Return value is a tuple (macs, params): Nones in case of a failure during computations, or + Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple (macs, params): Nones in case of a failure during computations, or strings if :as_strings is true or integers otherwise. """ assert type(input_res) is tuple From 02a6844d61cb968b059d69b743735fb8450176e6 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 22 Dec 2023 10:41:12 +0900 Subject: [PATCH 07/10] Fix flake8 --- ptflops/flops_counter.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/ptflops/flops_counter.py b/ptflops/flops_counter.py index 6dc8893..dddc0aa 100644 --- a/ptflops/flops_counter.py +++ b/ptflops/flops_counter.py @@ -27,24 +27,28 @@ def get_model_complexity_info(model: nn.Module, backend: str = 'pytorch', flops_units: Optional[str] = None, param_units: Optional[str] = None, - output_precision: int = 2) -> Tuple[Union[str, int, None], Union[str, int, None]]: + output_precision: int = 2) -> Tuple[Union[str, int, None], + Union[str, int, None]]: """ Analyzes the input model and collects the amounts of parameters and MACs required to make a forward pass of the model. :param model: Input model to analyze :type model: nn.Module - :param input_res: A tuple that sets the input resolution for the model. Batch dimension is added automatically: - (3, 224, 224) -> (1, 3, 224, 224). + :param input_res: A tuple that sets the input resolution for the model. Batch + dimension is added automatically: (3, 224, 224) -> (1, 3, 224, 224). :type input_res: Tuple[int, ...] - :param print_per_layer_stat: Flag to enable or disable printing of per-layer MACs/params statistics. - This feature works only for layers derived from torch.nn.Module. Other operations are ignored. + :param print_per_layer_stat: Flag to enable or disable printing of per-layer + MACs/params statistics. This feature works only for layers derived + from torch.nn.Module. Other operations are ignored. :type print_per_layer_stat: bool - :param as_strings: Flag that allows to get ready-to-print string representation of the final params/MACs estimations. - Otherwise, a tuple with raw numbers will be returned. + :param as_strings: Flag that allows to get ready-to-print string representation + of the final params/MACs estimations. Otherwise, a tuple with raw numbers + will be returned. :type as_strings: bool - :param input_constructor: A callable that takes the :input_res parameter and returns an output suitable for the model. - It can be used if model requires more than one input tensor or any other kind of irregular input. + :param input_constructor: A callable that takes the :input_res parameter and + returns an output suitable for the model. It can be used if model requires + more than one input tensor or any other kind of irregular input. :type input_constructor: Callable :param ost: A stream to print output. :type ost: TextIO @@ -58,12 +62,14 @@ def get_model_complexity_info(model: nn.Module, :type flops_units: Optional[str] :param param_units: Units for string representation of params (M, K or B). :type param_units: Optional[str] - :param output_precision: Floating point precision for representing MACs/params in given units. + :param output_precision: Floating point precision for representing MACs/params in + given units. :type output_precision: int Returns: - Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple (macs, params): Nones in case of a failure during computations, or - strings if :as_strings is true or integers otherwise. + Tuple[Union[str, int, None], Union[str, int, None]]: Return value is a tuple + (macs, params): Nones in case of a failure during computations, or + strings if :as_strings is true or integers otherwise. """ assert type(input_res) is tuple assert len(input_res) >= 1 From 2372b25e78f7ae41f9b4d076d681694257ddd368 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 22 Dec 2023 10:48:38 +0900 Subject: [PATCH 08/10] Futher update of type annotations --- ptflops/flops_counter.py | 4 ++-- ptflops/pytorch_engine.py | 16 +++++++++------- ptflops/utils.py | 6 +++--- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/ptflops/flops_counter.py b/ptflops/flops_counter.py index dddc0aa..20caf34 100644 --- a/ptflops/flops_counter.py +++ b/ptflops/flops_counter.py @@ -19,7 +19,7 @@ def get_model_complexity_info(model: nn.Module, input_res: Tuple[int, ...], print_per_layer_stat: bool = True, as_strings: bool = True, - input_constructor: Optional[Callable] = None, + input_constructor: Optional[Callable[[Tuple], Dict]] = None, ost: TextIO = sys.stdout, verbose: bool = False, ignore_modules: List[nn.Module] = [], @@ -49,7 +49,7 @@ def get_model_complexity_info(model: nn.Module, :param input_constructor: A callable that takes the :input_res parameter and returns an output suitable for the model. It can be used if model requires more than one input tensor or any other kind of irregular input. - :type input_constructor: Callable + :type input_constructor: Optional[Callable[[Tuple], Dict]] :param ost: A stream to print output. :type ost: TextIO :param verbose: Parameter to control printing of extra information and warnings. diff --git a/ptflops/pytorch_engine.py b/ptflops/pytorch_engine.py index 1df2a3f..b7ed8bd 100644 --- a/ptflops/pytorch_engine.py +++ b/ptflops/pytorch_engine.py @@ -9,7 +9,7 @@ import sys import traceback from functools import partial -from typing import Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -25,10 +25,10 @@ def get_flops_pytorch(model, input_res, input_constructor=None, ost=sys.stdout, verbose=False, ignore_modules=[], custom_modules_hooks={}, - output_precision=3, - flops_units: Union[str, None] = 'GMac', - param_units: Union[str, None] = 'M') \ - -> Union[tuple[None, None], tuple[int, int]]: + output_precision=2, + flops_units: Optional[str] = 'GMac', + param_units: Optional[str] = 'M') -> Tuple[Union[int, None], + Union[int, None]]: global CUSTOM_MODULES_MAPPING CUSTOM_MODULES_MAPPING = custom_modules_hooks flops_model = add_flops_counting_methods(model) @@ -99,8 +99,10 @@ def accumulate_flops(self): return sum -def print_model_with_flops(model, total_flops, total_params, flops_units='GMac', - param_units='M', precision=3, ost=sys.stdout): +def print_model_with_flops(model, total_flops, total_params, + flops_units: Optional[str] = 'GMac', + param_units: Optional[str] = 'M', + precision=3, ost=sys.stdout): if total_flops < 1: total_flops = 1 if total_params < 1: diff --git a/ptflops/utils.py b/ptflops/utils.py index 2e5d508..69575b5 100644 --- a/ptflops/utils.py +++ b/ptflops/utils.py @@ -7,10 +7,10 @@ ''' -from typing import Union +from typing import Optional -def flops_to_string(flops: int, units: Union[str, None] = None, precision: int = 2): +def flops_to_string(flops: int, units: Optional[str] = None, precision: int = 2): if units is None: if flops // 10**9 > 0: return str(round(flops / 10.**9, precision)) + ' GMac' @@ -31,7 +31,7 @@ def flops_to_string(flops: int, units: Union[str, None] = None, precision: int = return str(flops) + ' Mac' -def params_to_string(params_num: int, units: Union[str, None] = None, precision: int = 2): +def params_to_string(params_num: int, units: Optional[str] = None, precision: int = 2): if units is None: if params_num // 10 ** 6 > 0: return str(round(params_num / 10 ** 6, precision)) + ' M' From 41c869adf3a453c0be22b8fc9f2ef4131c635ddf Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 22 Dec 2023 10:52:17 +0900 Subject: [PATCH 09/10] Update module public interface --- ptflops/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/ptflops/__init__.py b/ptflops/__init__.py index b052a65..445292f 100644 --- a/ptflops/__init__.py +++ b/ptflops/__init__.py @@ -1 +1,17 @@ +''' +Copyright (C) 2019-2023 Sovrasov V. - All Rights Reserved + * You may use, distribute and modify this code under the + * terms of the MIT license. + * You should have received a copy of the MIT license with + * this file. If not visit https://opensource.org/licenses/MIT +''' + + from .flops_counter import get_model_complexity_info +from .utils import flops_to_string, params_to_string + +__all__ = [ + "get_model_complexity_info", + "flops_to_string", + "params_to_string", + ] From f9067f0385185f3191a6327b0cf3f33de6b14f59 Mon Sep 17 00:00:00 2001 From: Vladislav Sovrasov Date: Fri, 22 Dec 2023 10:59:40 +0900 Subject: [PATCH 10/10] Fininsh covering of public api with docs --- ptflops/utils.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/ptflops/utils.py b/ptflops/utils.py index 69575b5..0e70a24 100644 --- a/ptflops/utils.py +++ b/ptflops/utils.py @@ -10,7 +10,15 @@ from typing import Optional -def flops_to_string(flops: int, units: Optional[str] = None, precision: int = 2): +def flops_to_string(flops: int, units: Optional[str] = None, precision: int = 2) -> str: + """ + Converts integer MACs representation to a readable string. + + :param flops: Input MACs. + :param units: Units for string representation of MACs (GMac, MMac or KMac). + :param precision: Floating point precision for representing MACs in + given units. + """ if units is None: if flops // 10**9 > 0: return str(round(flops / 10.**9, precision)) + ' GMac' @@ -31,7 +39,16 @@ def flops_to_string(flops: int, units: Optional[str] = None, precision: int = 2) return str(flops) + ' Mac' -def params_to_string(params_num: int, units: Optional[str] = None, precision: int = 2): +def params_to_string(params_num: int, units: Optional[str] = None, + precision: int = 2) -> str: + """ + Converts integer params representation to a readable string. + + :param flops: Input number of parameters. + :param units: Units for string representation of params (M, K or B). + :param precision: Floating point precision for representing params in + given units. + """ if units is None: if params_num // 10 ** 6 > 0: return str(round(params_num / 10 ** 6, precision)) + ' M'