Skip to content

Commit 35504d5

Browse files
iamzainhuda and facebook-github-bot
authored and committed
add support for List[torch.Tensor]
Summary: List[torch.Tensor] states were not handled correctly; this diff adds the requisite support. TODO: in the next diff we will optimize the collective calls so that we are not all-gathering per state per metric. Differential Revision: D76471511
1 parent 4e43395 commit 35504d5

File tree

2 files changed

+61
-23
lines changed

2 files changed

+61
-23
lines changed

torchrec/metrics/metric_module.py

Lines changed: 46 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
import abc
1313
import logging
1414
import time
15-
from typing import Any, Dict, List, Optional, Type, Union
15+
from collections import defaultdict
16+
from typing import Any, Dict, List, Optional, Sequence, Type, TypeVar, Union
1617

1718
import torch
1819
import torch.distributed as dist
@@ -106,6 +107,8 @@
106107
}
107108

108109

110+
T = TypeVar("T")
111+
109112
# Label used for emitting model metrics to the corresponding trainer publishers.
110113
MODEL_METRIC_LABEL: str = "model"
111114

@@ -370,31 +373,29 @@ def _get_metric_states(
370373
world_size: int,
371374
process_group: Union[dist.ProcessGroup, DeviceMesh],
372375
) -> Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]:
373-
metric_computations = metric._metrics_computations
374-
tasks = metric._tasks
375-
376-
state_aggregated = {}
377-
for task, metric_computation in zip(tasks, metric_computations):
378-
inputs = []
379-
state_aggregated[task.name] = {}
376+
result = defaultdict(dict)
377+
for task, computation in zip(metric._tasks, metric._metrics_computations):
380378
# pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
381379
# `items`.
382-
for attr, reduction_fn in metric_computation._reductions.items():
383-
inputs.append((attr, getattr(metric_computation, attr), reduction_fn))
384-
385-
# TODO: do one all gather call per metric, instead of one per state
386-
# this may require more logic as shapes of states are not guaranteed to be the same
387-
# may need padding
388-
for state, tensor, reduction_fn in inputs:
389-
gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
390-
dist.all_gather(gather_list, tensor, group=process_group)
391-
state_aggregated[task.name][state] = (
392-
reduction_fn(torch.stack(gather_list))
393-
if reduction_fn is not None
394-
else gather_list
380+
for state_name, reduction_fn in computation._reductions.items():
381+
tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
382+
computation, state_name
383+
)
384+
385+
if isinstance(tensor_or_list, list):
386+
gathered = _all_gather_tensor_list(
387+
tensor_or_list, world_size, process_group
388+
)
389+
else:
390+
gathered = torch.stack(
391+
_all_gather_tensor(tensor_or_list, world_size, process_group)
392+
)
393+
reduced = (
394+
reduction_fn(gathered) if reduction_fn is not None else gathered
395395
)
396+
result[task.name][state_name] = reduced
396397

397-
return state_aggregated
398+
return result
398399

399400
def get_pre_compute_states(
400401
self, pg: Optional[Union[dist.ProcessGroup, DeviceMesh]] = None
@@ -611,3 +612,26 @@ def generate_metric_module(
611612
)
612613
metrics.to(device)
613614
return metrics
615+
616+
617+
def _all_gather_tensor(
618+
tensor: torch.Tensor,
619+
world_size: int,
620+
pg: Union[dist.ProcessGroup, DeviceMesh],
621+
) -> List[torch.Tensor]:
622+
"""All-gather a single tensor and return the gathered list."""
623+
out = [torch.empty_like(tensor) for _ in range(world_size)]
624+
dist.all_gather(out, tensor, group=pg)
625+
return out
626+
627+
628+
def _all_gather_tensor_list(
629+
tensors: List[torch.Tensor],
630+
world_size: int,
631+
pg: Union[dist.ProcessGroup, DeviceMesh],
632+
) -> List[torch.Tensor]:
633+
"""All-gather every tensor in a list and flatten the result."""
634+
gathered: List[torch.Tensor] = []
635+
for t in tensors:
636+
gathered.extend(_all_gather_tensor(t, world_size, pg))
637+
return gathered

torchrec/metrics/tests/test_metric_module.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
MetricsConfig,
4040
RecMetricDef,
4141
RecMetricEnum,
42+
ThroughputDef,
4243
)
4344
from torchrec.metrics.model_utils import parse_task_model_outputs
4445
from torchrec.metrics.rec_metric import RecMetricList, RecTaskInfo
@@ -681,7 +682,20 @@ def setUp(self, backend: str = "nccl") -> None:
681682
def test_metric_module_gather_state(self) -> None:
682683
world_size = 2
683684
backend = "nccl"
684-
metrics_config = DefaultMetricsConfig
685+
# use NE to test torch.Tensor state and AUC to test List[torch.Tensor] state
686+
metrics_config = MetricsConfig(
687+
rec_tasks=[DefaultTaskInfo],
688+
rec_metrics={
689+
RecMetricEnum.NE: RecMetricDef(
690+
rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
691+
),
692+
RecMetricEnum.AUC: RecMetricDef(
693+
rec_tasks=[DefaultTaskInfo], window_size=_DEFAULT_WINDOW_SIZE
694+
),
695+
},
696+
throughput_metric=ThroughputDef(),
697+
state_metrics=[],
698+
)
685699
batch_size = 128
686700

687701
self._run_multi_process_test(

0 commit comments

Comments
 (0)