
Commit

fix: Gradient problem when the number of devices is 1
ChaosCodes committed Nov 8, 2023
1 parent fb05026 commit 782f182
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pretrain/tinyllama.py
@@ -132,7 +132,7 @@ def main(fabric, train_data_dir, val_data_dir, resume):

     fabric.print(f"Loading model with {config.__dict__}")
     t0 = time.perf_counter()
-    with fabric.init_module(empty_init=True):
+    with fabric.init_module(empty_init=(fabric.world_size > 1)):
         model = GPT(config)
         model.apply(partial(model._init_weights ,n_layer=config.n_layer))
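
Note on the changed condition: the new argument enables empty initialization only when more than one process takes part in training, so a single-device run now builds the model with ordinary initialization. The snippet below is a minimal sketch of that condition on its own, assuming `world_size` means the total number of participating processes. `empty_init_for` is a hypothetical helper written for illustration; it is not part of `pretrain/tinyllama.py` or of Lightning Fabric.

```python
def empty_init_for(world_size: int) -> bool:
    """Mirror the expression `fabric.world_size > 1` from the diff.

    Hypothetical helper: returns True when empty (skipped) initialization
    would be requested, False when the model is initialized normally.
    """
    return world_size > 1


if __name__ == "__main__":
    # Print the flag that would be passed as `empty_init` for a few sizes.
    for n in (1, 2, 8):
        print(f"world_size={n} -> empty_init={empty_init_for(n)}")
```

Passing the comparison directly, as the commit does, keeps the multi-device behavior unchanged and alters only the single-device path.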
