Update init #92

Merged · 2 commits · Nov 9, 2023
lit_gpt/model.py (12 changes: 5 additions, 7 deletions)

```diff
@@ -40,20 +40,18 @@ def __init__(self, config: Config) -> None:
     def _init_weights(self, module: nn.Module, n_layer) -> None:
         """Meant to be used with `gpt.apply(gpt._init_weights)`."""
         # GPT-NeoX https://arxiv.org/pdf/2204.06745.pdf
-        # print module name
         if isinstance(module, nn.Embedding):
+            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
             # RWKV: set it to 1e-4
-            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.size(1)))
-            # torch.nn.init.normal_(module.weight, -1e-4, 1e-4)
             # torch.nn.init.uniform_(module.weight, -1e-4, 1e-4)
         elif isinstance(module, nn.Linear):
             # fan-in variance scaling intializer
-            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / module.weight.size(1)))
+            torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / self.config.n_embd))
             if module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
         # GPT-NeoX
         for name, p in module.named_parameters():
-            if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU)): #if use xformer swiglu, fc2 layer will be renamed to w3
-                nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(p.shape[-1]) / n_layer)
+            if (name == "proj.weight" and isinstance(module, LLaMAMLP)) or (name == "w3.weight" and isinstance(module, SwiGLU) or (name=="proj.weight" and isinstance(module, CausalSelfAttention))): #if use xformer swiglu, fc2 layer will be renamed to w3
+                nn.init.normal_(p, mean=0.0, std=1 / math.sqrt(self.config.n_embd) / n_layer)
 
 
```
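For readers skimming the diff: the net effect appears to be that every weight std is now derived from the model width (`config.n_embd`) rather than each tensor's own fan-in, and the attention output projection joins the depth-scaled residual init. Below is a minimal, self-contained sketch of that scheme applied to plain `nn.Embedding`/`nn.Linear` modules; the dimensions and the `init_residual_proj` helper are made up for illustration, whereas the real code dispatches on `LLaMAMLP`, `SwiGLU`, and `CausalSelfAttention`.

```python
import math

import torch
import torch.nn as nn

# Hypothetical dimensions, chosen only for illustration.
n_embd, n_layer, vocab_size = 2048, 22, 32000

def init_weights(module: nn.Module) -> None:
    # Width-based "small init" in the GPT-NeoX style: std = sqrt(2 / (5 * n_embd)),
    # applied to embeddings and linear layers alike after this PR.
    if isinstance(module, (nn.Embedding, nn.Linear)):
        torch.nn.init.normal_(module.weight, mean=0.0, std=math.sqrt(2.0 / 5 / n_embd))
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

def init_residual_proj(weight: torch.Tensor) -> None:
    # Extra 1/n_layer depth scaling for residual-path output projections,
    # mirroring the `proj.weight` / `w3.weight` branch in the diff.
    nn.init.normal_(weight, mean=0.0, std=1 / math.sqrt(n_embd) / n_layer)

toy = nn.Sequential(nn.Embedding(vocab_size, n_embd), nn.Linear(n_embd, n_embd))
toy.apply(init_weights)
init_residual_proj(toy[1].weight)   # pretend toy[1] is an attention/MLP output projection
print(toy[1].weight.std().item())   # roughly 1 / (sqrt(n_embd) * n_layer), about 1e-3 here
```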
pretrain/tinyllama.py (2 changes: 1 addition, 1 deletion)

```diff
@@ -132,7 +132,7 @@ def main(fabric, train_data_dir, val_data_dir, resume):
 
     fabric.print(f"Loading model with {config.__dict__}")
     t0 = time.perf_counter()
-    with fabric.init_module(empty_init=(fabric.world_size > 1)):
+    with fabric.init_module(empty_init=False):
         model = GPT(config)
         model.apply(partial(model._init_weights ,n_layer=config.n_layer))
 
```
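A note on the `empty_init` flag for anyone reproducing this: with `empty_init=True`, Fabric allocates the parameters without giving them defined values, on the assumption that a checkpoint or init function will later overwrite every tensor; `empty_init=False` runs the usual constructors first, so anything the subsequent `model.apply(...)` pass happens to skip is not left as uninitialized memory. A minimal sketch of the pattern, assuming `lightning>=2.0` is installed and using `nn.Linear` as a stand-in for `GPT(config)`:

```python
from functools import partial

import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)

# empty_init=False materializes constructor-initialized weights up front, so every
# parameter has a defined value even if the custom init below does not touch it.
# With empty_init=True the tensors are allocated but left uninitialized, which is
# only safe when something later overwrites all of them.
with fabric.init_module(empty_init=False):
    model = nn.Linear(16, 16)  # stand-in for GPT(config)

def _init_weights(module: nn.Module, n_layer: int) -> None:
    # Placeholder init for this sketch; the real _init_weights lives in lit_gpt/model.py.
    if isinstance(module, nn.Linear):
        nn.init.zeros_(module.weight)

model.apply(partial(_init_weights, n_layer=22))
```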