Commit ff9c574

Re-enable llama3 runtime (#164)
* Re-enable the llama3 runtime. This is a copy of #163, but pointing at main.
* Add back the verbose flag; it's useful for debugging.
1 parent: 38ab40e

1 file changed: +13 −13 lines

examples/example_llama3.py

@@ -192,23 +192,23 @@ def add_tp_constraints(autop):
     if enable_manual_constraint:
         add_tp_constraints(autop)
 
     t = time.time()
-    sharding_placement = autop.optimize_placement(verbose=False)
+    sharding_placement = autop.optimize_placement(verbose=True)
     print(f"Took {time.time() - t:.2f} s")
     parallel_mod = autop.apply_placement(sharding_placement)
 
     # run weight init on our sharded DTensor params
-    # parallel_mod.to_empty(device="cuda")
-    # parallel_mod.init_weights()
+    parallel_mod.to_empty(device="cuda")
+    parallel_mod.init_weights()
 
     # now let's run it
-    # x = (
-    #     torch.randint(
-    #         0,
-    #         vocab_size,
-    #         (batch_size // mesh.shape[0], seqlen),
-    #         device=torch.device("cuda"),
-    #     ),
-    # )
-    # out = parallel_mod(*x)
-    # out.backward(torch.randn_like(out))
+    x = (
+        torch.randint(
+            0,
+            vocab_size,
+            (batch_size // mesh.shape[0], seqlen),
+            device=torch.device("cuda"),
+        ),
+    )
+    out = parallel_mod(*x)
+    out.backward(torch.randn_like(out))
     print("All good!")
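For reference, here is a sketch of how the affected section of examples/example_llama3.py reads after this commit. It is not standalone: `autop` (the auto-parallel optimizer object), `mesh`, `vocab_size`, `batch_size`, and `seqlen` are all constructed earlier in the example and are assumed here; the surrounding indentation (the `if enable_manual_constraint:` guard) is likewise a plausible reconstruction, not shown in this hunk.

import time

import torch

# Assumed context: autop, mesh, vocab_size, batch_size, and seqlen are
# defined earlier in examples/example_llama3.py and are not redefined here.

t = time.time()
# verbose=True was added back by this commit; the extra output helps debugging.
sharding_placement = autop.optimize_placement(verbose=True)
print(f"Took {time.time() - t:.2f} s")
parallel_mod = autop.apply_placement(sharding_placement)

# Run weight init on the sharded DTensor params (re-enabled by this commit).
parallel_mod.to_empty(device="cuda")
parallel_mod.init_weights()

# Forward/backward smoke test with random token ids; the batch is split
# across the first mesh dimension (also re-enabled by this commit).
x = (
    torch.randint(
        0,
        vocab_size,
        (batch_size // mesh.shape[0], seqlen),
        device=torch.device("cuda"),
    ),
)
out = parallel_mod(*x)
out.backward(torch.randn_like(out))
print("All good!")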
