
Commit 02309fa

__call__ -> forward
1 parent eaa248c commit 02309fa

5 files changed: 9 additions & 9 deletions

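The rename follows the standard PyTorch convention: `torch.nn.Module.__call__` runs any registered hooks and then dispatches to `forward`, so a module that overrides `__call__` directly bypasses the hook machinery. Below is a minimal sketch of the convention, assuming the repository's `Module` base class behaves like `torch.nn.Module`; the `Scale` module is hypothetical, not taken from this repository.

import torch
from torch import nn


class Scale(nn.Module):
    """Toy module that multiplies its input by a constant factor."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # `nn.Module.__call__` runs any registered hooks and then calls
        # this method; user code overrides `forward`, never `__call__`.
        return self.factor * x


m = Scale(2.0)
m.register_forward_hook(lambda module, args, output: print('hook fired'))
y = m(torch.ones(3))  # prints 'hook fired'; y is tensor([2., 2., 2.])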

labml_nn/gan/cycle_gan/__init__.py

Lines changed: 3 additions & 3 deletions

@@ -112,7 +112,7 @@ def __init__(self, input_channels: int, n_residual_blocks: int):
         # Initialize weights to $\mathcal{N}(0, 0.2)$
         self.apply(weights_init_normal)
 
-    def __call__(self, x):
+    def forward(self, x):
         return self.layers(x)
 
 
@@ -132,7 +132,7 @@ def __init__(self, in_features: int):
             nn.ReLU(inplace=True),
         )
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
        return x + self.block(x)
 
 
@@ -184,7 +184,7 @@ def __init__(self, in_filters: int, out_filters: int, normalize: bool = True):
         layers.append(nn.LeakyReLU(0.2, inplace=True))
         self.layers = nn.Sequential(*layers)
 
-    def __call__(self, x: torch.Tensor):
+    def forward(self, x: torch.Tensor):
         return self.layers(x)
labml_nn/gan/dcgan/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def __init__(self):
 
         self.apply(_weights_init)
 
-    def __call__(self, x):
+    def forward(self, x):
         # Change from shape `[batch_size, 100]` to `[batch_size, 100, 1, 1]`
         x = x.unsqueeze(-1).unsqueeze(-1)
         x = self.layers(x)

labml_nn/gan/original/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -76,7 +76,7 @@ def __init__(self, smoothing: float = 0.2):
         self.register_buffer('labels_true', _create_labels(256, 1.0 - smoothing, 1.0), False)
         self.register_buffer('labels_false', _create_labels(256, 0.0, smoothing), False)
 
-    def __call__(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
+    def forward(self, logits_true: torch.Tensor, logits_false: torch.Tensor):
         """
         `logits_true` are logits from $D(\pmb{x}^{(i)})$ and
         `logits_false` are logits from $D(G(\pmb{z}^{(i)}))$
@@ -111,7 +111,7 @@ def __init__(self, smoothing: float = 0.2):
         # the above gradient.
         self.register_buffer('fake_labels', _create_labels(256, 1.0 - smoothing, 1.0), False)
 
-    def __call__(self, logits: torch.Tensor):
+    def forward(self, logits: torch.Tensor):
         if len(logits) > len(self.fake_labels):
             self.register_buffer("fake_labels",
                                  _create_labels(len(logits), 1.0 - self.smoothing, 1.0, logits.device), False)
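These two loss modules implement the original GAN objectives with label smoothing; the `_create_labels(n, low, high)` buffers above hold the smoothed targets. Here is a rough sketch of the equivalent discriminator computation, assuming `_create_labels` draws targets uniformly from `[low, high]` (an assumption; the helper itself is not shown in this diff):

import torch
from torch.nn import functional as F


def smoothed_discriminator_loss(logits_true: torch.Tensor,
                                logits_false: torch.Tensor,
                                smoothing: float = 0.2):
    # Smoothed targets: real samples get labels in [1 - smoothing, 1.0],
    # generated samples get labels in [0.0, smoothing].
    labels_true = torch.empty_like(logits_true).uniform_(1.0 - smoothing, 1.0)
    labels_false = torch.empty_like(logits_false).uniform_(0.0, smoothing)
    # Binary cross-entropy on the discriminator logits for each batch.
    return (F.binary_cross_entropy_with_logits(logits_true, labels_true),
            F.binary_cross_entropy_with_logits(logits_false, labels_false))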

labml_nn/gan/wasserstein/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -101,7 +101,7 @@ class DiscriminatorLoss(Module):
     \frac{1}{m} \sum_{i=1}^m f_w \big( g_\theta(z^{(i)}) \big)$$
     """
 
-    def __call__(self, f_real: torch.Tensor, f_fake: torch.Tensor):
+    def forward(self, f_real: torch.Tensor, f_fake: torch.Tensor):
         """
         * `f_real` is $f_w(x)$
         * `f_fake` is $f_w(g_\theta(z))$
@@ -127,7 +127,7 @@ class GeneratorLoss(Module):
 
     """
 
-    def __call__(self, f_fake: torch.Tensor):
+    def forward(self, f_fake: torch.Tensor):
         """
         * `f_fake` is $f_w(g_\theta(z))$
         """

labml_nn/gan/wasserstein/gradient_penalty/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ class GradientPenalty(Module):
     ## Gradient Penalty
     """
 
-    def __call__(self, x: torch.Tensor, f: torch.Tensor):
+    def forward(self, x: torch.Tensor, f: torch.Tensor):
         """
         * `x` is $x \sim \mathbb{P}_r$
         * `f` is $D(x)$
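`GradientPenalty.forward` receives real samples `x` and the critic outputs `f = D(x)`, and penalizes the gradient of `D` with respect to `x`. A hedged sketch of that computation; the exact penalty form (squared norm here, versus the WGAN-GP `(norm - 1) ** 2`) is an assumption:

import torch


def gradient_penalty(x: torch.Tensor, f: torch.Tensor) -> torch.Tensor:
    # `x` must have been created with requires_grad=True before computing
    # f = D(x), so that the autograd graph reaches back to the input.
    batch_size = x.shape[0]
    gradients, = torch.autograd.grad(outputs=f,
                                     inputs=x,
                                     grad_outputs=f.new_ones(f.shape),
                                     create_graph=True)
    gradients = gradients.reshape(batch_size, -1)
    norm = gradients.norm(2, dim=-1)
    # Mean squared gradient norm at the real data points.
    return torch.mean(norm ** 2)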
