Skip to content

mark metricmodule methods as experimental #2983

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import DeviceMesh
from torch.profiler import record_function
from torchrec.metrics.accuracy import AccuracyMetric
from torchrec.metrics.auc import AUCMetric
Expand Down Expand Up @@ -67,6 +68,7 @@
from torchrec.metrics.unweighted_ne import UnweightedNEMetric
from torchrec.metrics.weighted_avg import WeightedAvgMetric
from torchrec.metrics.xauc import XAUCMetric
from torchrec.utils.experimental import experimental


logger: logging.Logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -354,6 +356,120 @@ def reset(self) -> None:
def get_required_inputs(self) -> Optional[List[str]]:
return self.rec_metrics.get_required_inputs()

def _get_throughput_metric_states(
    self, metric: ThroughputMetric
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    Collect the buffers of the throughput metric, keyed by its metric name.

    Args:
        metric (ThroughputMetric): the throughput metric whose buffers are read.

    Returns:
        Dict[str, Dict[str, torch.Tensor]]: a single-entry dict mapping
        ``metric._metric_name.value`` to a ``{buffer_name: buffer}`` dict.
    """
    # `named_buffers` is used instead of `state_dict` because some buffers
    # are not persistent.
    buffers = dict(metric.named_buffers())
    return {metric._metric_name.value: buffers}

def _get_metric_states(
    self,
    metric: RecMetric,
    world_size: int,
    process_group: Union[dist.ProcessGroup, DeviceMesh],
    reduce_metrics: bool = True,
) -> Dict[str, Dict[str, torch.Tensor]]:
    """
    All-gather every state of every task computation of ``metric``.

    Args:
        metric (RecMetric): the metric whose ``_tasks`` and
            ``_metrics_computations`` are walked in lockstep.
        world_size (int): number of per-rank buffers allocated for each gather.
        process_group: passed as ``group`` to ``dist.all_gather``.
        reduce_metrics (bool): when True and the state has a reduction
            function, the stacked gathered tensors are reduced with it.

    Returns:
        A dict ``{task.name: {state_name: value}}``. ``value`` is the reduced
        tensor, or the list of gathered per-rank tensors when no reduction is
        applied.
    """
    gathered_by_task = {}
    for task, computation in zip(metric._tasks, metric._metrics_computations):
        # Snapshot (name, local tensor, reduction) for each registered state.
        pending = [
            (attr, getattr(computation, attr), reduction_fn)
            for attr, reduction_fn in computation._reductions.items()
        ]
        task_states = {}
        gathered_by_task[task.name] = task_states

        # TODO: do one all gather call per metric, instead of one per state
        # this may require more logic as shapes of states are not guaranteed to be same
        # may need padding
        for state_name, local_tensor, reduction_fn in pending:
            per_rank = [torch.empty_like(local_tensor) for _ in range(world_size)]
            dist.all_gather(per_rank, local_tensor, group=process_group)
            if reduce_metrics and reduction_fn is not None:
                task_states[state_name] = reduction_fn(torch.stack(per_rank))
            else:
                task_states[state_name] = per_rank

    return gathered_by_task

@experimental
def get_pre_compute_states(
    self, pg: Union[dist.ProcessGroup, DeviceMesh], reduce_metrics: bool = True
) -> Dict[str, Dict[str, Dict[str, torch.Tensor]]]:
    """
    Return the states of each metric so they can be saved. RecMetric states are
    all-gathered across ``pg`` and then aggregated by the reduction function
    defined for each state. The aggregation can optionally be disabled by setting
    ``reduce_metrics`` to False. Because the states are all-gathered, the gathered
    RecMetric states are the same on every rank.

    Each metric has N tasks associated with it. This is reflected in the metric
    state, where the size of the tensor is typically (n_tasks, 1). Depending on the
    `RecComputeMode` the metric is in, the number of tasks can be 1 or len(tasks).

    The output is a nested dictionary keyed by ``metric._namespace``, each entry
    mapping tasks to their states and associated tensors:
        metric : str -> { task : {state : tensor} }

    The throughput metric, when present, is not a RecMetric; its entry is built
    from its buffers and keyed by its metric name instead of a task name.

    Args:
        pg (Union[dist.ProcessGroup, DeviceMesh]): the process group to use for all gather.
            For a DeviceMesh, the group of the ``"shard"`` mesh dimension is used.
        reduce_metrics (bool): whether to reduce the metrics or not. Default is True.

    Returns:
        Dict[str, Dict[str, Dict[str, torch.Tensor]]]: the states for each metric to be saved
    """
    process_group: dist.ProcessGroup = (
        pg.get_group(mesh_dim="shard") if isinstance(pg, DeviceMesh) else pg
    )
    # Under 2D parallel context, this should be sharding world size
    world_size = dist.get_world_size(process_group)

    aggregated_states = {}
    for rec_metric in self.rec_metrics.rec_metrics:
        gathered = self._get_metric_states(
            rec_metric, world_size, process_group, reduce_metrics
        )
        aggregated_states[rec_metric._namespace.value] = gathered

    # throughput metric requires special handling, since it's not a RecMetric
    throughput_metric = self.throughput_metric
    if throughput_metric is None:
        return aggregated_states

    aggregated_states[throughput_metric._namespace.value] = (
        self._get_throughput_metric_states(throughput_metric)
    )
    return aggregated_states

@experimental
def load_pre_compute_states(
    self, source: Dict[str, Dict[str, Dict[str, torch.Tensor]]]
) -> None:
    """
    Restore metric states produced by ``get_pre_compute_states``. This is called
    on every rank; no collectives are called in this function.

    For each RecMetric, every saved state is assigned onto the matching task
    computation with ``setattr``. For the throughput metric (if present), each
    saved buffer is copied in place into the existing buffer of the same name.

    Args:
        source (Dict[str, Dict[str, Dict[str, torch.Tensor]]]): the source states to load from.
            This is the output of ``get_pre_compute_states``.

    Returns:
        None
    """
    for rec_metric in self.rec_metrics.rec_metrics:
        per_task_states = source[rec_metric._namespace.value]
        task_pairs = zip(rec_metric._tasks, rec_metric._metrics_computations)
        for task, computation in task_pairs:
            saved = per_task_states[task.name]
            for attr, tensor in saved.items():
                setattr(computation, attr, tensor)

    if self.throughput_metric is None:
        return

    # Throughput states are nested under the metric name rather than a task name.
    saved_buffers = source[self.throughput_metric._namespace.value][
        self.throughput_metric._metric_name.value  # pyre-ignore[16]
    ]
    for name, buf in self.throughput_metric.named_buffers():  # pyre-ignore[16]
        buf.copy_(saved_buffers[name])


def _generate_rec_metrics(
metrics_config: MetricsConfig,
Expand Down
6 changes: 6 additions & 0 deletions torchrec/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,10 @@ def _fused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value, metric_report.description

def _unfused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
"""
For each task, we generate an associated RecMetricComputation object for it.
This would mean in the states of each RecMetricComputation object, the n_tasks dimension is 1.
"""
for task, metric_computation in zip(self._tasks, self._metrics_computations):
metric_computation.pre_compute()
for metric_report in getattr(
Expand All @@ -494,6 +498,7 @@ def _unfused_tasks_iter(self, compute_scope: str) -> ComputeIterType:
or metric_computation.has_valid_update[0] > 0
else torch.zeros_like(metric_report.value)
)
# ultimately compute result comes here, and is then written to tensorboard, for fused tasks we need to know the metric prefix val and description
yield task, metric_report.name, valid_metric_value, compute_scope + metric_report.metric_prefix.value, metric_report.description

def _fuse_update_buffers(self) -> Dict[str, RecModelOutput]:
Expand Down Expand Up @@ -758,6 +763,7 @@ def update(
def compute(self) -> Dict[str, torch.Tensor]:
self._check_fused_update(force=True)
ret = {}
# we need to know the tasks, how does this relate to a recmetric and computed value
for task, metric_name, metric_value, prefix, description in self._tasks_iter(
""
):
Expand Down
90 changes: 89 additions & 1 deletion torchrec/metrics/tests/test_metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import torch
import torch.distributed as dist
import torch.distributed.launcher as pet
from torchrec.distributed.test_utils.multi_process import (
MultiProcessContext,
MultiProcessTestBase,
)
from torchrec.metrics.auc import AUCMetric
from torchrec.metrics.metric_module import (
generate_metric_module,
Expand All @@ -32,14 +36,15 @@
BatchSizeStage,
DefaultMetricsConfig,
DefaultTaskInfo,
EmptyMetricsConfig,
MetricsConfig,
RecMetricDef,
RecMetricEnum,
)
from torchrec.metrics.model_utils import parse_task_model_outputs
from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo
from torchrec.metrics.test_utils import gen_test_batch, get_launch_config
from torchrec.metrics.throughput import ThroughputMetric
from torchrec.test_utils import seed_and_log, skip_if_asan_class

METRIC_MODULE_PATH = "torchrec.metrics.metric_module"

Expand Down Expand Up @@ -603,3 +608,86 @@ def test_save_and_load_state_dict(self) -> None:
no_bss_metric_module.load_state_dict(state_dict)
# Make sure num_batch wasn't created on the throughput module (and no exception was thrown above)
self.assertFalse(hasattr(no_bss_metric_module.throughput_metric, "_num_batch"))


def metric_module_gather_state(
    rank: int,
    world_size: int,
    backend: str,
    config: MetricsConfig,
    batch_size: int,
    local_size: Optional[int] = None,
) -> None:
    """
    Per-process body of the distributed ``get_pre_compute_states`` test.

    A metric module is updated with 100 generated batches and computed. Its
    pre-compute states are then gathered and loaded into a fresh single-rank
    metric module, whose computed values must match the original ones.

    Args:
        rank (int): rank of this process.
        world_size (int): number of processes taking part in the test.
        backend (str): distributed backend, used for both the multi-process
            context and the single-rank group of the second module.
        config (MetricsConfig): metrics configuration for both modules.
        batch_size (int): batch size for generated batches and both modules.
        local_size (Optional[int]): forwarded to ``MultiProcessContext``.
    """
    with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
        metric_module = generate_metric_module(
            TestMetricModule,
            metrics_config=config,
            batch_size=batch_size,
            world_size=world_size,
            my_rank=rank,
            state_metrics_mapping={},
            device=ctx.device,
            process_group=ctx.pg,
        )

        for _ in range(100):
            test_batch = gen_test_batch(batch_size)
            for k in test_batch.keys():
                test_batch[k] = test_batch[k].to(ctx.device)
            metric_module.update(test_batch)

        computed_value = metric_module.compute()
        states = metric_module.get_pre_compute_states(pg=ctx.pg)  # pyre-ignore[6]

        dist.barrier(ctx.pg)
        # Compare to computing metrics on metric module that loads from pre_compute_states
        # NOTE(review): every rank creates its own single-rank group here, with a
        # different ``ranks`` argument per process — confirm this is acceptable
        # for ``dist.new_group`` on the backends this test runs with.
        new_metric_module = generate_metric_module(
            TestMetricModule,
            metrics_config=config,
            batch_size=batch_size,
            world_size=1,
            my_rank=0,
            state_metrics_mapping={},
            # Reuse the context's device and the requested backend instead of
            # hard-coding ``cuda:{rank}`` / ``"nccl"``.
            device=ctx.device,
            process_group=dist.new_group(ranks=[rank], backend=backend),
        )
        new_metric_module.load_pre_compute_states(states)
        new_computed_value = new_metric_module.compute()

        for metric, tensor in computed_value.items():
            new_tensor = new_computed_value[metric]
            torch.testing.assert_close(tensor, new_tensor, check_device=False)


@skip_if_asan_class
class MetricModuleDistributedTest(MultiProcessTestBase):
    """Multi-process test for gathering and re-loading metric module states."""

    @seed_and_log
    def setUp(self, backend: str = "nccl") -> None:
        """Record the backend and skip the test when CUDA is unavailable."""
        super().setUp()
        self.backend = backend

        if not torch.cuda.is_available():
            self.skipTest("CUDA required for distributed test")
        self.device = torch.device("cuda")

    def test_metric_module_gather_state(self) -> None:
        """Run ``metric_module_gather_state`` on two NCCL processes."""
        self._run_multi_process_test(
            callable=metric_module_gather_state,
            world_size=2,
            backend="nccl",
            batch_size=128,
            config=DefaultMetricsConfig,
        )
9 changes: 9 additions & 0 deletions torchrec/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from . import experimental # noqa
70 changes: 70 additions & 0 deletions torchrec/utils/experimental.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from __future__ import annotations

import functools
import warnings
from typing import Any, Callable, overload, ParamSpec, Type, TypeVar, Union

P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")


@overload
def experimental(
obj: Callable[P, R],
feature: str | None = None,
since: str | None = None,
) -> Callable[P, R]: ...


@overload
def experimental(
obj: Type[T],
feature: str | None = None,
since: str | None = None,
) -> Type[T]: ...


def experimental(
obj: Union[Callable[P, R], Type[T]],
feature: str | None = None,
since: str | None = None,
) -> Union[Callable[P, R], Type[T]]:
tag: str = feature or obj.__name__ # pyre-ignore[16]
message_parts: list[str] = [
f"`{tag}` is *experimental* and may change or be removed without notice."
]
if since:
message_parts.insert(0, f"[since {since}]")
warning_message: str = " ".join(message_parts)

@functools.lru_cache(maxsize=1)
def _issue_warning() -> None:
warnings.warn(warning_message, UserWarning, stacklevel=3)

if isinstance(obj, type):
orig_init: Callable[..., None] = obj.__init__

@functools.wraps(orig_init)
def new_init(self, *args: Any, **kwargs: Any) -> Any: # pyre-ignore[3]
_issue_warning()
return orig_init(self, *args, **kwargs)

obj.__init__ = new_init
return obj
else:

@functools.wraps(obj)
def wrapper(*args: Any, **kwargs: Any) -> Any: # pyre-ignore[3]
_issue_warning()
return obj(*args, **kwargs) # pyre-ignore[29]

return wrapper
Loading