2020import queue
2121import threading
2222from datetime import timedelta
23- from typing import TYPE_CHECKING , Dict , List , Optional , Tuple , Type , Union
23+ from typing import TYPE_CHECKING , Any , Dict , List , Optional , Tuple , Type , Union
2424
2525import torch
2626import torch .distributed as dist
@@ -871,6 +871,8 @@ def extend_device_mesh(
871871
872872
873873class ManagedDeviceMesh (DeviceMesh ):
874+ replicate_pg_singleton : Optional ["ManagedProcessGroup" ] = None
875+
874876 def __init__ (
875877 self ,
876878 mesh : Optional [DeviceMesh ],
@@ -899,6 +901,16 @@ def __init__(
899901 self ._flatten_mesh_list : Tuple [DeviceMesh , ...] = tuple ()
900902 self ._thread_id : Optional [int ] = None
901903
def __getstate__(self) -> Dict[str, Any]:
    """Return the instance state to pickle, with ``replicate_pg`` blanked out.

    The result is a shallow copy of the instance ``__dict__`` in which the
    ``"replicate_pg"`` entry is replaced by ``None``; the live instance is
    left untouched. ``__setstate__`` re-attaches the process group from
    ``replicate_pg_singleton`` when the object is restored.
    """
    # Build a new dict so that overriding the entry never mutates ``self``.
    return {**self.__dict__, "replicate_pg": None}
908+
def __setstate__(self, state: Dict[str, Any]) -> None:
    """Restore pickled state and re-attach the shared replicate process group.

    Args:
        state: attribute mapping, as produced by ``__getstate__``.

    The attributes in ``state`` are copied onto the instance, then
    ``replicate_pg`` is overwritten with ``replicate_pg_singleton``. The
    singleton must already be set, otherwise the assertion fails.
    """
    vars(self).update(state)
    singleton = self.replicate_pg_singleton
    assert singleton is not None
    self.replicate_pg = singleton
913+
902914 def __getitem__ (self , mesh_dim_names : Union [str , Tuple [str , ...]]) -> DeviceMesh :
903915 if isinstance (mesh_dim_names , str ):
904916 if mesh_dim_names == self .replicate_dim_name :
@@ -916,13 +928,16 @@ def __getitem__(self, mesh_dim_names: Union[str, Tuple[str, ...]]) -> DeviceMesh
916928 return self .mesh [mesh_dim_names ]
917929 else :
918930 assert isinstance (mesh_dim_names , tuple )
919- if self .replicate_dim_name in mesh_dim_names :
931+ if self .replicate_dim_name not in mesh_dim_names :
920932 assert self .mesh is not None
921933 return self .mesh [mesh_dim_names ]
922934 else :
935+ mesh_dim_names_wo_replicate = tuple (
936+ n for n in mesh_dim_names if n != self .replicate_dim_name
937+ )
923938 assert self .mesh is not None
924939 return ManagedDeviceMesh (
925- self .mesh [mesh_dim_names ],
940+ self .mesh [mesh_dim_names_wo_replicate ],
926941 mesh_dim_names ,
927942 self .replicate_pg ,
928943 mesh_dim_names .index (self .replicate_dim_name ),
@@ -957,14 +972,18 @@ def _flatten(self, mesh_dim_name: Optional[str]) -> "DeviceMesh":
957972 return flatten_mesh
958973
def size(self, mesh_dim: Optional[int] = None) -> int:
    """Return the size of the whole managed mesh or of one dimension.

    Args:
        mesh_dim: dimension index to query, or ``None`` for the total size.

    Returns:
        * ``mesh_dim is None``: the replicate group size, multiplied by the
          wrapped mesh's total size when a wrapped mesh exists.
        * ``mesh_dim == self.replicate_dim``: the replicate group size.
        * otherwise: the wrapped mesh's size along the translated dimension
          (``self._real_mesh_dim(mesh_dim)``).
    """
    # We have to lie to the users if there are zero participants.
    # This is possible during the initialization stage of training.
    replicate_size = self.replicate_pg.size()
    if replicate_size == 0:
        replicate_size = 1

    if mesh_dim is not None:
        if mesh_dim == self.replicate_dim:
            return replicate_size
        assert self.mesh is not None
        return self.mesh.size(self._real_mesh_dim(mesh_dim))

    if self.mesh is None:
        return replicate_size
    return self.mesh.size() * replicate_size
def get_coordinate(self) -> Optional[List[int]]:
    """
    Return the relative indices of this rank relative to all
    dimensions of the mesh. If this rank is not part of the mesh, return None.

    The coordinate reported by the wrapped mesh is extended with this rank's
    position in the replicate process group, placed at index
    ``self.replicate_dim``. The wrapped mesh's own coordinate list is never
    modified.
    """
    assert self.mesh is not None
    inner_coordinate = self.mesh._coordinate_on_dim
    if not inner_coordinate:
        # A missing (or empty) coordinate is reported as None.
        return None

    # We need to copy because we are going to modify the coordinate.
    coordinate = inner_coordinate.copy()
    # list.insert takes (index, value): the replicate rank is the value and
    # the replicate dimension is the position it belongs at.
    coordinate.insert(self.replicate_dim, get_rank(self.replicate_pg))
    return coordinate
10181046
10191047 def get_all_groups (self ) -> List [BaseProcessGroup ]:
10201048 raise NotImplementedError
@@ -1076,19 +1104,11 @@ def ft_init_device_mesh(
10761104 mesh_dim_names = tuple (_mesh_dim_names ),
10771105 )
10781106
1079- if device_type == "cpu" :
1080- pg = ProcessGroupGloo ()
1081- elif device_type == "cuda" :
1082- pg = ProcessGroupNCCL ()
1083- else :
1084- raise ValueError ()
1085-
1086- manager ._pg = pg
10871107 replicate_pg = ManagedProcessGroup (manager )
1088- # We have to use MultiProcessTestCase, otherwise c10d will complain
1089- # the same backend has been registered.
10901108 replicate_pg .register (mesh_dim_names [replicate_dim ])
10911109
1110+ ManagedDeviceMesh .replicate_pg_singleton = replicate_pg
1111+
10921112 return ManagedDeviceMesh (
10931113 mesh = mesh ,
10941114 mesh_dim_names = mesh_dim_names ,
0 commit comments