Commit 31e328e

Add terminate block
1 parent 8ca3d48 commit 31e328e
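
The terminate block is a new final code cell that calls google.colab.runtime.unassign(), which releases the assigned Colab runtime once the notebook has finished. The remaining changes appear to come from re-saving the notebook JSON (cells listed before the top-level metadata, keys sorted alphabetically); the existing cell sources are unchanged.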

File tree

1 file changed (+174, -163 lines)

gpt_pytorch.ipynb

Lines changed: 174 additions & 163 deletions
@@ -1,165 +1,176 @@
 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "private_outputs": true,
-      "provenance": [],
-      "machine_shape": "hm",
-      "gpuType": "V100"
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "language_info": {
-      "name": "python"
-    },
-    "accelerator": "GPU"
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "xosv4DUIlvi9"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install torchinfo"
+   ]
   },
-  "cells": [
-    {
-      "cell_type": "code",
-      "source": [
-        "!pip install torchinfo"
-      ],
-      "metadata": {
-        "id": "xosv4DUIlvi9"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "ma5VwOTFT0CX"
-      },
-      "outputs": [],
-      "source": [
-        "import torch\n",
-        "import torch.nn as nn\n",
-        "import math\n",
-        "import torch.nn.functional as F\n",
-        "from torchinfo import summary"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "class MaskedMultiSelfAttention(nn.Module):\n",
-        "    def __init__(self, h_dim, max_T, n_heads, drop_p):\n",
-        "        super().__init__()\n",
-        "        self.n_heads = n_heads\n",
-        "\n",
-        "        self.q_net = nn.Linear(h_dim, h_dim)\n",
-        "        self.k_net = nn.Linear(h_dim, h_dim)\n",
-        "        self.v_net = nn.Linear(h_dim, h_dim)\n",
-        "\n",
-        "        self.proj_net = nn.Linear(h_dim, h_dim)\n",
-        "\n",
-        "        self.attn_drop = nn.Dropout(drop_p)\n",
-        "        self.proj_drop = nn.Dropout(drop_p)\n",
-        "\n",
-        "        # Make lower triangle matrix with one\n",
-        "        ones = torch.ones((max_T, max_T))\n",
-        "        mask = torch.tril(ones).view(1, 1, max_T, max_T)\n",
-        "\n",
-        "        # mask is constant\n",
-        "        self.register_buffer('mask', mask)\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        B, T, C = x.shape\n",
-        "        N, D = self.n_heads, C // self.n_heads\n",
-        "\n",
-        "        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)\n",
-        "        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)\n",
-        "        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)\n",
-        "\n",
-        "        weights = q @ k.transpose(2, 3) / math.sqrt(D)\n",
-        "\n",
-        "        # Masked causal weights\n",
-        "        weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))\n",
-        "\n",
-        "        # Normalize weights : all -inf -> 0 after softmax\n",
-        "        normalized_weights = F.softmax(weights, dim=-1)\n",
-        "\n",
-        "        # Masked causal attention (B, N, T, D)\n",
-        "        attention = self.attn_drop(normalized_weights @ v)\n",
-        "        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)\n",
-        "\n",
-        "        out = self.proj_drop(self.proj_net(attention))\n",
-        "        return out"
-      ],
-      "metadata": {
-        "id": "THxWrPdLVAd6"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "class TransformerDecoderBlock(nn.Module):\n",
-        "    def __init__(self, h_dim, max_T, n_heads, drop_p):\n",
-        "        super().__init__()\n",
-        "        self.attn = MaskedMultiSelfAttention(h_dim, max_T, n_heads, drop_p)\n",
-        "        self.mlp = nn.Sequential(\n",
-        "            nn.Linear(h_dim, 4 * h_dim),\n",
-        "            nn.GELU(),\n",
-        "            nn.Linear(4 * h_dim, h_dim),\n",
-        "            nn.Dropout(drop_p)\n",
-        "        )\n",
-        "        self.ln1 = nn.LayerNorm(h_dim)\n",
-        "        self.ln2 = nn.LayerNorm(h_dim)\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        # MaskedMultiSelfAttention -> LayerNorm -> FeedForward -> LayerNorm\n",
-        "        x = self.attn(x) + x\n",
-        "        x = self.ln1(x)\n",
-        "        x = self.mlp(x) + x\n",
-        "        x = self.ln2(x)\n",
-        "        return x"
-      ],
-      "metadata": {
-        "id": "copsi6yLUBtW"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "B, T, D = 4, 8, 64\n",
-        "n_heads = 12"
-      ],
-      "metadata": {
-        "id": "050mS7V4kyuO"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "block = TransformerDecoderBlock(h_dim=n_heads*D, max_T=T, n_heads=n_heads, drop_p=0.1)"
-      ],
-      "metadata": {
-        "id": "kZHgTOFuk6cw"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "summary(block, input_size=(B, T, n_heads * D))"
-      ],
-      "metadata": {
-        "id": "HS6ByiOPlars"
-      },
-      "execution_count": null,
-      "outputs": []
-    }
-  ]
-}
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "ma5VwOTFT0CX"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import math\n",
+    "import torch.nn.functional as F\n",
+    "from torchinfo import summary"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "THxWrPdLVAd6"
+   },
+   "outputs": [],
+   "source": [
+    "class MaskedMultiSelfAttention(nn.Module):\n",
+    "    def __init__(self, h_dim, max_T, n_heads, drop_p):\n",
+    "        super().__init__()\n",
+    "        self.n_heads = n_heads\n",
+    "\n",
+    "        self.q_net = nn.Linear(h_dim, h_dim)\n",
+    "        self.k_net = nn.Linear(h_dim, h_dim)\n",
+    "        self.v_net = nn.Linear(h_dim, h_dim)\n",
+    "\n",
+    "        self.proj_net = nn.Linear(h_dim, h_dim)\n",
+    "\n",
+    "        self.attn_drop = nn.Dropout(drop_p)\n",
+    "        self.proj_drop = nn.Dropout(drop_p)\n",
+    "\n",
+    "        # Make lower triangle matrix with one\n",
+    "        ones = torch.ones((max_T, max_T))\n",
+    "        mask = torch.tril(ones).view(1, 1, max_T, max_T)\n",
+    "\n",
+    "        # mask is constant\n",
+    "        self.register_buffer('mask', mask)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        B, T, C = x.shape\n",
+    "        N, D = self.n_heads, C // self.n_heads\n",
+    "\n",
+    "        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)\n",
+    "        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)\n",
+    "        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)\n",
+    "\n",
+    "        weights = q @ k.transpose(2, 3) / math.sqrt(D)\n",
+    "\n",
+    "        # Masked causal weights\n",
+    "        weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))\n",
+    "\n",
+    "        # Normalize weights : all -inf -> 0 after softmax\n",
+    "        normalized_weights = F.softmax(weights, dim=-1)\n",
+    "\n",
+    "        # Masked causal attention (B, N, T, D)\n",
+    "        attention = self.attn_drop(normalized_weights @ v)\n",
+    "        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)\n",
+    "\n",
+    "        out = self.proj_drop(self.proj_net(attention))\n",
+    "        return out"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "copsi6yLUBtW"
+   },
+   "outputs": [],
+   "source": [
+    "class TransformerDecoderBlock(nn.Module):\n",
+    "    def __init__(self, h_dim, max_T, n_heads, drop_p):\n",
+    "        super().__init__()\n",
+    "        self.attn = MaskedMultiSelfAttention(h_dim, max_T, n_heads, drop_p)\n",
+    "        self.mlp = nn.Sequential(\n",
+    "            nn.Linear(h_dim, 4 * h_dim),\n",
+    "            nn.GELU(),\n",
+    "            nn.Linear(4 * h_dim, h_dim),\n",
+    "            nn.Dropout(drop_p)\n",
+    "        )\n",
+    "        self.ln1 = nn.LayerNorm(h_dim)\n",
+    "        self.ln2 = nn.LayerNorm(h_dim)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # MaskedMultiSelfAttention -> LayerNorm -> FeedForward -> LayerNorm\n",
+    "        x = self.attn(x) + x\n",
+    "        x = self.ln1(x)\n",
+    "        x = self.mlp(x) + x\n",
+    "        x = self.ln2(x)\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "050mS7V4kyuO"
+   },
+   "outputs": [],
+   "source": [
+    "B, T, D = 4, 8, 64\n",
+    "n_heads = 12"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "kZHgTOFuk6cw"
+   },
+   "outputs": [],
+   "source": [
+    "block = TransformerDecoderBlock(h_dim=n_heads*D, max_T=T, n_heads=n_heads, drop_p=0.1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "HS6ByiOPlars"
+   },
+   "outputs": [],
+   "source": [
+    "summary(block, input_size=(B, T, n_heads * D))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from google.colab import runtime\n",
+    "\n",
+    "runtime.unassign()"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "V100",
+   "machine_shape": "hm",
+   "private_outputs": true,
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
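
A side note on the attention code carried over unchanged in this diff: torch.Tensor.masked_fill returns a new tensor rather than modifying its argument, so the unassigned masked_fill(...) call above leaves the weights unmasked when softmax runs. Below is a minimal sketch of the same causal masking with the result kept; the shapes and the B, T, n_heads, D values are taken from the notebook, but the fix itself is an assumption about the intent, not part of this commit.

import math
import torch
import torch.nn.functional as F

B, T, N, D = 4, 8, 12, 64                    # batch, sequence length, heads, head dim (notebook values)
q = torch.randn(B, N, T, D)                  # per-head queries
k = torch.randn(B, N, T, D)                  # per-head keys
v = torch.randn(B, N, T, D)                  # per-head values

mask = torch.tril(torch.ones(T, T)).view(1, 1, T, T)     # lower-triangular causal mask

weights = q @ k.transpose(2, 3) / math.sqrt(D)           # (B, N, T, T) scaled dot products
# keep the result (or use the in-place masked_fill_) so the mask actually applies
weights = weights.masked_fill(mask[..., :T, :T] == 0, float('-inf'))
attention = F.softmax(weights, dim=-1) @ v               # (B, N, T, D); future positions get zero weight

Inside MaskedMultiSelfAttention.forward, the equivalent change would be reassigning weights or calling weights.masked_fill_.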
