diff --git a/src/sparseml/pytorch/sparsification/training/__init__.py b/src/sparseml/pytorch/sparsification/training/__init__.py index 96dfc9225a6..49848fb3dd7 100644 --- a/src/sparseml/pytorch/sparsification/training/__init__.py +++ b/src/sparseml/pytorch/sparsification/training/__init__.py @@ -15,6 +15,7 @@ # limitations under the License. from .modifier_epoch import * +from .modifier_logging import * from .modifier_lr import * from .modifier_params import * from .modifier_regularizer import * diff --git a/src/sparseml/pytorch/sparsification/training/modifier_logging.py b/src/sparseml/pytorch/sparsification/training/modifier_logging.py new file mode 100644 index 00000000000..6a205f6fc4b --- /dev/null +++ b/src/sparseml/pytorch/sparsification/training/modifier_logging.py @@ -0,0 +1,79 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.nn import Module +from torch.optim.optimizer import Optimizer + +from sparseml.pytorch.sparsification.modifier import ( + PyTorchModifierYAML, + ScheduledUpdateModifier, +) +from sparseml.pytorch.utils import log_module_sparsification_info + + +__all__ = ["SparsificationLoggingModifier"] + + +@PyTorchModifierYAML() +class SparsificationLoggingModifier(ScheduledUpdateModifier): + """ + Modifier to log the sparsification information of a module. + Whenever this modifier is called, it will log the sparsification information + of the module that it is attached to, using the logger(s) provided to it. + + | Sample yaml: + | !SparsificationLoggingModifier + | start_epoch: 0.0 + | end_epoch: 10.0 + | update_frequency: 1 + + + :param start_epoch: The epoch to start the modifier at + (set to -1.0, so it starts immediately) + :param end_epoch: The epoch to end the modifier at, + (set to -1.0, so it doesn't end) + :param update_frequency: if set to -1.0, will log module's + sparsification information on each training step. + If set to a positive integer, will update at the given frequency, + at every epoch + """ + + def __init__( + self, + start_epoch: float, + end_epoch: float = -1.0, + update_frequency: float = 1.0, + ): + super(SparsificationLoggingModifier, self).__init__( + start_epoch=start_epoch, + end_epoch=end_epoch, + update_frequency=update_frequency, + end_comparator=-1, + ) + + def update( + self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int + ): + """ + Calls into the lr scheduler to step given the epoch + Additionally will first set the lr to the init_lr if not set yet + + :param module: module to modify + :param optimizer: optimizer to modify + :param epoch: current epoch and progress within the current epoch + :param steps_per_epoch: number of steps taken within each epoch + (calculate batch number using this and epoch) + """ + super().update(module, optimizer, epoch, steps_per_epoch) + log_module_sparsification_info(module=module, logger=self.loggers, step=epoch) diff --git a/src/sparseml/pytorch/utils/__init__.py b/src/sparseml/pytorch/utils/__init__.py index 88fad7adf76..c99bd5638b1 100644 --- a/src/sparseml/pytorch/utils/__init__.py +++ b/src/sparseml/pytorch/utils/__init__.py @@ -23,6 +23,7 @@ from .distributed import * from .exporter import * from .helpers import * +from .log_sparsification_info import * from .logger import * from .loss import * from .model import * diff --git a/src/sparseml/pytorch/utils/log_sparsification_info.py b/src/sparseml/pytorch/utils/log_sparsification_info.py new file mode 100644 index 00000000000..5679aec70a9 --- /dev/null +++ b/src/sparseml/pytorch/utils/log_sparsification_info.py @@ -0,0 +1,44 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch + +from sparseml.pytorch.utils.logger import BaseLogger +from sparseml.pytorch.utils.sparsification_info.module_sparsification_info import ( + ModuleSparsificationInfo, +) + + +__all__ = ["log_module_sparsification_info"] + + +def log_module_sparsification_info( + module: torch.nn.Module, logger: BaseLogger, step: Optional[float] = None +): + """ + Log the sparsification information for the given module to the given logger + + :param module: the module to log the sparsification information for + :param logger: the logger to log the sparsification information to + :param step: the global step for when the sparsification information + is being logged. By default, is None + """ + sparsification_info = ModuleSparsificationInfo.from_module(module) + for tag, value in sparsification_info.loggable_items(): + if isinstance(value, dict): + logger.log_scalars(tag=tag, step=step, values=value) + else: + logger.log_scalar(tag=tag, step=step, value=value) diff --git a/src/sparseml/pytorch/utils/sparsification_info/__init__.py b/src/sparseml/pytorch/utils/sparsification_info/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/src/sparseml/pytorch/utils/sparsification_info/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/sparseml/pytorch/utils/sparsification_info/configs.py b/src/sparseml/pytorch/utils/sparsification_info/configs.py new file mode 100644 index 00000000000..32292eee008 --- /dev/null +++ b/src/sparseml/pytorch/utils/sparsification_info/configs.py @@ -0,0 +1,274 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from collections import Counter, defaultdict +from typing import Dict, Generator, Tuple, Union + +import torch.nn +from pydantic import BaseModel, Field + +from sparseml.pytorch.utils.sparsification_info.helpers import ( + get_leaf_operations, + get_precision_information, + is_quantized, +) + + +__all__ = [ + "SparsificationSummaries", + "SparsificationPruning", + "SparsificationQuantization", + "SparsificationInfo", +] + + +class SparsificationInfo(BaseModel, ABC): + @classmethod + @abstractmethod + def from_module( + cls, + module: torch.nn.Module, + **kwargs, + ) -> "SparsificationInfo": + """ + Factory method to create SparsificationInfo object from a module. + + :param module: The module to create the SparsificationInfo object from. + :param kwargs: Additional arguments to pass to the SparsificationInfo object. + :return: A SparsificationInfo object. + """ + raise NotImplementedError() + + @abstractmethod + def loggable_items( + self, + ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: + """ + Yield the loggable items for SparsificationInfo object. + + :return: A generator that yields the loggable items for this object. + """ + raise NotImplementedError() + + +class CountAndPercent(BaseModel): + count: int = Field(description="The count of items") + percent: float = Field(description="The percentage of those items out of the total") + + +class SparsificationSummaries(SparsificationInfo): + """ + A model that contains the sparsification summaries for a torch module. + """ + + quantized: CountAndPercent = Field( + description="A model that contains the number of " + "operations/the percent of operations that are quantized." + ) + pruned: CountAndPercent = Field( + description="A model that contains the number of " + "parameters/the percent of parameters that are pruned." + ) + parameter_counts: Dict[str, int] = Field( + description="A dictionary that maps the name of a parameter " + "to the number of elements (weights) in that parameter." + ) + operation_counts: Dict[str, int] = Field( + description="A dictionary that maps the name of an operation " + "to the number of times that operation is used in the model." + ) + + @classmethod + def from_module( + cls, + module=torch.nn.Module, + pruning_thresholds: Tuple[float, float] = (0.05, 1 - 1e-9), + ) -> "SparsificationSummaries": + """ + Factory method to create a SparsificationSummaries object from a module. + + :param module: The module to create the SparsificationSummaries object from. + :param pruning_thresholds: The lower and upper thresholds used to determine + whether a parameter is pruned. If it's percentage of zero weights is between + the lower and upper thresholds, it is considered pruned. + :return: A SparsificationSummaries object. + """ + operations = get_leaf_operations(module) + num_quantized_ops = sum([is_quantized(op) for op in operations]) + total_num_params = len(list(module.parameters())) + + lower_threshold_pruning = min(pruning_thresholds) + upper_threshold_pruning = max(pruning_thresholds) + total_num_params_pruned = 0 + count_parameters = defaultdict(int) + + for param_name, param in module.named_parameters(): + num_parameters = param.numel() + num_zero_parameters = param.numel() - param.count_nonzero().item() + + if ( + lower_threshold_pruning + <= num_zero_parameters / num_parameters + <= upper_threshold_pruning + ): + total_num_params_pruned += 1 + + count_parameters[param_name] = num_parameters + + return cls( + pruned=CountAndPercent( + count=total_num_params_pruned, + percent=total_num_params_pruned / total_num_params, + ), + quantized=CountAndPercent( + count=num_quantized_ops, percent=num_quantized_ops / len(operations) + ), + parameter_counts=count_parameters, + operation_counts=Counter([op.__class__.__name__ for op in operations]), + ) + + def loggable_items( + self, + ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: + """ + Yield the loggable items for SparsificationSummaries object. + + :return: A generator that yields the loggable items for this object. + """ + main_tag = self.__class__.__name__ + yield f"{main_tag}/OperationCounts", self.operation_counts + yield f"{main_tag}/ParameterCounts", self.parameter_counts + yield f"{main_tag}/QuantizedOperations/count", self.quantized.count + yield f"{main_tag}/QuantizedOperations/percent", self.quantized.percent + yield f"{main_tag}/PrunedParameters/count", self.pruned.count + yield f"{main_tag}/PrunedParameters/percent", self.pruned.percent + + +class SparsificationPruning(SparsificationInfo): + """ + A model that contains the pruning information for a torch module. + """ + + sparse_parameters: Dict[str, CountAndPercent] = Field( + description="A dictionary that maps the name of a parameter " + "to the number/percent of weights that are zeroed out " + "in that layer." + ) + + @classmethod + def from_module(cls, module: torch.nn.Module) -> "SparsificationPruning": + """ + Factory method to create a SparsificationPruning object from a module. + + :param module: The module to create the SparsificationPruning object from. + :return: A SparsificationPruning object. + """ + sparse_parameters_count = defaultdict(CountAndPercent) + for param_name, param in module.named_parameters(): + num_parameters = param.numel() + num_zero_parameters = param.numel() - param.count_nonzero().item() + + zero_count = num_zero_parameters + zero_count_percent = num_zero_parameters / num_parameters + + sparse_parameters_count[param_name] = CountAndPercent( + count=zero_count, percent=zero_count_percent + ) + + return cls(sparse_parameters=sparse_parameters_count) + + def loggable_items( + self, + ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: + """ + Yield the loggable items for SparsificationPruning object. + + :return: A generator that yields the loggable items for this object. + """ + main_tag = self.__class__.__name__ + for param_name, count_and_percent in self.sparse_parameters.items(): + yield f"{main_tag}/SparseParameters/{param_name}/count", count_and_percent.count # noqa: E501 + yield f"{main_tag}/SparseParameters/{param_name}/percent", count_and_percent.percent # noqa: E501 + + +class SparsificationQuantization(SparsificationInfo): + """ + A model that contains the quantization information for a torch module. + """ + + enabled: Dict[str, bool] = Field( + description="A dictionary that maps the name of an " + "operation to a boolean flag that indicates whether " + "the operation is quantized or not." + ) + precision: Dict[str, Union[BaseModel, None, int]] = Field( + description="A dictionary that maps the name of a layer" + "to the precision of that layer." + ) + + class Config: + arbitrary_types_allowed = True + + @classmethod + def from_module( + cls, + module: torch.nn.Module, + ) -> "SparsificationQuantization": + """ + Factory method to create a SparsificationQuantization object from a module. + + :param module: The module to create the SparsificationQuantization object from. + :return: A SparsificationQuantization object. + """ + operations = get_leaf_operations(module) + enabled = defaultdict(bool) + precision = defaultdict(str) + for op in operations: + operation_name = op.__class__.__name__ + operation_counter = 0 + # make sure that the operation name is unique + while enabled.get(operation_name) is not None: + operation_counter += 1 + operation_name = f"{op.__class__.__name__}_{operation_counter}" + + enabled[operation_name] = is_quantized(op) + precision[operation_name] = get_precision_information(op) + + return cls(enabled=enabled, precision=precision) + + def loggable_items( + self, + ) -> Generator[Tuple[str, Union[Dict[str, int], float, int]], None, None]: + """ + Yield the loggable items for SparsificationQuantization object. + + :return: A generator that yields the loggable items for this object. + """ + main_tag = self.__class__.__name__ + for operation in self.enabled.keys(): + yield f"{main_tag}/{operation}/enabled", self.enabled[operation] + + precision = self.precision[operation] + if precision is None: + yield f"{main_tag}/{operation}/precision", precision + elif isinstance(precision, int): + yield f"{main_tag}/{operation}/precision.weights/num_bits", precision + elif isinstance(precision, BaseModel): + yield f"{main_tag}/{operation}/precision/weights/num_bits", precision.weights.num_bits # noqa: E501 + yield f"{main_tag}/{operation}/precision/input_activations/num_bits", precision.input_activations.num_bits # noqa: E501 + else: + raise ValueError( + f"The precision is not a valid type {type(precision)}." + ) diff --git a/src/sparseml/pytorch/utils/sparsification_info/helpers.py b/src/sparseml/pytorch/utils/sparsification_info/helpers.py new file mode 100644 index 00000000000..5245fe823ab --- /dev/null +++ b/src/sparseml/pytorch/utils/sparsification_info/helpers.py @@ -0,0 +1,123 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union + +import torch +from torch.nn.modules.linear import Identity +from torch.quantization import QuantWrapper + + +__all__ = ["get_leaf_operations", "is_quantized", "get_precision_information"] + + +def get_leaf_operations( + model: torch.nn.Module, + operations_to_skip: Optional[List[torch.nn.Module]] = None, + operations_to_unwrap: Optional[List[torch.nn.Module]] = None, +) -> List[torch.nn.Module]: + """ + Get the leaf operations in the model + (those that do not have operations as children) + + :param model: the model to get the leaf operations from + :param operations_to_skip: a list of leaf operations that will be + omitted when getting the leaf operations. If None passed, by + default the Identity operation will be skipped + :param operations_to_unwrap: a list of operations that will be unwrapped + when getting the leaf operations. Unwrapping means that we directly + add the module(s) that is/are wrapped by the operation (i.e. operation's + `module` attribute) to the list + of leaf operations. If None passed, by default the QuantWrapper + operation will be unwrapped + :return: a list of the leaf operations + """ + if operations_to_skip is None: + operations_to_skip = [Identity] + + if operations_to_unwrap is None: + operations_to_unwrap = [QuantWrapper] + + leaf_operations = [] + children = list(model.children()) + + if children == []: + return model + else: + for child in children: + if isinstance(child, tuple(operations_to_unwrap)): + leaf_operations.append(child.module) + continue + try: + leaf_operations.extend(get_leaf_operations(child)) + except TypeError: + leaf_operations.append(get_leaf_operations(child)) + leaf_operations = [ + op for op in leaf_operations if not isinstance(op, tuple(operations_to_skip)) + ] + return leaf_operations + + +def is_quantized(operation: torch.nn.Module) -> bool: + """ + Check whether the operation is quantized (contains + a quantization scheme) + """ + return hasattr(operation, "quantization_scheme") + + +def get_precision_information( + operation: torch.nn.Module, +) -> Union[None, int, "QuantizationScheme"]: # noqa F821 + """ + Get the information about the precision of the operation. + + 1) If operation is quantized, returns the quantization + scheme of the operation. + 2) If operation is not quantized, returns the numer of bits + of the operation's weights. + 3) If operation is not quantized and does not have a weights, + returns None. + + :param operation: the operation to get the quantization scheme from + :return: the quantization scheme of the operation, the number of bits + of the operation's weights, or None if the operation is not quantized + and does not have a weight + """ + + if hasattr(operation, "quantization_scheme"): + return getattr(operation, "quantization_scheme") + elif hasattr(operation, "weight"): + return _get_num_bits(operation.weight.dtype) + else: + return None + + +def _get_num_bits(dtype: torch.dtype) -> int: + # Get the number of bits of a torch dtype + if dtype == torch.float16: + return 16 + elif dtype == torch.float32: + return 32 + elif dtype == torch.float64: + return 64 + elif dtype == torch.int8: + return 8 + elif dtype == torch.int16: + return 16 + elif dtype == torch.int32: + return 32 + elif dtype == torch.int64: + return 64 + else: + raise ValueError("Unknown dtype: {}".format(dtype)) diff --git a/src/sparseml/pytorch/utils/sparsification_info/module_sparsification_info.py b/src/sparseml/pytorch/utils/sparsification_info/module_sparsification_info.py new file mode 100644 index 00000000000..c5ab483f449 --- /dev/null +++ b/src/sparseml/pytorch/utils/sparsification_info/module_sparsification_info.py @@ -0,0 +1,78 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Generator, Tuple + +import torch +from pydantic import Field + +from sparseml.pytorch.utils.sparsification_info.configs import ( + SparsificationInfo, + SparsificationPruning, + SparsificationQuantization, + SparsificationSummaries, +) + + +class ModuleSparsificationInfo(SparsificationInfo): + """ + Pydantic model for storing sparsification information of a torch module. + """ + + summary_info: SparsificationSummaries = Field( + description="Model that holds the sparsification summary info of the module" + ) + pruning_info: SparsificationPruning = Field( + description="Model that holds the pruning info of the module" + ) + quantization_info: SparsificationQuantization = Field( + description="Model that holds the quantization info of the module" + ) + + @classmethod + def from_module(cls, module: torch.nn.Module) -> "ModuleSparsificationInfo": + """ + Factory method to create a ModuleSparsificationInfo object from a torch module. + + :param module: the module to create the ModuleSparsificationInfo object from + :return: the ModuleSparsificationInfo object created from the module + """ + if not isinstance(module, torch.nn.Module): + raise ValueError( + "Module must be a torch.nn.Module, not {}".format(type(module)) + ) + + return cls( + summary_info=SparsificationSummaries.from_module(module), + pruning_info=SparsificationPruning.from_module(module), + quantization_info=SparsificationQuantization.from_module(module), + ) + + return cls( + summary_info=SparsificationSummaries.from_module(module), + pruning_info=SparsificationPruning.from_module(module), + quantization_info=SparsificationQuantization.from_module(module), + ) + + def loggable_items(self) -> Generator[Tuple[str, Any], None, None]: + """ + A generator that yields the loggable items of + the ModuleSparsificationInfo object. + + :return a generator that yields a tuple of: + - the name of the loggable item + - the value of the loggable item + """ + for info in [self.summary_info, self.pruning_info, self.quantization_info]: + yield from info.loggable_items() diff --git a/tests/sparseml/pytorch/sparsification/training/test_modifier_logging.py b/tests/sparseml/pytorch/sparsification/training/test_modifier_logging.py new file mode 100644 index 00000000000..909dd70478b --- /dev/null +++ b/tests/sparseml/pytorch/sparsification/training/test_modifier_logging.py @@ -0,0 +1,64 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import pytest + +from sparseml.pytorch.sparsification.training import SparsificationLoggingModifier + + +@pytest.mark.skipif( + os.getenv("NM_ML_SKIP_PYTORCH_TESTS", False), + reason="Skipping pytorch tests", +) +@pytest.mark.parametrize( + "start_epoch, end_epoch, update_frequency", + [ + (0.0, 10.0, 2), + (5.0, -1, 1), + (0.0, 1.0, -1), + ], +) +def test_epoch_range_yaml(start_epoch, end_epoch, update_frequency): + yaml_str = """ + !SparsificationLoggingModifier + start_epoch: {start_epoch} + end_epoch: {end_epoch} + update_frequency: {update_frequency} + """.format( + start_epoch=start_epoch, end_epoch=end_epoch, update_frequency=update_frequency + ) + yaml_modifier = SparsificationLoggingModifier.load_obj(yaml_str) + serialized_modifier = SparsificationLoggingModifier.load_obj(str(yaml_modifier)) + obj_modifier = SparsificationLoggingModifier( + start_epoch=start_epoch, end_epoch=end_epoch, update_frequency=update_frequency + ) + + assert isinstance(yaml_modifier, SparsificationLoggingModifier) + assert ( + yaml_modifier.start_epoch + == serialized_modifier.start_epoch + == obj_modifier.start_epoch + ) + assert ( + yaml_modifier.end_epoch + == serialized_modifier.end_epoch + == obj_modifier.end_epoch + ) + assert ( + yaml_modifier.update_frequency + == serialized_modifier.update_frequency + == obj_modifier.update_frequency + ) diff --git a/tests/sparseml/pytorch/utils/test_sparsification_info/__init__.py b/tests/sparseml/pytorch/utils/test_sparsification_info/__init__.py new file mode 100644 index 00000000000..0c44f887a47 --- /dev/null +++ b/tests/sparseml/pytorch/utils/test_sparsification_info/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/sparseml/pytorch/utils/test_sparsification_info/test_configs.py b/tests/sparseml/pytorch/utils/test_sparsification_info/test_configs.py new file mode 100644 index 00000000000..6a7418c4b49 --- /dev/null +++ b/tests/sparseml/pytorch/utils/test_sparsification_info/test_configs.py @@ -0,0 +1,177 @@ +# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + +from sparseml.pytorch.optim import ScheduledModifierManager +from sparseml.pytorch.utils.sparsification_info.configs import ( + SparsificationPruning, + SparsificationQuantization, + SparsificationSummaries, +) + + +QUANT_RECIPE = """ +!QuantizationModifier + start_epoch: 0.0 + scheme: + input_activations: + num_bits: 8 + symmetric: False + weights: + num_bits: 4 + symmetric: True + scheme_overrides: + classifier: + input_activations: + num_bits: 8 + symmetric: False + weights: null + Conv2d: + input_activations: + num_bits: 8 + symmetric: True + ignore: ["ReLU", "input"] + """ + + +def _create_test_model(quantization_recipe: str) -> torch.nn.Module: + sub_model_1 = torch.nn.Sequential( + torch.nn.Conv2d(1, 16, 3, padding=1), + torch.nn.ReLU(), + torch.nn.Conv2d(16, 32, 3, padding=1), + torch.nn.ReLU(), + torch.nn.MaxPool2d(2), + ) + sub_model_2 = torch.nn.Sequential( + torch.nn.Flatten(), torch.nn.Linear(32 * 14 * 14, 10) + ) + + model = torch.nn.ModuleList([sub_model_1, sub_model_2]) + + # set some weights to zero to simulate pruning + named_parameters = dict(model.named_parameters()) + named_parameters["1.1.weight"].data[:5, :] = torch.zeros_like( + named_parameters["1.1.weight"].data + )[:5, :] + + manager = ScheduledModifierManager.from_yaml(quantization_recipe) + manager.apply(model) + return model + + +_expected_summaries = { + "OperationCounts": { + "ConvReLU2d": 2, + "MaxPool2d": 1, + "Flatten": 1, + "Linear": 1, + }, + "ParameterCounts": { + "0.0.module.weight": 144, + "0.0.module.bias": 16, + "0.2.module.weight": 4608, + "0.2.module.bias": 32, + "1.1.module.weight": 62720, + "1.1.module.bias": 10, + }, + "QuantizedOperations/count": 4, + "QuantizedOperations/percent": 0.8, + "PrunedParameters/count": 1, + "PrunedParameters/percent": 0.16666666666666666, +} + +_expected_pruning = { + "0.0.module.weight/count": 0, + "0.0.module.weight/percent": 0.0, + "0.0.module.bias/count": 0, + "0.0.module.bias/percent": 0.0, + "0.2.module.weight/count": 0, + "0.2.module.weight/percent": 0.0, + "0.2.module.bias/count": 0, + "0.2.module.bias/percent": 0.0, + "1.1.module.weight/count": 31360, + "1.1.module.weight/percent": 0.5, + "1.1.module.bias/count": 0, + "1.1.module.bias/percent": 0.0, +} + +_expected_quantization = { + "ConvReLU2d/enabled": True, + "ConvReLU2d/precision/weights/num_bits": 4, + "ConvReLU2d/precision/input_activations/num_bits": 8, + "ConvReLU2d_1/enabled": True, + "ConvReLU2d_1/precision/weights/num_bits": 4, + "ConvReLU2d_1/precision/input_activations/num_bits": 8, + "MaxPool2d/enabled": True, + "MaxPool2d/precision/weights/num_bits": 4, + "MaxPool2d/precision/input_activations/num_bits": 8, + "Flatten/enabled": False, + "Flatten/precision": None, + "Linear/enabled": True, + "Linear/precision/weights/num_bits": 4, + "Linear/precision/input_activations/num_bits": 8, +} + + +@pytest.mark.parametrize( + "model, expected_summaries, expected_pruning, expected_quantization", + [ + ( + _create_test_model(quantization_recipe=QUANT_RECIPE), + _expected_summaries, + _expected_pruning, + _expected_quantization, + ) + ], +) +class TestSparsificationModels: + @pytest.fixture() + def setup(self, model, expected_summaries, expected_pruning, expected_quantization): + self.expected_summaries = expected_summaries + self.expected_pruning = expected_pruning + self.expected_quantization = expected_quantization + + yield model + + def test_sparsification_summaries(self, setup): + sparsification_summary = SparsificationSummaries.from_module(module=setup) + for tag, item in sparsification_summary.loggable_items(): + assert ( + self.expected_summaries[tag.replace("SparsificationSummaries/", "")] + == item + ) + + def test_sparsification_pruning(self, setup): + sparsification_pruning = SparsificationPruning.from_module(module=setup) + for tag, item in sparsification_pruning.loggable_items(): + assert ( + self.expected_pruning[ + tag.replace("SparsificationPruning/SparseParameters/", "") + ] + == item + ) + + def test_sparsification_quantization(self, setup): + sparsification_quantization = SparsificationQuantization.from_module( + module=setup + ) + for tag, item in sparsification_quantization.loggable_items(): + assert ( + self.expected_quantization[ + tag.replace("SparsificationQuantization/", "") + ] + == item + )