Skip to content

Commit 39b95c3

Browse files
committed
Implement GPT2 class
1 parent a9c5f3d commit 39b95c3

File tree

3 files changed

+31
-5
lines changed

3 files changed

+31
-5
lines changed

pytorch/gpt_pytorch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import sys
44
import os
55
from utils import load_encoder_hparams_and_params
6+
from model import GPT2
7+
import torchsummaryX
68

79

810
if __name__ == "__main__":
@@ -26,3 +28,6 @@
2628
input_ids = encoder.encode(args.prompt)
2729
input_text = encoder.decode(input_ids)
2830
print("input_ids:", input_ids)
31+
32+
model = GPT2(params, hparams, drop_p=0.1)
33+
torchsummaryX.summary(model, torch.ones(1, len(input_ids), dtype=torch.long))

pytorch/model.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,32 @@ def forward(self, x):
6060

6161

6262
class GPT2(nn.Module):
    """GPT-2 style decoder-only language model.

    Token and position embeddings are summed, passed through a stack of
    ``TransformerDecoderBlock`` layers and a final ``LayerNorm``, and then
    projected with the transpose of the token-embedding matrix (the same
    matrix is used for input embedding and output projection).

    Args:
        params: Mapping of pretrained weights. This constructor reads only
            ``"wte.weight"`` and ``"wpe.weight"`` from it.
        hparams: Mapping of hyper-parameters; ``"n_embd"``, ``"n_head"`` and
            ``"n_layer"`` are read.
        drop_p: Dropout probability forwarded to every decoder block.
    """

    def __init__(self, params, hparams, drop_p=0.1):
        super().__init__()
        self.params = params
        self.hparams = hparams
        self.drop_p = drop_p
        self.h_dim = hparams["n_embd"]
        self.n_heads = hparams["n_head"]

        # Embedding tables taken as-is from the checkpoint mapping.
        # NOTE(review): presumably torch tensors of shape (vocab, n_embd) and
        # (n_ctx, n_embd) -- confirm against load_encoder_hparams_and_params.
        # If they are plain tensors (not nn.Parameter) they are not tracked by
        # nn.Module and will not follow .to(device).
        self.wte = self.params["wte.weight"]
        self.wpe = self.params["wpe.weight"]

        # nn.ModuleList (rather than a plain Python list) registers every
        # block as a sub-module, so their parameters appear in parameters()
        # and state_dict(), and .to()/.train()/.eval() propagate to them.
        self.blocks = nn.ModuleList(
            [
                TransformerDecoderBlock(
                    h_dim=self.h_dim, n_heads=self.n_heads, drop_p=self.drop_p
                )
                for _ in range(self.hparams["n_layer"])
            ]
        )
        self.layer_norm = nn.LayerNorm(self.h_dim)

    def forward(self, input_ids):
        """Run the model on a batch of token ids.

        Args:
            input_ids: Integer tensor of token ids; its second dimension is
                used as the sequence length.

        Returns:
            The result of multiplying the final hidden states by the
            transposed token-embedding matrix, one score vector per position.
        """
        seq_len = input_ids.shape[1]
        # The first seq_len rows of the position table line up with positions
        # 0..seq_len-1 and broadcast across the leading (batch) dimension.
        x = self.wte[input_ids] + self.wpe[:seq_len]
        for block in self.blocks:
            x = block(x)
        x = self.layer_norm(x)
        # Output projection reuses the token-embedding matrix.
        return x @ self.wte.T
6889

6990

7091
if __name__ == "__main__":

pytorch/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
# import picoGPT
99
sys.path.append(os.path.join(os.path.dirname(__file__), ".."))
1010
sys.path.append(os.path.join(os.path.dirname(__file__), "../picoGPT"))
11-
from encoder import get_encoder
11+
from picoGPT.encoder import get_encoder
1212

1313

1414
def save_file(file_path, req):

0 commit comments

Comments
 (0)