
Merge pull request #98 from jameschapman19/pytorchlightning
Pytorchlightning
jameschapman19 authored Nov 17, 2021
2 parents d06e913 + 156efc4 commit b1318ba
Showing 50 changed files with 1,799 additions and 7,804 deletions.
1 change: 0 additions & 1 deletion cca_zoo/data/toy.py
@@ -34,7 +34,6 @@ def __init__(
self.dataset = datasets.KMNIST("../../data", train=train, download=True)

self.data = self.dataset.data
self.base_transform = transforms.ToTensor()
self.targets = self.dataset.targets
self.flatten = flatten

5 changes: 4 additions & 1 deletion cca_zoo/deepmodels/__init__.py
@@ -2,9 +2,12 @@
import cca_zoo.deepmodels.objectives
from ._dcca_base import _DCCA_base
from .dcca import DCCA
from .dcca_barlow_twins import BarlowTwins
from .dcca_noi import DCCA_NOI
from .dcca_sdl import DCCA_SDL
from .dccae import DCCAE
from .deepwrapper import DeepWrapper
from .dtcca import DTCCA
from .dvcca import DVCCA
from .splitae import SplitAE
from .trainers import CCALightning
from .utils import get_dataloaders, process_data
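
The new exports above wire the deep models into a PyTorch Lightning training path. A minimal, hedged sketch of how that path might be used (the CCALightning, Encoder, CCA_Dataset, and get_dataloaders signatures below are assumptions for illustration, not taken from this diff):

import numpy as np
import pytorch_lightning as pl

from cca_zoo.data import CCA_Dataset
from cca_zoo.deepmodels import DCCA, CCALightning, get_dataloaders
from cca_zoo.deepmodels.architectures import Encoder

# Two toy views with 10 features each (illustrative data only).
X1, X2 = np.random.rand(256, 10), np.random.rand(256, 10)

# Assumed encoder signature: one encoder per view, mapping features to latents.
encoders = [Encoder(latent_dims=2, feature_size=10) for _ in range(2)]
dcca = DCCA(latent_dims=2, encoders=encoders)

model = CCALightning(dcca)                       # assumed: wraps any _DCCA_base model
loader = get_dataloaders(CCA_Dataset([X1, X2]))  # assumed dataset/loader helpers
pl.Trainer(max_epochs=10).fit(model, loader)
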
7 changes: 7 additions & 0 deletions cca_zoo/deepmodels/_dcca_base.py
@@ -19,6 +19,13 @@ def forward(self, *args):
"""
raise NotImplementedError

@abstractmethod
def loss(self, *args, **kwargs):
"""
Required when using the LightningTrainer
"""
raise NotImplementedError

def post_transform(self, *z_list, train=False) -> Iterable[np.ndarray]:
"""
Some models require a final linear CCA after model training.
8 changes: 4 additions & 4 deletions cca_zoo/deepmodels/dcca.py
@@ -10,10 +10,10 @@ class DCCA(_DCCA_base):
"""
A class used to fit a DCCA model.
Examples
--------
>>> from cca_zoo.deepmodels import DCCA
>>> model = DCCA()
:Citation:
Andrew, Galen, et al. "Deep canonical correlation analysis." International conference on machine learning. PMLR, 2013.
"""

def __init__(
51 changes: 51 additions & 0 deletions cca_zoo/deepmodels/dcca_barlow_twins.py
@@ -0,0 +1,51 @@
from typing import Iterable

import torch

from cca_zoo.deepmodels import DCCA
from cca_zoo.deepmodels.architectures import BaseEncoder, Encoder


class BarlowTwins(DCCA):
"""
A class used to fit a Barlow Twins model.
:Citation:
Zbontar, Jure, et al. "Barlow twins: Self-supervised learning via redundancy reduction." arXiv preprint arXiv:2103.03230 (2021).
"""

def __init__(
self,
latent_dims: int,
encoders: Iterable[BaseEncoder] = [Encoder, Encoder],
lam=1,
):
"""
Constructor class for Barlow Twins
:param latent_dims: # latent dimensions
:param encoders: list of encoder networks
:param lam: weighting of off diagonal loss terms
"""
super().__init__(latent_dims=latent_dims, encoders=encoders)
self.lam = lam
self.bns = torch.nn.ModuleList(
[torch.nn.BatchNorm1d(latent_dims, affine=False) for _ in self.encoders]
)

def forward(self, *args):
z = []
for i, (encoder, bn) in enumerate(zip(self.encoders, self.bns)):
z.append(bn(encoder(args[i])))
return tuple(z)

def loss(self, *args):
z = self(*args)
cross_cov = z[0].T @ z[1] / (z[0].shape[0] - 1)
invariance = torch.mean(torch.pow(1 - torch.diag(cross_cov), 2))
covariance = torch.mean(
torch.triu(torch.pow(cross_cov, 2), diagonal=1)
) + torch.mean(torch.tril(torch.pow(cross_cov, 2), diagonal=-1))
return invariance + covariance
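
For reference, a self-contained sketch of the loss defined above, operating on two batch-normalised latent matrices (the function name and random inputs are illustrative only):

import torch

def barlow_twins_loss(z1, z2):
    # Cross-covariance between the two (batch-normalised) views.
    cross_cov = z1.T @ z2 / (z1.shape[0] - 1)
    # Invariance term: push the diagonal towards 1.
    invariance = torch.mean((1 - torch.diag(cross_cov)) ** 2)
    # Redundancy-reduction term: push the off-diagonal entries towards 0.
    redundancy = torch.mean(torch.triu(cross_cov ** 2, diagonal=1)) + torch.mean(
        torch.tril(cross_cov ** 2, diagonal=-1)
    )
    return invariance + redundancy  # as in the diff, lam is not applied here

z1, z2 = torch.randn(64, 8), torch.randn(64, 8)
print(barlow_twins_loss(z1, z2))
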
40 changes: 17 additions & 23 deletions cca_zoo/deepmodels/dcca_noi.py
@@ -8,6 +8,11 @@ class DCCA_NOI(DCCA):
"""
A class used to fit a DCCA model by non-linear orthogonal iterations
:Citation:
Wang, Weiran, et al. "Stochastic optimization for deep CCA via nonlinear orthogonal iterations." 2015 53rd Annual Allerton Conference on Communication, Control, and Computing (Allerton). IEEE, 2015.
"""

def __init__(
@@ -17,7 +22,7 @@ def __init__(
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-3,
eps: float = 1e-9,
shared_target: bool = False,
):
"""
@@ -39,7 +44,7 @@ def __init__(
self.eps = eps
self.rho = rho
self.shared_target = shared_target
self.mse = torch.nn.MSELoss()
self.mse = torch.nn.MSELoss(reduction="sum")
# Authors state that a final linear layer is an important part of their algorithmic implementation
self.linear_layers = torch.nn.ModuleList(
[
@@ -61,7 +66,7 @@ def forward(self, *args):
def loss(self, *args):
z = self(*args)
z_copy = [z_.detach().clone() for z_ in z]
self.update_covariances(*z_copy)
self._update_covariances(*z_copy)
covariance_inv = [
torch.linalg.inv(objectives.MatrixSquareRoot.apply(cov))
for cov in self.covs
@@ -70,25 +75,14 @@ def loss(self, *args):
loss = self.mse(z[0], preds[1]) + self.mse(z[1], preds[0])
return loss

def update_mean(self, *z):
batch_means = [torch.mean(z_, dim=0) for z_ in z]
if self.means is not None:
self.means = [
self.rho * self.means[i].detach() + (1 - self.rho) * batch_mean
for i, batch_mean in enumerate(batch_means)
]
else:
self.means = batch_means
z = [z_ - mean for (z_, mean) in zip(z, self.means)]
return z

def update_covariances(self, *z):
def _update_covariances(self, *z, train=True):
b = z[0].shape[0]
batch_covs = [self.N * z_.T @ z_ / b for z_ in z]
if self.covs is not None:
self.covs = [
self.rho * self.covs[i] + (1 - self.rho) * batch_cov
for i, batch_cov in enumerate(batch_covs)
]
else:
self.covs = batch_covs
if train:
if self.covs is not None:
self.covs = [
self.rho * self.covs[i] + (1 - self.rho) * batch_cov
for i, batch_cov in enumerate(batch_covs)
]
else:
self.covs = batch_covs
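
The renamed _update_covariances above keeps an exponential moving average of each view's covariance with memory rho, and only updates it in training mode. A standalone sketch of that update for a single view (function name is illustrative):

import torch

def update_covariance(running_cov, z, N, rho=0.2):
    # Batch estimate scaled by dataset size over batch size, as above.
    batch_cov = N * z.T @ z / z.shape[0]
    if running_cov is None:
        return batch_cov
    # Exponential moving average with memory rho.
    return rho * running_cov + (1 - rho) * batch_cov
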
92 changes: 92 additions & 0 deletions cca_zoo/deepmodels/dcca_sdl.py
@@ -0,0 +1,92 @@
import torch
import torch.nn.functional as F

from cca_zoo.deepmodels import DCCA_NOI


class DCCA_SDL(DCCA_NOI):
"""
A class used to fit a Deep CCA by Stochastic Decorrelation model.
:Citation:
Chang, Xiaobin, Tao Xiang, and Timothy M. Hospedales. "Scalable and effective deep CCA via soft decorrelation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018.
"""

def __init__(
self,
latent_dims: int,
N: int,
encoders=None,
r: float = 0,
rho: float = 0.2,
eps: float = 1e-3,
shared_target: bool = False,
lam=0.5,
):
"""
Constructor class for DCCA
:param latent_dims: # latent dimensions
:param encoders: list of encoder networks
:param r: regularisation parameter of tracenorm CCA like ridge CCA
:param rho: covariance memory like DCCA non-linear orthogonal iterations paper
:param eps: epsilon used throughout
:param shared_target: not used
"""
super().__init__(
latent_dims=latent_dims,
N=N,
encoders=encoders,
r=r,
rho=rho,
eps=eps,
shared_target=shared_target,
)
self.c = None
self.cross_cov = None
self.lam = lam
self.bns = torch.nn.ModuleList(
[
torch.nn.BatchNorm1d(latent_dims, affine=False)
for _ in range(latent_dims)
]
)

def forward(self, *args):
z = []
for i, (encoder, bn) in enumerate(zip(self.encoders, self.bns)):
z.append(bn(encoder(args[i])))
return tuple(z)

def loss(self, *args):
z = self(*args)
self._update_covariances(*z, train=self.training)
SDL_loss = self._sdl_loss(self.covs)
l2_loss = F.mse_loss(z[0], z[1])
return l2_loss + self.lam * SDL_loss

def _sdl_loss(self, covs):
loss = 0
for cov in covs:
cov = cov
sgn = torch.sign(cov)
sgn.fill_diagonal_(0)
loss += torch.mean(cov * sgn)
return loss

def _update_covariances(self, *z, train=True):
batch_covs = [z_.T @ z_ for z_ in z]
if train:
if self.c is not None:
self.c = self.rho * self.c + 1
self.covs = [
self.rho * self.covs[i].detach() + (1 - self.rho) * batch_cov
for i, batch_cov in enumerate(batch_covs)
]
else:
self.c = 1
self.covs = batch_covs
# pytorch-lightning runs validation once so this just fixes the bug
elif self.covs is None:
self.covs = batch_covs
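
The _sdl_loss penalty above multiplies each running covariance by its element-wise sign with the diagonal zeroed, which amounts to an L1-style penalty on the off-diagonal entries. A brief illustrative sketch:

import torch

def sdl_penalty(cov):
    sgn = torch.sign(cov)
    sgn.fill_diagonal_(0)            # keep the variances, penalise the rest
    return torch.mean(cov * sgn)     # mean absolute off-diagonal covariance

cov = torch.randn(8, 8)
cov = cov @ cov.T                    # a symmetric positive semi-definite example
print(sdl_penalty(cov))
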
13 changes: 6 additions & 7 deletions cca_zoo/deepmodels/dccae.py
@@ -10,10 +10,10 @@ class DCCAE(_DCCA_base):
"""
A class used to fit a DCCAE model.
Examples
--------
>>> from cca_zoo.deepmodels import DCCAE
>>> model = DCCAE()
:Citation:
Wang, Weiran, et al. "On deep multi-view representation learning." International conference on machine learning. PMLR, 2015.
"""

def __init__(
@@ -57,7 +57,6 @@ def decode(self, *z):
"""
This method is used to decode from the latent space to the best prediction of the original views
:param args:
"""
recon = []
for i, decoder in enumerate(self.decoders):
@@ -67,11 +66,11 @@ def loss(self, *args):
def loss(self, *args):
z = self(*args)
recon = self.decode(*z)
recon_loss = self.recon_loss(args[: len(recon)], recon)
recon_loss = self._recon_loss(args[: len(recon)], recon)
return self.lam * recon_loss + self.objective.loss(*z)

@staticmethod
def recon_loss(x, recon):
def _recon_loss(x, recon):
recons = [
F.mse_loss(recon_, x_, reduction="mean") for recon_, x_ in zip(recon, x)
]
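
The DCCAE loss above trades per-view reconstruction error against the CCA objective on the latents via lam. A hedged sketch of that combination (the helper name is illustrative, and the aggregation of per-view MSEs is an assumption since the diff truncates the original return statement):

import torch
import torch.nn.functional as F

def dccae_objective(views, recons, z, cca_objective, lam=0.5):
    # Per-view reconstruction MSE; summing them is an assumption.
    recon = torch.stack(
        [F.mse_loss(r, x, reduction="mean") for r, x in zip(recons, views)]
    ).sum()
    # Weighted reconstruction term plus the CCA objective on the latents.
    return lam * recon + cca_objective.loss(*z)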