This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

add extra work during phase advance in classification task #630

Open: wants to merge 4 commits into base: main
classy_train.py (2 changes: 1 addition & 1 deletion)
@@ -117,7 +117,7 @@ def main(args, config):


def configure_hooks(args, config):
hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]
hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook(verbose=True)]

# Make a folder to store checkpoints and tensorboard logging outputs
suffix = datetime.now().isoformat()
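The diff that teaches ModelComplexityHook about the new flag is not visible in the portion of this page that loaded, so how the hook consumes verbose=True is an inference here. A minimal sketch of the intended flow, assuming the hook simply stores the flag and forwards it to the profiler helpers in classy_vision/generic/profiler.py (the class below is a hypothetical stand-in, not the library's implementation):

import logging

from classy_vision.generic.profiler import compute_activations, compute_flops, count_params


class VerboseModelComplexityHook:
    """Hypothetical stand-in that forwards `verbose` to the profiler helpers."""

    def __init__(self, verbose: bool = False) -> None:
        self.verbose = verbose

    def on_start(self, task) -> None:
        model = task.base_model
        # With verbose=True the profiler logs a per-layer breakdown at INFO level
        # instead of only the aggregate counts.
        num_flops = compute_flops(model, verbose=self.verbose)
        num_activations = compute_activations(model, verbose=self.verbose)
        logging.info(
            f"flops: {num_flops}, activations: {num_activations}, params: {count_params(model)}"
        )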
classy_vision/dataset/dataloader_limit_wrapper.py (2 changes: 1 addition & 1 deletion)
@@ -58,7 +58,7 @@ def __next__(self) -> Any:
if self.wrap_around:
# create a new iterator to load data from the beginning
logging.info(
f"Wrapping around after {self._count} calls. Limit: {self.limit}"
f"Wrapping around after {self._count - 1} calls. Limit: {self.limit}"
)
try:
self._iter = iter(self.dataloader)
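The "- 1" fixes an off-by-one in the log message: __next__ bumps self._count before it tries to fetch a batch, so when the underlying iterator is exhausted the call that triggered the wrap-around has not actually been served yet. A minimal sketch of that control flow, assuming the increment sits at the top of __next__ (that part of the file is not shown in the loaded hunk):

def __next__(self):
    self._count += 1                      # counted before the fetch
    try:
        return next(self._iter)
    except StopIteration:
        if not self.wrap_around:
            raise
        # the call that raised was never served, hence `self._count - 1` completed calls
        logging.info(f"Wrapping around after {self._count - 1} calls. Limit: {self.limit}")
        self._iter = iter(self.dataloader)
        return next(self._iter)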
classy_vision/generic/profiler.py (233 changes: 99 additions & 134 deletions)
@@ -86,7 +86,24 @@ def get_shape(x: Union[Tuple, List, Dict]) -> Union[Tuple, List, Dict]:
return x.size()


def _layer_flops(layer: nn.Module, x: Any, y: Any) -> int:
def _get_batchsize_per_replica(x: Union[Tuple, List, Dict]) -> int:
"""
Some layers may take a tuple/list/dict/list[dict] as input to their forward function. We
recursively descend into the tuple/list until we reach a tensor and infer the batch size from it.
"""
while isinstance(x, (list, tuple)):
assert len(x) > 0, "input x of tuple/list type must have at least one element"
x = x[0]

if isinstance(x, (dict,)):
# dimension zero is always the batch size; select an arbitrary key.
key_list = list(x.keys())
x = x[key_list[0]]

return x.size()[0]
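
A quick illustration of how the new helper unwraps nested containers; the tensors below are arbitrary examples, not taken from the library:

import torch

images = torch.randn(8, 3, 224, 224)                        # plain tensor
clip_pair = [torch.randn(8, 3, 224, 224)] * 2                # list of tensors
video_batch = [{"video": torch.randn(8, 3, 16, 112, 112)}]   # list[dict], e.g. video input

# dimension 0 of the first tensor reached is taken as the batch size
assert _get_batchsize_per_replica(images) == 8
assert _get_batchsize_per_replica(clip_pair) == 8
assert _get_batchsize_per_replica(video_batch) == 8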


def _layer_flops(layer: nn.Module, x: Any, y: Any, verbose: bool = False) -> int:
"""
Computes the number of FLOPs required for a single layer.

@@ -146,6 +163,36 @@ def flops(self, x):
/ layer.groups
)

# 3D convolution
elif layer_type in ["Conv3d"]:
out_t = int(
(x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
// layer.stride[0]
+ 1
)
out_h = int(
(x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
// layer.stride[1]
+ 1
)
out_w = int(
(x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
// layer.stride[2]
+ 1
)
flops = (
batchsize_per_replica
* layer.in_channels
* layer.out_channels
* layer.kernel_size[0]
* layer.kernel_size[1]
* layer.kernel_size[2]
* out_t
* out_h
* out_w
/ layer.groups
)

# learned group convolution:
elif layer_type in ["LearnedGroupConv"]:
conv = layer.conv
@@ -170,45 +217,31 @@ def flops(self, x):
)
flops = count1 + count2

# non-linearities:
# non-linearities are not considered in MAC counting
elif layer_type in ["ReLU", "ReLU6", "Tanh", "Sigmoid", "Softmax"]:
flops = x.numel()

# 2D pooling layers:
elif layer_type in ["AvgPool2d", "MaxPool2d"]:
in_h = x.size()[2]
in_w = x.size()[3]
if isinstance(layer.kernel_size, int):
layer.kernel_size = (layer.kernel_size, layer.kernel_size)
kernel_ops = layer.kernel_size[0] * layer.kernel_size[1]
out_h = 1 + int(
(in_h + 2 * layer.padding - layer.kernel_size[0]) / layer.stride
)
out_w = 1 + int(
(in_w + 2 * layer.padding - layer.kernel_size[1]) / layer.stride
flops = 0

elif layer_type in [
"MaxPool1d",
"MaxPool2d",
"MaxPool3d",
"AdaptiveMaxPool1d",
"AdaptiveMaxPool2d",
"AdaptiveMaxPool3d",
]:
flops = 0

elif layer_type in ["AvgPool1d", "AvgPool2d", "AvgPool3d"]:
kernel_ops = 1
flops = kernel_ops * y.numel()

elif layer_type in ["AdaptiveAvgPool1d", "AdaptiveAvgPool2d", "AdaptiveAvgPool3d"]:
assert isinstance(layer.output_size, (list, tuple))
kernel = torch.Tensor(list(x.shape[2:])) // torch.Tensor(
[list(layer.output_size)]
)
flops = x.size()[0] * x.size()[1] * out_w * out_h * kernel_ops

# adaptive avg pool2d
# This is approximate and works only for downsampling without padding
# based on aten/src/ATen/native/AdaptiveAveragePooling.cpp
elif layer_type in ["AdaptiveAvgPool2d"]:
in_h = x.size()[2]
in_w = x.size()[3]
if isinstance(layer.output_size, int):
out_h, out_w = layer.output_size, layer.output_size
elif len(layer.output_size) == 1:
out_h, out_w = layer.output_size[0], layer.output_size[0]
else:
out_h, out_w = layer.output_size
if out_h > in_h or out_w > in_w:
raise ClassyProfilerNotImplementedError(layer)
batchsize_per_replica = x.size()[0]
num_channels = x.size()[1]
kh = in_h - out_h + 1
kw = in_w - out_w + 1
kernel_ops = kh * kw
flops = batchsize_per_replica * num_channels * out_h * out_w * kernel_ops
kernel_ops = torch.prod(kernel)
flops = kernel_ops * y.numel()

# linear layer:
elif layer_type in ["Linear"]:
@@ -224,94 +257,12 @@ def flops(self, x):
"SyncBatchNorm",
"LayerNorm",
]:
flops = 2 * x.numel()

# 3D convolution
elif layer_type in ["Conv3d"]:
out_t = int(
(x.size()[2] + 2 * layer.padding[0] - layer.kernel_size[0])
// layer.stride[0]
+ 1
)
out_h = int(
(x.size()[3] + 2 * layer.padding[1] - layer.kernel_size[1])
// layer.stride[1]
+ 1
)
out_w = int(
(x.size()[4] + 2 * layer.padding[2] - layer.kernel_size[2])
// layer.stride[2]
+ 1
)
flops = (
batchsize_per_replica
* layer.in_channels
* layer.out_channels
* layer.kernel_size[0]
* layer.kernel_size[1]
* layer.kernel_size[2]
* out_t
* out_h
* out_w
/ layer.groups
)

# 3D pooling layers
elif layer_type in ["AvgPool3d", "MaxPool3d"]:
in_t = x.size()[2]
in_h = x.size()[3]
in_w = x.size()[4]
if isinstance(layer.kernel_size, int):
layer.kernel_size = (
layer.kernel_size,
layer.kernel_size,
layer.kernel_size,
)
if isinstance(layer.padding, int):
layer.padding = (layer.padding, layer.padding, layer.padding)
if isinstance(layer.stride, int):
layer.stride = (layer.stride, layer.stride, layer.stride)
kernel_ops = layer.kernel_size[0] * layer.kernel_size[1] * layer.kernel_size[2]
out_t = 1 + int(
(in_t + 2 * layer.padding[0] - layer.kernel_size[0]) / layer.stride[0]
)
out_h = 1 + int(
(in_h + 2 * layer.padding[1] - layer.kernel_size[1]) / layer.stride[1]
)
out_w = 1 + int(
(in_w + 2 * layer.padding[2] - layer.kernel_size[2]) / layer.stride[2]
)
flops = batchsize_per_replica * x.size()[1] * out_t * out_h * out_w * kernel_ops

# adaptive avg pool3d
# This is approximate and works only for downsampling without padding
# based on aten/src/ATen/native/AdaptiveAveragePooling3d.cpp
elif layer_type in ["AdaptiveAvgPool3d"]:
in_t = x.size()[2]
in_h = x.size()[3]
in_w = x.size()[4]
out_t = layer.output_size[0]
out_h = layer.output_size[1]
out_w = layer.output_size[2]
if out_t > in_t or out_h > in_h or out_w > in_w:
raise ClassyProfilerNotImplementedError(layer)
batchsize_per_replica = x.size()[0]
num_channels = x.size()[1]
kt = in_t - out_t + 1
kh = in_h - out_h + 1
kw = in_w - out_w + 1
kernel_ops = kt * kh * kw
flops = (
batchsize_per_replica * num_channels * out_t * out_w * out_h * kernel_ops
)
# batchnorm can be merged into the conv op, so count 0 FLOPs
flops = 0

# dropout layer
elif layer_type in ["Dropout"]:
# At test time, we do not drop values but scale the feature map by the
# dropout ratio
flops = 1
for dim_size in x.size():
flops *= dim_size
flops = 0

elif layer_type == "Identity":
flops = 0
@@ -335,11 +286,14 @@ def flops(self, x):
f"params(M): {count_params(layer) / 1e6}",
f"flops(M): {int(flops) / 1e6}",
]
logging.debug("\t".join(message))
return flops
if verbose:
logging.info("\t".join(message))
return int(flops)
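
Two quick sanity checks of the formulas added above, with arbitrarily chosen shapes, and assuming the elided top of _layer_flops sets batchsize_per_replica from the input via the new _get_batchsize_per_replica helper:

import torch
import torch.nn as nn

# Conv3d(3, 16, kernel_size=3, stride=1, padding=1) on (2, 3, 8, 32, 32):
# out_t = 8, out_h = out_w = 32, so flops = 2 * 3 * 16 * 27 * 8 * 32 * 32 / 1 = 21_233_664
conv3d = nn.Conv3d(3, 16, kernel_size=3, stride=1, padding=1)
x3d = torch.randn(2, 3, 8, 32, 32)
print(_layer_flops(conv3d, x3d, conv3d(x3d)))  # expected: 21233664

# AdaptiveAvgPool2d((7, 7)) on (1, 512, 14, 14): approximate kernel = (14 // 7, 14 // 7) = (2, 2),
# so flops = 4 * y.numel() = 4 * 1 * 512 * 7 * 7 = 100_352
pool = nn.AdaptiveAvgPool2d((7, 7))
x2d = torch.randn(1, 512, 14, 14)
print(_layer_flops(pool, x2d, pool(x2d)))  # expected: 100352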


def _layer_activations(layer: nn.Module, x: Any, out: Any) -> int:
def _layer_activations(
layer: nn.Module, x: Any, out: Any, verbose: bool = False
) -> int:
"""
Computes the number of activations produced by a single layer.

@@ -360,8 +314,9 @@ def activations(self, x, out):
return 0

message = [f"module: {typestr}", f"activations: {activations}"]
logging.debug("\t".join(message))
return activations
if verbose:
logging.info("\t".join(message))
return int(activations)


def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str:
@@ -386,17 +341,19 @@ def summarize_profiler_info(prof: torch.autograd.profiler.profile) -> str:


class ComplexityComputer:
def __init__(self, compute_fn: Callable, count_unique: bool):
def __init__(self, compute_fn: Callable, count_unique: bool, verbose: bool = False):
self.compute_fn = compute_fn
self.count_unique = count_unique
self.count = 0
self.verbose = verbose
self.seen_modules = set()

def compute(self, layer: nn.Module, x: Any, out: Any, module_name: str):
if self.count_unique and module_name in self.seen_modules:
return
logging.debug(f"module name: {module_name}")
self.count += self.compute_fn(layer, x, out)
self.count += self.compute_fn(layer, x, out, self.verbose)
if self.verbose:
logging.info(f"module name: {module_name}, count {self.count}")
self.seen_modules.add(module_name)

def reset(self):
@@ -482,6 +439,7 @@ def compute_complexity(
input_key: Optional[Union[str, List[str]]] = None,
patch_attr: str = None,
compute_unique: bool = False,
verbose: bool = False,
) -> int:
"""
Compute the complexity of a forward pass.
@@ -501,7 +459,7 @@
else:
input = get_model_dummy_input(model, input_shape, input_key)

complexity_computer = ComplexityComputer(compute_fn, compute_unique)
complexity_computer = ComplexityComputer(compute_fn, compute_unique, verbose)

# measure FLOPs:
modify_forward(model, complexity_computer, patch_attr=patch_attr)
@@ -519,25 +477,32 @@ def compute_flops(
model: nn.Module,
input_shape: Tuple[int] = (3, 224, 224),
input_key: Optional[Union[str, List[str]]] = None,
verbose: bool = False,
) -> int:
"""
Compute the number of FLOPs needed for a forward pass.
"""
return compute_complexity(
model, _layer_flops, input_shape, input_key, patch_attr="flops"
model, _layer_flops, input_shape, input_key, patch_attr="flops", verbose=verbose
)


def compute_activations(
model: nn.Module,
input_shape: Tuple[int] = (3, 224, 224),
input_key: Optional[Union[str, List[str]]] = None,
verbose: bool = False,
) -> int:
"""
Compute the number of activations created in a forward pass.
"""
return compute_complexity(
model, _layer_activations, input_shape, input_key, patch_attr="activations"
model,
_layer_activations,
input_shape,
input_key,
patch_attr="activations",
verbose=verbose,
)


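With the plumbing above, a per-layer report can be requested directly from the profiler functions. A minimal usage sketch; the toy model is an arbitrary example, and any module built from the layer types covered above should behave the same way:

import logging

import torch.nn as nn

from classy_vision.generic.profiler import compute_activations, compute_flops

logging.basicConfig(level=logging.INFO)

model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d((1, 1)),
)

# verbose=True makes ComplexityComputer log each module's name and running count,
# and _layer_flops/_layer_activations log a per-layer breakdown at INFO level.
flops = compute_flops(model, input_shape=(3, 32, 32), verbose=True)
activations = compute_activations(model, input_shape=(3, 32, 32), verbose=True)
print(f"total FLOPs: {flops}, total activations: {activations}")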
classy_vision/hooks/loss_lr_meter_logging_hook.py (8 changes: 8 additions & 0 deletions)
@@ -7,6 +7,7 @@
import logging
from typing import Optional

import torch
from classy_vision.generic.distributed_util import get_rank
from classy_vision.hooks import register_hook
from classy_vision.hooks.classy_hook import ClassyHook
@@ -49,6 +50,13 @@ def on_phase_end(self, task) -> None:
# for meters to not provide a sync function.
self._log_loss_lr_meters(task, prefix="Synced meters: ", log_batches=True)

logging.info(
f"max memory allocated(MB) {torch.cuda.max_memory_allocated() // 1e6}"
)
logging.info(
f"max memory reserved(MB) {torch.cuda.max_memory_reserved() // 1e6}"
)

def on_step(self, task) -> None:
"""
Log the LR every log_freq batches, if log_freq is not None.
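The two new log lines report peak CUDA memory at the end of every phase. Note that torch.cuda.max_memory_allocated() is cumulative since program start (or the last reset), so every phase will report the global maximum so far; a hook that wants true per-phase peaks would also have to reset the counters. A hedged sketch of that variant, not what this PR does:

import logging

import torch


def log_and_reset_peak_memory() -> None:
    """Log peak CUDA memory for the phase that just ended, then reset the counters."""
    if not torch.cuda.is_available():
        return
    logging.info(f"max memory allocated(MB) {torch.cuda.max_memory_allocated() // 1e6}")
    logging.info(f"max memory reserved(MB) {torch.cuda.max_memory_reserved() // 1e6}")
    # reset so the next phase reports its own peak rather than the global maximum
    torch.cuda.reset_peak_memory_stats()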