File tree: 1 file changed, +13 −13 lines.
lines changed Original file line number Diff line number Diff line change @@ -192,23 +192,23 @@ def add_tp_constraints(autop):
192192 add_tp_constraints (autop )
193193
194194 t = time .time ()
195- sharding_placement = autop .optimize_placement (verbose = False )
195+ sharding_placement = autop .optimize_placement (verbose = True )
196196 print (f"Took { time .time () - t :.2f} s" )
197197 parallel_mod = autop .apply_placement (sharding_placement )
198198
199199# run weight init on our sharded DTensor params
200- # parallel_mod.to_empty(device="cuda")
201- # parallel_mod.init_weights()
200+ parallel_mod .to_empty (device = "cuda" )
201+ parallel_mod .init_weights ()
202202
203203# now let's run it
204- # x = (
205- # torch.randint(
206- # 0,
207- # vocab_size,
208- # (batch_size // mesh.shape[0], seqlen),
209- # device=torch.device("cuda"),
210- # ),
211- # )
212- # out = parallel_mod(*x)
213- # out.backward(torch.randn_like(out))
204+ x = (
205+ torch .randint (
206+ 0 ,
207+ vocab_size ,
208+ (batch_size // mesh .shape [0 ], seqlen ),
209+ device = torch .device ("cuda" ),
210+ ),
211+ )
212+ out = parallel_mod (* x )
213+ out .backward (torch .randn_like (out ))
214214print ("All good!" )
You can’t perform that action at this time.
0 commit comments