Skip to content

Commit bc4b8a0

Browse files
committed
plumb args/kwargs to init_weights
1 parent 63b2e09 commit bc4b8a0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

autoparallel/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,11 +419,11 @@ def apply_placement(self, sharding_placement=None):
419419
**sharded_buffers_with_fqns,
420420
}
421421

422-
def init_weights():
422+
def init_weights(*args, **kwargs):
423423
with stateless._reparametrize_module(
424424
self.model, sharded_params_buffers
425425
):
426-
self.model.init_weights()
426+
self.model.init_weights(*args, **kwargs)
427427

428428
else:
429429
init_weights = None

0 commit comments

Comments
 (0)