Skip to content

Commit a8e3695

Browse files
committed
__call__ -> forward
1 parent 0a4b5b6 commit a8e3695

File tree

4 files changed

+15
-15
lines changed

4 files changed

+15
-15
lines changed

labml_nn/sketch_rnn/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ def __init__(self, d_z: int, enc_hidden_size: int):
211211
# Head to get $\hat{\sigma}$
212212
self.sigma_head = nn.Linear(2 * enc_hidden_size, d_z)
213213

214-
def __call__(self, inputs: torch.Tensor, state=None):
214+
def forward(self, inputs: torch.Tensor, state=None):
215215
# The hidden state of the bidirectional LSTM is the concatenation of the
216216
# output of the last token in the forward direction and
217217
# first token in the reverse direction, which is what we want.
@@ -269,7 +269,7 @@ def __init__(self, d_z: int, dec_hidden_size: int, n_distributions: int):
269269
self.n_distributions = n_distributions
270270
self.dec_hidden_size = dec_hidden_size
271271

272-
def __call__(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
272+
def forward(self, x: torch.Tensor, z: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]]):
273273
# Calculate the initial state
274274
if state is None:
275275
# $[h_0; c_0] = \tanh(W_{z}z + b_z)$
@@ -314,7 +314,7 @@ class ReconstructionLoss(Module):
314314
## Reconstruction Loss
315315
"""
316316

317-
def __call__(self, mask: torch.Tensor, target: torch.Tensor,
317+
def forward(self, mask: torch.Tensor, target: torch.Tensor,
318318
dist: 'BivariateGaussianMixture', q_logits: torch.Tensor):
319319
# Get $\Pi$ and $\mathcal{N}(\mu_{x}, \mu_{y}, \sigma_{x}, \sigma_{y}, \rho_{xy})$
320320
pi, mix = dist.get_distribution()
@@ -355,7 +355,7 @@ class KLDivLoss(Module):
355355
This calculates the KL divergence between a given normal distribution and $\mathcal{N}(0, 1)$
356356
"""
357357

358-
def __call__(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
358+
def forward(self, sigma_hat: torch.Tensor, mu: torch.Tensor):
359359
# $$L_{KL} = - \frac{1}{2 N_z} \bigg( 1 + \hat{\sigma} - \mu^2 - \exp(\hat{\sigma}) \bigg)$$
360360
return -0.5 * torch.mean(1 + sigma_hat - mu ** 2 - torch.exp(sigma_hat))
361361

labml_nn/transformers/fast_weights/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def __init__(self, nu: int = 1, eps: float = 1e-6):
145145
self.relu = nn.ReLU()
146146
self.eps = eps
147147

148-
def __call__(self, k: torch.Tensor):
148+
def forward(self, k: torch.Tensor):
149149
# Get $\color{lightgreen}{\phi(k)}$
150150
k = self.dpfp(k)
151151
# Normalize by $\sum^{d_{dot}}_{j=1} \color{lightgreen}{\phi(k)_j}$
@@ -228,7 +228,7 @@ def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
228228
# Dropout
229229
self.dropout = nn.Dropout(dropout_prob)
230230

231-
def __call__(self, x: torch.Tensor):
231+
def forward(self, x: torch.Tensor):
232232
# Get the number of steps $L$
233233
seq_len = x.shape[0]
234234
# $\color{lightgreen}{\phi'(q^{(i)})}$ for all steps and heads
@@ -291,7 +291,7 @@ def __init__(self, *,
291291
self.norm_self_attn = nn.LayerNorm([d_model])
292292
self.norm_ff = nn.LayerNorm([d_model])
293293

294-
def __call__(self, x: torch.Tensor):
294+
def forward(self, x: torch.Tensor):
295295
# Calculate fast weights self attention
296296
attn = self.attn(x)
297297
# Add the self attention results
@@ -319,7 +319,7 @@ def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
319319
# Final normalization layer
320320
self.norm = nn.LayerNorm([layer.size])
321321

322-
def __call__(self, x: torch.Tensor):
322+
def forward(self, x: torch.Tensor):
323323
for i, layer in enumerate(self.layers):
324324
# Get layer output
325325
x = layer(x)

labml_nn/transformers/fast_weights/token_wise.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self, heads: int, d_model: int, dropout_prob: float, phi: DPFP):
4343
# Dropout
4444
self.dropout = nn.Dropout(dropout_prob)
4545

46-
def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
46+
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
4747
query = self.phi(self.query(x))
4848
key = self.phi(self.key(x))
4949
value = self.value(x)
@@ -84,7 +84,7 @@ def __init__(self, *,
8484
self.norm_self_attn = nn.LayerNorm([d_model])
8585
self.norm_ff = nn.LayerNorm([d_model])
8686

87-
def __call__(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
87+
def forward(self, x: torch.Tensor, weights: Optional[torch.Tensor]):
8888
attn, weights = self.attn(x, weights)
8989
# Add the self attention results
9090
x = x + self.dropout(attn)
@@ -108,7 +108,7 @@ def __init__(self, layer: FastWeightsAttentionTransformerLayer, n_layers: int):
108108
# Final normalization layer
109109
self.norm = nn.LayerNorm([layer.size])
110110

111-
def __call__(self, x_seq: torch.Tensor):
111+
def forward(self, x_seq: torch.Tensor):
112112
# Split the input to a list along the sequence axis
113113
x_seq = torch.unbind(x_seq, dim=0)
114114
# List to store the outputs

labml_nn/transformers/vit/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def __init__(self, d_model: int, patch_size: int, in_channels: int):
7575
# transformation on each patch.
7676
self.conv = nn.Conv2d(in_channels, d_model, patch_size, stride=patch_size)
7777

78-
def __call__(self, x: torch.Tensor):
78+
def forward(self, x: torch.Tensor):
7979
"""
8080
* `x` is the input image of shape `[batch_size, channels, height, width]`
8181
"""
@@ -109,7 +109,7 @@ def __init__(self, d_model: int, max_len: int = 5_000):
109109
# Positional embeddings for each location
110110
self.positional_encodings = nn.Parameter(torch.zeros(max_len, 1, d_model), requires_grad=True)
111111

112-
def __call__(self, x: torch.Tensor):
112+
def forward(self, x: torch.Tensor):
113113
"""
114114
* `x` is the patch embeddings of shape `[patches, batch_size, d_model]`
115115
"""
@@ -141,7 +141,7 @@ def __init__(self, d_model: int, n_hidden: int, n_classes: int):
141141
# Second layer
142142
self.linear2 = nn.Linear(n_hidden, n_classes)
143143

144-
def __call__(self, x: torch.Tensor):
144+
def forward(self, x: torch.Tensor):
145145
"""
146146
* `x` is the transformer encoding for `[CLS]` token
147147
"""
@@ -187,7 +187,7 @@ def __init__(self, transformer_layer: TransformerLayer, n_layers: int,
187187
# Final normalization layer
188188
self.ln = nn.LayerNorm([transformer_layer.size])
189189

190-
def __call__(self, x: torch.Tensor):
190+
def forward(self, x: torch.Tensor):
191191
"""
192192
* `x` is the input image of shape `[batch_size, channels, height, width]`
193193
"""

0 commit comments

Comments
 (0)