add support for List[torch.Tensor] #3088

Closed · wants to merge 1 commit

68 changes: 46 additions & 22 deletions torchrec/metrics/metric_module.py
```diff
@@ -12,7 +12,8 @@
 import abc
 import logging
 import time
-from typing import Any, Dict, List, Optional, Type, Union
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Sequence, Type, TypeVar, Union

 import torch
 import torch.distributed as dist
@@ -106,6 +107,8 @@
 }


+T = TypeVar("T")
+
 # Label used for emitting model metrics to the corresponding trainer publishers.
 MODEL_METRIC_LABEL: str = "model"

@@ -370,31 +373,29 @@ def _get_metric_states(
         world_size: int,
         process_group: Union[dist.ProcessGroup, DeviceMesh],
     ) -> Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]:
-        metric_computations = metric._metrics_computations
-        tasks = metric._tasks
-
-        state_aggregated = {}
-        for task, metric_computation in zip(tasks, metric_computations):
-            inputs = []
-            state_aggregated[task.name] = {}
+        result = defaultdict(dict)
+        for task, computation in zip(metric._tasks, metric._metrics_computations):
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             # `items`.
-            for attr, reduction_fn in metric_computation._reductions.items():
-                inputs.append((attr, getattr(metric_computation, attr), reduction_fn))
-
-            # TODO: do one all gather call per metric, instead of one per state
-            # this may require more logic as shapes of states are not guaranteed to be same
-            # may need padding
-            for state, tensor, reduction_fn in inputs:
-                gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
-                dist.all_gather(gather_list, tensor, group=process_group)
-                state_aggregated[task.name][state] = (
-                    reduction_fn(torch.stack(gather_list))
-                    if reduction_fn is not None
-                    else gather_list
-                )
+            for state_name, reduction_fn in computation._reductions.items():
+                tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
+                    computation, state_name
+                )
+
+                if isinstance(tensor_or_list, list):
+                    gathered = _all_gather_tensor_list(
+                        tensor_or_list, world_size, process_group
+                    )
+                else:
+                    gathered = torch.stack(
+                        _all_gather_tensor(tensor_or_list, world_size, process_group)
+                    )
+                reduced = (
+                    reduction_fn(gathered) if reduction_fn is not None else gathered
+                )
+                result[task.name][state_name] = reduced

-        return state_aggregated
+        return result

     def get_pre_compute_states(
         self, pg: Optional[Union[dist.ProcessGroup, DeviceMesh]] = None
```
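The rewritten loop above branches on the state's type: a plain tensor state is all-gathered and stacked so the metric's `reduction_fn` can consume a single tensor, while a `List[torch.Tensor]` state is gathered element-wise and flattened, staying a list. A minimal, self-contained sketch of that control flow (everything here is a hypothetical stand-in: `WORLD_SIZE`, `fake_all_gather`, and `gather_state` are not code from the PR):

```python
from typing import List, Union

import torch

WORLD_SIZE = 2


def fake_all_gather(tensor: torch.Tensor) -> List[torch.Tensor]:
    # Stand-in for dist.all_gather: every rank is assumed to hold a copy.
    return [tensor.clone() for _ in range(WORLD_SIZE)]


def gather_state(
    state: Union[torch.Tensor, List[torch.Tensor]],
) -> Union[torch.Tensor, List[torch.Tensor]]:
    if isinstance(state, list):
        # List[torch.Tensor] state: gather element-wise and flatten;
        # the result stays a list because shapes may differ per element.
        flat: List[torch.Tensor] = []
        for t in state:
            flat.extend(fake_all_gather(t))
        return flat
    # Tensor state: gather, then stack into one tensor that a
    # reduction_fn (e.g. a sum over dim 0) can consume directly.
    return torch.stack(fake_all_gather(state))


print(gather_state(torch.ones(3)).shape)                  # torch.Size([2, 3])
print(len(gather_state([torch.ones(3), torch.ones(4)])))  # 4
```

Note that list states are never stacked: the tensors gathered from different ranks need not share a shape, which is exactly why the pre-existing stack-and-reduce path could not handle them.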
```diff
@@ -611,3 +612,26 @@ def generate_metric_module(
     )
     metrics.to(device)
     return metrics
+
+
+def _all_gather_tensor(
+    tensor: torch.Tensor,
+    world_size: int,
+    pg: Union[dist.ProcessGroup, DeviceMesh],
+) -> List[torch.Tensor]:
+    """All-gather a single tensor and return the gathered list."""
+    out = [torch.empty_like(tensor) for _ in range(world_size)]
+    dist.all_gather(out, tensor, group=pg)
+    return out
+
+
+def _all_gather_tensor_list(
+    tensors: List[torch.Tensor],
+    world_size: int,
+    pg: Union[dist.ProcessGroup, DeviceMesh],
+) -> List[torch.Tensor]:
+    """All-gather every tensor in a list and flatten the result."""
+    gathered: List[torch.Tensor] = []
+    for t in tensors:
+        gathered.extend(_all_gather_tensor(t, world_size, pg))
+    return gathered
```
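The two helpers keep the gather logic in one place: `_all_gather_tensor` wraps the `torch.empty_like` buffer allocation plus `dist.all_gather`, and `_all_gather_tensor_list` applies it per element. A hedged usage sketch, assuming this patch is applied locally (the PR is closed, so these private helpers may not exist in released torchrec); launch with `torchrun --nproc_per_node=2 demo.py`, where `demo.py` is a hypothetical file name:

```python
import torch
import torch.distributed as dist

from torchrec.metrics.metric_module import (  # assumes this patch is applied
    _all_gather_tensor,
    _all_gather_tensor_list,
)


def main() -> None:
    dist.init_process_group("gloo")  # gloo so the demo runs on CPU
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    pg = dist.group.WORLD

    # Tensor state: one gathered entry per rank.
    state = torch.full((2,), float(rank))
    assert len(_all_gather_tensor(state, world_size, pg)) == world_size

    # List[torch.Tensor] state: gathered element-wise, then flattened.
    list_state = [torch.full((2,), float(rank)) for _ in range(3)]
    assert len(_all_gather_tensor_list(list_state, world_size, pg)) == 3 * world_size

    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```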
16 changes: 15 additions & 1 deletion torchrec/metrics/tests/test_metric_module.py
```diff
@@ -39,6 +39,7 @@
     MetricsConfig,
     RecMetricDef,
     RecMetricEnum,
+    ThroughputDef,
 )
 from torchrec.metrics.model_utils import parse_task_model_outputs
 from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo
@@ -681,7 +682,20 @@ def setUp(self, backend: str = "nccl") -> None:
     def test_metric_module_gather_state(self) -> None:
         world_size = 2
         backend = "nccl"
-        metrics_config = DefaultMetricsConfig
+        # use NE to test torch.Tensor state and AUC to test List[torch.Tensor] state
+        metrics_config = MetricsConfig(
+            rec_tasks=[DefaultTaskInfo],
+            rec_metrics={
+                RecMetricEnum.NE: RecMetricDef(
+                    rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
+                ),
+                RecMetricEnum.AUC: RecMetricDef(
+                    rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
+                ),
+            },
+            throughput_metric=ThroughputDef(),
+            state_metrics=[],
+        )
         batch_size = 128

         self._run_multi_process_test(
```
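As the inline comment says, the config pairs NE (whose states are running-sum tensors) with AUC (which buffers per-batch tensors in Python lists), so the test exercises both branches of the new `_get_metric_states`. For reference, roughly the same config can be built stand-alone; the `torchrec.metrics.metrics_config` import path and the window-size value below are assumptions, since the test file resolves these through its own imports and its `_DEFAULT_WINDOW_SIZE` constant:

```python
from torchrec.metrics.metrics_config import (
    DefaultTaskInfo,
    MetricsConfig,
    RecMetricDef,
    RecMetricEnum,
    ThroughputDef,
)

_DEFAULT_WINDOW_SIZE = 1024  # placeholder; the test uses its own constant

metrics_config = MetricsConfig(
    rec_tasks=[DefaultTaskInfo],
    rec_metrics={
        RecMetricEnum.NE: RecMetricDef(
            rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
        ),
        RecMetricEnum.AUC: RecMetricDef(
            rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
        ),
    },
    throughput_metric=ThroughputDef(),
    state_metrics=[],
)
print(metrics_config)
```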