@@ -12,7 +12,8 @@
 import abc
 import logging
 import time
-from typing import Any, Dict, List, Optional, Type, Union
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Sequence, Type, TypeVar, Union
 
 import torch
 import torch.distributed as dist
@@ -106,6 +107,8 @@
 }
 
 
+T = TypeVar("T")
+
 # Label used for emitting model metrics to the corresponding trainer publishers.
 MODEL_METRIC_LABEL: str = "model"
 
@@ -370,31 +373,29 @@ def _get_metric_states(
         world_size: int,
         process_group: Union[dist.ProcessGroup, DeviceMesh],
     ) -> Dict[str, Dict[str, Union[torch.Tensor, List[torch.Tensor]]]]:
-        metric_computations = metric._metrics_computations
-        tasks = metric._tasks
-
-        state_aggregated = {}
-        for task, metric_computation in zip(tasks, metric_computations):
-            inputs = []
-            state_aggregated[task.name] = {}
+        result = defaultdict(dict)
+        for task, computation in zip(metric._tasks, metric._metrics_computations):
             # pyre-fixme[16]: Item `Tensor` of `Tensor | Module` has no attribute
             # `items`.
-            for attr, reduction_fn in metric_computation._reductions.items():
-                inputs.append((attr, getattr(metric_computation, attr), reduction_fn))
-
-            # TODO: do one all-gather call per metric, instead of one per state;
-            # this may require more logic as shapes of states are not guaranteed
-            # to be the same, and may need padding
-            for state, tensor, reduction_fn in inputs:
-                gather_list = [torch.empty_like(tensor) for _ in range(world_size)]
-                dist.all_gather(gather_list, tensor, group=process_group)
-                state_aggregated[task.name][state] = (
-                    reduction_fn(torch.stack(gather_list))
-                    if reduction_fn is not None
-                    else gather_list
+            for state_name, reduction_fn in computation._reductions.items():
+                tensor_or_list: Union[List[torch.Tensor], torch.Tensor] = getattr(
+                    computation, state_name
+                )
+
+                if isinstance(tensor_or_list, list):
+                    gathered = _all_gather_tensor_list(
+                        tensor_or_list, world_size, process_group
+                    )
+                else:
+                    gathered = torch.stack(
+                        _all_gather_tensor(tensor_or_list, world_size, process_group)
+                    )
+                reduced = (
+                    reduction_fn(gathered) if reduction_fn is not None else gathered
                 )
+                result[task.name][state_name] = reduced
 
-        return state_aggregated
+        return result
 
     def get_pre_compute_states(
         self, pg: Optional[Union[dist.ProcessGroup, DeviceMesh]] = None
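
Reviewer note: the rewritten loop walks each computation's `_reductions` map, all-gathers the tensor behind each state across ranks, and applies that state's reduction (the raw gathered result is kept when the reduction is `None`). Below is a minimal single-process sketch of that gather-then-reduce pattern; the `weighted_num_samples` state name, the `dim_zero_sum` reduction, and `fake_all_gather` are illustrative stand-ins, not the library's API.

```python
from typing import Callable, Dict, List, Optional

import torch


def dim_zero_sum(x: torch.Tensor) -> torch.Tensor:
    # Sum over the gathered (rank) dimension, in the style of a
    # torchmetrics dim-zero reduction.
    return torch.sum(x, dim=0)


def fake_all_gather(tensor: torch.Tensor, world_size: int) -> List[torch.Tensor]:
    # Stand-in for dist.all_gather: pretend every rank holds this same state.
    return [tensor.clone() for _ in range(world_size)]


# Hypothetical _reductions map for one metric computation.
reductions: Dict[str, Optional[Callable[[torch.Tensor], torch.Tensor]]] = {
    "weighted_num_samples": dim_zero_sum,
}
states = {"weighted_num_samples": torch.tensor([1.0, 2.0])}

aggregated: Dict[str, torch.Tensor] = {}
for state_name, reduction_fn in reductions.items():
    gathered = torch.stack(fake_all_gather(states[state_name], world_size=2))
    aggregated[state_name] = (
        reduction_fn(gathered) if reduction_fn is not None else gathered
    )

print(aggregated["weighted_num_samples"])  # tensor([2., 4.])
```
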
@@ -611,3 +612,26 @@ def generate_metric_module(
     )
     metrics.to(device)
     return metrics
+
+
+def _all_gather_tensor(
+    tensor: torch.Tensor,
+    world_size: int,
+    pg: Union[dist.ProcessGroup, DeviceMesh],
+) -> List[torch.Tensor]:
+    """All-gather a single tensor and return the gathered list."""
+    out = [torch.empty_like(tensor) for _ in range(world_size)]
+    dist.all_gather(out, tensor, group=pg)
+    return out
+
+
+def _all_gather_tensor_list(
+    tensors: List[torch.Tensor],
+    world_size: int,
+    pg: Union[dist.ProcessGroup, DeviceMesh],
+) -> List[torch.Tensor]:
+    """All-gather every tensor in a list and flatten the result."""
+    gathered: List[torch.Tensor] = []
+    for t in tensors:
+        gathered.extend(_all_gather_tensor(t, world_size, pg))
+    return gathered
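
The two new helpers are thin wrappers over `dist.all_gather`. Note that `dist.all_gather` expects same-shape tensors on every rank, which is why list-valued states are gathered one element at a time rather than in a single fused call (the padding idea from the removed TODO). A usage sketch of the underlying pattern, assuming a process group launched with torchrun; the file name and the gloo backend are illustrative choices:

```python
# demo_all_gather.py
# Run with: torchrun --nproc_per_node=2 demo_all_gather.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # gloo keeps the demo CPU-only
rank = dist.get_rank()
world_size = dist.get_world_size()

# Same pattern as _all_gather_tensor: every rank contributes one tensor
# and receives the full list, ordered by rank.
state = torch.tensor([float(rank)])
out = [torch.empty_like(state) for _ in range(world_size)]
dist.all_gather(out, state, group=dist.group.WORLD)
print(f"rank {rank}: {out}")  # both ranks print [tensor([0.]), tensor([1.])]

dist.destroy_process_group()
```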