|
1 | 1 | """ |
2 | 2 | ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ์ ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ๋ฅผ ์ฌ์ฉํ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ํ์ต |
3 | | -==================================================================================== |
| 3 | +======================================================================= |
4 | 4 |
|
5 | 5 | **Author**: `Pritam Damania <https://github.com/pritamdamania87>`_ |
6 | 6 | **๋ฒ์ญ**: `๋ฐฑ์ ํฌ <https://github.com/spongebob03>`_ |
|
22 | 22 |
|
23 | 23 | ###################################################################### |
24 | 24 | # ๋ชจ๋ธ ์ ์ํ๊ธฐ |
25 | | -# ---------------- |
| 25 | +# ------------- |
26 | 26 | # |
27 | 27 |
|
28 | 28 | ###################################################################### |
29 | 29 | # ``PositionalEncoding`` ๋ชจ๋์ ์ํ์ค์์ ํ ํฐ์ ์๋์ , ์ ๋ ์์น์ ๋ํ |
30 | 30 | # ์ผ๋ถ ์ ๋ณด๋ฅผ ์ฃผ์
ํฉ๋๋ค. |
31 | 31 | # ์์น ์ธ์ฝ๋ฉ์ ์๋ฒ ๋ฉ๊ณผ ๊ฐ์ ์ฐจ์์ ๊ฐ์ง๋ฏ๋ก |
32 | | -# ๋์ ํฉ์น ์ ์์ต๋๋ค. ์ฌ๊ธฐ์, ์ฃผํ์๊ฐ ๋ค๋ฅธ ``sine``๊ณผ ``cosine`` ๊ธฐ๋ฅ์ |
| 32 | +# ๋์ ํฉ์น ์ ์์ต๋๋ค. ์ฌ๊ธฐ์, ์ฃผํ์๊ฐ ๋ค๋ฅธ ``sine`` ๊ณผ ``cosine`` ๊ธฐ๋ฅ์ |
33 | 33 | # ์ฌ์ฉํฉ๋๋ค. |
34 | 34 |
|
35 | 35 | import sys |
@@ -73,7 +73,7 @@ def forward(self, x): |
73 | 73 | # `nn.TransformerEncoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html>`__ ๊ณ์ธต(layer)์ ํฌํจ๋ฉ๋๋ค. |
74 | 74 | # `nn.TransformerEncoder <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html>`__ ๋ |
75 | 75 | # `nn.TransformerEncoderLayer <https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html>`__ ์ ``nlayers`` ๋ก ๊ตฌ์ฑ๋์ด ์์ต๋๋ค. |
76 | | -# ๊ฒฐ๊ณผ์ ์ผ๋ก, ์ฐ๋ฆฌ๋ ``nn.TransformerEncoder`` ์ ์ค์ ์ ๋๊ณ ์์ผ๋ฉฐ |
| 76 | +# ๊ฒฐ๊ณผ์ ์ผ๋ก, ``nn.TransformerEncoder`` ์ ์ค์ ์ ๋๊ณ ์์ผ๋ฉฐ |
77 | 77 | # ``nn.TransformerEncoderLayer`` ์ ์ ๋ฐ์ ํ GPU์ ๋๊ณ |
78 | 78 | # ๋๋จธ์ง ์ ๋ฐ์ ๋ค๋ฅธ GPU์ ์๋๋ก ๋ชจ๋ธ์ ๋ถํ ํฉ๋๋ค. ์ด๋ฅผ ์ํด์ ``Encoder`` ์ |
79 | 79 | # ``Decoder`` ์น์
์ ๋ถ๋ฆฌ๋ ๋ชจ๋๋ก ๋นผ๋ธ ๋ค์, ์๋ณธ ํธ๋์คํฌ๋จธ ๋ชจ๋์ |
@@ -122,20 +122,20 @@ def forward(self, inp): |
122 | 122 |
|
123 | 123 | ###################################################################### |
124 | 124 | # ํ์ต์ ์ํ ๋ค์ค ํ๋ก์ธ์ค ์์ |
125 | | -# ------------------------------------- |
| 125 | +# ------------------------------ |
126 | 126 | # |
127 | 127 |
|
128 | 128 |
|
129 | 129 | ###################################################################### |
130 | 130 | # ๊ฐ์ ๋ ๊ฐ์ GPU์์ ์์ฒด ํ์ดํ๋ผ์ธ์ ๊ตฌ๋ํ๋ ๋ ๊ฐ์ง ํ๋ก์ธ์ค๋ฅผ ์์ํฉ๋๋ค. |
131 | | -# ``run_worker``๋ ๊ฐ ํ๋ก์ธ์ค์ ์คํ๋ฉ๋๋ค. |
| 131 | +# ``run_worker`` ๋ ๊ฐ ํ๋ก์ธ์ค์ ์คํ๋ฉ๋๋ค. |
132 | 132 |
|
133 | 133 | def run_worker(rank, world_size): |
134 | 134 |
|
135 | 135 |
|
136 | 136 | ###################################################################### |
137 | 137 | # ๋ฐ์ดํฐ ๋ก๋ํ๊ณ ๋ฐฐ์น ๋ง๋ค๊ธฐ |
138 | | -# ------------------- |
| 138 | +# --------------------------- |
139 | 139 | # |
140 | 140 |
|
141 | 141 |
|
@@ -210,7 +210,7 @@ def batchify(data, bsz, rank, world_size, is_train=False): |
210 | 210 |
|
211 | 211 | ###################################################################### |
212 | 212 | # ์
๋ ฅ๊ณผ ํ๊ฒ ์ํ์ค๋ฅผ ์์ฑํ๊ธฐ ์ํ ํจ์๋ค |
213 | | -# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 213 | +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
214 | 214 | # |
215 | 215 |
|
216 | 216 |
|
@@ -239,15 +239,15 @@ def get_batch(source, i): |
239 | 239 |
|
240 | 240 | ###################################################################### |
241 | 241 | # ๋ชจ๋ธ ๊ท๋ชจ์ ํ์ดํ ์ด๊ธฐํ |
242 | | -# ----------------------------------- |
| 242 | +# ------------------------- |
243 | 243 | # |
244 | 244 |
|
245 | 245 |
|
246 | 246 | ###################################################################### |
247 | 247 | # ํ์ดํ๋ผ์ธ ๋ณ๋ ฌํ๋ฅผ ํ์ฉํ ๋ํ ํธ๋์คํฌ๋จธ ๋ชจ๋ธ ํ์ต์ ์ฆ๋ช
ํ๊ธฐ ์ํด, |
248 | 248 | # ํธ๋์คํฌ๋จธ ๊ณ์ธต ๊ท๋ชจ๋ฅผ ์ ์ ํ ํ์ฅ์ํต๋๋ค. We use an embedding |
249 | 249 | # 4096์ฐจ์์ ์๋ฒ ๋ฉ ๋ฒกํฐ, 4096์ ์๋ ์ฌ์ด์ฆ, 16๊ฐ์ ์ดํ
์
ํค๋(attention head)์ ์ด 8 ๊ฐ์ |
250 | | -# ํธ๋์คํฌ๋จธ ๊ณ์ธต (``nn.TransformerEncoderLayer``)๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด๋ ์ต๋ |
| 250 | +# ํธ๋์คํฌ๋จธ ๊ณ์ธต (``nn.TransformerEncoderLayer``)๋ฅผ ์ฌ์ฉํฉ๋๋ค. ์ด๋ ์ต๋ |
251 | 251 | # **~1 ์ต** ๊ฐ์ ํ๋ผ๋ฏธํฐ๋ฅผ ๊ฐ๋ ๋ชจ๋ธ์ ์์ฑํฉ๋๋ค. |
252 | 252 | # |
253 | 253 | # `RPC ํ๋ ์์ํฌ <https://pytorch.org/docs/stable/rpc.html>`__ ๋ฅผ ์ด๊ธฐํํด์ผ ํฉ๋๋ค. |
@@ -434,7 +434,7 @@ def evaluate(eval_model, data_source): |
434 | 434 |
|
435 | 435 | ###################################################################### |
436 | 436 | # ํ๊ฐ ๋ฐ์ดํฐ์
์ผ๋ก ๋ชจ๋ธ ํ๊ฐํ๊ธฐ |
437 | | -# ------------------------------------- |
| 437 | +# ------------------------------- |
438 | 438 | # |
439 | 439 | # ํ๊ฐ ๋ฐ์ดํฐ์
์์์ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ๊ธฐ ์ํด ์ต๊ณ ์ ๋ชจ๋ธ์ ์ ์ฉํฉ๋๋ค. |
440 | 440 |
|
|
0 commit comments