 {
-  "nbformat": 4,
-  "nbformat_minor": 0,
-  "metadata": {
-    "colab": {
-      "private_outputs": true,
-      "provenance": [],
-      "machine_shape": "hm",
-      "gpuType": "V100"
-    },
-    "kernelspec": {
-      "name": "python3",
-      "display_name": "Python 3"
-    },
-    "language_info": {
-      "name": "python"
-    },
-    "accelerator": "GPU"
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "xosv4DUIlvi9"
+   },
+   "outputs": [],
+   "source": [
+    "!pip install torchinfo"
+   ]
   },
-  "cells": [
-    {
-      "cell_type": "code",
-      "source": [
-        "!pip install torchinfo"
-      ],
-      "metadata": {
-        "id": "xosv4DUIlvi9"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "execution_count": null,
-      "metadata": {
-        "id": "ma5VwOTFT0CX"
-      },
-      "outputs": [],
-      "source": [
-        "import torch\n",
-        "import torch.nn as nn\n",
-        "import math\n",
-        "import torch.nn.functional as F\n",
-        "from torchinfo import summary"
-      ]
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "class MaskedMultiSelfAttention(nn.Module):\n",
-        "    def __init__(self, h_dim, max_T, n_heads, drop_p):\n",
-        "        super().__init__()\n",
-        "        self.n_heads = n_heads\n",
-        "\n",
-        "        self.q_net = nn.Linear(h_dim, h_dim)\n",
-        "        self.k_net = nn.Linear(h_dim, h_dim)\n",
-        "        self.v_net = nn.Linear(h_dim, h_dim)\n",
-        "\n",
-        "        self.proj_net = nn.Linear(h_dim, h_dim)\n",
-        "\n",
-        "        self.attn_drop = nn.Dropout(drop_p)\n",
-        "        self.proj_drop = nn.Dropout(drop_p)\n",
-        "\n",
-        "        # Make lower triangular matrix of ones\n",
-        "        ones = torch.ones((max_T, max_T))\n",
-        "        mask = torch.tril(ones).view(1, 1, max_T, max_T)\n",
-        "\n",
-        "        # mask is constant\n",
-        "        self.register_buffer('mask', mask)\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        B, T, C = x.shape\n",
-        "        N, D = self.n_heads, C // self.n_heads\n",
-        "\n",
-        "        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)\n",
-        "        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)\n",
-        "        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)\n",
-        "\n",
-        "        weights = q @ k.transpose(2, 3) / math.sqrt(D)\n",
-        "\n",
-        "        # Masked causal weights\n",
-        "        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))\n",
-        "\n",
-        "        # Normalize weights: all -inf -> 0 after softmax\n",
-        "        normalized_weights = F.softmax(weights, dim=-1)\n",
-        "\n",
-        "        # Masked causal attention (B, N, T, D)\n",
-        "        attention = self.attn_drop(normalized_weights @ v)\n",
-        "        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)\n",
-        "\n",
-        "        out = self.proj_drop(self.proj_net(attention))\n",
-        "        return out"
-      ],
-      "metadata": {
-        "id": "THxWrPdLVAd6"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "class TransformerDecoderBlock(nn.Module):\n",
-        "    def __init__(self, h_dim, max_T, n_heads, drop_p):\n",
-        "        super().__init__()\n",
-        "        self.attn = MaskedMultiSelfAttention(h_dim, max_T, n_heads, drop_p)\n",
-        "        self.mlp = nn.Sequential(\n",
-        "            nn.Linear(h_dim, 4 * h_dim),\n",
-        "            nn.GELU(),\n",
-        "            nn.Linear(4 * h_dim, h_dim),\n",
-        "            nn.Dropout(drop_p)\n",
-        "        )\n",
-        "        self.ln1 = nn.LayerNorm(h_dim)\n",
-        "        self.ln2 = nn.LayerNorm(h_dim)\n",
-        "\n",
-        "    def forward(self, x):\n",
-        "        # MaskedMultiSelfAttention -> LayerNorm -> FeedForward -> LayerNorm\n",
-        "        x = self.attn(x) + x\n",
-        "        x = self.ln1(x)\n",
-        "        x = self.mlp(x) + x\n",
-        "        x = self.ln2(x)\n",
-        "        return x"
-      ],
-      "metadata": {
-        "id": "copsi6yLUBtW"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "B, T, D = 4, 8, 64\n",
-        "n_heads = 12"
-      ],
-      "metadata": {
-        "id": "050mS7V4kyuO"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "block = TransformerDecoderBlock(h_dim=n_heads*D, max_T=T, n_heads=n_heads, drop_p=0.1)"
-      ],
-      "metadata": {
-        "id": "kZHgTOFuk6cw"
-      },
-      "execution_count": null,
-      "outputs": []
-    },
-    {
-      "cell_type": "code",
-      "source": [
-        "summary(block, input_size=(B, T, n_heads * D))"
-      ],
-      "metadata": {
-        "id": "HS6ByiOPlars"
-      },
-      "execution_count": null,
-      "outputs": []
-    }
-  ]
-}
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "ma5VwOTFT0CX"
+   },
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import math\n",
+    "import torch.nn.functional as F\n",
+    "from torchinfo import summary"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "THxWrPdLVAd6"
+   },
+   "outputs": [],
+   "source": [
+    "class MaskedMultiSelfAttention(nn.Module):\n",
+    "    def __init__(self, h_dim, max_T, n_heads, drop_p):\n",
+    "        super().__init__()\n",
+    "        self.n_heads = n_heads\n",
+    "\n",
+    "        self.q_net = nn.Linear(h_dim, h_dim)\n",
+    "        self.k_net = nn.Linear(h_dim, h_dim)\n",
+    "        self.v_net = nn.Linear(h_dim, h_dim)\n",
+    "\n",
+    "        self.proj_net = nn.Linear(h_dim, h_dim)\n",
+    "\n",
+    "        self.attn_drop = nn.Dropout(drop_p)\n",
+    "        self.proj_drop = nn.Dropout(drop_p)\n",
+    "\n",
+    "        # Make lower triangular matrix of ones\n",
+    "        ones = torch.ones((max_T, max_T))\n",
+    "        mask = torch.tril(ones).view(1, 1, max_T, max_T)\n",
+    "\n",
+    "        # mask is constant\n",
+    "        self.register_buffer('mask', mask)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        B, T, C = x.shape\n",
+    "        N, D = self.n_heads, C // self.n_heads\n",
+    "\n",
+    "        q = self.q_net(x).view(B, T, N, D).transpose(1, 2)\n",
+    "        k = self.k_net(x).view(B, T, N, D).transpose(1, 2)\n",
+    "        v = self.v_net(x).view(B, T, N, D).transpose(1, 2)\n",
+    "\n",
+    "        weights = q @ k.transpose(2, 3) / math.sqrt(D)\n",
+    "\n",
+    "        # Masked causal weights\n",
+    "        weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))\n",
+    "\n",
+    "        # Normalize weights: all -inf -> 0 after softmax\n",
+    "        normalized_weights = F.softmax(weights, dim=-1)\n",
+    "\n",
+    "        # Masked causal attention (B, N, T, D)\n",
+    "        attention = self.attn_drop(normalized_weights @ v)\n",
+    "        attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)\n",
+    "\n",
+    "        out = self.proj_drop(self.proj_net(attention))\n",
+    "        return out"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "copsi6yLUBtW"
+   },
+   "outputs": [],
+   "source": [
+    "class TransformerDecoderBlock(nn.Module):\n",
+    "    def __init__(self, h_dim, max_T, n_heads, drop_p):\n",
+    "        super().__init__()\n",
+    "        self.attn = MaskedMultiSelfAttention(h_dim, max_T, n_heads, drop_p)\n",
+    "        self.mlp = nn.Sequential(\n",
+    "            nn.Linear(h_dim, 4 * h_dim),\n",
+    "            nn.GELU(),\n",
+    "            nn.Linear(4 * h_dim, h_dim),\n",
+    "            nn.Dropout(drop_p)\n",
+    "        )\n",
+    "        self.ln1 = nn.LayerNorm(h_dim)\n",
+    "        self.ln2 = nn.LayerNorm(h_dim)\n",
+    "\n",
+    "    def forward(self, x):\n",
+    "        # MaskedMultiSelfAttention -> LayerNorm -> FeedForward -> LayerNorm\n",
+    "        x = self.attn(x) + x\n",
+    "        x = self.ln1(x)\n",
+    "        x = self.mlp(x) + x\n",
+    "        x = self.ln2(x)\n",
+    "        return x"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "050mS7V4kyuO"
+   },
+   "outputs": [],
+   "source": [
+    "B, T, D = 4, 8, 64\n",
+    "n_heads = 12"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "kZHgTOFuk6cw"
+   },
+   "outputs": [],
+   "source": [
+    "block = TransformerDecoderBlock(h_dim=n_heads*D, max_T=T, n_heads=n_heads, drop_p=0.1)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "id": "HS6ByiOPlars"
+   },
+   "outputs": [],
+   "source": [
+    "summary(block, input_size=(B, T, n_heads * D))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from google.colab import runtime\n",
+    "\n",
+    "runtime.unassign()"
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "V100",
+   "machine_shape": "hm",
+   "private_outputs": true,
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}