🐛 Bug
Empty list metrics bypass a call to `torch.distributed.barrier()` and cause deadlocks or other strange synchronization issues.
To Reproduce
Define the ListMetric class as follows:
import torch
import torchmetrics
from torchmetrics.utilities import dim_zero_cat
from typing import Any, Callable, Optional


class ListMetric(torchmetrics.Metric):
    full_state_update: bool = False

    def __init__(self):
        super().__init__(sync_on_compute=False)
        self.add_state("example_state", default=[], dist_reduce_fx="cat")

    def update(self, x: torch.Tensor):
        # in this example, x is a (2,)-shaped tensor and the internal state
        # of this metric is a list of tensors of this shape
        self.example_state.append(x)

    def compute(self):
        return self.example_state

    def sync(
        self,
        dist_sync_fn: Optional[Callable] = None,
        process_group: Optional[Any] = None,
        should_sync: bool = True,
        distributed_available: Optional[Callable] = None,
    ) -> None:
        super().sync(
            dist_sync_fn=dist_sync_fn,
            process_group=process_group,
            should_sync=should_sync,
            distributed_available=distributed_available,
        )
        tensor_and_empty = (
            isinstance(self.example_state, torch.Tensor)
            and torch.numel(self.example_state) == 0
        )
        list_and_empty = (
            isinstance(self.example_state, list) and len(self.example_state) == 0
        )
        if tensor_and_empty or list_and_empty:
            self.example_state = []
            return
        # Torchmetrics has a strange quirk: depending on whether it's being run
        # in a distributed setting, the type of `self.example_state` may differ
        # (a list of tensors vs. a single concatenated tensor). The
        # `dim_zero_cat()` function fixes this; see
        # https://lightning.ai/docs/torchmetrics/v1.3.1/pages/implement.html,
        # "working with list states".
        self.example_state = dim_zero_cat(self.example_state)
        self.example_state = list(torch.split(self.example_state, 2))
In your training loop, do this with 2 or more GPUs:
def on_validation_epoch_end(self, outputs):
    lm = ListMetric()
    if self.local_rank == 0:
        lm.update(torch.zeros((2,)))
    # every rank except local_rank == 0 skips the barrier inside this call
    # to `sync()`, and then the script deadlocks
    lm.sync()
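A possible mitigation, offered as an untested assumption rather than an official fix: have every rank contribute at least a placeholder element before syncing, so that no rank's state is an empty list and every rank enters the collective. A hypothetical sketch, inside the same hook:

import math
import torch

lm = ListMetric()
if self.local_rank == 0:
    lm.update(torch.zeros((2,)))
else:
    # hypothetical sentinel so this rank's state is non-empty;
    # compute() would need to filter these entries out after syncing
    lm.update(torch.full((2,), math.nan))
lm.sync()  # every rank now reaches the blocking gather, so no rank deadlocks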
Expected behavior
Empty list metrics should sync properly (waiting on the barrier) alongside non-empty list metrics on other GPUs.
Environment
Linux; torchmetrics 1.3.0.post0, installed via pip.
Additional context
It seems this is caused by the line `if reduction_fn == dim_zero_cat and isinstance(input_dict[attr], list) and len(input_dict[attr]) > 1:` in `Metric._sync_dist()`. For an empty list state that guard is false, so the state is passed along as a (still empty) list, and `return [function(x, *args, **kwargs) for x in data]` in `lightning_utilities.core.apply_func` executes with `data` empty. Since `function` contains the blocking call that waits on the barrier, the comprehension runs zero times and no blocking ever occurs on GPUs whose metric state is an empty list.
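To make the failure mode concrete, here is a minimal, self-contained sketch of that code path; the function names are illustrative stand-ins, not the actual torchmetrics or lightning_utilities internals:

import torch

def apply_elementwise(data, function):
    # stand-in for the comprehension in lightning_utilities.core.apply_func:
    # `function` runs once per element, so for an empty list it never runs
    return [function(x) for x in data]

def fake_gather(x):
    # stand-in for the distributed gather; in the real code path this is
    # where each rank blocks on the collective/barrier
    print("entered the (blocking) collective")
    return x

apply_elementwise([torch.zeros(2)], fake_gather)  # rank with data: blocks here
apply_elementwise([], fake_gather)  # rank with an empty state: never blocks -> deadlock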