
Commit b6bef1d

cleanup
1 parent ab4264c commit b6bef1d

8 files changed (+215 additions, -223 deletions)


docs/transformers/basic/autoregressive_experiment.html

Lines changed: 6 additions & 6 deletions

@@ -76,10 +76,10 @@ <h1>Transformer Auto-Regression Experiment</h1>
 17 import torch
-18
-19 from labml import experiment
-20 from labml.configs import option
-21 from labml_helpers.module import Module
+18 from torch import nn
+19
+20 from labml import experiment
+21 from labml.configs import option
 22 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 23 from labml_nn.transformers import TransformerConfigs, Encoder
 24 from labml_nn.transformers.utils import subsequent_mask
@@ -94,7 +94,7 @@ <h2>Auto-Regressive model</h2>
-27 class AutoregressiveTransformer(Module):
+27 class AutoregressiveTransformer(nn.Module):
@@ -111,7 +111,7 @@ <h2>Auto-Regressive model</h2>
-31     def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
+31     def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
docs/transformers/mha.html

Lines changed: 57 additions & 58 deletions
Large diffs are not rendered by default.

docs/transformers/models.html

Lines changed: 102 additions & 103 deletions
Large diffs are not rendered by default.

docs/transformers/positional_encoding.html

Lines changed: 35 additions & 37 deletions
Large diffs are not rendered by default.

labml_nn/transformers/basic/autoregressive_experiment.py

Lines changed: 3 additions & 3 deletions

@@ -15,20 +15,20 @@
 """
 
 import torch
+from torch import nn
 
 from labml import experiment
 from labml.configs import option
-from labml_helpers.module import Module
 from labml_nn.experiments.nlp_autoregression import NLPAutoRegressionConfigs
 from labml_nn.transformers import TransformerConfigs, Encoder
 from labml_nn.transformers.utils import subsequent_mask
 
 
-class AutoregressiveTransformer(Module):
+class AutoregressiveTransformer(nn.Module):
     """
     ## Auto-Regressive model
     """
-    def __init__(self, encoder: Encoder, src_embed: Module, generator: Module):
+    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
         """
         * `encoder` is the transformer [Encoder](../models.html#Encoder)
         * `src_embed` is the token
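The net effect of this file's change, as a minimal runnable sketch. Only the imports, class name, and the `__init__` signature come from the diff; the attribute wiring, mask caching, and forward pass are assumptions about the surrounding code for illustration:

import torch
from torch import nn

from labml_nn.transformers import Encoder
from labml_nn.transformers.utils import subsequent_mask


class AutoregressiveTransformer(nn.Module):
    # After this commit the class subclasses nn.Module directly,
    # removing the labml_helpers dependency.
    def __init__(self, encoder: Encoder, src_embed: nn.Module, generator: nn.Module):
        super().__init__()
        self.src_embed = src_embed  # token embeddings (+ positional encoding)
        self.encoder = encoder      # transformer encoder stack
        self.generator = generator  # projection back to vocabulary logits
        self.mask = None            # cached causal mask

    def forward(self, x: torch.Tensor):
        # Rebuild the causal mask only when the sequence length changes,
        # so position i can attend only to positions <= i.
        if self.mask is None or self.mask.size(0) != len(x):
            self.mask = subsequent_mask(len(x)).to(x.device)
        x = self.src_embed(x)
        x = self.encoder(x, self.mask)
        return self.generator(x)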

labml_nn/transformers/mha.py

Lines changed: 3 additions & 4 deletions

@@ -26,13 +26,12 @@
 from typing import Optional, List
 
 import torch
-from torch import nn as nn
+from torch import nn
 
 from labml import tracker
-from labml_helpers.module import Module
 
 
-class PrepareForMultiHeadAttention(Module):
+class PrepareForMultiHeadAttention(nn.Module):
     """
     <a id="PrepareMHA"></a>
@@ -68,7 +67,7 @@ def forward(self, x: torch.Tensor):
         return x
 
 
-class MultiHeadAttention(Module):
+class MultiHeadAttention(nn.Module):
     r"""
     <a id="MHA"></a>
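For orientation, a sketch of the first class being retargeted here: `PrepareForMultiHeadAttention` applies one linear projection and splits the result into per-head vectors for use as queries, keys, or values. The constructor arguments below are assumptions for illustration; only the class name and the base-class change come from the diff:

import torch
from torch import nn


class PrepareForMultiHeadAttention(nn.Module):
    # Projects d_model-dim vectors and reshapes the last dimension
    # into `heads` vectors of size d_k each.
    def __init__(self, d_model: int, heads: int, d_k: int, bias: bool = True):
        super().__init__()
        self.linear = nn.Linear(d_model, heads * d_k, bias=bias)
        self.heads = heads
        self.d_k = d_k

    def forward(self, x: torch.Tensor):
        head_shape = x.shape[:-1]   # e.g. (seq_len, batch)
        x = self.linear(x)          # (..., heads * d_k)
        return x.view(*head_shape, self.heads, self.d_k)

`MultiHeadAttention` combines three such projections and, in the standard formulation, scales attention scores by 1/sqrt(d_k) before the softmax.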

labml_nn/transformers/models.py

Lines changed: 8 additions & 9 deletions

@@ -15,15 +15,14 @@
 
 import torch
 import torch.nn as nn
-from labml_helpers.module import Module
 
 from labml_nn.utils import clone_module_list
 from .feed_forward import FeedForward
 from .mha import MultiHeadAttention
 from .positional_encoding import get_positional_encoding
 
 
-class EmbeddingsWithPositionalEncoding(Module):
+class EmbeddingsWithPositionalEncoding(nn.Module):
     """
     <a id="EmbeddingsWithPositionalEncoding"></a>
@@ -41,7 +40,7 @@ def forward(self, x: torch.Tensor):
         return self.linear(x) * math.sqrt(self.d_model) + pe
 
 
-class EmbeddingsWithLearnedPositionalEncoding(Module):
+class EmbeddingsWithLearnedPositionalEncoding(nn.Module):
     """
     <a id="EmbeddingsWithLearnedPositionalEncoding"></a>
@@ -59,7 +58,7 @@ def forward(self, x: torch.Tensor):
         return self.linear(x) * math.sqrt(self.d_model) + pe
 
 
-class TransformerLayer(Module):
+class TransformerLayer(nn.Module):
     """
     <a id="TransformerLayer"></a>
@@ -139,7 +138,7 @@ def forward(self, *,
         return x
 
 
-class Encoder(Module):
+class Encoder(nn.Module):
     """
     <a id="Encoder"></a>
@@ -161,7 +160,7 @@ def forward(self, x: torch.Tensor, mask: torch.Tensor):
         return self.norm(x)
 
 
-class Decoder(Module):
+class Decoder(nn.Module):
     """
     <a id="Decoder"></a>
@@ -183,7 +182,7 @@ def forward(self, x: torch.Tensor, memory: torch.Tensor, src_mask: torch.Tensor,
         return self.norm(x)
 
 
-class Generator(Module):
+class Generator(nn.Module):
     """
     <a id="Generator"></a>
@@ -201,14 +200,14 @@ def forward(self, x):
         return self.projection(x)
 
 
-class EncoderDecoder(Module):
+class EncoderDecoder(nn.Module):
     """
     <a id="EncoderDecoder"></a>
 
     ## Combined Encoder-Decoder
     """
 
-    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: Module, tgt_embed: Module, generator: Module):
+    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: nn.Module, tgt_embed: nn.Module, generator: nn.Module):
         super().__init__()
         self.encoder = encoder
         self.decoder = decoder
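The `return self.norm(x)` context lines above belong to the usual stack-of-layers loop; sketched below for the `Encoder` case. The forward loop matches the context shown in the hunks; the constructor (use of `clone_module_list` from the imports, a final `nn.LayerNorm`, and the explicit `d_model` argument) is an assumption, not quoted from the file:

import torch
import torch.nn as nn

from labml_nn.utils import clone_module_list


class Encoder(nn.Module):
    # Stacks n_layers copies of one transformer layer and applies a
    # final layer normalization, matching `return self.norm(x)` above.
    def __init__(self, layer: nn.Module, n_layers: int, d_model: int):
        super().__init__()
        self.layers = clone_module_list(layer, n_layers)
        self.norm = nn.LayerNorm([d_model])

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        # Run through each transformer layer, then normalize the output.
        for layer in self.layers:
            x = layer(x=x, mask=mask)
        return self.norm(x)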

labml_nn/transformers/positional_encoding.py

Lines changed: 1 addition & 3 deletions

@@ -26,10 +26,8 @@
 import torch
 import torch.nn as nn
 
-from labml_helpers.module import Module
 
-
-class PositionalEncoding(Module):
+class PositionalEncoding(nn.Module):
     def __init__(self, d_model: int, dropout_prob: float, max_len: int = 5000):
         super().__init__()
         self.dropout = nn.Dropout(dropout_prob)
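`get_positional_encoding`, which models.py imports from this file, computes the fixed sinusoidal table from "Attention Is All You Need". A sketch of the standard computation; the exact function body is not part of this diff:

import math

import torch


def get_positional_encoding(d_model: int, max_len: int = 5000):
    # pe[p, 2i]   = sin(p / 10000^(2i / d_model))
    # pe[p, 2i+1] = cos(p / 10000^(2i / d_model))
    encodings = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)
    div_term = torch.exp(two_i * -(math.log(10000.0) / d_model))
    encodings[:, 0::2] = torch.sin(position * div_term)  # even indices
    encodings[:, 1::2] = torch.cos(position * div_term)  # odd indices
    return encodings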
