
Commit 63b2e09

update tests, make init_weights an explicit method for user to call
1 parent 7689c13 commit 63b2e09

3 files changed: +19 -3 lines changed


autoparallel/api.py

Lines changed: 13 additions & 2 deletions
@@ -418,9 +418,20 @@ def apply_placement(self, sharding_placement=None):
                 **sharded_weights_with_fqns,
                 **sharded_buffers_with_fqns,
             }
-            with stateless._reparametrize_module(self.model, sharded_params_buffers):
-                self.model.init_weights()
+
+            def init_weights():
+                with stateless._reparametrize_module(
+                    self.model, sharded_params_buffers
+                ):
+                    self.model.init_weights()
+
+        else:
+            init_weights = None
 
         self.parallel_model = self.parallel_model_fn(sharded_weights, sharded_buffers)
+        # assign an init_weights method onto the output mod.
+        # all it does is sneakily run the original user mod's init_weights method,
+        # but with our new DTensor sharded params attached to the user module.
+        self.parallel_model.init_weights = init_weights
 
         return self.parallel_model
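For readers unfamiliar with the reparametrization trick used above, here is a minimal, self-contained sketch, illustrative only and not the library code: ToyModel and the plain tensors stand in for the user module and the sharded DTensor parameters. It shows what the attached closure does: temporarily swap the user module's parameters for externally held tensors via torch.nn.utils.stateless._reparametrize_module (a private PyTorch helper), then call the module's own init_weights so the initialization writes into those tensors instead of the originals.

import torch
import torch.nn as nn
from torch.nn.utils import stateless


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def init_weights(self):
        nn.init.ones_(self.linear.weight)
        nn.init.zeros_(self.linear.bias)


model = ToyModel()

# stand-ins for the sharded parameters; in autoparallel these would be DTensors
new_params = {
    "linear.weight": torch.empty(4, 4),
    "linear.bias": torch.empty(4),
}

# while reparametrized, attribute access on the module resolves to the swapped-in
# tensors, so init_weights() fills new_params rather than the module's own weights
with stateless._reparametrize_module(model, new_params):
    model.init_weights()

print(new_params["linear.weight"].mean())  # 1.0 after the in-place init above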

examples/example_autoparallel.py

Lines changed: 3 additions & 0 deletions
@@ -100,6 +100,9 @@ def input_fn():
     sharding_placement = autop.optimize_placement()
     parallel_mod = autop.apply_placement(sharding_placement)
 
+    # run weight init on our sharded DTensor params
+    parallel_mod.init_weights()
+
     # now let's run it
     x = (torch.rand(bs // mesh.shape[0], seq_len, dim1, device="cuda"),)
     out = parallel_mod(*x)
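With this change, apply_placement() no longer runs weight initialization as a side effect; the example instead calls parallel_mod.init_weights() explicitly, after placement and before the first forward pass. Note that, per the else branch in api.py above, the attribute can also be set to None (presumably when the wrapped model does not define its own init_weights), in which case there is nothing to call.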

examples/example_llama3.py

Lines changed: 3 additions & 1 deletion
@@ -475,7 +475,6 @@ def __init__(self, model_args: TransformerModelArgs):
             self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args)
         self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps)
         self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False)
-        self.init_weights()
 
     def init_weights(
         self,
@@ -628,6 +627,9 @@ def input_fn():
     print(f"Took {time.time() - t:.2f} s")
     parallel_mod = autop.apply_placement(sharding_placement)
 
+    # run weight init on our sharded DTensor params
+    parallel_mod.init_weights()
+
     # now let's run it
     x = (
         torch.randint(
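The deletion of self.init_weights() from the Transformer constructor is the counterpart of the API change: construction no longer initializes weights eagerly, so initialization runs exactly once, later, against the sharded DTensor parameters. A rough sketch of the deferred-init pattern (ToyTransformer is illustrative, not the example's class; the autoparallel calls in the comments follow the example above):

import torch.nn as nn


class ToyTransformer(nn.Module):
    def __init__(self, dim: int, vocab_size: int):
        super().__init__()
        self.norm = nn.RMSNorm(dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)
        # no self.init_weights() here: the caller decides when, and against
        # which parameter storage, initialization happens

    def init_weights(self):
        nn.init.trunc_normal_(self.output.weight, std=0.02)


# eager path: plain tensors, initialize right after construction
model = ToyTransformer(dim=64, vocab_size=1000)
model.init_weights()

# autoparallel path (sketch): apply_placement() captures init_weights and
# re-exposes it on the parallel module, to be called on the sharded params:
#   parallel_mod = autop.apply_placement(sharding_placement)
#   parallel_mod.init_weights()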
