Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytorchlightning #98

Merged
merged 136 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
136 commits
Select commit Hold shift + click to select a range
daa77a1
Switching wrapper to pytorch-lightning
jameschapman19 Oct 27, 2021
452fc3b
Switching wrapper to pytorch-lightning
jameschapman19 Oct 27, 2021
79dca6d
Switching wrapper to pytorch-lightning
jameschapman19 Oct 27, 2021
7f41007
Switching wrapper to pytorch-lightning
jameschapman19 Oct 27, 2021
7432948
Switching wrapper to pytorch-lightning
jameschapman19 Oct 27, 2021
e003334
Merge branch 'main' of https://github.com/jameschapman19/MultiViewMet…
jameschapman19 Nov 2, 2021
88caa14
Switching wrapper to pytorch-lightning
jameschapman19 Nov 2, 2021
a055b04
Switching wrapper to pytorch-lightning
jameschapman19 Nov 5, 2021
cd94b3b
Switching wrapper to pytorch-lightning
jameschapman19 Nov 5, 2021
7b60f41
Switching wrapper to pytorch-lightning
jameschapman19 Nov 11, 2021
2221131
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
7792dbe
Merge remote-tracking branch 'origin/pytorchlightning' into pytorchli…
jameschapman19 Nov 13, 2021
a605a0c
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
b64fbcb
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
610d717
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
303a1bc
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
dee4766
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
4ccc829
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
fcef45b
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
9a77ab6
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
f361ab0
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
c0bcf6f
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
b9a90e3
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
5aac39c
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
dbdf962
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
90d82c5
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
cd346e9
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
cffff4a
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
88ca119
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
774b8ad
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
512d243
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
475f2c6
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
03d5f54
Switching tutorials to sphinx gallery
jameschapman19 Nov 13, 2021
c666eca
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
d69cb50
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
80f9845
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
302ec31
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
fa74e85
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
9bf2aca
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
21e24cf
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
7700420
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
045e9a3
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
928789b
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
3b1b50d
Merge branch 'dev' into pytorchlightning
jameschapman19 Nov 14, 2021
b8a8a03
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
f12c430
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
f967e0a
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
2ff74b8
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
fe295d1
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
2918dae
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
88190a2
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
e11260f
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
dfd26f2
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
e449dc9
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
4ab374a
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
6fdf50f
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
7d694a8
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
7fe2c6a
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
1e30447
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
fa44dcc
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
4054793
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
1c41fcb
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
a2d37df
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
e4107e4
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
02f991e
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
2a5fc2e
Switching tutorials to sphinx gallery
jameschapman19 Nov 14, 2021
4063e84
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
e255bc6
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
ccc11a1
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
67ca66a
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
2adfa6b
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
b74ffbd
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
323518a
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
5465444
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
8c4a36f
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
e88c804
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
a0292ed
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
1aec37b
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
7fed1b0
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
cdf18d1
Update python-package.yml
jameschapman19 Nov 15, 2021
81ff36c
Update python-package.yml
jameschapman19 Nov 15, 2021
49e5d87
Switching tutorials to sphinx gallery
jameschapman19 Nov 15, 2021
a5b282e
Merge remote-tracking branch 'origin/pytorchlightning' into pytorchli…
jameschapman19 Nov 15, 2021
e103c73
Doctests
jameschapman19 Nov 15, 2021
3ea22a8
Doctests
jameschapman19 Nov 15, 2021
728ac41
Doctests
jameschapman19 Nov 15, 2021
33762d3
Doctests
jameschapman19 Nov 15, 2021
022d6c5
Doctests
jameschapman19 Nov 16, 2021
a71db44
Doctests
jameschapman19 Nov 16, 2021
1583ef7
Doctests
jameschapman19 Nov 16, 2021
a1c43db
Doctests
jameschapman19 Nov 16, 2021
701d80d
Doctests
jameschapman19 Nov 16, 2021
6532d82
Doctests
jameschapman19 Nov 16, 2021
6788bab
black
jameschapman19 Nov 16, 2021
59d5ba7
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
93aa41c
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
b643e05
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
1ab5ed7
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
5d1be86
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
f0dc07b
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
7de358b
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
3282116
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
29a2bc3
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
a26bfb4
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
19bcce5
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
6e70bc9
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
1da0419
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
df37754
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
aa73ff8
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
af07f18
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
5aaf6b3
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
8141bec
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
31701a3
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
31cd5b7
Remove Tutorial Notebooks as they are replaced with sphinx-gallery
jameschapman19 Nov 16, 2021
c6fe837
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
2153ac3
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
c361eb2
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
d2348b6
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
8519824
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
4048c08
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
ae4e7f2
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
cbbc4a3
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
a53b229
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
8608705
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
4482afc
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
1a02cf9
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
c24e82f
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
f5c1a1b
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
0055712
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
e72af45
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
810ba76
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
e2c60a8
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
d43f928
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
4eb6461
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 16, 2021
7826a4b
Removing the slow constrained option for Elastic CCA. Now uses maxvar…
jameschapman19 Nov 17, 2021
156efc4
Changed iterative error to warning. This stops gridsearch breaking if…
jameschapman19 Nov 17, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ jobs:
build:

runs-on: ubuntu-latest
env:
MODULE_NAME: cca_zoo
strategy:
matrix:
python-version: [3.7, 3.8, 3.9]
Expand Down Expand Up @@ -44,3 +46,6 @@ jobs:
uses: codecov/codecov-action@v1
with:
fail_ci_if_error: true
- name: Run doctests
run: |
pytest --doctest-modules --ignore=$MODULE_NAME/tests $MODULE_NAME
1 change: 0 additions & 1 deletion cca_zoo/data/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def __init__(
self.dataset = datasets.KMNIST("../../data", train=train, download=True)

self.data = self.dataset.data
self.base_transform = transforms.ToTensor()
self.targets = self.dataset.targets
self.flatten = flatten

Expand Down
5 changes: 4 additions & 1 deletion cca_zoo/deepmodels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@
import cca_zoo.deepmodels.objectives
from ._dcca_base import _DCCA_base
from .dcca import DCCA
from .dcca_barlow_twins import BarlowTwins
from .dcca_noi import DCCA_NOI
from .dcca_sdl import DCCA_SDL
from .dccae import DCCAE
from .deepwrapper import DeepWrapper
from .dtcca import DTCCA
from .dvcca import DVCCA
from .splitae import SplitAE
from .trainers import CCALightning
from .utils import get_dataloaders, process_data
7 changes: 7 additions & 0 deletions cca_zoo/deepmodels/_dcca_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@ def forward(self, *args):
"""
raise NotImplementedError

@abstractmethod
def loss(self, *args, **kwargs):
"""
Required when using the LightningTrainer
"""
raise NotImplementedError

def post_transform(self, *z_list, train=False) -> Iterable[np.ndarray]:
"""
Some models require a final linear CCA after model training.
Expand Down
8 changes: 4 additions & 4 deletions cca_zoo/deepmodels/dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ class DCCA(_DCCA_base):
"""
A class used to fit a DCCA model.

Examples
--------
>>> from cca_zoo.deepmodels import DCCA
>>> model = DCCA()
:Citation:

Andrew, Galen, et al. "Deep canonical correlation analysis." International conference on machine learning. PMLR, 2013.

"""

def __init__(
Expand Down
51 changes: 51 additions & 0 deletions cca_zoo/deepmodels/dcca_barlow_twins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Iterable

import torch

from cca_zoo.deepmodels import DCCA
from cca_zoo.deepmodels.architectures import BaseEncoder, Encoder


class BarlowTwins(DCCA):
"""
A class used to fit a Barlow Twins model.

:Citation:

Zbontar, Jure, et al. "Barlow twins: Self-supervised learning via redundancy reduction." arXiv preprint arXiv:2103.03230 (2021).

"""

def __init__(
self,
latent_dims: int,
encoders: Iterable[BaseEncoder] = [Encoder, Encoder],
lam=1,
):
"""
Constructor class for Barlow Twins

:param latent_dims: # latent dimensions
:param encoders: list of encoder networks
:param lam: weighting of off diagonal loss terms
"""
super().__init__(latent_dims=latent_dims, encoders=encoders)
self.lam = lam
self.bns = torch.nn.ModuleList(
[torch.nn.BatchNorm1d(latent_dims, affine=False) for _ in self.encoders]
)

def forward(self, *args):
z = []
for i, (encoder, bn) in enumerate(zip(self.encoders, self.bns)):
z.append(bn(encoder(args[i])))
return tuple(z)

def loss(self, *args):
z = self(*args)
cross_cov = z[0].T @ z[1] / (z[0].shape[0] - 1)
invariance = torch.mean(torch.pow(1 - torch.diag(cross_cov), 2))
covariance = torch.mean(
torch.triu(torch.pow(cross_cov, 2), diagonal=1)
) + torch.mean(torch.tril(torch.pow(cross_cov, 2), diagonal=-1))
return invariance + covariance
40 changes: 17 additions & 23 deletions cca_zoo/deepmodels/dcca_noi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ class DCCA_NOI(DCCA):
"""
A class used to fit a DCCA model by non-linear orthogonal iterations


:Citation:

Wang, Weiran, et al. "Stochastic optimization for deep CCA via nonlinear orthogonal iterations." 2015 53rd Annual Allerton Conference on Communication, Control, and Computing (Allerton). IEEE, 2015.

"""

def __init__(
Expand All @@ -17,7 +22,7 @@ def __init__(
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-3,
eps: float = 1e-9,
shared_target: bool = False,
):
"""
Expand All @@ -39,7 +44,7 @@ def __init__(
self.eps = eps
self.rho = rho
self.shared_target = shared_target
self.mse = torch.nn.MSELoss()
self.mse = torch.nn.MSELoss(reduction="sum")
# Authors state that a final linear layer is an important part of their algorithmic implementation
self.linear_layers = torch.nn.ModuleList(
[
Expand All @@ -61,7 +66,7 @@ def forward(self, *args):
def loss(self, *args):
z = self(*args)
z_copy = [z_.detach().clone() for z_ in z]
self.update_covariances(*z_copy)
self._update_covariances(*z_copy)
covariance_inv = [
torch.linalg.inv(objectives.MatrixSquareRoot.apply(cov))
for cov in self.covs
Expand All @@ -70,25 +75,14 @@ def loss(self, *args):
loss = self.mse(z[0], preds[1]) + self.mse(z[1], preds[0])
return loss

def update_mean(self, *z):
batch_means = [torch.mean(z_, dim=0) for z_ in z]
if self.means is not None:
self.means = [
self.rho * self.means[i].detach() + (1 - self.rho) * batch_mean
for i, batch_mean in enumerate(batch_means)
]
else:
self.means = batch_means
z = [z_ - mean for (z_, mean) in zip(z, self.means)]
return z

def update_covariances(self, *z):
def _update_covariances(self, *z, train=True):
b = z[0].shape[0]
batch_covs = [self.N * z_.T @ z_ / b for z_ in z]
if self.covs is not None:
self.covs = [
self.rho * self.covs[i] + (1 - self.rho) * batch_cov
for i, batch_cov in enumerate(batch_covs)
]
else:
self.covs = batch_covs
if train:
if self.covs is not None:
self.covs = [
self.rho * self.covs[i] + (1 - self.rho) * batch_cov
for i, batch_cov in enumerate(batch_covs)
]
else:
self.covs = batch_covs
92 changes: 92 additions & 0 deletions cca_zoo/deepmodels/dcca_sdl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import torch
import torch.nn.functional as F

from cca_zoo.deepmodels import DCCA_NOI


class DCCA_SDL(DCCA_NOI):
"""
A class used to fit a Deep CCA by Stochastic Decorrelation model.

:Citation:

Chang, Xiaobin, Tao Xiang, and Timothy M. Hospedales. "Scalable and effective deep CCA via soft decorrelation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.

"""

def __init__(
self,
latent_dims: int,
N: int,
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-3,
shared_target: bool = False,
lam=0.5,
):
"""
Constructor class for DCCA
:param latent_dims: # latent dimensions
:param encoders: list of encoder networks
:param r: regularisation parameter of tracenorm CCA like ridge CCA
:param rho: covariance memory like DCCA non-linear orthogonal iterations paper
:param eps: epsilon used throughout
:param shared_target: not used
"""
super().__init__(
latent_dims=latent_dims,
N=N,
encoders=encoders,
r=r,
rho=rho,
eps=eps,
shared_target=shared_target,
)
self.c = None
self.cross_cov = None
self.lam = lam
self.bns = torch.nn.ModuleList(
[
torch.nn.BatchNorm1d(latent_dims, affine=False)
for _ in range(latent_dims)
]
)

def forward(self, *args):
z = []
for i, (encoder, bn) in enumerate(zip(self.encoders, self.bns)):
z.append(bn(encoder(args[i])))
return tuple(z)

def loss(self, *args):
z = self(*args)
self._update_covariances(*z, train=self.training)
SDL_loss = self._sdl_loss(self.covs)
l2_loss = F.mse_loss(z[0], z[1])
return l2_loss + self.lam * SDL_loss

def _sdl_loss(self, covs):
loss = 0
for cov in covs:
cov = cov
sgn = torch.sign(cov)
sgn.fill_diagonal_(0)
loss += torch.mean(cov * sgn)
return loss

def _update_covariances(self, *z, train=True):
batch_covs = [z_.T @ z_ for z_ in z]
if train:
if self.c is not None:
self.c = self.rho * self.c + 1
self.covs = [
self.rho * self.covs[i].detach() + (1 - self.rho) * batch_cov
for i, batch_cov in enumerate(batch_covs)
]
else:
self.c = 1
self.covs = batch_covs
# pytorch-lightning runs validation once so this just fixes the bug
elif self.covs is None:
self.covs = batch_covs
13 changes: 6 additions & 7 deletions cca_zoo/deepmodels/dccae.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ class DCCAE(_DCCA_base):
"""
A class used to fit a DCCAE model.

Examples
--------
>>> from cca_zoo.deepmodels import DCCAE
>>> model = DCCAE()
:Citation:

Wang, Weiran, et al. "On deep multi-view representation learning." International conference on machine learning. PMLR, 2015.

"""

def __init__(
Expand Down Expand Up @@ -57,7 +57,6 @@ def decode(self, *z):
"""
This method is used to decode from the latent space to the best prediction of the original views

:param args:
"""
recon = []
for i, decoder in enumerate(self.decoders):
Expand All @@ -67,11 +66,11 @@ def decode(self, *z):
def loss(self, *args):
z = self(*args)
recon = self.decode(*z)
recon_loss = self.recon_loss(args[: len(recon)], recon)
recon_loss = self._recon_loss(args[: len(recon)], recon)
return self.lam * recon_loss + self.objective.loss(*z)

@staticmethod
def recon_loss(x, recon):
def _recon_loss(x, recon):
recons = [
F.mse_loss(recon_, x_, reduction="mean") for recon_, x_ in zip(recon, x)
]
Expand Down
Loading