Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SparsificationLoggingModifier implementation #1453

Merged
merged 9 commits into from
Mar 28, 2023
1 change: 1 addition & 0 deletions src/sparseml/pytorch/sparsification/training/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# limitations under the License.

from .modifier_epoch import *
from .modifier_logging import *
from .modifier_lr import *
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
from .modifier_params import *
from .modifier_regularizer import *
79 changes: 79 additions & 0 deletions src/sparseml/pytorch/sparsification/training/modifier_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from torch.nn import Module
from torch.optim.optimizer import Optimizer

from sparseml.pytorch.sparsification.modifier import (
PyTorchModifierYAML,
ScheduledUpdateModifier,
)
from sparseml.pytorch.utils import log_module_sparsification_info


__all__ = ["SparsificationLoggingModifier"]


@PyTorchModifierYAML()
class SparsificationLoggingModifier(ScheduledUpdateModifier):
    """
    Modifier to log the sparsification information of a module.
    Whenever this modifier is called, it will log the sparsification information
    of the module that it is attached to, using the logger(s) provided to it.

    | Sample yaml:
    |   !SparsificationLoggingModifier
    |       start_epoch: 0.0
    |       end_epoch: 10.0
    |       update_frequency: 1

    :param start_epoch: The epoch to start the modifier at (required; no
        default is provided, unlike ``end_epoch``/``update_frequency``)
    :param end_epoch: The epoch to end the modifier at
        (defaults to -1.0, so it doesn't end)
    :param update_frequency: if set to -1.0, will log module's
        sparsification information on each training step.
        If set to a positive integer, will update at the given frequency,
        at every epoch (defaults to 1.0)
    """

    def __init__(
        self,
        start_epoch: float,
        end_epoch: float = -1.0,
        update_frequency: float = 1.0,
    ):
        # end_comparator=-1 relaxes the usual end_epoch > start_epoch
        # validation in ScheduledUpdateModifier — presumably so the
        # open-ended end_epoch=-1.0 default is accepted; TODO confirm
        # against the base-class contract.
        super(SparsificationLoggingModifier, self).__init__(
            start_epoch=start_epoch,
            end_epoch=end_epoch,
            update_frequency=update_frequency,
            end_comparator=-1,
        )

    def update(
        self, module: Module, optimizer: Optimizer, epoch: float, steps_per_epoch: int
    ):
        """
        Log the sparsification information of ``module`` to this modifier's
        logger(s), tagging the entries with the current epoch as the step.

        (The previous docstring described LR-scheduler behavior; this
        modifier performs no optimizer/LR changes — it only logs.)

        :param module: module to log sparsification information for
        :param optimizer: optimizer passed through to the scheduled-update
            base class; not otherwise used here
        :param epoch: current epoch and progress within the current epoch;
            used as the logging step
        :param steps_per_epoch: number of steps taken within each epoch
            (calculate batch number using this and epoch)
        """
        super().update(module, optimizer, epoch, steps_per_epoch)
        # NOTE(review): `self.loggers` (plural) is passed to a parameter
        # named `logger` annotated as BaseLogger in
        # log_module_sparsification_info — looks like it relies on the
        # loggers collection exposing the same log_scalar/log_scalars
        # interface; confirm against the logger manager implementation.
        log_module_sparsification_info(module=module, logger=self.loggers, step=epoch)
1 change: 1 addition & 0 deletions src/sparseml/pytorch/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .distributed import *
from .exporter import *
from .helpers import *
from .log_sparsification_info import *
from .logger import *
from .loss import *
from .model import *
Expand Down
44 changes: 44 additions & 0 deletions src/sparseml/pytorch/utils/log_sparsification_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch

from sparseml.pytorch.utils.logger import BaseLogger
from sparseml.pytorch.utils.sparsification_info.module_sparsification_info import (
ModuleSparsificationInfo,
)


__all__ = ["log_module_sparsification_info"]


def log_module_sparsification_info(
    module: torch.nn.Module, logger: BaseLogger, step: Optional[float] = None
):
    """
    Log the sparsification information for the given module to the given logger.

    Each loggable item produced for the module is forwarded to the logger:
    dict-valued items go through ``log_scalars`` and scalar-valued items
    through ``log_scalar``.

    :param module: the module to log the sparsification information for
    :param logger: the logger to log the sparsification information to
        (NOTE(review): callers appear to also pass a loggers collection here —
        assumes it exposes the same log_scalar/log_scalars API; verify)
    :param step: the global step for when the sparsification information
        is being logged. By default, is None
    """
    info = ModuleSparsificationInfo.from_module(module)
    for tag, loggable in info.loggable_items():
        if isinstance(loggable, dict):
            emit = logger.log_scalars
            payload = {"values": loggable}
        else:
            emit = logger.log_scalar
            payload = {"value": loggable}
        emit(tag=tag, step=step, **payload)
13 changes: 13 additions & 0 deletions src/sparseml/pytorch/utils/sparsification_info/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading