 
 from autoparallel.api import AutoParallel
 
+world_size = 256
+
+fake_store = FakeStore()
+torch.distributed.init_process_group(
+    "fake", store=fake_store, rank=0, world_size=world_size
+)
+mesh = torch.distributed.device_mesh.init_device_mesh(
+    "cuda",
+    (world_size // 32, 8, 4),
+    mesh_dim_names=(
+        "dp",
+        "tp",
+        "cp",
+    ),
+)
+assert mesh.ndim == 3, "Please also update local_map"
+
 
 def policy_fn(ctx, op, *args, **kwargs):
     if (
@@ -37,7 +54,7 @@ def policy_fn(ctx, op, *args, **kwargs):
     ),
     redistribute_inputs=True,
     in_grad_placements=None,
-    device_mesh=None,
+    device_mesh=mesh,
 )
 def replicate_linear(w, x):
     return torch.matmul(x, w.t())
@@ -54,7 +71,7 @@ def replicate_linear(w, x):
     ),
     redistribute_inputs=True,
     in_grad_placements=None,
-    device_mesh=None,
+    device_mesh=mesh,
 )
 def sharded_pointwise(x, scalar):
     return x + scalar, scalar
@@ -69,7 +86,7 @@ def sharded_pointwise(x, scalar):
     ),
     redistribute_inputs=True,
     in_grad_placements=None,
-    device_mesh=None,
+    device_mesh=mesh,
 )
 def context_parallel_attention(query, key, value):
     out = nn.functional.scaled_dot_product_attention(
@@ -128,22 +145,6 @@ def forward(self, x):
         return o
 
 
-world_size = 256
-
-fake_store = FakeStore()
-torch.distributed.init_process_group(
-    "fake", store=fake_store, rank=0, world_size=world_size
-)
-mesh = torch.distributed.device_mesh.init_device_mesh(
-    "cuda",
-    (world_size // 32, 8, 4),
-    mesh_dim_names=(
-        "dp",
-        "tp",
-        "cp",
-    ),
-)
-
 bs = 8 * mesh.shape[0]
 seq_len = 256
 nheads = 48