diff --git a/cca_zoo/data/toy.py b/cca_zoo/data/toy.py index bae2b882..236f8125 100644 --- a/cca_zoo/data/toy.py +++ b/cca_zoo/data/toy.py @@ -34,7 +34,6 @@ def __init__( self.dataset = datasets.KMNIST("../../data", train=train, download=True) self.data = self.dataset.data - self.base_transform = transforms.ToTensor() self.targets = self.dataset.targets self.flatten = flatten diff --git a/cca_zoo/deepmodels/__init__.py b/cca_zoo/deepmodels/__init__.py index 9aeda2af..97d43dec 100644 --- a/cca_zoo/deepmodels/__init__.py +++ b/cca_zoo/deepmodels/__init__.py @@ -2,9 +2,12 @@ import cca_zoo.deepmodels.objectives from ._dcca_base import _DCCA_base from .dcca import DCCA +from .dcca_barlow_twins import BarlowTwins from .dcca_noi import DCCA_NOI +from .dcca_sdl import DCCA_SDL from .dccae import DCCAE -from .deepwrapper import DeepWrapper from .dtcca import DTCCA from .dvcca import DVCCA from .splitae import SplitAE +from .trainers import CCALightning +from .utils import get_dataloaders, process_data diff --git a/cca_zoo/deepmodels/_dcca_base.py b/cca_zoo/deepmodels/_dcca_base.py index c1bb1f40..7ad94a86 100644 --- a/cca_zoo/deepmodels/_dcca_base.py +++ b/cca_zoo/deepmodels/_dcca_base.py @@ -19,6 +19,13 @@ def forward(self, *args): """ raise NotImplementedError + @abstractmethod + def loss(self, *args, **kwargs): + """ + Required when using the LightningTrainer + """ + raise NotImplementedError + def post_transform(self, *z_list, train=False) -> Iterable[np.ndarray]: """ Some models require a final linear CCA after model training. diff --git a/cca_zoo/deepmodels/dcca.py b/cca_zoo/deepmodels/dcca.py index 7e05e09e..9a58c863 100644 --- a/cca_zoo/deepmodels/dcca.py +++ b/cca_zoo/deepmodels/dcca.py @@ -10,10 +10,10 @@ class DCCA(_DCCA_base): """ A class used to fit a DCCA model. - Examples - -------- - >>> from cca_zoo.deepmodels import DCCA - >>> model = DCCA() + :Citation: + + Andrew, Galen, et al. "Deep canonical correlation analysis." 
International conference on machine learning. PMLR, 2013. + """ def __init__( diff --git a/cca_zoo/deepmodels/dcca_barlow_twins.py b/cca_zoo/deepmodels/dcca_barlow_twins.py new file mode 100644 index 00000000..27c1da94 --- /dev/null +++ b/cca_zoo/deepmodels/dcca_barlow_twins.py @@ -0,0 +1,51 @@ +from typing import Iterable + +import torch + +from cca_zoo.deepmodels import DCCA +from cca_zoo.deepmodels.architectures import BaseEncoder, Encoder + + +class BarlowTwins(DCCA): + """ + A class used to fit a Barlow Twins model. + + :Citation: + + Zbontar, Jure, et al. "Barlow twins: Self-supervised learning via redundancy reduction." arXiv preprint arXiv:2103.03230 (2021). + + """ + + def __init__( + self, + latent_dims: int, + encoders: Iterable[BaseEncoder] = [Encoder, Encoder], + lam=1, + ): + """ + Constructor class for Barlow Twins + + :param latent_dims: # latent dimensions + :param encoders: list of encoder networks + :param lam: weighting of off diagonal loss terms + """ + super().__init__(latent_dims=latent_dims, encoders=encoders) + self.lam = lam + self.bns = torch.nn.ModuleList( + [torch.nn.BatchNorm1d(latent_dims, affine=False) for _ in self.encoders] + ) + + def forward(self, *args): + z = [] + for i, (encoder, bn) in enumerate(zip(self.encoders, self.bns)): + z.append(bn(encoder(args[i]))) + return tuple(z) + + def loss(self, *args): + z = self(*args) + cross_cov = z[0].T @ z[1] / (z[0].shape[0] - 1) + invariance = torch.mean(torch.pow(1 - torch.diag(cross_cov), 2)) + covariance = torch.mean( + torch.triu(torch.pow(cross_cov, 2), diagonal=1) + ) + torch.mean(torch.tril(torch.pow(cross_cov, 2), diagonal=-1)) + return invariance + covariance diff --git a/cca_zoo/deepmodels/dcca_noi.py b/cca_zoo/deepmodels/dcca_noi.py index 5bcefd4f..6ab09929 100644 --- a/cca_zoo/deepmodels/dcca_noi.py +++ b/cca_zoo/deepmodels/dcca_noi.py @@ -8,6 +8,11 @@ class DCCA_NOI(DCCA): """ A class used to fit a DCCA model by non-linear orthogonal iterations + + :Citation: + + Wang, 
Weiran, et al. "Stochastic optimization for deep CCA via nonlinear orthogonal iterations." 2015 53rd Annual Allerton Conference on Communication, Control, and Computing (Allerton). IEEE, 2015. + """ def __init__( @@ -17,7 +22,7 @@ def __init__( encoders=None, r: float = 0, rho: float = 0.2, - eps: float = 1e-3, + eps: float = 1e-9, shared_target: bool = False, ): """ @@ -39,7 +44,7 @@ def __init__( self.eps = eps self.rho = rho self.shared_target = shared_target - self.mse = torch.nn.MSELoss() + self.mse = torch.nn.MSELoss(reduction="sum") # Authors state that a final linear layer is an important part of their algorithmic implementation self.linear_layers = torch.nn.ModuleList( [ @@ -61,7 +66,7 @@ def forward(self, *args): def loss(self, *args): z = self(*args) z_copy = [z_.detach().clone() for z_ in z] - self.update_covariances(*z_copy) + self._update_covariances(*z_copy) covariance_inv = [ torch.linalg.inv(objectives.MatrixSquareRoot.apply(cov)) for cov in self.covs @@ -70,25 +75,14 @@ def loss(self, *args): loss = self.mse(z[0], preds[1]) + self.mse(z[1], preds[0]) return loss - def update_mean(self, *z): - batch_means = [torch.mean(z_, dim=0) for z_ in z] - if self.means is not None: - self.means = [ - self.rho * self.means[i].detach() + (1 - self.rho) * batch_mean - for i, batch_mean in enumerate(batch_means) - ] - else: - self.means = batch_means - z = [z_ - mean for (z_, mean) in zip(z, self.means)] - return z - - def update_covariances(self, *z): + def _update_covariances(self, *z, train=True): b = z[0].shape[0] batch_covs = [self.N * z_.T @ z_ / b for z_ in z] - if self.covs is not None: - self.covs = [ - self.rho * self.covs[i] + (1 - self.rho) * batch_cov - for i, batch_cov in enumerate(batch_covs) - ] - else: - self.covs = batch_covs + if train: + if self.covs is not None: + self.covs = [ + self.rho * self.covs[i] + (1 - self.rho) * batch_cov + for i, batch_cov in enumerate(batch_covs) + ] + else: + self.covs = batch_covs diff --git 
a/cca_zoo/deepmodels/dcca_sdl.py b/cca_zoo/deepmodels/dcca_sdl.py new file mode 100644 index 00000000..dcbde8ac --- /dev/null +++ b/cca_zoo/deepmodels/dcca_sdl.py @@ -0,0 +1,92 @@ +import torch +import torch.nn.functional as F + +from cca_zoo.deepmodels import DCCA_NOI + + +class DCCA_SDL(DCCA_NOI): + """ + A class used to fit a Deep CCA by Stochastic Decorrelation model. + + :Citation: + + Chang, Xiaobin, Tao Xiang, and Timothy M. Hospedales. "Scalable and effective deep CCA via soft decorrelation." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2018. + + """ + + def __init__( + self, + latent_dims: int, + N: int, + encoders=None, + r: float = 0, + rho: float = 0.2, + eps: float = 1e-3, + shared_target: bool = False, + lam=0.5, + ): + """ + Constructor class for DCCA + :param latent_dims: # latent dimensions + :param encoders: list of encoder networks + :param r: regularisation parameter of tracenorm CCA like ridge CCA + :param rho: covariance memory like DCCA non-linear orthogonal iterations paper + :param eps: epsilon used throughout + :param shared_target: not used + """ + super().__init__( + latent_dims=latent_dims, + N=N, + encoders=encoders, + r=r, + rho=rho, + eps=eps, + shared_target=shared_target, + ) + self.c = None + self.cross_cov = None + self.lam = lam + self.bns = torch.nn.ModuleList( + [ + torch.nn.BatchNorm1d(latent_dims, affine=False) + for _ in range(latent_dims) + ] + ) + + def forward(self, *args): + z = [] + for i, (encoder, bn) in enumerate(zip(self.encoders, self.bns)): + z.append(bn(encoder(args[i]))) + return tuple(z) + + def loss(self, *args): + z = self(*args) + self._update_covariances(*z, train=self.training) + SDL_loss = self._sdl_loss(self.covs) + l2_loss = F.mse_loss(z[0], z[1]) + return l2_loss + self.lam * SDL_loss + + def _sdl_loss(self, covs): + loss = 0 + for cov in covs: + cov = cov + sgn = torch.sign(cov) + sgn.fill_diagonal_(0) + loss += torch.mean(cov * sgn) + return loss + + def 
_update_covariances(self, *z, train=True): + batch_covs = [z_.T @ z_ for z_ in z] + if train: + if self.c is not None: + self.c = self.rho * self.c + 1 + self.covs = [ + self.rho * self.covs[i].detach() + (1 - self.rho) * batch_cov + for i, batch_cov in enumerate(batch_covs) + ] + else: + self.c = 1 + self.covs = batch_covs + # pytorch-lightning runs validation once so this just fixes the bug + elif self.covs is None: + self.covs = batch_covs diff --git a/cca_zoo/deepmodels/dccae.py b/cca_zoo/deepmodels/dccae.py index 9e38ba01..f9185480 100644 --- a/cca_zoo/deepmodels/dccae.py +++ b/cca_zoo/deepmodels/dccae.py @@ -10,10 +10,10 @@ class DCCAE(_DCCA_base): """ A class used to fit a DCCAE model. - Examples - -------- - >>> from cca_zoo.deepmodels import DCCAE - >>> model = DCCAE() + :Citation: + + Wang, Weiran, et al. "On deep multi-view representation learning." International conference on machine learning. PMLR, 2015. + """ def __init__( @@ -57,7 +57,6 @@ def decode(self, *z): """ This method is used to decode from the latent space to the best prediction of the original views - :param args: """ recon = [] for i, decoder in enumerate(self.decoders): @@ -67,11 +66,11 @@ def decode(self, *z): def loss(self, *args): z = self(*args) recon = self.decode(*z) - recon_loss = self.recon_loss(args[: len(recon)], recon) + recon_loss = self._recon_loss(args[: len(recon)], recon) return self.lam * recon_loss + self.objective.loss(*z) @staticmethod - def recon_loss(x, recon): + def _recon_loss(x, recon): recons = [ F.mse_loss(recon_, x_, reduction="mean") for recon_, x_ in zip(recon, x) ] diff --git a/cca_zoo/deepmodels/deepwrapper.py b/cca_zoo/deepmodels/deepwrapper.py deleted file mode 100644 index 5cd83492..00000000 --- a/cca_zoo/deepmodels/deepwrapper.py +++ /dev/null @@ -1,348 +0,0 @@ -import copy -import itertools -from typing import Union, Iterable - -import numpy as np -import torch -from torch.utils.data import DataLoader - -from cca_zoo.data import CCA_Dataset -from 
cca_zoo.deepmodels import _DCCA_base, DCCA, DCCAE -from cca_zoo.models import _CCA_Base -from ..utils.check_values import _check_batch_size - - -class DeepWrapper(_CCA_Base): - """ - This class is used as a wrapper for DCCA, DCCAE, DVCCA, DTCCA, SplitAE. It can be inherited and adapted to - customise the training loop. By inheriting _CCA_Base, the DeepWrapper class gives access to fit_transform. - """ - - def __init__( - self, - model: _DCCA_base, - device: str = "cuda", - optimizer: torch.optim.Optimizer = None, - scheduler=None, - lr: float = 1e-3, - clip_value=float("inf"), - random_state: int = 1, - ): - """ - - :param model: An instance of a model - :param device: device to train on - :param optimizer: optimizer used to update model parameters each iteration - :param scheduler: scheduler used to update the optimizer e.g. learning rate decay - :param lr: learning rate if not specified in the optimizer - :param clip_value: - """ - super().__init__(latent_dims=model.latent_dims) - self.model = model - self.device = device - if not torch.cuda.is_available(): - self.device = "cpu" - torch.manual_seed(random_state) - torch.cuda.manual_seed(random_state) - self.latent_dims = model.latent_dims - self.optimizer = optimizer - if optimizer is None: - if isinstance(self.model, DCCA): - # Andrew G, Arora R, Bilmes J, Livescu K. Deep canonical correlation analysis. InInternational conference on machine learning 2013 May 26 (pp. 1247-1255). PMLR. - self.optimizer = torch.optim.LBFGS(self.model.parameters(), lr=lr) - elif isinstance(self.model, DCCAE): - # Wang W, Arora R, Livescu K, Bilmes J. On deep multi-view representation learning. InInternational conference on machine learning 2015 Jun 1 (pp. 1083-1092). PMLR. 
- self.optimizer = torch.optim.SGD( - self.model.parameters(), lr=lr, weight_decay=1e-4 - ) - else: - self.optimizer = torch.optim.SGD(self.model.parameters(), lr=lr) - self.scheduler = scheduler - self.clip_value = clip_value - - def fit( - self, - train_dataset: Union[torch.utils.data.Dataset, Iterable[np.ndarray]], - val_dataset: Union[torch.utils.data.Dataset, Iterable[np.ndarray]] = None, - train_labels=None, - val_labels=None, - val_split: float = 0, - batch_size: int = 0, - val_batch_size: int = 0, - patience: int = 0, - epochs: int = 1, - post_transform=True, - ): - """ - - :param train_dataset: either tuple of 2d numpy arrays (one for each view) or torch dataset - :param val_dataset: either tuple of 2d numpy arrays (one for each view), torch dataset or None - :param train_labels: - :param val_labels: - :param val_split: if val_dataset is None, - :param batch_size: the minibatch size - :param patience: if 0 train to num_epochs, else if validation score doesn't improve after patience epochs stop training - :param epochs: maximum number of epochs to train - """ - train_dataset, val_dataset = self._process_data( - train_dataset, val_dataset, train_labels, val_labels, val_split - ) - train_dataloader, val_dataloader = self._get_dataloaders( - train_dataset, batch_size, val_dataset, val_batch_size - ) - num_params = sum(p.numel() for p in self.model.parameters()) - print("total parameters: ", num_params) - best_model = copy.deepcopy(self.model.state_dict()) - self.model.to(self.device) - min_val_loss = torch.tensor(np.inf) - epochs_no_improve = 0 - early_stop = False - - for epoch in range(1, epochs + 1): - if not early_stop: - # Train - epoch_train_loss = self._train_epoch(train_dataloader) - print( - "====> Epoch: {} Average train loss: {:.4f}".format( - epoch, epoch_train_loss - ) - ) - # Val - if val_dataset: - epoch_val_loss = self._val_epoch(val_dataloader) - print( - "====> Epoch: {} Average val loss: {:.4f}".format( - epoch, epoch_val_loss - ) - ) - if 
epoch_val_loss < min_val_loss or epoch == 1: - min_val_loss = epoch_val_loss - best_model = copy.deepcopy(self.model.state_dict()) - print("Min loss %0.2f" % min_val_loss) - epochs_no_improve = 0 - else: - epochs_no_improve += 1 - # Check early stopping condition - if epochs_no_improve == patience and patience > 0: - print("Early stopping!") - early_stop = True - self.model.load_state_dict(best_model) - # Scheduler step - if self.scheduler: - try: - self.scheduler.step() - except: - self.scheduler.step(epoch_train_loss) - if post_transform: - self.transform(train_dataset, batch_size=batch_size, train=True) - return self - - def _train_epoch(self, train_dataloader: torch.utils.data.DataLoader): - """ - Train a single epoch - - :param train_dataloader: a dataloader for training data - :return: average loss over the epoch - """ - self.model.train() - train_loss = 0 - for batch_idx, (data, label) in enumerate(train_dataloader): - data = [d.to(self.device) for d in list(data)] - loss = self._update_weights(*data) - train_loss += loss.item() - return train_loss / len(train_dataloader) - - def _update_weights(self, *args): - """ - A complete update of the weights used every batch - - :param args: batches for each view separated by commas - :return: - """ - if type(self.optimizer) == torch.optim.LBFGS: - - def closure(): - """ - Required by LBFGS optimizer - """ - self.optimizer.zero_grad() - loss = self.model.loss(*args) - loss.backward() - return loss - - torch.nn.utils.clip_grad_value_( - self.model.parameters(), clip_value=self.clip_value - ) - self.optimizer.step(closure) - loss = closure() - else: - for p in self.model.parameters(): - p.grad = None - loss = self.model.loss(*args) - loss.backward() - torch.nn.utils.clip_grad_value_( - self.model.parameters(), clip_value=self.clip_value - ) - self.optimizer.step() - return loss - - def _val_epoch(self, val_dataloader: torch.utils.data.DataLoader): - """ - Validate a single epoch - - :param val_dataloader: a dataloder 
for validation data - :return: average validation loss over the epoch - """ - self.model.eval() - for param in self.model.parameters(): - param.grad = None - total_val_loss = 0 - for batch_idx, (data, label) in enumerate(val_dataloader): - data = [d.to(self.device) for d in list(data)] - loss = self.model.loss(*data) - total_val_loss += loss.item() - return total_val_loss / len(val_dataloader) - - def correlations( - self, - test_dataset: Union[ - torch.utils.data.Dataset, Iterable[np.ndarray], torch.utils.data.DataLoader - ], - train: bool = False, - batch_size: int = 0, - ): - """ - - - :return: numpy array containing correlations between each pair of views for each dimension (#views*#views*#latent_dimensions) - """ - transformed_views = self.transform( - test_dataset, train=train, batch_size=batch_size - ) - all_corrs = [] - for x, y in itertools.product(transformed_views, repeat=2): - all_corrs.append(np.diag(np.corrcoef(x.T, y.T)[: x.shape[1], y.shape[1]:])) - all_corrs = np.array(all_corrs).reshape( - (len(transformed_views), len(transformed_views), -1) - ) - return all_corrs - - def transform( - self, - test_dataset: Union[ - torch.utils.data.Dataset, Iterable[np.ndarray], torch.utils.data.DataLoader - ], - test_labels=None, - train: bool = False, - batch_size: int = 0, - ): - if isinstance(test_dataset, torch.utils.data.DataLoader): - test_dataloader = test_dataset - else: - test_dataset = self._process_data(test_dataset, labels=test_labels)[0] - if batch_size > 0: - test_dataloader = DataLoader(test_dataset, batch_size=batch_size) - else: - test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset)) - with torch.no_grad(): - for batch_idx, (data, label) in enumerate(test_dataloader): - data = [d.to(self.device) for d in list(data)] - z = self.model(*data) - if batch_idx == 0: - z_list = [z_i.detach().cpu().numpy() for i, z_i in enumerate(z)] - else: - z_list = [ - np.append(z_list[i], z_i.detach().cpu().numpy(), axis=0) - for i, z_i in 
enumerate(z) - ] - z_list = self.model.post_transform(*z_list, train=train) - return z_list - - def predict_view( - self, - test_dataset: Union[torch.utils.data.Dataset, Iterable[np.ndarray]], - test_labels=None, - ): - test_dataset = self._process_data(test_dataset, labels=test_labels)[0] - test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset)) - with torch.no_grad(): - for batch_idx, (data, label) in enumerate(test_dataloader): - data = [d.to(self.device) for d in list(data)] - x = self.model.recon(*data) - if batch_idx == 0: - x_list = [x_i.detach().cpu().numpy() for i, x_i in enumerate(x)] - else: - x_list = [ - np.append(x_list[i], x_i.detach().cpu().numpy(), axis=0) - for i, x_i in enumerate(x) - ] - return x_list - - def score( - self, - test_dataset: Union[ - torch.utils.data.Dataset, Iterable[np.ndarray], torch.utils.data.DataLoader - ], - train: bool = False, - batch_size: int = 0, - ): - # by default return the average pairwise correlation in each dimension (for 2 views just the correlation) - pair_corrs = self.correlations( - test_dataset=test_dataset, train=train, batch_size=batch_size - ) - # n views - n_views = pair_corrs.shape[0] - # sum all the pairwise correlations for each dimension. Subtract the self correlations. Divide by the number of views. Gives average correlation - dim_corrs = ( - pair_corrs.sum(axis=tuple(range(pair_corrs.ndim - 1))) - n_views - ) / (n_views ** 2 - n_views) - return dim_corrs - - def _process_data( - self, - dataset: Union[torch.utils.data.Dataset, Iterable[np.ndarray]], - val_dataset: Union[torch.utils.data.Dataset, Iterable[np.ndarray]] = None, - labels=None, - val_labels=None, - val_split: float = 0, - ): - # Ensure datasets are in the right form (e.g. 
if numpy arrays are passed turn them into - if isinstance(dataset, tuple): - dataset = CCA_Dataset(dataset, labels=labels) - if val_dataset is None and val_split > 0: - lengths = [ - len(dataset) - int(len(dataset) * val_split), - int(len(dataset) * val_split), - ] - dataset, val_dataset = torch.utils.data.random_split(dataset, lengths) - elif isinstance(val_dataset, tuple): - val_dataset = CCA_Dataset(val_dataset, labels=val_labels) - return dataset, val_dataset - - def _get_dataloaders( - self, dataset, batch_size, val_dataset=None, val_batch_size=None, num_workers=0 - ): - if batch_size == 0: - batch_size = len(dataset) - train_dataloader = DataLoader( - dataset, - batch_size=batch_size, - drop_last=True, - num_workers=num_workers, - pin_memory=True, - shuffle=True, - ) - _check_batch_size(batch_size, self.latent_dims) - if val_dataset: - if val_batch_size == 0: - val_batch_size = len(val_dataset) - val_dataloader = DataLoader( - val_dataset, - batch_size=val_batch_size, - drop_last=True, - num_workers=num_workers, - pin_memory=True, - ) - _check_batch_size(batch_size, self.latent_dims) - return train_dataloader, val_dataloader - return train_dataloader, None diff --git a/cca_zoo/deepmodels/dtcca.py b/cca_zoo/deepmodels/dtcca.py index b824fd86..6d0b0535 100644 --- a/cca_zoo/deepmodels/dtcca.py +++ b/cca_zoo/deepmodels/dtcca.py @@ -13,10 +13,10 @@ class DTCCA(DCCA): Is just a thin wrapper round DCCA with the DTCCA objective and a TCCA post-processing - Examples - -------- - >>> from cca_zoo.deepmodels import DTCCA - >>> model = DTCCA() + :Citation: + + Wong, Hok Shing, et al. "Deep Tensor CCA for Multi-view Learning." IEEE Transactions on Big Data (2021). + """ def __init__( diff --git a/cca_zoo/deepmodels/dvcca.py b/cca_zoo/deepmodels/dvcca.py index 861f6f7d..65e1c311 100644 --- a/cca_zoo/deepmodels/dvcca.py +++ b/cca_zoo/deepmodels/dvcca.py @@ -12,9 +12,14 @@ class DVCCA(_DCCA_base): """ A class used to fit a DVCCA model. + :Citation: + + Wang, Weiran, et al. 
"Deep variational canonical correlation analysis." arXiv preprint arXiv:1610.03454 (2016). + https: // arxiv.org / pdf / 1610.03454.pdf - With pieces borrowed from the variational autoencoder implementation @ - # https: // github.com / pytorch / examples / blob / master / vae / main.py + + https: // github.com / pytorch / examples / blob / master / vae / main.py + """ def __init__( @@ -49,7 +54,7 @@ def forward(self, *args, mle=True): :return: """ # Used when we get reconstructions - mu, logvar = self.encode(*args) + mu, logvar = self._encode(*args) if mle: z = mu else: @@ -59,7 +64,7 @@ def forward(self, *args, mle=True): if len(self.encoders) == 1: z = z * len(args) if self.private_encoders: - mu_p, logvar_p = self.encode_private(*args) + mu_p, logvar_p = self._encode_private(*args) if mle: z_p = mu_p else: @@ -68,7 +73,7 @@ def forward(self, *args, mle=True): z = [torch.cat([z_] + z_p, dim=-1) for z_ in z] return z - def encode(self, *args): + def _encode(self, *args): """ :param args: :return: @@ -81,7 +86,7 @@ def encode(self, *args): logvar.append(logvar_i) return mu, logvar - def encode_private(self, *args): + def _encode_private(self, *args): """ :param args: :return: @@ -94,14 +99,14 @@ def encode_private(self, *args): logvar.append(logvar_i) return mu, logvar - def decode(self, z): + def _decode(self, z): """ :param z: :return: """ x = [] for i, decoder in enumerate(self.decoders): - x_i = decoder(z) + x_i = F.sigmoid(decoder(z)) x.append(x_i) return x @@ -111,16 +116,16 @@ def recon(self, *args): :return: """ z = self(*args) - return [self.decode(z_i) for z_i in z][0] + return [self._decode(z_i) for z_i in z][0] def loss(self, *args): """ :param args: :return: """ - mus, logvars = self.encode(*args) + mus, logvars = self._encode(*args) if self.private_encoders: - mus_p, logvars_p = self.encode_private(*args) + mus_p, logvars_p = self._encode_private(*args) losses = [ self.vcca_private_loss( *args, mu=mu, logvar=logvar, mu_p=mu_p, logvar_p=logvar_p @@ 
-147,7 +152,7 @@ def vcca_loss(self, *args, mu, logvar): kl = torch.mean( -0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2), dim=1), dim=0 ) - recons = self.decode(z) + recons = self._decode(z) bces = torch.stack( [ F.binary_cross_entropy(recon, arg, reduction="sum") / batch_n @@ -182,7 +187,7 @@ def vcca_private_loss(self, *args, mu, logvar, mu_p, logvar_p): -0.5 * torch.sum(1 + logvar - logvar.exp() - mu.pow(2), dim=1), dim=0 ) z_combined = torch.cat([z, z_p], dim=-1) - recon = self.decode(z_combined) + recon = self._decode(z_combined) bces = torch.stack( [ F.binary_cross_entropy(recon[i], args[i], reduction="sum") / batch_n diff --git a/cca_zoo/deepmodels/objectives.py b/cca_zoo/deepmodels/objectives.py index 96287069..7a094f57 100644 --- a/cca_zoo/deepmodels/objectives.py +++ b/cca_zoo/deepmodels/objectives.py @@ -273,6 +273,6 @@ def loss(self, *z): M = torch.unsqueeze(M, -1) @ el M = torch.mean(M, 0) tl.set_backend("pytorch") - M_parafac = parafac(M.detach(), self.latent_dims) + M_parafac = parafac(M.detach(), self.latent_dims, verbose=False) M_hat = cp_to_tensor(M_parafac) return torch.norm(M - M_hat) diff --git a/cca_zoo/deepmodels/splitae.py b/cca_zoo/deepmodels/splitae.py index 967707d6..a71a2a73 100644 --- a/cca_zoo/deepmodels/splitae.py +++ b/cca_zoo/deepmodels/splitae.py @@ -9,10 +9,10 @@ class SplitAE(_DCCA_base): """ A class used to fit a Split Autoencoder model. - Examples - -------- - >>> from cca_zoo.deepmodels import SplitAE - >>> model = SplitAE() + :Citation: + + Ngiam, Jiquan, et al. "Multimodal deep learning." ICML. 2011. 
+ """ def __init__(self, latent_dims: int, encoder: BaseEncoder = Encoder, decoders=None): @@ -30,7 +30,7 @@ def __init__(self, latent_dims: int, encoder: BaseEncoder = Encoder, decoders=No def forward(self, *args): z = self.encoder(args[0]) - return z + return [z] def decode(self, z): """ @@ -45,11 +45,13 @@ def decode(self, z): def loss(self, *args): z = self(*args) - recon = self.decode(z) + recon = self.decode(*z) recon_loss = self.recon_loss(args, recon) return recon_loss @staticmethod def recon_loss(x, recon): - recons = [F.mse_loss(recon[i], x[i], reduction="mean") for i in range(len(x))] + recons = [ + F.mse_loss(recon[i], x[i], reduction="mean") for i in range(len(recon)) + ] return torch.stack(recons).sum(dim=0) diff --git a/cca_zoo/deepmodels/trainers.py b/cca_zoo/deepmodels/trainers.py new file mode 100644 index 00000000..1462b267 --- /dev/null +++ b/cca_zoo/deepmodels/trainers.py @@ -0,0 +1,229 @@ +import itertools +import sys +from typing import Optional, Union + +import numpy as np +import torch +from pytorch_lightning import LightningModule +from torch.utils.data import DataLoader + +from cca_zoo.deepmodels import _DCCA_base + + +class CCALightning(LightningModule): + def __init__( + self, + model: _DCCA_base, + optimizer: Union[torch.optim.Optimizer, str] = "Adam", + learning_rate: float = 1e-3, + weight_decay: float = 0.1, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + StepLR_step_size: float = None, + StepLR_gamma: float = None, + lr_factor: float = None, + lr_patience: float = None, + OneCycleLR_max_lr: float = None, + OneCycleLR_epochs: float = None, + train_trajectories: float = None, + T: float = None, + ): + """ + + :param model: a model instance from deepmodels + :param optimizer: a pytorch optimizer with parameters from model or a string like 'Adam' to use Adam optimizer with default parameters or those specified by the user + :param learning_rate: learning rate used when optimizer is instantiated with a string + :param 
weight_decay: weight decay used when optimizer is instantiated with a string + :param lr_scheduler: a pytorch learning rate scheduler or a string like "StepLR" or None + :param StepLR_step_size: step size used by "StepLR" + :param StepLR_gamma: gamma used by "StepLR" + :param lr_factor: factor used by "ReduceLROnPlateau" + :param lr_patience: patience used by "ReduceLROnPlateau" + :param OneCycleLR_max_lr: max lr used by "OneCycleLR" + :param OneCycleLR_epochs: epochs used by "OneCycleLR" + :param train_trajectories: train trajectories used by "OneCycleLR" + :param T: T used by "OneCycleLR" + """ + super().__init__() + self.save_hyperparameters() + self.model = model + + def forward(self, *args): + z = self.encode(*args) + return z + + def loss(self, *args, **kwargs): + return self.model.loss(*args, **kwargs) + + # Configuration. Add more for learning schedulers, etc.? + def configure_optimizers(self): + if isinstance(self.hparams.optimizer, torch.optim.Optimizer): + optimizer = self.hparams.optimizer + elif self.hparams.optimizer == "Adam": + optimizer = torch.optim.Adam( + self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay, + ) + elif self.hparams.optimizer == "SGD": + # Left out the momentum options for now + optimizer = torch.optim.SGD( + self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay, + ) + elif self.hparams.optimizer == "LBFGS": + optimizer = torch.optim.LBFGS( + self.parameters(), + # or can have self.hparams.learning_rate with warning if too low. + lr=1, + tolerance_grad=1e-5, # can add to parameters if useful. + tolerance_change=1e-9, # can add to parameters if useful. + ) + else: + print("Invalid optimizer. 
See --help") + sys.exit() + + if self.hparams.lr_scheduler is None: + return optimizer + elif isinstance( + self.hparams.lr_scheduler, torch.optim.lr_scheduler._LRScheduler + ): + scheduler = self.hparams.lr_scheduler + elif self.hparams.lr_scheduler == "StepLR": + step_size = self.hparams.StepLR_step_size + gamma = self.hparams.StepLR_gamma + scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma) + elif self.hparams.lr_scheduler == "ReduceLROnPlateau": + factor = self.hparams.lr_factor + patience = self.hparams.lr_patience + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + optimizer, mode="min", factor=factor, patience=patience + ) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": self.hparams.LRScheduler_metric, + } + elif self.hparams.lr_scheduler == "OneCycleLR": + max_lr = self.hparams.OneCycleLR_max_lr + epochs = self.hparams.OneCycleLR_epochs + steps_per_epoch = self.hparams.train_trajectories * (self.hparams.T + 1) + scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=max_lr, + epochs=epochs, + steps_per_epoch=steps_per_epoch, + ) + else: + print("Invalid scheduler configuration. 
See --help") + raise + return [optimizer], [scheduler] + + def training_step(self, batch, batch_idx): + data, label = batch + loss = self.model.loss(*data) + return loss + + def validation_step(self, batch, batch_idx): + data, label = batch + loss = self.model.loss(*data) + return loss + + def test_step(self, batch, batch_idx): + data, label = batch + loss = self.model.loss(*data) + return loss + + def on_train_epoch_end(self, unused: Optional = None) -> None: + score = self.score(self.trainer.train_dataloader, train=True).sum() + self.log("train corr", score) + + def on_validation_epoch_end(self, unused: Optional = None) -> None: + score = self.score(self.trainer.val_dataloaders[0], train=True).sum() + self.log("val corr", score) + + def correlations( + self, + loader: torch.utils.data.DataLoader, + train: bool = False, + ): + """ + + :param loader: a dataloader that matches the structure of that used for training + :param train: if True and the model requires a final linear CCA this solves and stores the linear CCA + :return: numpy array containing correlations between each pair of views for each dimension (#views*#views*#latent_dimensions) + """ + transformed_views = self.transform(loader, train=train) + if len(transformed_views) < 2: + return None + all_corrs = [] + for x, y in itertools.product(transformed_views, repeat=2): + all_corrs.append(np.diag(np.corrcoef(x.T, y.T)[: x.shape[1], y.shape[1]:])) + all_corrs = np.array(all_corrs).reshape( + (len(transformed_views), len(transformed_views), -1) + ) + return all_corrs + + def transform( + self, + loader: torch.utils.data.DataLoader, + train: bool = False, + ): + """ + + :param loader: a dataloader that matches the structure of that used for training + :param train: if True and the model requires a final linear CCA this solves and stores the linear CCA + :return: transformed views + """ + with torch.no_grad(): + for batch_idx, (data, label) in enumerate(loader): + data = [d.to(self.device) for d in list(data)] 
+ z = self.model(*data) + if batch_idx == 0: + z_list = [z_i.detach().cpu().numpy() for i, z_i in enumerate(z)] + else: + z_list = [ + np.append(z_list[i], z_i.detach().cpu().numpy(), axis=0) + for i, z_i in enumerate(z) + ] + z_list = self.model.post_transform(*z_list, train=train) + return z_list + + def score( + self, + loader: torch.utils.data.DataLoader, + train: bool = False, + ): + """ + + :param loader: a dataloader that matches the structure of that used for training + :param train: if True and the model requires a final linear CCA this solves and stores the linear CCA + :return: by default returns the average pairwise correlation in each dimension (for 2 views just the correlation) + """ + pair_corrs = self.correlations(loader, train=train) + if pair_corrs is None: + return np.zeros(1) + # n views + n_views = pair_corrs.shape[0] + # sum all the pairwise correlations for each dimension. Subtract the self correlations. Divide by the number of views. Gives average correlation + dim_corrs = ( + pair_corrs.sum(axis=tuple(range(pair_corrs.ndim - 1))) - n_views + ) / (n_views ** 2 - n_views) + return dim_corrs + + def predict_view( + self, + loader: torch.utils.data.DataLoader, + ): + with torch.no_grad(): + for batch_idx, (data, label) in enumerate(loader): + data = [d.to(self.device) for d in list(data)] + x = self.model.recon(*data) + if batch_idx == 0: + x_list = [x_i.detach().cpu().numpy() for i, x_i in enumerate(x)] + else: + x_list = [ + np.append(x_list[i], x_i.detach().cpu().numpy(), axis=0) + for i, x_i in enumerate(x) + ] + return x_list diff --git a/cca_zoo/deepmodels/utils.py b/cca_zoo/deepmodels/utils.py new file mode 100644 index 00000000..8c04edc3 --- /dev/null +++ b/cca_zoo/deepmodels/utils.py @@ -0,0 +1,55 @@ +from typing import Union, Iterable + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from cca_zoo.data.utils import CCA_Dataset + + +def process_data( + dataset: Union[torch.utils.data.Dataset, 
Iterable[np.ndarray]], + val_dataset: Union[torch.utils.data.Dataset, Iterable[np.ndarray]] = None, + labels=None, + val_labels=None, + val_split: float = 0, +): + # Ensure datasets are in the right form (e.g. if numpy arrays are passed turn them into + if isinstance(dataset, tuple): + dataset = CCA_Dataset(dataset, labels=labels) + if val_dataset is None and val_split > 0: + lengths = [ + len(dataset) - int(len(dataset) * val_split), + int(len(dataset) * val_split), + ] + dataset, val_dataset = torch.utils.data.random_split(dataset, lengths) + elif isinstance(val_dataset, tuple): + val_dataset = CCA_Dataset(val_dataset, labels=val_labels) + return dataset, val_dataset + + +def get_dataloaders( + dataset, val_dataset=None, batch_size=None, val_batch_size=None, num_workers=0 +): + if batch_size is None: + batch_size = len(dataset) + train_dataloader = DataLoader( + dataset, + batch_size=batch_size, + drop_last=True, + num_workers=num_workers, + pin_memory=True, + shuffle=True, + ) + if val_dataset: + if val_batch_size is None: + val_batch_size = len(val_dataset) + val_dataloader = DataLoader( + val_dataset, + batch_size=val_batch_size, + drop_last=True, + num_workers=num_workers, + pin_memory=True, + ) + return train_dataloader, val_dataloader + return train_dataloader diff --git a/cca_zoo/model_selection/_search.py b/cca_zoo/model_selection/_search.py index 35521597..57b20826 100644 --- a/cca_zoo/model_selection/_search.py +++ b/cca_zoo/model_selection/_search.py @@ -51,8 +51,9 @@ def param2grid(params): --------- >>> params = {'regs': [[1, 2], [3, 4]]} >>> param2grid(params) - [[1,3], [1,4], [2,3], [2,4]] + {'regs': [[1, 3], [1, 4], [2, 3], [2, 4]]} """ + params = params.copy() for k, v in params.items(): if any([isinstance(v_, list) for v_ in v]): # itertools expects all lists to perform product diff --git a/cca_zoo/models/gcca.py b/cca_zoo/models/gcca.py index 01695eff..dd5bc322 100644 --- a/cca_zoo/models/gcca.py +++ b/cca_zoo/models/gcca.py @@ -6,25 +6,38 @@ 
from sklearn.utils.validation import check_is_fitted from cca_zoo.models import rCCA -from ..utils.check_values import _process_parameter, check_views +from cca_zoo.utils.check_values import _process_parameter, check_views class GCCA(rCCA): - """ + r""" A class used to fit GCCA model. For more than 2 views, GCCA optimizes the sum of correlations with a shared auxiliary vector - Citation - -------- + :Maths: + + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{ \sum_iw_i^TX_i^TT \}\\ + + \text{subject to:} + + T^TT=1 + + :Citation: + Tenenhaus, Arthur, and Michel Tenenhaus. "Regularized generalized canonical correlation analysis." Psychometrika 76.2 (2011): 257. :Example: >>> from cca_zoo.models import GCCA + >>> import numpy as np >>> rng=np.random.RandomState(0) - >>> X1 = rng.random(10,5) - >>> X2 = np.random.rand(10,5) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> X3 = rng.random((10,5)) >>> model = GCCA() - >>> model.fit([X1,X2]) + >>> model.fit((X1,X2,X3)).score((X1,X2,X3)) + array([0.97229856]) """ def __init__( @@ -95,21 +108,34 @@ def _solve_evp(self, views: Iterable[np.ndarray], C, D=None, **kwargs): class KGCCA(GCCA): - """ + r""" A class used to fit KGCCA model. For more than 2 views, KGCCA optimizes the sum of correlations with a shared auxiliary vector - Citation - -------- + :Maths: + + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{ \sum_i\alpha_i^TK_i^TT \}\\ + + \text{subject to:} + + T^TT=1 + + :Citation: + Tenenhaus, Arthur, Cathy Philippe, and Vincent Frouin. "Kernel generalized canonical correlation analysis." Computational Statistics & Data Analysis 90 (2015): 114-131. 
:Example: >>> from cca_zoo.models import KGCCA + >>> import numpy as np >>> rng=np.random.RandomState(0) - >>> X1 = rng.random(10,5) - >>> X2 = np.random.rand(10,5) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> X3 = rng.random((10,5)) >>> model = KGCCA() - >>> model.fit([X1,X2]) + >>> model.fit((X1,X2,X3)).score((X1,X2,X3)) + array([0.97019284]) """ def __init__( @@ -128,6 +154,8 @@ def __init__( kernel_params: Iterable[dict] = None, ): """ + Constructor for KGCCA + :param latent_dims: number of latent dimensions to fit :param scale: normalize variance in each column before fitting :param centre: demean data by column before fitting (and before transforming out of sample @@ -168,7 +196,7 @@ def _check_params(self): ) def _get_kernel(self, view, X, Y=None): - if callable(self.kernel): + if callable(self.kernel[view]): params = self.kernel_params[view] or {} else: params = { diff --git a/cca_zoo/models/innerloop.py b/cca_zoo/models/innerloop.py index 1c7f5b31..1dc52afa 100644 --- a/cca_zoo/models/innerloop.py +++ b/cca_zoo/models/innerloop.py @@ -26,17 +26,14 @@ def __init__( self, max_iter: int = 100, tol: float = 1e-5, - generalized: bool = False, initialization: str = "unregularized", random_state=None, ): """ :param max_iter: maximum number of iterations to perform if tol is not reached :param tol: tolerance value used for stopping criteria - :param generalized: use an auxiliary variable to :param initialization: initialise the optimisation with either the 'unregularized' (CCA/PLS) solution, or a 'random' initialisation """ - self.generalized = generalized self.initialization = initialization self.max_iter = max_iter self.tol = tol @@ -124,14 +121,12 @@ def __init__( self, max_iter: int = 100, tol=1e-5, - generalized: bool = False, initialization: str = "unregularized", random_state=None, ): super().__init__( max_iter=max_iter, tol=tol, - generalized=generalized, initialization=initialization, random_state=random_state, ) @@ -179,7 +174,6 @@ def 
__init__( self, max_iter: int = 100, tol=1e-5, - generalized: bool = False, initialization: str = "unregularized", c=None, positive=None, @@ -188,7 +182,6 @@ def __init__( super().__init__( max_iter=max_iter, tol=tol, - generalized=generalized, initialization=initialization, random_state=random_state, ) @@ -242,7 +235,6 @@ def __init__( self, max_iter: int = 100, tol=1e-5, - generalized: bool = False, initialization: str = "unregularized", c=None, random_state=None, @@ -250,14 +242,13 @@ def __init__( super().__init__( max_iter=max_iter, tol=tol, - generalized=generalized, initialization=initialization, random_state=random_state, ) self.c = c def _check_params(self): - self.c = _process_parameter("c", self.c, [0.0001], len(self.views)) + self.c = _process_parameter("c", self.c, 0.0001, len(self.views)) if any(c <= 0 for c in self.c): raise ("All regularisation parameters should be above 0. " f"c=[{self.c}]") @@ -283,11 +274,10 @@ def __init__( self, max_iter: int = 100, tol=1e-5, - generalized: bool = False, initialization: str = "unregularized", c=None, l1_ratio=None, - constrained=False, + maxvar=True, stochastic=True, positive=None, random_state=None, @@ -295,15 +285,14 @@ def __init__( super().__init__( max_iter=max_iter, tol=tol, - generalized=generalized, initialization=initialization, random_state=random_state, ) self.stochastic = stochastic - self.constrained = constrained self.c = c self.l1_ratio = l1_ratio self.positive = positive + self.maxvar = maxvar def _check_params(self): self.c = _process_parameter("c", self.c, 0, len(self.views)) @@ -313,8 +302,6 @@ def _check_params(self): self.positive = _process_parameter( "positive", self.positive, False, len(self.views) ) - if self.constrained: - self.gamma = np.zeros(len(self.views)) self.regressors = [] for alpha, l1_ratio, positive in zip(self.c, self.l1_ratio, self.positive): if self.stochastic: @@ -398,14 +385,12 @@ def _update_view(self, view_index: int): :param view_index: index of view being updated 
:return: updated weights """ - if self.generalized: + if self.maxvar: target = self.scores.mean(axis=0) + target /= np.linalg.norm(target) else: target = self.scores[view_index - 1] - if self.constrained: - self._elastic_solver_constrained(self.views[view_index], target, view_index) - else: - self._elastic_solver(self.views[view_index], target, view_index) + self._elastic_solver(self.views[view_index], target, view_index) _check_converged_weights(self.weights[view_index], view_index) self.scores[view_index] = self.views[view_index] @ self.weights[view_index] @@ -418,35 +403,6 @@ def _elastic_solver(self, X, y, view_index): self.views[view_index] @ self.weights[view_index] ) / np.sqrt(self.n) - @ignore_warnings(category=ConvergenceWarning) - def _elastic_solver_constrained(self, X, y, view_index): - converged = False - min_ = -1 - max_ = 1 - previous = self.gamma[view_index] - previous_val = None - i = 0 - while not converged: - i += 1 - coef = ( - self.regressors[view_index] - .fit( - np.sqrt(self.gamma[view_index] + 1) * X, - y.ravel() / np.sqrt(self.gamma[view_index] + 1), - ) - .coef_ - ) - current_val = 1 - (np.linalg.norm(X @ coef) ** 2) / self.n - self.gamma[view_index], previous, min_, max_ = _bin_search( - self.gamma[view_index], previous, current_val, previous_val, min_, max_ - ) - previous_val = current_val - if np.abs(current_val) < 1e-5: - converged = True - elif np.abs(max_ - min_) < 1e-30 or i == 50: - converged = True - self.weights[view_index] = coef - def _objective(self): views = len(self.views) c = np.array(self.c) @@ -455,7 +411,6 @@ def _objective(self): l2 = c * (1 - ratio) total_objective = 0 for i in range(views): - # TODO this looks like it could be tidied up. 
In particular can we make the generalized objective correspond to the 2 view target = self.scores.mean(axis=0) objective = ( views @@ -469,7 +424,7 @@ def _objective(self): def _early_stop(self) -> bool: # Some kind of early stopping - if np.abs(self.track['objective'][-2] - self.track['objective'][-1]) < self.tol: + if np.abs(self.track["objective"][-2] - self.track["objective"][-1]) < self.tol: return True else: return False @@ -480,7 +435,6 @@ def __init__( self, max_iter: int = 100, tol=1e-5, - generalized: bool = False, initialization: str = "unregularized", mu=None, lam=None, @@ -491,7 +445,6 @@ def __init__( super().__init__( max_iter=max_iter, tol=tol, - generalized=generalized, initialization=initialization, random_state=random_state, ) @@ -601,7 +554,6 @@ def __init__( self, max_iter: int = 100, tol=1e-5, - generalized: bool = False, initialization: str = "unregularized", c=None, regularisation="l0", @@ -612,7 +564,6 @@ def __init__( super().__init__( max_iter=max_iter, tol=tol, - generalized=generalized, initialization=initialization, random_state=random_state, ) @@ -664,7 +615,6 @@ def __init__( self, max_iter: int = 100, tol=1e-20, - generalized: bool = False, initialization: str = "unregularized", regularisation="l0", c=None, @@ -675,7 +625,6 @@ def __init__( super().__init__( max_iter=max_iter, tol=tol, - generalized=generalized, initialization=initialization, random_state=random_state, ) @@ -688,9 +637,11 @@ def __init__( self.positive = positive def _check_params(self): + if self.sample_support is None: + self.sample_support = self.views[0].shape[0] self.sample_weights = np.ones((self.views[0].shape[0], 1)) self.sample_weights /= np.linalg.norm(self.sample_weights) - self.c = _process_parameter("c", self.c, 1, len(self.views)) + self.c = _process_parameter("c", self.c, 2, len(self.views)) self.positive = _process_parameter( "positive", self.positive, False, len(self.views) ) diff --git a/cca_zoo/models/iterative.py b/cca_zoo/models/iterative.py 
index a365c4f7..aab90457 100644 --- a/cca_zoo/models/iterative.py +++ b/cca_zoo/models/iterative.py @@ -5,8 +5,8 @@ import numpy as np -from .cca_base import _CCA_Base -from .innerloop import ( +from cca_zoo.models.cca_base import _CCA_Base +from cca_zoo.models.innerloop import ( PLSInnerLoop, PMDInnerLoop, ParkhomenkoInnerLoop, @@ -15,7 +15,7 @@ SpanCCAInnerLoop, SWCCAInnerLoop, ) -from ..utils import check_views +from cca_zoo.utils import check_views class _Iterative(_CCA_Base): @@ -33,7 +33,6 @@ def __init__( random_state=None, deflation="cca", max_iter: int = 100, - generalized: bool = False, initialization: str = "unregularized", tol: float = 1e-9, ): @@ -47,7 +46,7 @@ def __init__( :param random_state: Pass for reproducible output across multiple function calls :param deflation: the type of deflation. :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) + :param initialization: intialization for optimisation. 'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping """ @@ -59,7 +58,6 @@ def __init__( accept_sparse=["csc", "csr"], ) self.max_iter = max_iter - self.generalized = generalized self.initialization = initialization self.tol = tol self.deflation = deflation @@ -126,25 +124,35 @@ def _set_loop_params(self): """ self.loop = PLSInnerLoop( max_iter=self.max_iter, - generalized=self.generalized, initialization=self.initialization, random_state=self.random_state, ) class PLS_ALS(_Iterative): - """ + r""" A class used to fit a PLS model Fits a partial least squares model with CCA deflation by NIPALS algorithm + .. 
math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \|X_iw_i-X_jw_j\|^2\}\\ + + \text{subject to:} + + w_i^Tw_i=1 + :Example: >>> from cca_zoo.models import PLS - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) - >>> model = PLS_ALS() - >>> model.fit([X1,X2]) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = PLS_ALS(random_state=0) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.81796873]) """ def __init__( @@ -155,7 +163,6 @@ def __init__( copy_data=True, random_state=None, max_iter: int = 100, - generalized: bool = False, initialization: str = "unregularized", tol: float = 1e-9, ): @@ -168,7 +175,6 @@ def __init__( :param copy_data: If True, X will be copied; else, it may be overwritten :param random_state: Pass for reproducible output across multiple function calls :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) :param initialization: intialization for optimisation. 'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping """ @@ -179,7 +185,6 @@ def __init__( copy_data=copy_data, deflation="pls", max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, random_state=random_state, @@ -188,7 +193,6 @@ def __init__( def _set_loop_params(self): self.loop = PLSInnerLoop( max_iter=self.max_iter, - generalized=self.generalized, initialization=self.initialization, tol=self.tol, random_state=self.random_state, @@ -196,20 +200,31 @@ def _set_loop_params(self): class ElasticCCA(_Iterative): - """ - Fits an elastic CCA by iterative rescaled elastic net regression + r""" + Fits an elastic CCA by iterating elastic net regression - Citation - -------- - Waaijenborg, Sandra, Philip C. 
Verselewel de Witt Hamer, and Aeilko H. Zwinderman. "Quantifying the association between gene expressions and DNA-markers by penalized canonical correlation analysis." Statistical applications in genetics and molecular biology 7.1 (2008). + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \|X_iw_i-X_jw_j\|^2 + c\|w_i\|^2_2 + \text{l1_ratio}\|w_i\|_1\}\\ + + \text{subject to:} + + w_i^TX_i^TX_iw_i=1 + + :Citation: + + Fu, Xiao, et al. "Scalable and flexible multiview MAX-VAR canonical correlation analysis." IEEE Transactions on Signal Processing 65.16 (2017): 4150-4165. :Example: >>> from cca_zoo.models import ElasticCCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) - >>> model = ElasticCCA(c=[0.1,0.1],l1_ratio=[0.5,0.5]) - >>> model.fit([X1,X2]) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = ElasticCCA(c=[0.1,0.1],l1_ratio=[0.5,0.5], random_state=0) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.95818397]) """ def __init__( @@ -221,12 +236,11 @@ def __init__( random_state=None, deflation="cca", max_iter: int = 100, - generalized: bool = False, initialization: str = "unregularized", tol: float = 1e-9, c: Union[Iterable[float], float] = None, l1_ratio: Union[Iterable[float], float] = None, - constrained: bool = False, + maxvar: bool = True, stochastic=False, positive: Union[Iterable[bool], bool] = None, ): @@ -240,18 +254,17 @@ def __init__( :param random_state: Pass for reproducible output across multiple function calls :param deflation: the type of deflation. :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) :param initialization: intialization for optimisation. 
'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping :param c: lasso alpha :param l1_ratio: l1 ratio in lasso subproblems - :param constrained: force unit norm constraint with binary search + :param maxvar: use auxiliary variable "maxvar" formulation :param stochastic: use stochastic regression optimisers for subproblems :param positive: constrain model weights to be positive """ self.c = c self.l1_ratio = l1_ratio - self.constrained = constrained + self.maxvar = maxvar self.stochastic = stochastic self.positive = positive if self.positive is not None and stochastic: @@ -266,7 +279,6 @@ def __init__( copy_data=copy_data, deflation=deflation, max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, random_state=random_state, @@ -277,10 +289,9 @@ def _set_loop_params(self): max_iter=self.max_iter, c=self.c, l1_ratio=self.l1_ratio, - generalized=self.generalized, + maxvar=self.maxvar, initialization=self.initialization, tol=self.tol, - constrained=self.constrained, stochastic=self.stochastic, positive=self.positive, random_state=self.random_state, @@ -288,20 +299,31 @@ def _set_loop_params(self): class CCA_ALS(ElasticCCA): - """ + r""" Fits a CCA model with CCA deflation by NIPALS algorithm. Implemented by ElasticCCA with 0 regularisation - Citation - -------- + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \|X_iw_i-X_jw_j\|^2 }\\ + + \text{subject to:} + + w_i^TX_i^TX_iw_i=1 + + :Citation: + Golub, Gene H., and Hongyuan Zha. "The canonical correlations of matrix pairs and their numerical computation." Linear algebra for signal processing. Springer, New York, NY, 1995. 27-49. 
:Example: >>> from cca_zoo.models import CCA_ALS - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) - >>> model = CCA_ALS() - >>> model.fit(X1,X2) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,3)) + >>> X2 = rng.random((10,3)) + >>> model = CCA_ALS(random_state=0) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.85890619]) """ def __init__( @@ -312,7 +334,6 @@ def __init__( copy_data=True, random_state=None, max_iter: int = 100, - generalized: bool = False, initialization: str = "random", tol: float = 1e-9, stochastic=True, @@ -327,8 +348,7 @@ def __init__( :param copy_data: If True, X will be copied; else, it may be overwritten :param random_state: Pass for reproducible output across multiple function calls :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) - :param initialization: intialization for optimisation. 'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores + :param initialization: initialization for optimisation. 'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping :param stochastic: use stochastic regression optimisers for subproblems :param positive: constrain model weights to be positive @@ -337,10 +357,8 @@ def __init__( super().__init__( latent_dims=latent_dims, max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, - constrained=False, stochastic=stochastic, centre=centre, copy_data=copy_data, @@ -348,24 +366,36 @@ def __init__( positive=positive, random_state=random_state, c=1e-3, + maxvar=False, ) class SCCA(ElasticCCA): - """ + r""" Fits a sparse CCA model by iterative rescaled lasso regression. 
Implemented by ElasticCCA with l1 ratio=1 - Citation - -------- + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \|X_iw_i-X_jw_j\|^2 + \text{l1_ratio}\|w_i\|_1\}\\ + + \text{subject to:} + + w_i^TX_i^TX_iw_i=1 + + :Citation: + Mai, Qing, and Xin Zhang. "An iterative penalized least squares approach to sparse canonical correlation analysis." Biometrics 75.3 (2019): 734-744. :Example: >>> from cca_zoo.models import SCCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) - >>> model = SCCA(c=[0.001,0.001]) - >>> model.fit(X1,X2) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = SCCA(c=[0.001,0.001], random_state=0) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.99998919]) """ def __init__( @@ -377,7 +407,7 @@ def __init__( random_state=None, c: Union[Iterable[float], float] = None, max_iter: int = 100, - generalized: bool = False, + maxvar: bool = False, initialization: str = "unregularized", tol: float = 1e-9, stochastic=False, @@ -392,7 +422,7 @@ def __init__( :param copy_data: If True, X will be copied; else, it may be overwritten :param random_state: Pass for reproducible output across multiple function calls :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) + :param maxvar: use auxiliary variable "maxvar" form :param initialization: intialization for optimisation. 
'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping :param c: lasso alpha @@ -405,12 +435,11 @@ def __init__( centre=centre, copy_data=copy_data, max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, c=c, l1_ratio=1, - constrained=False, + maxvar=maxvar, stochastic=stochastic, positive=positive, random_state=random_state, @@ -418,20 +447,33 @@ def __init__( class PMD(_Iterative): - """ + r""" Fits a Sparse CCA (Penalized Matrix Decomposition) model. - Citation - -------- + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{ w_1^TX_1^TX_2w_2 \}\\ + + \text{subject to:} + + w_i^Tw_i=1 + + \|w_i\|<=c_i + + :Citation: + Witten, Daniela M., Robert Tibshirani, and Trevor Hastie. "A penalized matrix decomposition, with applications to sparse principal components and canonical correlation analysis." Biostatistics 10.3 (2009): 515-534. 
:Example: >>> from cca_zoo.models import PMD - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) - >>> model = PMD(c=[1,1]) - >>> model.fit(X1,X2) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = PMD(c=[1,1],random_state=0) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.69792082]) """ def __init__( @@ -443,7 +485,6 @@ def __init__( random_state=None, c: Union[Iterable[float], float] = None, max_iter: int = 100, - generalized: bool = False, initialization: str = "unregularized", tol: float = 1e-9, positive: Union[Iterable[bool], bool] = None, @@ -458,7 +499,6 @@ def __init__( :param random_state: Pass for reproducible output across multiple function calls :param c: l1 regularisation parameter between 1 and sqrt(number of features) for each view :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) :param initialization: intialization for optimisation. 'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping :param positive: constrain model weights to be positive @@ -471,7 +511,6 @@ def __init__( centre=centre, copy_data=copy_data, max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, random_state=random_state, @@ -481,7 +520,6 @@ def _set_loop_params(self): self.loop = PMDInnerLoop( max_iter=self.max_iter, c=self.c, - generalized=self.generalized, initialization=self.initialization, tol=self.tol, positive=self.positive, @@ -490,20 +528,31 @@ def _set_loop_params(self): class ParkhomenkoCCA(_Iterative): - """ + r""" Fits a sparse CCA (penalized CCA) model - Citation - -------- + .. 
math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{ w_1^TX_1^TX_2w_2 \} + c_i\|w_i\|\\ + + \text{subject to:} + + w_i^Tw_i=1 + + :Citation: + Parkhomenko, Elena, David Tritchler, and Joseph Beyene. "Sparse canonical correlation analysis with application to genomic data integration." Statistical applications in genetics and molecular biology 8.1 (2009). :Example: >>> from cca_zoo.models import ParkhomenkoCCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) - >>> model = ParkhomenkoCCA(c=[0.001,0.001]) - >>> model.fit(X1,X2) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = ParkhomenkoCCA(c=[0.001,0.001],random_state=0) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.81803543]) """ def __init__( @@ -515,7 +564,6 @@ def __init__( random_state=None, c: Union[Iterable[float], float] = None, max_iter: int = 100, - generalized: bool = False, initialization: str = "unregularized", tol: float = 1e-9, ): @@ -529,7 +577,6 @@ def __init__( :param random_state: Pass for reproducible output across multiple function calls :param c: l1 regularisation parameter :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) :param initialization: intialization for optimisation. 
'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping """ @@ -540,7 +587,6 @@ def __init__( centre=centre, copy_data=copy_data, max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, random_state=random_state, @@ -550,7 +596,6 @@ def _set_loop_params(self): self.loop = ParkhomenkoInnerLoop( max_iter=self.max_iter, c=self.c, - generalized=self.generalized, initialization=self.initialization, tol=self.tol, random_state=self.random_state, @@ -558,20 +603,31 @@ def _set_loop_params(self): class SCCA_ADMM(_Iterative): - """ + r""" Fits a sparse CCA model by alternating ADMM - Citation - -------- + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \|X_iw_i-X_jw_j\|^2 + \text{l1_ratio}\|w_i\|_1\}\\ + + \text{subject to:} + + w_i^TX_i^TX_iw_i=1 + + :Citation: + Suo, Xiaotong, et al. "Sparse canonical correlation analysis." arXiv preprint arXiv:1705.10865 (2017). 
:Example: >>> from cca_zoo.models import SCCA_ADMM - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) - >>> model = SCCA_ADMM() - >>> model.fit(X1,X2) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = SCCA_ADMM(random_state=0) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.99999997]) """ def __init__( @@ -586,7 +642,6 @@ def __init__( lam: Union[Iterable[float], float] = None, eta: Union[Iterable[float], float] = None, max_iter: int = 100, - generalized: bool = False, initialization: str = "unregularized", tol: float = 1e-9, ): @@ -600,7 +655,6 @@ def __init__( :param random_state: Pass for reproducible output across multiple function calls :param c: l1 regularisation parameter :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) :param initialization: intialization for optimisation. 'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping :param mu: @@ -617,7 +671,6 @@ def __init__( centre=centre, copy_data=copy_data, max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, random_state=random_state, @@ -630,7 +683,6 @@ def _set_loop_params(self): mu=self.mu, lam=self.lam, eta=self.eta, - generalized=self.generalized, initialization=self.initialization, tol=self.tol, random_state=self.random_state, @@ -638,13 +690,32 @@ def _set_loop_params(self): class SpanCCA(_Iterative): - """ + r""" Fits a Sparse CCA model using SpanCCA. - Citation - -------- + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \|X_iw_i-X_jw_j\|^2 + \text{l1_ratio}\|w_i\|_1\}\\ + + \text{subject to:} + + w_i^TX_i^TX_iw_i=1 + + :Citation: + Asteris, Megasthenis, et al. 
"A simple and provable algorithm for sparse diagonal CCA." International Conference on Machine Learning. PMLR, 2016. + + :Example: + + >>> from cca_zoo.models import SpanCCA + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = SpanCCA(regularisation="l0", c=[2, 2]) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.84556666]) """ def __init__( @@ -654,7 +725,6 @@ def __init__( centre=True, copy_data=True, max_iter: int = 100, - generalized: bool = False, initialization: str = "uniform", tol: float = 1e-9, regularisation="l0", @@ -671,7 +741,6 @@ def __init__( :param copy_data: If True, X will be copied; else, it may be overwritten :param random_state: Pass for reproducible output across multiple function calls :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) :param initialization: intialization for optimisation. 'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping :param regularisation: @@ -685,7 +754,6 @@ def __init__( centre=centre, copy_data=copy_data, max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, random_state=random_state, @@ -699,7 +767,6 @@ def _set_loop_params(self): self.loop = SpanCCAInnerLoop( max_iter=self.max_iter, c=self.c, - generalized=self.generalized, initialization=self.initialization, tol=self.tol, regularisation=self.regularisation, @@ -710,14 +777,23 @@ def _set_loop_params(self): class SWCCA(_Iterative): - """ + r""" A class used to fit SWCCA model - Citation - -------- - Wenwen, M. I. N., L. I. U. Juan, and Shihua Zhang. "Sparse Weighted Canonical Correlation Analysis." Chinese Journal of Electronics 27.3 (2018): 459-466. + :Citation: + + .. Wenwen, M. I. N., L. I. U. 
Juan, and Shihua Zhang. "Sparse Weighted Canonical Correlation Analysis." Chinese Journal of Electronics 27.3 (2018): 459-466. + :Example: + >>> from cca_zoo.models import SWCCA + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = SWCCA(regularisation='l0',c=[2, 2], sample_support=5, random_state=0) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.61620969]) """ def __init__( @@ -728,7 +804,6 @@ def __init__( copy_data=True, random_state=None, max_iter: int = 500, - generalized: bool = False, initialization: str = "uniform", tol: float = 1e-9, regularisation="l0", @@ -744,7 +819,6 @@ def __init__( :param copy_data: If True, X will be copied; else, it may be overwritten :param random_state: Pass for reproducible output across multiple function calls :param max_iter: the maximum number of iterations to perform in the inner optimization loop - :param generalized: use auxiliary variables (required for >2 views) :param initialization: intialization for optimisation. 
'unregularized' uses CCA or PLS solution,'random' uses random initialization,'uniform' uses uniform initialization of weights and scores :param tol: tolerance value used for early stopping :param regularisation: the type of regularisation on the weights either 'l0' or 'l1' @@ -763,7 +837,6 @@ def __init__( centre=centre, copy_data=copy_data, max_iter=max_iter, - generalized=generalized, initialization=initialization, tol=tol, random_state=random_state, @@ -772,7 +845,6 @@ def __init__( def _set_loop_params(self): self.loop = SWCCAInnerLoop( max_iter=self.max_iter, - generalized=self.generalized, initialization=self.initialization, tol=self.tol, regularisation=self.regularisation, diff --git a/cca_zoo/models/mcca.py b/cca_zoo/models/mcca.py index ba1df3cd..4d0294c1 100644 --- a/cca_zoo/models/mcca.py +++ b/cca_zoo/models/mcca.py @@ -6,23 +6,37 @@ from sklearn.utils.validation import check_is_fitted from cca_zoo.models import rCCA -from ..utils.check_values import _process_parameter, check_views +from cca_zoo.utils.check_values import _process_parameter, check_views class MCCA(rCCA): - """ + r""" A class used to fit MCCA model. For more than 2 views, MCCA optimizes the sum of pairwise correlations. - Citation - -------- + :Maths: + + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} w_i^TX_i^TX_jw_j \}\\ + + \text{subject to:} + + (1-c_i)w_i^TX_i^TX_iw_i+c_iw_i^Tw_i=1 + + :Citation: + Kettenring, Jon R. "Canonical analysis of several sets of variables." Biometrika 58.3 (1971): 433-451. 
:Example: >>> from cca_zoo.models import MCCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> X3 = rng.random((10,5)) >>> model = MCCA() - >>> model.fit([X1,X2]) + >>> model.fit((X1,X2,X3)).score((X1,X2,X3)) + array([0.97200847]) """ def __init__( @@ -87,16 +101,30 @@ def _solve_evp(self, views: Iterable[np.ndarray], C, D=None, **kwargs): class KCCA(MCCA): - """ + r""" A class used to fit KCCA model. + :Maths: + + .. math:: + + \alpha_{opt}=\underset{\alpha}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \alpha_i^TK_i^TK_j\alpha_j \}\\ + + \text{subject to:} + + c_i\alpha_i^TK_i\alpha_i + (1-c_i)\alpha_i^TK_i^TK_i\alpha_i=1 + :Example: >>> from cca_zoo.models import KCCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> X3 = rng.random((10,5)) >>> model = KCCA() - >>> model.fit([X1,X2]) + >>> model.fit((X1,X2,X3)).score((X1,X2,X3)) + array([0.96893666]) """ def __init__( @@ -152,7 +180,7 @@ def _check_params(self): self.c = _process_parameter("c", self.c, 0, self.n_views) def _get_kernel(self, view, X, Y=None): - if callable(self.kernel): + if callable(self.kernel[view]): params = self.kernel_params[view] or {} else: params = { diff --git a/cca_zoo/models/ncca.py b/cca_zoo/models/ncca.py index ab6d1abe..2fe741c6 100644 --- a/cca_zoo/models/ncca.py +++ b/cca_zoo/models/ncca.py @@ -13,22 +13,30 @@ class NCCA(_CCA_Base): """ A class used to fit nonparametric (NCCA) model. - Citation - -------- + :Citation: + Michaeli, Tomer, Weiran Wang, and Karen Livescu. "Nonparametric canonical correlation analysis." International conference on machine learning. PMLR, 2016. 
:Example: - >>> from cca_zoo.experimental import NCCA + >>> from cca_zoo.models import NCCA >>> X1 = np.random.rand(10,5) >>> X2 = np.random.rand(10,5) >>> model = NCCA() - >>> model.fit([X1,X2]) + >>> model.fit((X1,X2)).score((X1,X2)) + array([1.]) """ - def __init__(self, latent_dims: int = 1, scale=True, centre=True, copy_data=True, accept_sparse=False, - random_state: Union[int, np.random.RandomState] = None, nearest_neighbors=None, - gamma: Iterable[float] = None, - ): + def __init__( + self, + latent_dims: int = 1, + scale=True, + centre=True, + copy_data=True, + accept_sparse=False, + random_state: Union[int, np.random.RandomState] = None, + nearest_neighbors=None, + gamma: Iterable[float] = None, + ): """ Constructor for NCCA @@ -38,15 +46,19 @@ def __init__(self, latent_dims: int = 1, scale=True, centre=True, copy_data=True :param copy_data: If True, X will be copied; else, it may be overwritten :param accept_sparse: Whether model can take sparse data as input :param random_state: Pass for reproducible output across multiple function calls - :param nearest_neighbors: Number of nearest neighbors (l2 distance) to consider when constructing affinity + :param nearest_neighbors: Number of neaest neighbors (l2 distance) to consider when constructing affinity :param gamma: Bandwidth parameter for rbf kernel """ - super().__init__(latent_dims, scale, centre, copy_data, accept_sparse, random_state) + super().__init__( + latent_dims, scale, centre, copy_data, accept_sparse, random_state + ) self.nearest_neighbors = nearest_neighbors self.gamma = gamma def _check_params(self): - self.nearest_neighbors = _process_parameter("nearest_neighbors", self.nearest_neighbors, 1, self.n_views) + self.nearest_neighbors = _process_parameter( + "nearest_neighbors", self.nearest_neighbors, 1, self.n_views + ) self.gamma = _process_parameter("gamma", self.gamma, None, self.n_views) self.kernel = _process_parameter("kernel", None, "rbf", self.n_views) @@ -59,18 +71,25 @@ def 
fit(self, views: Iterable[np.ndarray], y=None, **kwargs): self.n = views[0].shape[0] self._check_params() self.train_views = views - self.KNs = [NearestNeighbors(n_neighbors=self.nearest_neighbors[i]).fit(view) for i, view in - enumerate(views)] - NNs = [self.KNs[i].kneighbors(view, self.nearest_neighbors[i]) for i, view in enumerate(views)] + self.KNs = [ + NearestNeighbors(n_neighbors=self.nearest_neighbors[i]).fit(view) + for i, view in enumerate(views) + ] + NNs = [ + self.KNs[i].kneighbors(view, self.nearest_neighbors[i]) + for i, view in enumerate(views) + ] kernels = [self._get_kernel(i, view) for i, view in enumerate(self.train_views)] self.Ws = [fill_W(kernel, inds) for kernel, (dists, inds) in zip(kernels, NNs)] - self.Ws = [self.Ws[0] / self.Ws[0].sum(axis=1, keepdims=True), - self.Ws[1] / self.Ws[1].sum(axis=0, keepdims=True)] + self.Ws = [ + self.Ws[0] / self.Ws[0].sum(axis=1, keepdims=True), + self.Ws[1] / self.Ws[1].sum(axis=0, keepdims=True), + ] S = self.Ws[0] @ self.Ws[1] U, S, Vt = np.linalg.svd(S) - self.f = U[:, 1:self.latent_dims + 1] * np.sqrt(self.n) - self.g = Vt[1:self.latent_dims + 1, :].T * np.sqrt(self.n) - self.S = S[1:self.latent_dims + 1] + self.f = U[:, 1: self.latent_dims + 1] * np.sqrt(self.n) + self.g = Vt[1: self.latent_dims + 1, :].T * np.sqrt(self.n) + self.S = S[1: self.latent_dims + 1] return self def transform(self, views: Iterable[np.ndarray], y=None, **kwargs): @@ -80,18 +99,24 @@ def transform(self, views: Iterable[np.ndarray], y=None, **kwargs): :param views: numpy arrays with the same number of rows (samples) separated by commas :param kwargs: any additional keyword arguments required by the given model """ - check_is_fitted(self, attributes=["U", "V", "f", "g"]) + check_is_fitted(self, attributes=["f", "g"]) views = check_views( *views, copy=self.copy_data, accept_sparse=self.accept_sparse ) views = self._centre_scale_transform(views) - NNs = [self.KNs[i].kneighbors(view, self.nearest_neighbors[i]) for i, view in 
enumerate(views)] + NNs = [ + self.KNs[i].kneighbors(view, self.nearest_neighbors[i]) + for i, view in enumerate(views) + ] kernels = [ self._get_kernel(i, self.train_views[i], Y=view) for i, view in enumerate(views) ] Wst = [fill_W(kernel, inds) for kernel, (dists, inds) in zip(kernels, NNs)] - Wst = [Wst[0] / Wst[0].sum(axis=1, keepdims=True), Wst[1] / Wst[1].sum(axis=1, keepdims=True)] + Wst = [ + Wst[0] / Wst[0].sum(axis=1, keepdims=True), + Wst[1] / Wst[1].sum(axis=1, keepdims=True), + ] St = [Wst[0] @ self.Ws[1], Wst[1] @ self.Ws[0]] return St[0] @ self.g * (1 / self.S), St[1] @ self.f * (1 / self.S) @@ -108,4 +133,4 @@ def fill_W(kernels, inds): W = np.zeros_like(kernels) for i, ind in enumerate(inds): W[ind, i] = kernels[ind, i] - return W.T \ No newline at end of file + return W.T diff --git a/cca_zoo/models/rcca.py b/cca_zoo/models/rcca.py index 2285ef6c..4247240b 100644 --- a/cca_zoo/models/rcca.py +++ b/cca_zoo/models/rcca.py @@ -4,28 +4,44 @@ import numpy as np from scipy.linalg import block_diag, eigh -from .cca_base import _CCA_Base -from ..utils.check_values import _process_parameter, check_views +from cca_zoo.models.cca_base import _CCA_Base +from cca_zoo.utils.check_values import _process_parameter, check_views # from hyperopt import fmin, tpe, Trials class rCCA(_CCA_Base): - """ + r""" A class used to fit Regularised CCA (canonical ridge) model. Uses PCA to perform the optimization efficiently for high dimensional data. - Citation - -------- + :Maths: + + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{ w_1^TX_1^TX_2w_2 \}\\ + + \text{subject to:} + + (1-c_1)w_1^TX_1^TX_1w_1+c_1w_1^Tw_1=1 + + (1-c_2)w_2^TX_2^TX_2w_2+c_2w_2^Tw_2=1 + + + :Citation: + Vinod, Hrishikesh D. "Canonical ridge and econometrics of joint production." Journal of econometrics 4.2 (1976): 147-166. 
:Example: >>> from cca_zoo.models import rCCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) >>> model = rCCA(c=[0.1,0.1]) - >>> model.fit([X1,X2]) + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.95222128]) """ def __init__( @@ -157,22 +173,37 @@ def _multi_view_evp(self, Us, Ss): class CCA(rCCA): - """ + r""" A class used to fit a simple CCA model Implements CCA by inheriting regularised CCA with 0 regularisation - Citation - -------- + :Maths: + + .. math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{ w_1^TX_1^TX_2w_2 \}\\ + + \text{subject to:} + + w_1^TX_1^TX_1w_1=1 + + w_2^TX_2^TX_2w_2=1 + + :Citation: + Hotelling, Harold. "Relations between two sets of variates." Breakthroughs in statistics. Springer, New York, NY, 1992. 162-190. :Example: >>> from cca_zoo.models import CCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) >>> model = CCA() - >>> model.fit(X1,X2) + >>> model.fit((X1,X2)).score((X1,X2)) + array([1.]) """ def __init__( @@ -185,6 +216,7 @@ def __init__( ): """ Constructor for CCA + :param latent_dims: number of latent dimensions to fit :param scale: normalize variance in each column before fitting :param centre: demean data by column before fitting (and before transforming out of sample @@ -202,18 +234,33 @@ def __init__( class PLS(rCCA): - """ + r""" A class used to fit a simple PLS model Implements PLS by inheriting regularised CCA with maximal regularisation + :Maths: + + .. 
math:: + + w_{opt}=\underset{w}{\mathrm{argmax}}\{ w_1^TX_1^TX_2w_2 \}\\ + + \text{subject to:} + + w_1^Tw_1=1 + + w_2^Tw_2=1 + :Example: >>> from cca_zoo.models import PLS - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) - >>> model = CCA() - >>> model.fit([X1,X2]) + >>> import numpy as np + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> model = PLS() + >>> model.fit((X1,X2)).score((X1,X2)) + array([0.81796873]) """ def __init__( @@ -225,7 +272,8 @@ def __init__( random_state=None, ): """ - Constructor for CCA + Constructor for PLS + :param latent_dims: number of latent dimensions to fit :param scale: normalize variance in each column before fitting :param centre: demean data by column before fitting (and before transforming out of sample diff --git a/cca_zoo/models/tcca.py b/cca_zoo/models/tcca.py index 8f7289aa..740ab737 100644 --- a/cca_zoo/models/tcca.py +++ b/cca_zoo/models/tcca.py @@ -7,27 +7,39 @@ from sklearn.utils.validation import check_is_fitted from tensorly.decomposition import parafac -from .cca_base import _CCA_Base -from ..utils.check_values import _process_parameter, check_views +from cca_zoo.models.cca_base import _CCA_Base +from cca_zoo.utils.check_values import _process_parameter, check_views class TCCA(_CCA_Base): - """ + r""" Fits a Tensor CCA model. Tensor CCA maximises higher order correlations - Citation - -------- - Kim, Tae-Kyun, Shu-Fai Wong, and Roberto Cipolla. "Tensor canonical correlation analysis for action classification." 2007 IEEE Conference on Computer Vision and Pattern Recognition. IEEE, 2007 + :Maths: + + .. math:: + + \alpha_{opt}=\underset{\alpha}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \alpha_i^TK_i^TK_j\alpha_j \}\\ - My own port from https://github.com/rciszek/mdr_tcca + \text{subject to:} + + \alpha_i^TK_i^TK_i\alpha_i=1 + + :Citation: + + Kim, Tae-Kyun, Shu-Fai Wong, and Roberto Cipolla. "Tensor canonical correlation analysis for action classification." 
2007 IEEE Conference on Computer Vision and Pattern Recognition. IEEE, 2007 + https://github.com/rciszek/mdr_tcca :Example: >>> from cca_zoo.models import TCCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> X3 = rng.random((10,5)) >>> model = TCCA() - >>> model.fit([X1,X2]) + >>> model.fit((X1,X2,X3)).score((X1,X2,X3)) + array([1.14595755]) """ def __init__( @@ -92,7 +104,7 @@ def fit(self, views: Iterable[np.ndarray], y=None, **kwargs): M = np.expand_dims(M, -1) @ el M = np.mean(M, 0) tl.set_backend("numpy") - M_parafac = parafac(M, self.latent_dims, verbose=True) + M_parafac = parafac(M, self.latent_dims, verbose=False) self.weights = [ cov_invsqrt @ fac for i, (view, cov_invsqrt, fac) in enumerate( @@ -148,20 +160,33 @@ def _setup_tensor(self, *views: np.ndarray, **kwargs): class KTCCA(TCCA): - """ + r""" Fits a Kernel Tensor CCA model. Tensor CCA maximises higher order correlations - Citation - -------- + :Maths: + + .. math:: + + \alpha_{opt}=\underset{\alpha}{\mathrm{argmax}}\{\sum_i\sum_{j\neq i} \alpha_i^TK_i^TK_j\alpha_j \}\\ + + \text{subject to:} + + \alpha_i^TK_i^TK_i\alpha_i=1 + + :Citation: + Kim, Tae-Kyun, Shu-Fai Wong, and Roberto Cipolla. "Tensor canonical correlation analysis for action classification." 2007 IEEE Conference on Computer Vision and Pattern Recognition. 
IEEE, 2007 :Example: >>> from cca_zoo.models import KTCCA - >>> X1 = np.random.rand(10,5) - >>> X2 = np.random.rand(10,5) + >>> rng=np.random.RandomState(0) + >>> X1 = rng.random((10,5)) + >>> X2 = rng.random((10,5)) + >>> X3 = rng.random((10,5)) >>> model = KTCCA() - >>> model.fit([X1,X2]) + >>> model.fit((X1,X2,X3)).score((X1,X2,X3)) + array([1.69896269]) """ def __init__( @@ -218,7 +243,7 @@ def _check_params(self): self.c = _process_parameter("c", self.c, 0, self.n_views) def _get_kernel(self, view, X, Y=None): - if callable(self.kernel): + if callable(self.kernel[view]): params = self.kernel_params[view] or {} else: params = { diff --git a/cca_zoo/probabilisticmodels/vcca.py b/cca_zoo/probabilisticmodels/vcca.py index 5f9696eb..eaf5ed3a 100644 --- a/cca_zoo/probabilisticmodels/vcca.py +++ b/cca_zoo/probabilisticmodels/vcca.py @@ -13,14 +13,11 @@ class VariationalCCA(_CCA_Base): """ - A class used to fit a variational bayesian CCA + A class used to fit a variational bayesian CCA. Not quite the same due to using VI methods rather than EM - Citation - -------- - Wang, Chong. "Variational Bayesian approach to canonical correlation analysis." IEEE Transactions on Neural Networks 18.3 (2007): 905-910. - - :Example: + :Citation: + Wang, Chong. "Variational Bayesian approach to canonical correlation analysis." IEEE Transactions on Neural Networks 18.3 (2007): 905-910. 
""" diff --git a/cca_zoo/test/test_data.py b/cca_zoo/test/test_data.py deleted file mode 100644 index 5a0ac1c0..00000000 --- a/cca_zoo/test/test_data.py +++ /dev/null @@ -1,56 +0,0 @@ -import numpy as np -from sklearn.utils.validation import check_random_state - -from cca_zoo.data import Noisy_MNIST_Dataset, Tangled_MNIST_Dataset, Split_MNIST_Dataset -from cca_zoo.data import generate_covariance_data -from cca_zoo.models import MCCA, CCA - -rng = check_random_state(0) -X = rng.rand(500, 20) -Y = rng.rand(500, 21) -Z = rng.rand(500, 22) - - -def test_data_rand(): - (x, y), true_feats = generate_covariance_data( - 1000, [10, 11], 1, [0.5, 0.5], correlation=0.5, structure="random" - ) - - -def test_data_gen(): - (x, y, z), true_feats = generate_covariance_data( - 1000, - [10, 11, 12], - 1, - [0.5, 0.5, 0.5], - correlation=0.5, - structure=["identity", "identity", "identity"], - ) - cca = CCA().fit((x[:500], y[:500])) - cca_pred = cca.score((x[500:], y[500:])) - mcca = MCCA().fit((x[:500], y[:500], z[:500])) - mcca_pred = mcca.score((x[500:], y[500:], z[500:])) - - (x, y), true_feats = generate_covariance_data( - 1000, - [10, 11], - 1, - [0.5, 0.5], - correlation=0.5, - structure=["gaussian", "toeplitz"], - ) - - -def test_deep_data(): - dataset = Noisy_MNIST_Dataset(mnist_type="FashionMNIST", train=True) - (train_view_1, train_view_2), (train_rotations, train_labels) = dataset.to_numpy( - np.arange(10) - ) - dataset = Tangled_MNIST_Dataset(mnist_type="FashionMNIST", train=True) - (train_view_1, train_view_2), ( - train_rotations_1, - train_rotations_2, - train_labels, - ) = dataset.to_numpy(np.arange(10)) - dataset = Split_MNIST_Dataset(mnist_type="FashionMNIST", train=True) - (train_view_1, train_view_2), (train_labels) = dataset.to_numpy(np.arange(10)) diff --git a/cca_zoo/test/test_deepmodels.py b/cca_zoo/test/test_deepmodels.py index 63080bba..dcad9bb9 100644 --- a/cca_zoo/test/test_deepmodels.py +++ b/cca_zoo/test/test_deepmodels.py @@ -1,11 +1,22 @@ import numpy 
as np +import pytorch_lightning as pl from sklearn.utils.validation import check_random_state from torch import optim, manual_seed -from torch.utils.data import Subset from cca_zoo import data -from cca_zoo.data import Noisy_MNIST_Dataset -from cca_zoo.deepmodels import DCCA, DCCAE, DVCCA, DCCA_NOI, DTCCA, SplitAE, DeepWrapper +from cca_zoo.deepmodels import ( + DCCA, + DCCAE, + DVCCA, + DCCA_NOI, + DTCCA, + SplitAE, + CCALightning, + get_dataloaders, + process_data, + BarlowTwins, + DCCA_SDL, +) from cca_zoo.deepmodels import objectives, architectures from cca_zoo.models import CCA @@ -16,290 +27,169 @@ Z = rng.rand(200, 14) X_conv = rng.rand(100, 1, 16, 16) Y_conv = rng.rand(100, 1, 16, 16) -train_dataset = data.CCA_Dataset([X, Y]) - - -def test_input_types(): +dataset = data.CCA_Dataset([X, Y, Z]) +train_dataset, val_dataset = process_data(dataset, val_split=0.2) +train_dataset_numpy, val_dataset_numpy = process_data((X, Y, Z), val_split=0.2) +loader = get_dataloaders(dataset) +train_loader, val_loader = get_dataloaders(train_dataset, val_dataset) +train_loader_numpy, val_loader_numpy = get_dataloaders(train_dataset, val_dataset) +conv_dataset = data.CCA_Dataset((X_conv, Y_conv)) +conv_loader = get_dataloaders(conv_dataset) + + +def test_DCCA_methods(): + N = len(train_dataset) latent_dims = 2 - device = "cpu" + epochs = 100 + cca = CCA(latent_dims=latent_dims).fit((X, Y)) + # DCCA_NOI encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - # DCCA - dcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - objective=objectives.CCA, + dcca_noi = DCCA_NOI(latent_dims, N, encoders=[encoder_1, encoder_2], rho=0) + optimizer = optim.Adam(dcca_noi.parameters(), lr=1e-3) + dcca_noi = CCALightning(dcca_noi, optimizer=optimizer) + trainer = pl.Trainer( + max_epochs=epochs, log_every_n_steps=10, enable_checkpointing=False ) - - dcca_model = 
DeepWrapper(dcca_model, device=device) - dcca_model.fit(train_dataset, epochs=3) - dcca_model.fit(train_dataset, val_dataset=train_dataset, epochs=3) - dcca_model.fit((X, Y), val_dataset=(X, Y), epochs=3) - dcca_model.fit((X, Y), val_split=0.2, epochs=3) - - -def tutorial_test(): - # Load MNIST Data - N = 500 - latent_dims = 2 - dataset = Noisy_MNIST_Dataset(mnist_type="FashionMNIST", train=True) - ids = np.arange(min(2 * N, len(dataset))) - np.random.shuffle(ids) - train_ids, val_ids = np.array_split(ids, 2) - val_dataset = Subset(dataset, val_ids) - train_dataset = Subset(dataset, train_ids) - test_dataset = Noisy_MNIST_Dataset(mnist_type="FashionMNIST", train=False) - test_ids = np.arange(min(N, len(test_dataset))) - np.random.shuffle(test_ids) - test_dataset = Subset(test_dataset, test_ids) - print("DCCA") - encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784) - encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784) - dcca_model = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2]) - dcca_model = DeepWrapper(dcca_model) - dcca_model.fit(train_dataset, val_dataset=val_dataset, epochs=2) - dcca_results = np.stack( - (dcca_model.score(train_dataset), dcca_model.correlations(test_dataset)[0, 1]) + trainer.fit(dcca_noi, train_loader) + assert ( + np.testing.assert_array_less( + cca.score((X, Y)).sum(), trainer.model.score(train_loader).sum() + ) + is None ) - - -def test_large_p(): - large_p = 256 - X = rng.rand(2000, large_p) - Y = rng.rand(2000, large_p) - latent_dims = 32 - device = "cpu" - encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=large_p) - encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=large_p) - dcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - objective=objectives.MCCA, - eps=1e-3, + # Soft Decorrelation (stochastic Decorrelation Loss) + encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) + 
encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) + sdl = DCCA_SDL(latent_dims, N, encoders=[encoder_1, encoder_2], lam=1e-3) + optimizer = optim.SGD(sdl.parameters(), lr=1e-1) + sdl = CCALightning(sdl, optimizer=optimizer) + trainer = pl.Trainer(max_epochs=epochs, log_every_n_steps=10) + trainer.fit(sdl, train_loader) + assert ( + np.testing.assert_array_less( + cca.score((X, Y)).sum(), trainer.model.score(train_loader).sum() + ) + is None ) - optimizer = optim.Adam(dcca_model.parameters(), lr=1e-4) - dcca_model = DeepWrapper(dcca_model, device=device, optimizer=optimizer) - dcca_model.fit((X, Y), epochs=100) - - -def test_DCCA_methods_cpu(): - latent_dims = 4 - cca_model = CCA(latent_dims=latent_dims).fit((X, Y)) - device = "cpu" - epochs = 100 # DCCA encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - dcca_model = DCCA( + dcca = DCCA( latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.CCA, ) - optimizer = optim.SGD(dcca_model.parameters(), lr=1e-2) - dcca_model = DeepWrapper(dcca_model, device=device, optimizer=optimizer) - dcca_model.fit((X, Y), epochs=epochs) + optimizer = optim.SGD(dcca.parameters(), lr=1e-2) + dcca = CCALightning(dcca, optimizer=optimizer) + trainer = pl.Trainer( + max_epochs=epochs, log_every_n_steps=10, enable_checkpointing=False + ) + trainer.fit(dcca, train_loader) assert ( np.testing.assert_array_less( - cca_model.score((X, Y)).sum(), dcca_model.score((X, Y)).sum() + cca.score((X, Y)).sum(), trainer.model.score(train_loader).sum() ) is None ) # DGCCA encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - dgcca_model = DCCA( + dgcca = DCCA( latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.GCCA, ) - optimizer = optim.SGD(dgcca_model.parameters(), lr=1e-2) - 
dgcca_model = DeepWrapper(dgcca_model, device=device, optimizer=optimizer) - dgcca_model.fit((X, Y), epochs=epochs) + optimizer = optim.SGD(dgcca.parameters(), lr=1e-2) + dgcca = CCALightning(dgcca, optimizer=optimizer) + trainer = pl.Trainer( + max_epochs=epochs, log_every_n_steps=10, enable_checkpointing=False + ) + trainer.fit(dgcca, train_loader) assert ( np.testing.assert_array_less( - cca_model.score((X, Y)).sum(), dgcca_model.score((X, Y)).sum() + cca.score((X, Y)).sum(), trainer.model.score(train_loader).sum() ) is None ) # DMCCA encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - dmcca_model = DCCA( + dmcca = DCCA( latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.MCCA, ) - optimizer = optim.SGD(dmcca_model.parameters(), lr=1e-2) - dmcca_model = DeepWrapper(dmcca_model, device=device, optimizer=optimizer) - dmcca_model.fit((X, Y), epochs=epochs) + optimizer = optim.SGD(dmcca.parameters(), lr=1e-2) + dmcca = CCALightning(dmcca, optimizer=optimizer) + trainer = pl.Trainer( + max_epochs=epochs, log_every_n_steps=10, enable_checkpointing=False + ) + trainer.fit(dmcca, train_loader) assert ( np.testing.assert_array_less( - cca_model.score((X, Y)).sum(), dmcca_model.score((X, Y)).sum() + cca.score((X, Y)).sum(), trainer.model.score(train_loader).sum() ) is None ) - # DCCA_NOI + # Barlow Twins encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - dcca_noi_model = DCCA_NOI( - latent_dims, X.shape[0], encoders=[encoder_1, encoder_2], rho=0 + barlowtwins = BarlowTwins( + latent_dims=latent_dims, + encoders=[encoder_1, encoder_2], + ) + optimizer = optim.SGD(barlowtwins.parameters(), lr=1e-2) + barlowtwins = CCALightning(barlowtwins, optimizer=optimizer) + trainer = pl.Trainer( + max_epochs=epochs, log_every_n_steps=10, 
enable_checkpointing=False ) - optimizer = optim.Adam(dcca_noi_model.parameters(), lr=1e-3) - dcca_noi_model = DeepWrapper(dcca_noi_model, device=device, optimizer=optimizer) - dcca_noi_model.fit((X, Y), epochs=epochs) + trainer.fit(barlowtwins, train_loader) assert ( np.testing.assert_array_less( - cca_model.score((X, Y)).sum(), dcca_noi_model.score((X, Y)).sum() + cca.score((X, Y)).sum(), trainer.model.score(train_loader).sum() ) is None ) -def test_DTCCA_methods_cpu(): +def test_DTCCA_methods(): latent_dims = 2 - device = "cpu" + epochs = 5 encoder_1 = architectures.Encoder(latent_dims=10, feature_size=10) encoder_2 = architectures.Encoder(latent_dims=10, feature_size=12) - encoder_3 = architectures.Encoder(latent_dims=10, feature_size=14) - # DTCCA - dtcca_model = DTCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2]) - - dtcca_model = DeepWrapper(dtcca_model, device=device) - dtcca_model.fit((X, Y), epochs=20) - # DCCA - dcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - objective=objectives.GCCA, - ) - - dcca_model = DeepWrapper(dcca_model, device=device) - dcca_model.fit((X, Y), epochs=20) - - -def test_scheduler(): - latent_dims = 2 - device = "cpu" - encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) - encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - # DCCA - dcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - objective=objectives.CCA, - ) - optimizer = optim.Adam(dcca_model.parameters(), lr=1e-4) - scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1) - - dcca_model = DeepWrapper( - dcca_model, device=device, optimizer=optimizer, scheduler=scheduler - ) - dcca_model.fit((X, Y), epochs=20) - - -def test_DGCCA_methods_cpu(): - latent_dims = 2 - device = "cpu" - encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) - encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - encoder_3 
= architectures.Encoder(latent_dims=latent_dims, feature_size=14) - # DTCCA - dtcca_model = DTCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2]) - - dtcca_model = DeepWrapper(dtcca_model, device=device) - dtcca_model.fit((X, Y, Z)) - # DGCCA - dgcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2, encoder_3], - objective=objectives.GCCA, - ) - - dgcca_model = DeepWrapper(dgcca_model, device=device) - dgcca_model.fit((X, Y, Z)) - # DMCCA - dmcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2, encoder_3], - objective=objectives.MCCA, - ) - - dmcca_model = DeepWrapper(dmcca_model, device=device) - dmcca_model.fit((X, Y, Z)) + dtcca = DTCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2]) + dtcca = CCALightning(dtcca) + trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) + trainer.fit(dtcca, train_loader) -def test_DCCAE_methods_cpu(): +def test_DCCAE_methods(): latent_dims = 2 - device = "cpu" encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) decoder_1 = architectures.Decoder(latent_dims=latent_dims, feature_size=10) decoder_2 = architectures.Decoder(latent_dims=latent_dims, feature_size=12) - # DCCAE - dccae_model = DCCAE( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - decoders=[decoder_1, decoder_2], - ) - - dccae_model = DeepWrapper(dccae_model, device=device) - dccae_model.fit((X, Y), epochs=20) # SplitAE - splitae_model = SplitAE( + splitae = SplitAE( latent_dims=latent_dims, encoder=encoder_1, decoders=[decoder_1, decoder_2] ) - - splitae_model = DeepWrapper(splitae_model, device=device) - splitae_model.fit((X, Y), epochs=10) - - -def test_DCCAEconv_methods_cpu(): - latent_dims = 2 - device = "cpu" - encoder_1 = architectures.CNNEncoder(latent_dims=latent_dims, feature_size=[16, 16]) - encoder_2 = architectures.CNNEncoder(latent_dims=latent_dims, 
feature_size=[16, 16]) - decoder_1 = architectures.CNNDecoder(latent_dims=latent_dims, feature_size=[16, 16]) - decoder_2 = architectures.CNNDecoder(latent_dims=latent_dims, feature_size=[16, 16]) + splitae = CCALightning(splitae) + trainer = pl.Trainer(max_epochs=5, enable_checkpointing=False) + trainer.fit(splitae, train_loader) # DCCAE - dccae_model = DCCAE( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - decoders=[decoder_1, decoder_2], - ) - - dccae_model = DeepWrapper(dccae_model, device=device) - dccae_model.fit((X_conv, Y_conv)) - - -def test_DVCCA_methods_cpu(): - latent_dims = 2 - device = "cpu" - encoder_1 = architectures.Encoder( - latent_dims=latent_dims, feature_size=10, variational=True - ) - encoder_2 = architectures.Encoder( - latent_dims=latent_dims, feature_size=12, variational=True - ) - decoder_1 = architectures.Decoder( - latent_dims=latent_dims, feature_size=10, norm_output=True - ) - decoder_2 = architectures.Decoder( - latent_dims=latent_dims, feature_size=12, norm_output=True - ) - # DVCCA - dvcca_model = DVCCA( + dccae = DCCAE( latent_dims=latent_dims, encoders=[encoder_1, encoder_2], decoders=[decoder_1, decoder_2], ) + dccae = CCALightning(dccae) + trainer = pl.Trainer(max_epochs=5, enable_checkpointing=False) + trainer.fit(dccae, train_loader) - dvcca_model = DeepWrapper(dvcca_model, device=device) - dvcca_model.fit((X, Y)) - -def test_DVCCA_p_methods_cpu(): +def test_DVCCA_p_methods(): latent_dims = 2 - device = "cpu" encoder_1 = architectures.Encoder( latent_dims=latent_dims, feature_size=10, variational=True ) @@ -319,105 +209,20 @@ def test_DVCCA_p_methods_cpu(): latent_dims=2 * latent_dims, feature_size=12, norm_output=True ) # DVCCA - dvcca_model = DVCCA( + dvcca = DVCCA( latent_dims=latent_dims, encoders=[encoder_1, encoder_2], decoders=[decoder_1, decoder_2], private_encoders=[private_encoder_1, private_encoder_2], ) - dvcca_model = DeepWrapper(dvcca_model, device=device) - dvcca_model.fit((X, Y)) + dvcca = 
CCALightning(dvcca) + trainer = pl.Trainer(max_epochs=5, enable_checkpointing=False) + trainer.fit(dvcca, train_loader) -def test_DCCA_methods_gpu(): +def test_DVCCA_methods(): latent_dims = 2 - device = "cuda" - encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) - encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - # DCCA - dcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - objective=objectives.CCA, - ) - - dcca_model = DeepWrapper(dcca_model, device=device) - dcca_model.fit((X, Y)) - # DGCCA - dgcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - objective=objectives.GCCA, - ) - - dgcca_model = DeepWrapper(dgcca_model, device=device) - dgcca_model.fit((X, Y)) - # DMCCA - dmcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - objective=objectives.MCCA, - ) - - dmcca_model = DeepWrapper(dmcca_model, device=device) - dmcca_model.fit((X, Y)) - # DCCA_NOI - dcca_noi_model = DCCA_NOI( - latent_dims, X.shape[0], encoders=[encoder_1, encoder_2], rho=0 - ) - - dcca_noi_model = DeepWrapper(dcca_noi_model, device=device) - dcca_noi_model.fit((X, Y)) - - -def test_DGCCA_methods_gpu(): - latent_dims = 2 - device = "cuda" - encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) - encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - encoder_3 = architectures.Encoder(latent_dims=latent_dims, feature_size=14) - # DGCCA - dgcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2, encoder_3], - objective=objectives.GCCA, - ) - - dgcca_model = DeepWrapper(dgcca_model, device=device) - dgcca_model.fit((X, Y, Z)) - # DMCCA - dmcca_model = DCCA( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2, encoder_3], - objective=objectives.MCCA, - ) - - dmcca_model = DeepWrapper(dmcca_model, device=device) - dmcca_model.fit((X, Y, Z)) - - -def test_DCCAE_methods_gpu(): - 
latent_dims = 2 - device = "cuda" - encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=10) - encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=12) - decoder_1 = architectures.Decoder(latent_dims=latent_dims, feature_size=10) - decoder_2 = architectures.Decoder(latent_dims=latent_dims, feature_size=12) - # DCCAE - dccae_model = DCCAE( - latent_dims=latent_dims, - encoders=[encoder_1, encoder_2], - decoders=[decoder_1, decoder_2], - ) - - dccae_model = DeepWrapper(dccae_model, device=device) - dccae_model.fit((X, Y)) - - -def test_DVCCA_methods_gpu(): - latent_dims = 2 - device = "cuda" encoder_1 = architectures.Encoder( latent_dims=latent_dims, feature_size=10, variational=True ) @@ -430,27 +235,29 @@ def test_DVCCA_methods_gpu(): decoder_2 = architectures.Decoder( latent_dims=latent_dims, feature_size=12, norm_output=True ) - # DVCCA - dvcca_model = DVCCA( + dvcca = DVCCA( latent_dims=latent_dims, encoders=[encoder_1, encoder_2], decoders=[decoder_1, decoder_2], ) - dvcca_model = DeepWrapper(dvcca_model, device=device) - dvcca_model.fit((X, Y)) + dvcca = CCALightning(dvcca) + trainer = pl.Trainer(max_epochs=5, enable_checkpointing=False) + trainer.fit(dvcca, train_loader) def test_linear(): encoder_1 = architectures.LinearEncoder(latent_dims=1, feature_size=10) encoder_2 = architectures.LinearEncoder(latent_dims=1, feature_size=12) - dcca_model = DCCA(latent_dims=1, encoders=[encoder_1, encoder_2]) - dcca_model = DeepWrapper(dcca_model).fit((X, Y), epochs=40) + dcca = DCCA(latent_dims=1, encoders=[encoder_1, encoder_2]) + dcca = CCALightning(dcca, learning_rate=1e-1) + trainer = pl.Trainer(max_epochs=50, enable_checkpointing=False) + trainer.fit(dcca, loader) cca = CCA().fit((X, Y)) # check linear encoder with SGD matches vanilla linear CCA assert ( np.testing.assert_array_almost_equal( - cca.score((X, Y)), dcca_model.score((X, Y)), decimal=2 + cca.score((X, Y)), trainer.model.score(loader), decimal=2 ) is None ) diff 
--git a/cca_zoo/test/test_models.py b/cca_zoo/test/test_models.py index a95483e3..22f00b08 100644 --- a/cca_zoo/test/test_models.py +++ b/cca_zoo/test/test_models.py @@ -17,20 +17,23 @@ MCCA, GCCA, TCCA, - SCCA_ADMM, SpanCCA, SWCCA, PLS_ALS, KGCCA, + NCCA, + ParkhomenkoCCA, + SCCA_ADMM, ) from cca_zoo.utils.plotting import cv_plot +n = 50 rng = check_random_state(0) -X = rng.rand(500, 20) -Y = rng.rand(500, 21) -Z = rng.rand(500, 22) -X_sp = sp.random(500, 20, density=0.5, random_state=rng) -Y_sp = sp.random(500, 21, density=0.5, random_state=rng) +X = rng.rand(n, 4) +Y = rng.rand(n, 5) +Z = rng.rand(n, 6) +X_sp = sp.random(n, 4, density=0.5, random_state=rng) +Y_sp = sp.random(n, 5, density=0.5, random_state=rng) def test_unregularized_methods(): @@ -56,60 +59,60 @@ def test_unregularized_methods(): corr_kgcca = kgcca.score((X, Y)) corr_tcca = tcca.score((X, Y)) # Check the correlations from each unregularized method are the same - assert np.testing.assert_array_almost_equal(corr_cca, corr_iter, decimal=2) is None - assert np.testing.assert_array_almost_equal(corr_cca, corr_mcca, decimal=2) is None - assert np.testing.assert_array_almost_equal(corr_cca, corr_gcca, decimal=2) is None - assert np.testing.assert_array_almost_equal(corr_cca, corr_kcca, decimal=2) is None - assert np.testing.assert_array_almost_equal(corr_cca, corr_tcca, decimal=2) is None + assert np.testing.assert_array_almost_equal(corr_cca, corr_iter, decimal=1) is None + assert np.testing.assert_array_almost_equal(corr_cca, corr_mcca, decimal=1) is None + assert np.testing.assert_array_almost_equal(corr_cca, corr_gcca, decimal=1) is None + assert np.testing.assert_array_almost_equal(corr_cca, corr_kcca, decimal=1) is None + assert np.testing.assert_array_almost_equal(corr_cca, corr_tcca, decimal=1) is None assert ( - np.testing.assert_array_almost_equal(corr_kgcca, corr_gcca, decimal=2) is None + np.testing.assert_array_almost_equal(corr_kgcca, corr_gcca, decimal=1) is None ) # Check standardized 
models have standard outputs assert ( np.testing.assert_allclose( - np.linalg.norm(iter.transform((X, Y))[0], axis=0) ** 2, 500, rtol=0.1 + np.linalg.norm(iter.transform((X, Y))[0], axis=0) ** 2, n, rtol=0.2 ) is None ) assert ( np.testing.assert_allclose( - np.linalg.norm(cca.transform((X, Y))[0], axis=0) ** 2, 500, rtol=0.1 + np.linalg.norm(cca.transform((X, Y))[0], axis=0) ** 2, n, rtol=0.2 ) is None ) assert ( np.testing.assert_allclose( - np.linalg.norm(mcca.transform((X, Y))[0], axis=0) ** 2, 500, rtol=0.1 + np.linalg.norm(mcca.transform((X, Y))[0], axis=0) ** 2, n, rtol=0.2 ) is None ) assert ( np.testing.assert_allclose( - np.linalg.norm(kcca.transform((X, Y))[0], axis=0) ** 2, 500, rtol=0.1 + np.linalg.norm(kcca.transform((X, Y))[0], axis=0) ** 2, n, rtol=0.2 ) is None ) assert ( np.testing.assert_allclose( - np.linalg.norm(iter.transform((X, Y))[1], axis=0) ** 2, 500, rtol=0.1 + np.linalg.norm(iter.transform((X, Y))[1], axis=0) ** 2, n, rtol=0.2 ) is None ) assert ( np.testing.assert_allclose( - np.linalg.norm(cca.transform((X, Y))[1], axis=0) ** 2, 500, rtol=0.1 + np.linalg.norm(cca.transform((X, Y))[1], axis=0) ** 2, n, rtol=0.2 ) is None ) assert ( np.testing.assert_allclose( - np.linalg.norm(mcca.transform((X, Y))[1], axis=0) ** 2, 500, rtol=0.1 + np.linalg.norm(mcca.transform((X, Y))[1], axis=0) ** 2, n, rtol=0.2 ) is None ) assert ( np.testing.assert_allclose( - np.linalg.norm(kcca.transform((X, Y))[1], axis=0) ** 2, 500, rtol=0.1 + np.linalg.norm(kcca.transform((X, Y))[1], axis=0) ** 2, n, rtol=0.2 ) is None ) @@ -135,10 +138,9 @@ def test_sparse_input(): corr_mcca = mcca.score((X, Y)) corr_kcca = kcca.score((X, Y)) # Check the correlations from each unregularized method are the same - assert np.testing.assert_array_almost_equal(corr_cca, corr_iter, decimal=2) is None - assert np.testing.assert_array_almost_equal(corr_iter, corr_mcca, decimal=2) is None - assert np.testing.assert_array_almost_equal(corr_iter, corr_gcca, decimal=2) is None - assert 
np.testing.assert_array_almost_equal(corr_iter, corr_kcca, decimal=2) is None + assert np.testing.assert_array_almost_equal(corr_iter, corr_mcca, decimal=1) is None + assert np.testing.assert_array_almost_equal(corr_iter, corr_gcca, decimal=1) is None + assert np.testing.assert_array_almost_equal(corr_iter, corr_kcca, decimal=1) is None def test_unregularized_multi(): @@ -156,9 +158,9 @@ def test_unregularized_multi(): corr_kcca = kcca.score((X, Y, Z)) # Check the correlations from each unregularized method are the same assert np.testing.assert_array_almost_equal(corr_cca, corr_iter, decimal=1) is None - assert np.testing.assert_array_almost_equal(corr_cca, corr_mcca, decimal=2) is None - assert np.testing.assert_array_almost_equal(corr_cca, corr_gcca, decimal=2) is None - assert np.testing.assert_array_almost_equal(corr_cca, corr_kcca, decimal=2) is None + assert np.testing.assert_array_almost_equal(corr_cca, corr_mcca, decimal=1) is None + assert np.testing.assert_array_almost_equal(corr_cca, corr_gcca, decimal=1) is None + assert np.testing.assert_array_almost_equal(corr_cca, corr_kcca, decimal=1) is None def test_regularized_methods(): @@ -178,7 +180,6 @@ def test_regularized_methods(): corr_pls = pls.score((X, Y)) corr_rcca = rcca.score((X, Y)) # Check the correlations from each unregularized method are the same - # assert np.testing.assert_array_almost_equal(corr_pls, corr_gcca, decimal=2)) assert np.testing.assert_array_almost_equal(corr_pls, corr_mcca, decimal=1) is None assert ( np.testing.assert_array_almost_equal(corr_pls, corr_kernel, decimal=1) is None @@ -188,49 +189,63 @@ def test_regularized_methods(): def test_non_negative_methods(): latent_dims = 2 - nnelasticca = ElasticCCA( + nnelastic = ElasticCCA( latent_dims=latent_dims, tol=1e-9, positive=True, l1_ratio=[0.5, 0.5], c=[1e-4, 1e-5], ).fit([X, Y]) - als = CCA_ALS(latent_dims=latent_dims, tol=1e-9).fit([X, Y]) nnals = CCA_ALS(latent_dims=latent_dims, tol=1e-9, positive=True).fit([X, Y]) nnscca = 
SCCA(latent_dims=latent_dims, tol=1e-9, positive=True, c=[1e-4, 1e-5]).fit( (X, Y) ) + assert np.testing.assert_array_less(-1e-9, nnelastic.weights[0]) is None + assert np.testing.assert_array_less(-1e-9, nnelastic.weights[1]) is None + assert np.testing.assert_array_less(-1e-9, nnals.weights[0]) is None + assert np.testing.assert_array_less(-1e-9, nnals.weights[1]) is None + assert np.testing.assert_array_less(-1e-9, nnscca.weights[0]) is None + assert np.testing.assert_array_less(-1e-9, nnscca.weights[1]) is None def test_sparse_methods(): - # Test sparsity inducing methods. At the moment just checks running. - latent_dims = 2 c1 = [1, 3] c2 = [1, 3] - param_grid = {"c": [c1, c2]} pmd_cv = GridSearchCV(PMD(random_state=rng), param_grid=param_grid).fit([X, Y]) cv_plot(pmd_cv.cv_results_) - c1 = [1e-4, 1e-5] - c2 = [1e-4, 1e-5] + c1 = [5e-1] + c2 = [1e-1] param_grid = {"c": [c1, c2]} scca_cv = GridSearchCV(SCCA(random_state=rng), param_grid=param_grid).fit([X, Y]) - - c1 = loguniform(1e-4, 1e0) - c2 = loguniform(1e-4, 1e0) + c1 = [1e-1] + c2 = [1e-1] + param_grid = {"c": [c1, c2]} + parkhomenko_cv = GridSearchCV(ParkhomenkoCCA(random_state=rng), param_grid=param_grid).fit([X, Y]) + c1 = [2e-2] + c2 = [1e-2] param_grid = {"c": [c1, c2]} + admm_cv = GridSearchCV(SCCA_ADMM(random_state=rng), param_grid=param_grid).fit([X, Y]) + c1 = loguniform(1e-1, 2e-1) + c2 = loguniform(1e-1, 2e-1) + param_grid = {"c": [c1, c2], "l1_ratio": [[0.9], [0.9]]} elastic_cv = RandomizedSearchCV( ElasticCCA(random_state=rng), param_distributions=param_grid, n_iter=4 ).fit([X, Y]) - corr_pmd = pmd_cv.score((X, Y)) - corr_scca = scca_cv.score((X, Y)) - corr_elastic = elastic_cv.score((X, Y)) - scca_admm = SCCA_ADMM(c=[1e-4, 1e-4]).fit([X, Y]) - scca = SCCA(c=[1e-4, 1e-4]).fit([X, Y]) + assert (pmd_cv.best_estimator_.weights[0] == 0).sum() > 0 + assert (pmd_cv.best_estimator_.weights[1] == 0).sum() > 0 + assert (scca_cv.best_estimator_.weights[0] == 0).sum() > 0 + assert 
(scca_cv.best_estimator_.weights[1] == 0).sum() > 0 + assert (admm_cv.best_estimator_.weights[0] == 0).sum() > 0 + assert (admm_cv.best_estimator_.weights[1] == 0).sum() > 0 + assert (parkhomenko_cv.best_estimator_.weights[0] == 0).sum() > 0 + assert (parkhomenko_cv.best_estimator_.weights[1] == 0).sum() > 0 + assert (elastic_cv.best_estimator_.weights[0] == 0).sum() > 0 + assert (elastic_cv.best_estimator_.weights[1] == 0).sum() > 0 def test_weighted_GCCA_methods(): - # Test the 'fancy' additions to GCCA i.e. the view weighting and observation weighting. + # TODO we have view weighted GCCA and missing observation GCCA latent_dims = 2 c = 0 unweighted_gcca = GCCA(latent_dims=latent_dims, c=[c, c]).fit([X, Y]) @@ -252,12 +267,20 @@ def test_weighted_GCCA_methods(): def test_TCCA(): - # Tests tensor CCA methods - latent_dims = 2 + latent_dims = 1 tcca = TCCA(latent_dims=latent_dims, c=[0.2, 0.2, 0.2]).fit([X, X, Y]) ktcca = KTCCA(latent_dims=latent_dims, c=[0.2, 0.2]).fit([X, Y]) corr_tcca = tcca.score((X, X, Y)) corr_ktcca = ktcca.score((X, Y)) + assert corr_tcca > 0.4 + assert corr_ktcca > 0.4 + + +def test_NCCA(): + latent_dims = 1 + ncca = NCCA(latent_dims=latent_dims).fit((X, Y)) + corr_ncca = ncca.score((X, Y)) + assert corr_ncca > 0.9 def test_l0(): @@ -276,7 +299,7 @@ def test_VCCA(): from cca_zoo.data import generate_simple_data # Tests tensor CCA methods - (X, Y), (_) = generate_simple_data(100, [10, 10], random_state=rng, eps=0.1) + (X, Y), (_) = generate_simple_data(20, [9, 9], random_state=rng, eps=0.1) latent_dims = 1 cca = CCA(latent_dims=latent_dims).fit([X, Y]) vcca = VariationalCCA( diff --git a/cca_zoo/utils/check_values.py b/cca_zoo/utils/check_values.py index 00839f28..451145cb 100644 --- a/cca_zoo/utils/check_values.py +++ b/cca_zoo/utils/check_values.py @@ -54,7 +54,7 @@ def _check_parameter_number(parameter_name: str, parameter, n_views: int): def _check_converged_weights(weights, view_index): """check the converged weights are not zero.""" if 
np.linalg.norm(weights) <= 0: - raise ValueError( + warnings.warn( f"All result weights are zero in view {view_index}. " "Try less regularisation or another initialisation" ) diff --git a/cca_zoo/utils/plotting.py b/cca_zoo/utils/plotting.py index 69abf2a9..2857fb7c 100644 --- a/cca_zoo/utils/plotting.py +++ b/cca_zoo/utils/plotting.py @@ -9,7 +9,7 @@ from matplotlib import cm -def post_process_cv_results(df): +def _post_process_cv_results(df): cols = [col for col in df.columns if "param_" in col] for col in cols: df = df.join( @@ -26,7 +26,7 @@ def cv_plot(cv_results_): """ if isinstance(cv_results_, dict): cv_results_ = pd.DataFrame(cv_results_) - cv_results_ = post_process_cv_results(cv_results_) + cv_results_ = _post_process_cv_results(cv_results_) param_cols = [col for col in cv_results_.columns if "param_" in col] n_params = len(param_cols) n_uniques = [cv_results_[col].nunique() for col in param_cols] diff --git a/docs/requirements.txt b/docs/requirements.txt index 894675cb..1a019f5d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,10 +1,13 @@ -jax -numpyro -arviz +sphinx==1.8.5 +sphinx-autodoc-typehints +sphinx-gallery +pandas numpy -scikit-learn +matplotlib +numpy +scikit-learn>=0.23 mvlearn -scipy>=1.5 +scipy>=1.7 matplotlib seaborn pandas @@ -13,4 +16,7 @@ joblib torch>=1.9.0 torchvision Pillow -sphinx-autodoc-typehints +jax~=0.2.20 +numpyro +arviz +pytorch-lightning \ No newline at end of file diff --git a/docs/source/api/deepmodels.rst b/docs/source/api/deepmodels.rst index a92d7a23..600e207a 100644 --- a/docs/source/api/deepmodels.rst +++ b/docs/source/api/deepmodels.rst @@ -43,10 +43,10 @@ Split Autoencoders :members: :undoc-members: -Deep Wrapper for Training --------------------------------------- +CCALightning Module for training with pytorch-lightning +--------------------------------------------------------- -.. automodule:: cca_zoo.deepmodels.deepwrapper +.. 
automodule:: cca_zoo.deepmodels.CCALightning :members: :inherited-members: :exclude-members: get_params, set_params diff --git a/docs/source/api/iterative.rst b/docs/source/api/iterative.rst new file mode 100644 index 00000000..e557d68e --- /dev/null +++ b/docs/source/api/iterative.rst @@ -0,0 +1,46 @@ +Normal CCA and PLS by alternating least squares +-------------------------------------------------- +Quicker and more memory efficient for very large data + +.. autoclass:: cca_zoo.models.CCA_ALS + :inherited-members: + :exclude-members: get_params, set_params + +.. autoclass:: cca_zoo.models.PLS_ALS + :inherited-members: + :exclude-members: get_params, set_params + + +Sparsity Inducing Models +-------------------------------------------------- + +.. autoclass:: cca_zoo.models.PMD + :inherited-members: + :exclude-members: get_params, set_params + +.. autoclass:: cca_zoo.models.SCCA + :inherited-members: + :exclude-members: get_params, set_params + +.. autoclass:: cca_zoo.models.ElasticCCA + :inherited-members: + :exclude-members: get_params, set_params + +.. autoclass:: cca_zoo.models.SpanCCA + :inherited-members: + :exclude-members: get_params, set_params + +.. autoclass:: cca_zoo.models.ParkhomenkoCCA + :inherited-members: + :exclude-members: get_params, set_params + +.. autoclass:: cca_zoo.models.SCCA_ADMM + :inherited-members: + :exclude-members: get_params, set_params + +Miscellaneous +-------------------------------------------------- + +.. autoclass:: cca_zoo.models.SWCCA + :inherited-members: + :exclude-members: get_params, set_params \ No newline at end of file diff --git a/docs/source/api/models.rst b/docs/source/api/models.rst index 401c9ec0..2cfeaddc 100644 --- a/docs/source/api/models.rst +++ b/docs/source/api/models.rst @@ -1,54 +1,84 @@ Models ======================= -Base Class --------------------------------- -.. 
automodule:: cca_zoo.models.cca_base - :members: - :private-members: _CCA_Base +Regularized Canonical Correlation Analysis and Partial Least Squares +------------------------------------------------------------------------ + +Canonical Correlation Analysis +**************************************************** +.. autoclass:: cca_zoo.models.rcca.CCA + :inherited-members: :exclude-members: get_params, set_params -rCCA ---------------------------- +Partial Least Squares +**************************************************** +.. autoclass:: cca_zoo.models.rcca.PLS + :inherited-members: + :exclude-members: get_params, set_params -.. automodule:: cca_zoo.models.rcca +Ridge Regularized Canonical Correlation Analysis +**************************************************** +.. autoclass:: cca_zoo.models.rcca.rCCA :inherited-members: :exclude-members: get_params, set_params GCCA and KGCCA --------------------------- -.. automodule:: cca_zoo.models.gcca +Generalized (MAXVAR) Canonical Correlation Analysis +**************************************************** +.. autoclass:: cca_zoo.models.gcca.GCCA + :inherited-members: + :exclude-members: get_params, set_params + +Kernel Generalized (MAXVAR) Canonical Correlation Analysis +************************************************************ +.. autoclass:: cca_zoo.models.gcca.KGCCA :inherited-members: :exclude-members: get_params, set_params MCCA and KCCA --------------------------- -.. automodule:: cca_zoo.models.mcca +Multiset (SUMCOR) Canonical Correlation Analysis +************************************************** +.. autoclass:: cca_zoo.models.mcca.MCCA :inherited-members: :exclude-members: get_params, set_params -TCCA and KTCCA ---------------------------- - -.. automodule:: cca_zoo.models.tcca +Kernel Multiset (SUMCOR) Canonical Correlation Analysis +******************************************************** +.. 
autoclass:: cca_zoo.models.mcca.KCCA :inherited-members: :exclude-members: get_params, set_params -Iterative Models --------------------------------- +Tensor Canonical Correlation Analysis +---------------------------------------- -.. automodule:: cca_zoo.models.iterative +Tensor Canonical Correlation Analysis +************************************** +.. autoclass:: cca_zoo.models.tcca.TCCA :inherited-members: :exclude-members: get_params, set_params -Inner Loops --------------------------------- - -.. automodule:: cca_zoo.models.innerloop +Kernel Tensor Canonical Correlation Analysis +********************************************** +.. autoclass:: cca_zoo.models.tcca.KTCCA :inherited-members: :exclude-members: get_params, set_params +More Complex Regularisation using Iterative Models +----------------------------------------------------- + +.. toctree:: + :maxdepth: 4 + iterative.rst +Base Class +-------------------------------- + +.. automodule:: cca_zoo.models.cca_base + :members: + :private-members: _CCA_Base + :exclude-members: get_params, set_params diff --git a/docs/source/conf.py b/docs/source/conf.py index 7f0db98e..5321577f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -12,24 +12,48 @@ # import os import sys +import warnings -sys.path.insert(0, os.path.abspath('../..')) - +sys.path.insert(0, os.path.abspath("../..")) +warnings.filterwarnings( + "ignore", + category=UserWarning, + message="Matplotlib is currently using agg, which is a" + " non-GUI backend, so cannot show the figure.", +) # -- Project information ----------------------------------------------------- -project = 'cca-zoo' -copyright = '2021, James Chapman' -author = 'James Chapman' +project = "cca-zoo" +copyright = "2021, James Chapman" +author = "James Chapman" # -- General configuration --------------------------------------------------- extensions = [ - 'sphinx.ext.autodoc', + "sphinx.ext.autodoc", "sphinx.ext.autosummary", - 'sphinx_autodoc_typehints', - 
'sphinx.ext.viewcode' + "sphinx_autodoc_typehints", + "sphinx.ext.viewcode", + "sphinx_gallery.gen_gallery", + "sphinx.ext.napoleon", + "sphinx.ext.githubpages", + "sphinx.ext.mathjax", ] +sphinx_gallery_conf = { + "doc_module": "cca-zoo", + "examples_dirs": "../../examples", # path to your example scripts + "gallery_dirs": "auto_examples", # path to where to save gallery generated output +} + +# -- sphinx.ext.intersphinx +intersphinx_mapping = { + "numpy": ("https://docs.scipy.org/doc/numpy", None), + "python": ("https://docs.python.org/3", None), + "scipy": ("https://docs.scipy.org/doc/scipy/reference", None), + "sklearn": ("http://scikit-learn.org/dev", None), +} + # -- sphinx.ext.autosummary autosummary_generate = True @@ -37,10 +61,10 @@ autoclass_content = "both" autodoc_default_flags = ["members", "show-inheritance"] autodoc_member_order = "bysource" # default is alphabetical -special_members = '--init__' +special_members = "--init__" # Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] +templates_path = ["_templates"] # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -57,4 +81,6 @@ # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = ["_static"] + +master_doc = "index" diff --git a/docs/source/index.rst b/docs/source/index.rst index 69926a13..84857c44 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -13,11 +13,11 @@ Documentation documentation/install documentation/getting_started documentation/user_guide - documentation/tutorials + auto_examples/index .. 
toctree:: - :maxdepth: 1 + :maxdepth: 4 :caption: Reference api/data diff --git a/examples/README.rst b/examples/README.rst new file mode 100644 index 00000000..864f95df --- /dev/null +++ b/examples/README.rst @@ -0,0 +1,4 @@ +Tutorials and Examples Gallery +================================ + +Below is a gallery of examples \ No newline at end of file diff --git a/examples/plot_dcca.py b/examples/plot_dcca.py new file mode 100644 index 00000000..16da6ac0 --- /dev/null +++ b/examples/plot_dcca.py @@ -0,0 +1,69 @@ +""" +Deep CCA +=========================== + +This example demonstrates how to easily train Deep CCA models and variants +""" + +import numpy as np +import pytorch_lightning as pl +from torch.utils.data import Subset + +# %% +from cca_zoo.data import Split_MNIST_Dataset +from cca_zoo.deepmodels import ( + DCCA, + CCALightning, + get_dataloaders, + architectures, + DCCA_NOI, + DCCA_SDL, + BarlowTwins, +) + +n_train = 500 +n_val = 100 +train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True) +val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val)) +train_dataset = Subset(train_dataset, np.arange(n_train)) +train_loader, val_loader = get_dataloaders(train_dataset, val_dataset) + +# The number of latent dimensions across models +latent_dims = 2 +# number of epochs for deep models +epochs = 10 + +encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392) +encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392) + +# %% +# Deep CCA +dcca = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2]) +dcca = CCALightning(dcca) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca, train_loader, val_loader) + +# %% +# Deep CCA by Non-Linear Orthogonal Iterations +dcca_noi = DCCA_NOI( + latent_dims=latent_dims, N=len(train_dataset), encoders=[encoder_1, encoder_2] +) +dcca_noi = CCALightning(dcca_noi) +trainer = pl.Trainer(max_epochs=epochs, 
enable_checkpointing=False) +trainer.fit(dcca_noi, train_loader, val_loader) + +# %% +# Deep CCA by Stochastic Decorrelation Loss +dcca_sdl = DCCA_SDL( + latent_dims=latent_dims, N=len(train_dataset), encoders=[encoder_1, encoder_2] +) +dcca_sdl = CCALightning(dcca_sdl) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca_sdl, train_loader, val_loader) + +# %% +# Deep CCA by Barlow Twins +barlowtwins = BarlowTwins(latent_dims=latent_dims, encoders=[encoder_1, encoder_2]) +barlowtwins = CCALightning(barlowtwins) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca, train_loader, val_loader) diff --git a/examples/plot_dcca_custom.py b/examples/plot_dcca_custom.py new file mode 100644 index 00000000..694943b7 --- /dev/null +++ b/examples/plot_dcca_custom.py @@ -0,0 +1,39 @@ +""" +Deep CCA with more customisation +================================== + +Showing some examples of more advanced functionality with DCCA and pytorch-lightning +""" + +import numpy as np +# %% +import pytorch_lightning as pl +from torch import optim +from torch.utils.data import Subset + +from cca_zoo.data import Split_MNIST_Dataset +from cca_zoo.deepmodels import DCCA, CCALightning, get_dataloaders, architectures + +n_train = 500 +n_val = 100 +train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True) +val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val)) +train_dataset = Subset(train_dataset, np.arange(n_train)) +train_loader, val_loader = get_dataloaders(train_dataset, val_dataset) + +# The number of latent dimensions across models +latent_dims = 2 +# number of epochs for deep models +epochs = 10 + +# TODO add in custom architecture and schedulers and stuff to show it off +encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392) +encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392) + +# Deep CCA +dcca = DCCA(latent_dims=latent_dims, encoders=[encoder_1, 
encoder_2]) +dcca = CCALightning(dcca) +optimizer = optim.Adam(dcca.parameters(), lr=1e-3) +scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca, train_loader, val_loader) diff --git a/examples/plot_dcca_multi.py b/examples/plot_dcca_multi.py new file mode 100644 index 00000000..b0e77a76 --- /dev/null +++ b/examples/plot_dcca_multi.py @@ -0,0 +1,61 @@ +""" +Deep CCA for more than 2 views +================================= + +This example demonstrates how to easily train Deep CCA models and variants +""" + +import numpy as np +import pytorch_lightning as pl +from torch.utils.data import Subset + +# %% +from cca_zoo.data import Split_MNIST_Dataset +from cca_zoo.deepmodels import ( + DCCA, + CCALightning, + get_dataloaders, + architectures, + objectives, + DTCCA, +) + +n_train = 500 +n_val = 100 +train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True) +val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val)) +train_dataset = Subset(train_dataset, np.arange(n_train)) +train_loader, val_loader = get_dataloaders(train_dataset, val_dataset) + +# The number of latent dimensions across models +latent_dims = 2 +# number of epochs for deep models +epochs = 10 + +encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392) +encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392) + +# %% +# Deep MCCA +dcca = DCCA( + latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.MCCA +) +dcca = CCALightning(dcca) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca, train_loader, val_loader) + +# %% +# Deep GCCA +dcca = DCCA( + latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.GCCA +) +dcca = CCALightning(dcca) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca, train_loader, val_loader) + +# %% +# Deep 
TCCA +dcca = DTCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2]) +dcca = CCALightning(dcca) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca, train_loader, val_loader) diff --git a/examples/plot_dvcca.py b/examples/plot_dvcca.py new file mode 100644 index 00000000..f039c2e2 --- /dev/null +++ b/examples/plot_dvcca.py @@ -0,0 +1,87 @@ +""" +Deep Variational CCA and Deep Canonically Correlated Autoencoders +==================================================================== + +This example demonstrates multiview models which can reconstruct their inputs +""" + +import numpy as np +import pytorch_lightning as pl +from torch.utils.data import Subset + +# %% +from cca_zoo.data import Split_MNIST_Dataset +from cca_zoo.deepmodels import ( + CCALightning, + get_dataloaders, + architectures, + DCCAE, + DVCCA, +) + +n_train = 500 +n_val = 100 +train_dataset = Split_MNIST_Dataset(mnist_type="MNIST", train=True) +val_dataset = Subset(train_dataset, np.arange(n_train, n_train + n_val)) +train_dataset = Subset(train_dataset, np.arange(n_train)) +train_loader, val_loader = get_dataloaders(train_dataset, val_dataset) + +# The number of latent dimensions across models +latent_dims = 2 +# number of epochs for deep models +epochs = 10 + +encoder_1 = architectures.Encoder( + latent_dims=latent_dims, feature_size=392, variational=True +) +encoder_2 = architectures.Encoder( + latent_dims=latent_dims, feature_size=392, variational=True +) +decoder_1 = architectures.Decoder(latent_dims=latent_dims, feature_size=392) +decoder_2 = architectures.Decoder(latent_dims=latent_dims, feature_size=392) + +# %% +# Deep VCCA +dcca = DVCCA( + latent_dims=latent_dims, + encoders=[encoder_1, encoder_2], + decoders=[decoder_1, decoder_2], +) +dcca = CCALightning(dcca) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca, train_loader, val_loader) + +# %% +# Deep VCCA (private) +# We need to add additional private encoders 
and change (double) the dimensionality of the decoders. +private_encoder_1 = architectures.Encoder( + latent_dims=latent_dims, feature_size=392, variational=True +) +private_encoder_2 = architectures.Encoder( + latent_dims=latent_dims, feature_size=392, variational=True +) +private_decoder_1 = architectures.Decoder(latent_dims=2 * latent_dims, feature_size=392) +private_decoder_2 = architectures.Decoder(latent_dims=2 * latent_dims, feature_size=392) + +dcca = DVCCA( + latent_dims=latent_dims, + encoders=[encoder_1, encoder_2], + decoders=[private_decoder_1, private_decoder_2], + private_encoders=[private_encoder_1, private_encoder_2], +) +dcca = CCALightning(dcca) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dcca, train_loader, val_loader) + +# %% +# DCCAE +encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=392) +encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=392) +dccae_model = DCCAE( + latent_dims=latent_dims, + encoders=[encoder_1, encoder_2], + decoders=[decoder_1, decoder_2], +) +dccae_model = CCALightning(dccae_model) +trainer = pl.Trainer(max_epochs=epochs, enable_checkpointing=False) +trainer.fit(dccae_model, train_loader, val_loader) diff --git a/examples/plot_hyperparameter_selection.py b/examples/plot_hyperparameter_selection.py new file mode 100644 index 00000000..0d687e88 --- /dev/null +++ b/examples/plot_hyperparameter_selection.py @@ -0,0 +1,6 @@ +""" +Hyperparameter Selection +=========================== + +This script will show how to perform hyperparameter selection +""" diff --git a/examples/plot_kernel_cca.py b/examples/plot_kernel_cca.py new file mode 100644 index 00000000..581c4476 --- /dev/null +++ b/examples/plot_kernel_cca.py @@ -0,0 +1,100 @@ +""" +Kernel CCA and Nonparametric CCA +=================================== + +This script demonstrates how to use kernel and nonparametric methods +""" + +# %% +import numpy as np + +from cca_zoo.data import 
generate_covariance_data +from cca_zoo.model_selection import GridSearchCV +from cca_zoo.models import KCCA + +# %% +np.random.seed(42) +n = 200 +p = 100 +q = 100 +latent_dims = 1 +cv = 3 + +(X, Y), (tx, ty) = generate_covariance_data( + n, view_features=[p, q], latent_dims=latent_dims, correlation=[0.9] +) + + +# %% +# Custom Kernel +def my_kernel(X, Y, param=0): + """ + We create a custom kernel: + + """ + + return np.random.normal(0, param) + + +kernel_custom = KCCA( + latent_dims=latent_dims, + kernel=[my_kernel, my_kernel], + kernel_params=[{"param": 1}, {"param": 1}], +).fit((X, Y)) + +# %% +# Linear +c1 = [0.9, 0.99] +c2 = [0.9, 0.99] +param_grid = {"kernel": ["linear"], "c": [c1, c2]} +kernel_reg = ( + GridSearchCV( + KCCA(latent_dims=latent_dims), param_grid=param_grid, cv=cv, verbose=True + ) + .fit([X, Y]) + .best_estimator_ +) + +# %% +# Polynomial +degree1 = [2, 3] +degree2 = [2, 3] +param_grid = {"kernel": ["poly"], "degree": [degree1, degree2], "c": [c1, c2]} +kernel_poly = ( + GridSearchCV( + KCCA(latent_dims=latent_dims), param_grid=param_grid, cv=cv, verbose=True + ) + .fit([X, Y]) + .best_estimator_ +) + +# %% +# kernel cca (gaussian/rbf) +gamma1 = [1e-1, 1e-2] +gamma2 = [1e-1, 1e-2] +param_grid = {"kernel": ["rbf"], "gamma": [gamma1, gamma2], "c": [c1, c2]} +kernel_poly = ( + GridSearchCV( + KCCA(latent_dims=latent_dims), param_grid=param_grid, cv=cv, verbose=True + ) + .fit([X, Y]) + .best_estimator_ +) + + +# %% +# Custom Kernel +def my_kernel(X, Y, param=0): + """ + We create a custom kernel: + + """ + M = np.random.rand(X.shape[0], X.shape[0]) + param + return X @ M @ M.T @ Y.T + + +kernel_custom = KCCA( + latent_dims=latent_dims, + kernel=[my_kernel, my_kernel], + kernel_params=[{"param": 1}, {"param": 1}], +).fit((X, Y)) diff --git a/examples/plot_many_views.py b/examples/plot_many_views.py new file mode 100644 index 00000000..098e52c1 --- /dev/null +++ b/examples/plot_many_views.py @@ -0,0 +1,6 @@ +""" +More than 2 views 
+=========================== + +This will compare MCCA, GCCA, TCCA for linear models with more than 2 views +""" diff --git a/examples/plot_ridge_reg.py b/examples/plot_ridge_reg.py new file mode 100644 index 00000000..ad63399b --- /dev/null +++ b/examples/plot_ridge_reg.py @@ -0,0 +1,6 @@ +""" +Ridge Regularised CCA: From CCA to PLS +=========================== + +This script will show how CCA and PLS form opposite ends of a ridge regularisation spectrum +""" diff --git a/examples/plot_sparse_cca.py b/examples/plot_sparse_cca.py new file mode 100644 index 00000000..d73208a9 --- /dev/null +++ b/examples/plot_sparse_cca.py @@ -0,0 +1,137 @@ +""" +Sparse CCA Methods +=========================== + +This script shows how regularised methods can be used to extract sparse solutions to the CCA problem +""" + +# %% +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd + +from cca_zoo.data import generate_covariance_data +from cca_zoo.model_selection import GridSearchCV +from cca_zoo.models import PMD, SCCA, ElasticCCA, CCA, PLS, SCCA_ADMM, SpanCCA + +# %% +np.random.seed(42) +n = 200 +p = 100 +q = 100 +view_1_sparsity = 0.1 +view_2_sparsity = 0.1 +latent_dims = 1 + +(X, Y), (tx, ty) = generate_covariance_data( + n, + view_features=[p, q], + latent_dims=latent_dims, + view_sparsity=[view_1_sparsity, view_2_sparsity], + correlation=[0.9], +) +tx /= np.sqrt(n) +ty /= np.sqrt(n) + + +# %% +def plot_true_weights_coloured(ax, weights, true_weights, title=""): + ind = np.arange(len(true_weights)) + mask = np.squeeze(true_weights == 0) + ax.scatter(ind[~mask], weights[~mask], c="b") + ax.scatter(ind[mask], weights[mask], c="r") + ax.set_title(title) + + +def plot_model_weights(wx, wy, tx, ty): + fig, axs = plt.subplots(2, 2, sharex=True, sharey=True) + plot_true_weights_coloured(axs[0, 0], tx, tx, title="true x weights") + plot_true_weights_coloured(axs[0, 1], ty, ty, title="true y weights") + plot_true_weights_coloured(axs[1, 0], wx, tx, title="model x 
weights") + plot_true_weights_coloured(axs[1, 1], wy, ty, title="model y weights") + plt.tight_layout() + plt.show() + + +# %% +cca = CCA().fit([X, Y]) +plot_model_weights(cca.weights[0], cca.weights[1], tx, ty) + +# %% +pls = PLS().fit([X, Y]) +plot_model_weights(pls.weights[0], pls.weights[1], tx, ty) + +# %% +pmd = PMD(c=[2, 2]).fit([X, Y]) +plot_model_weights(pmd.weights[0], pmd.weights[1], tx, ty) + +# %% +plt.figure() +plt.title("Objective Convergence") +plt.plot(np.array(pmd.track[0]["objective"]).T) +plt.ylabel("Objective") +plt.xlabel("#iterations") + +# %% +c1 = [1, 3, 7, 9] +c2 = [1, 3, 7, 9] +param_grid = {"c": [c1, c2]} +pmd = GridSearchCV(PMD(), param_grid=param_grid, cv=3, verbose=True).fit([X, Y]) + +# %% +pd.DataFrame(pmd.cv_results_) + +# %% +scca = SCCA(c=[1e-3, 1e-3]).fit([X, Y]) +plot_model_weights(scca.weights[0], scca.weights[1], tx, ty) + +# Convergence +plt.figure() +plt.title("Objective Convergence") +plt.plot(np.array(scca.track[0]["objective"]).T) +plt.ylabel("Objective") +plt.xlabel("#iterations") + +# %% +scca_pos = SCCA(c=[1e-3, 1e-3], positive=[True, True]).fit([X, Y]) +plot_model_weights(scca_pos.weights[0], scca_pos.weights[1], tx, ty) + +# Convergence +plt.figure() +plt.title("Objective Convergence") +plt.plot(np.array(scca_pos.track[0]["objective"]).T) +plt.ylabel("Objective") +plt.xlabel("#iterations") + +# %% +elasticcca = ElasticCCA(c=[10000, 10000], l1_ratio=[0.000001, 0.000001]).fit([X, Y]) +plot_model_weights(elasticcca.weights[0], elasticcca.weights[1], tx, ty) + +# Convergence +plt.figure() +plt.title("Objective Convergence") +plt.plot(np.array(elasticcca.track[0]["objective"]).T) +plt.ylabel("Objective") +plt.xlabel("#iterations") + +# %% +scca_admm = SCCA_ADMM(c=[1e-3, 1e-3]).fit([X, Y]) +plot_model_weights(scca_admm.weights[0], scca_admm.weights[1], tx, ty) + +# Convergence +plt.figure() +plt.title("Objective Convergence") +plt.plot(np.array(scca_admm.track[0]["objective"]).T) +plt.ylabel("Objective") 
+plt.xlabel("#iterations") + +# %% +spancca = SpanCCA(c=[10, 10], max_iter=2000, rank=20).fit([X, Y]) +plot_model_weights(spancca.weights[0], spancca.weights[1], tx, ty) + +# Convergence +plt.figure() +plt.title("Objective Convergence") +plt.plot(np.array(spancca.track[0]["objective"]).T) +plt.ylabel("Objective") +plt.xlabel("#iterations") diff --git a/requirements/deep.txt b/requirements/deep.txt index b4296f02..effe6a81 100644 --- a/requirements/deep.txt +++ b/requirements/deep.txt @@ -1,3 +1,4 @@ torch>=1.9.0 torchvision +pytorch-lightning Pillow \ No newline at end of file diff --git a/setup.py b/setup.py index fa47162f..6406732c 100644 --- a/setup.py +++ b/setup.py @@ -8,28 +8,28 @@ EXTRA_PACKAGES = {} with open("./requirements/deep.txt", "r") as f: - EXTRA_PACKAGES['deep'] = f.read() + EXTRA_PACKAGES["deep"] = f.read() with open("./requirements/probabilistic.txt", "r") as f: - EXTRA_PACKAGES['probabilistic'] = f.read() + EXTRA_PACKAGES["probabilistic"] = f.read() setup( - name='cca_zoo', - version='0.0.0', + name="cca_zoo", + version="0.0.0", include_package_data=True, - keywords='cca', + keywords="cca", packages=find_packages(), - url='https://github.com/jameschapman19/cca_zoo', - license='MIT', - author='jameschapman', + url="https://github.com/jameschapman19/cca_zoo", + license="MIT", + author="jameschapman", description=( - 'Canonical Correlation Analysis Zoo: CCA, GCCA, MCCA, DCCA, DGCCA, DVCCA, DCCAE, KCCA and regularised variants including sparse CCA , ridge CCA and elastic CCA' + "Canonical Correlation Analysis Zoo: CCA, GCCA, MCCA, DCCA, DGCCA, DVCCA, DCCAE, KCCA and regularised variants including sparse CCA , ridge CCA and elastic CCA" ), long_description=long_description, long_description_content_type="text/markdown", - author_email='james.chapman.19@ucl.ac.uk', - python_requires='>=3.6', + author_email="james.chapman.19@ucl.ac.uk", + python_requires=">=3.6", install_requires=REQUIRED_PACKAGES, extras_require=EXTRA_PACKAGES, - test_suite='test', + 
test_suite="test", tests_require=[], ) diff --git a/tutorial_notebooks/CCA_Tutorial.ipynb b/tutorial_notebooks/CCA_Tutorial.ipynb deleted file mode 100644 index 0d31c1c4..00000000 --- a/tutorial_notebooks/CCA_Tutorial.ipynb +++ /dev/null @@ -1,3214 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "CCA_Tutorial.ipynb", - "provenance": [], - "collapsed_sections": [], - "authorship_tag": "ABX9TyODnGHi5fJq92xQu05U70sT", - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - }, - "widgets": { - "application/vnd.jupyter.widget-state+json": { - "d5f617ab68ec42b0a81333040a03ce70": { - "model_module": "@jupyter-widgets/controls", - "model_name": "VBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "VBoxView", - "_dom_classes": [ - "widget-interact" - ], - "_model_name": "VBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_082d16684fce464e8ab300fd88db366e", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_42a064361141414d9735dd76a475268a", - "IPY_MODEL_00c063db01344e58b6f4155605f5036e" - ] - } - }, - "082d16684fce464e8ab300fd88db366e": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - 
"grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "42a064361141414d9735dd76a475268a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "ToggleButtonsModel", - "model_module_version": "1.5.0", - "state": { - "_options_labels": [ - "CCA", - "PLS" - ], - "_view_name": "ToggleButtonsView", - "style": "IPY_MODEL_eb5dd98764d741b8ac1d86b5ac566ad2", - "_dom_classes": [], - "description": "Model:", - "_model_name": "ToggleButtonsModel", - "tooltips": [], - "index": 0, - "button_style": "", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "disabled": false, - "_view_module_version": "1.5.0", - "icons": [], - "description_tooltip": null, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_76bd49f092e24228924699e3fd9b735c" - } - }, - "00c063db01344e58b6f4155605f5036e": { - "model_module": "@jupyter-widgets/output", - "model_name": "OutputModel", - "model_module_version": "1.0.0", - "state": { - "_view_name": "OutputView", - "msg_id": "", - "_dom_classes": [], - "_model_name": "OutputModel", - "outputs": [ - { - "output_type": "error", - "ename": "AttributeError", - "evalue": "ignored", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - 
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/ipywidgets/widgets/interaction.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwidget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_interact_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mwidget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwarg\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 257\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 258\u001b[0m \u001b[0mshow_inline_matplotlib_plots\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_display\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36minteractive_cca\u001b[0;34m(tog)\u001b[0m\n\u001b[1;32m 13\u001b[0m 
\u001b[0mrcca\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrCCA\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlatent_dims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mX_tr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY_tr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mtest_scores\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mX_te\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY_te\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0mplot_latent_train_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscores\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_scores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0mplot_widget\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mwidgets\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minteractive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minteractive_cca\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtog\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtog\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'rCCA' object has no attribute 'scores'" - ] - } - ], - "_view_module": "@jupyter-widgets/output", - "_model_module_version": "1.0.0", - "_view_count": null, - "_view_module_version": "1.0.0", - "layout": "IPY_MODEL_f2c9bdfd43f4436dab5dfcf0cf8e8b80", - "_model_module": "@jupyter-widgets/output" - } - }, - "eb5dd98764d741b8ac1d86b5ac566ad2": { - "model_module": 
"@jupyter-widgets/controls", - "model_name": "ToggleButtonsStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "button_width": "", - "_model_name": "ToggleButtonsStyleModel", - "description_width": "", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "font_weight": "", - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "76bd49f092e24228924699e3fd9b735c": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "f2c9bdfd43f4436dab5dfcf0cf8e8b80": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": 
"@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "4186484e492348708679feb6fc1a5937": { - "model_module": "@jupyter-widgets/controls", - "model_name": "VBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "VBoxView", - "_dom_classes": [ - "widget-interact" - ], - "_model_name": "VBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_e9033710a974478088014f255d040332", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_a10e2175e99542f8984a5e1ac038d6c4", - "IPY_MODEL_07f1da8af80c4a33873945c3d206835d", - "IPY_MODEL_10adc5e3157243a5ba6c391d59254072" - ] - } - }, - "e9033710a974478088014f255d040332": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - 
"_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "a10e2175e99542f8984a5e1ac038d6c4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatLogSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "FloatLogSliderView", - "orientation": "horizontal", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "disabled": false, - "readout_format": ".9f", - "_model_module": "@jupyter-widgets/controls", - "style": "IPY_MODEL_e17dc362cb1c46379c2ed85b398d9250", - "layout": "IPY_MODEL_098c336771204a41aa40df3c5a820c70", - "min": -10, - "continuous_update": false, - "description_tooltip": null, - "_dom_classes": [], - "description": "cx", - "_model_name": "FloatLogSliderModel", - "max": 0, - "readout": true, - "step": 0.1, - "base": 10, - "value": 1e-10, - "_view_module_version": "1.5.0" - } - }, - "07f1da8af80c4a33873945c3d206835d": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatLogSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "FloatLogSliderView", - "orientation": "horizontal", - 
"_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "disabled": false, - "readout_format": ".9f", - "_model_module": "@jupyter-widgets/controls", - "style": "IPY_MODEL_690cbdac87c94770afaa6fae82747ec4", - "layout": "IPY_MODEL_1b30b8c885004e3e8e4aaeb9c703de24", - "min": -10, - "continuous_update": false, - "description_tooltip": null, - "_dom_classes": [], - "description": "cy", - "_model_name": "FloatLogSliderModel", - "max": 0, - "readout": true, - "step": 0.1, - "base": 10, - "value": 1e-10, - "_view_module_version": "1.5.0" - } - }, - "10adc5e3157243a5ba6c391d59254072": { - "model_module": "@jupyter-widgets/output", - "model_name": "OutputModel", - "model_module_version": "1.0.0", - "state": { - "_view_name": "OutputView", - "msg_id": "", - "_dom_classes": [], - "_model_name": "OutputModel", - "outputs": [ - { - "output_type": "error", - "ename": "ValueError", - "evalue": "ignored", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/ipywidgets/widgets/interaction.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwidget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_interact_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mwidget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwarg\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 257\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 258\u001b[0m \u001b[0mshow_inline_matplotlib_plots\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_display\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36minteractive_cca\u001b[0;34m(cx, cy)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0minteractive_cca\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcy\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mrcca\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrCCA\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlatent_dims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcy\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mX_tr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY_tr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m 
\u001b[0mtest_scores\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_te\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mY_te\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mplot_latent_train_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscores\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_scores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mplot_train_test_corrs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscores\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_scores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/cca_zoo/models/cca_base.py\u001b[0m in \u001b[0;36mtransform\u001b[0;34m(self, views, y, **kwargs)\u001b[0m\n\u001b[1;32m 68\u001b[0m \u001b[0mcheck_is_fitted\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattributes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"weights\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 69\u001b[0m views = check_views(\n\u001b[0;32m---> 70\u001b[0;31m \u001b[0;34m*\u001b[0m\u001b[0mviews\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccept_sparse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maccept_sparse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 71\u001b[0m )\n\u001b[1;32m 72\u001b[0m \u001b[0mviews\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_centre_scale_transform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mviews\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/cca_zoo/utils/check_values.py\u001b[0m in \u001b[0;36mcheck_views\u001b[0;34m(copy, accept_sparse, *views)\u001b[0m\n\u001b[1;32m 24\u001b[0m views = [\n\u001b[1;32m 25\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mallow_nd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccept_sparse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maccept_sparse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mview\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mviews\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 27\u001b[0m ]\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/cca_zoo/utils/check_values.py\u001b[0m in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 24\u001b[0m views = [\n\u001b[1;32m 25\u001b[0m \u001b[0mcheck_array\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mview\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mallow_nd\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcopy\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maccept_sparse\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maccept_sparse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 26\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0mview\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mviews\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 
27\u001b[0m ]\n\u001b[1;32m 28\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/sklearn/utils/validation.py\u001b[0m in \u001b[0;36mcheck_array\u001b[0;34m(array, accept_sparse, accept_large_sparse, dtype, order, copy, force_all_finite, ensure_2d, allow_nd, ensure_min_samples, ensure_min_features, estimator)\u001b[0m\n\u001b[1;32m 763\u001b[0m \u001b[0;34m\"Reshape your data either using array.reshape(-1, 1) if \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 764\u001b[0m \u001b[0;34m\"your data has a single feature or array.reshape(1, -1) \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 765\u001b[0;31m \u001b[0;34m\"if it contains a single sample.\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 766\u001b[0m )\n\u001b[1;32m 767\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mValueError\u001b[0m: Expected 2D array, got 1D array instead:\narray=[ 0.24873507 -1.10304645 -0.78789571 -0.2254131 -1.31878264 -0.4543386\n -1.07305892 1.42619916 -0.40343571 -0.94682405 0.67444959 -0.62848586\n -1.1024857 0.88752992 -0.32314154 0.8410213 -0.01505387 0.12919879\n 0.56317879 -0.63482736 -0.81021457 1.20183681 0.57891931 -0.06749798\n 0.79570091 0.54015971 0.94330287 -1.14283454 -0.06417904 1.129803\n 1.67722708 -0.53593825 0.15916538 0.01933582 -1.24418122 -0.28755981\n -1.57869268 -0.16732161 -1.03801888 0.16991478 -0.30989722 -0.32581508\n 1.55342149 0.07080917 1.87822239 0.44637153 0.81555286 -0.86952237\n -0.69367864 0.01726388 1.21307524 -0.36295402 -1.43145229 0.39799036\n 0.06397831 -0.55101008 -0.37306379 -0.87107367 0.92390675 0.0233241 ].\nReshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample." 
- ] - } - ], - "_view_module": "@jupyter-widgets/output", - "_model_module_version": "1.0.0", - "_view_count": null, - "_view_module_version": "1.0.0", - "layout": "IPY_MODEL_adc24ca24d0748edb2e13da7e26f635d", - "_model_module": "@jupyter-widgets/output" - } - }, - "e17dc362cb1c46379c2ed85b398d9250": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "098c336771204a41aa40df3c5a820c70": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - 
"690cbdac87c94770afaa6fae82747ec4": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "1b30b8c885004e3e8e4aaeb9c703de24": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "adc24ca24d0748edb2e13da7e26f635d": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - 
"_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "fd8e3e038fa446c48f29f3c719a88ee8": { - "model_module": "@jupyter-widgets/controls", - "model_name": "VBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "VBoxView", - "_dom_classes": [ - "widget-interact" - ], - "_model_name": "VBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_f5588cc41bfe4dfc99d34f0b8c7092ed", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_9032567411404f3d9e85862bff74b0be", - "IPY_MODEL_08ff7d25488b4b53ad841974fa97e8bd", - "IPY_MODEL_53f71b9ae00a4b4cb836d40358392790" - ] - } - }, - "f5588cc41bfe4dfc99d34f0b8c7092ed": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - 
"overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "9032567411404f3d9e85862bff74b0be": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "FloatSliderView", - "style": "IPY_MODEL_1947b94e59c0410b875b21164719c876", - "_dom_classes": [], - "description": "c1", - "step": 0.1, - "_model_name": "FloatSliderModel", - "orientation": "horizontal", - "max": 7.745966692414834, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 3, - "_view_count": null, - "disabled": false, - "_view_module_version": "1.5.0", - "min": 1, - "continuous_update": false, - "readout_format": ".5f", - "description_tooltip": null, - "readout": true, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_b2afc9cdf727400487ce19f9f0ce7d53" - } - }, - "08ff7d25488b4b53ad841974fa97e8bd": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "FloatSliderView", - "style": 
"IPY_MODEL_16b08cb41fc94870b7578021c7927d1a", - "_dom_classes": [], - "description": "c2", - "step": 0.1, - "_model_name": "FloatSliderModel", - "orientation": "horizontal", - "max": 7.745966692414834, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 3, - "_view_count": null, - "disabled": false, - "_view_module_version": "1.5.0", - "min": 1, - "continuous_update": false, - "readout_format": ".5f", - "description_tooltip": null, - "readout": true, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_654c62b8b9c3489c9f03fc76e81a2ed9" - } - }, - "53f71b9ae00a4b4cb836d40358392790": { - "model_module": "@jupyter-widgets/output", - "model_name": "OutputModel", - "model_module_version": "1.0.0", - "state": { - "_view_name": "OutputView", - "msg_id": "", - "_dom_classes": [], - "_model_name": "OutputModel", - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": "
" - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": "
" - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaMAAAEUCAYAAACGWlk5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3dd3gU1foH8O/WkArJphCkg0IIAUISuoQSgUAiTQRBiiLFAtf7sxC5V1EQr5FrxetFAUUEBJEIEkCK9NCCFDVgC6BIKpvEJJBsnd8fubtsmcm22ezu7Pt5Hh/JtHPOzDt75pw5MyNiGIYBIYQQ4kFiT2eAEEIIocqIEEKIx1FlRAghxOOoMiKEEOJxVBkRQgjxOKqMCCGEeBxVRj6mqKgIiYmJ0Ol0ns4Kr8aMGYPTp0/bteywYcNw4sQJN+eIeBuKfWHHPlVGdvCmAGjVqhXOnz8PiUTitjROnTqF6dOnIykpCcOGDXNbOqZ27dqFvn37uryd06dPY/DgwTzkiAD+F/tr1qxBRkYGEhMTMWzYMKxZs8ZtaRlQ7DegyogHWq3W01ngVVBQECZOnIjnn3/e01khXk5osc8wDLKzs5Gfn481a9Zg48aN2LVrl6ez5ReoMrLhueeeQ1FREebPn4/ExESsXr0af/75J7p06YKtW7diyJAhmDlzJutVielVpV6vx0cffYS0tDT07dsXf/vb31BVVcWaZnp6Og4dOmT8W6vVol+/figoKDCmbfgRqKmpweLFizFo0CDce++9ePvtt43dGEOHDsWPP/4IAPj666/RpUsX/PrrrwCArVu34oknnmBNv0ePHhg3bhzatGljc/8sWrQIH3/8MQCgtLQUXbp0wcaNGwEAf/zxB/r06QO9Xg8AOHToEMaOHYvk5GRMmTIFP/30E+u+qq+vx6JFi5CSkoL09HSsXr3aat9evnwZmZmZSEpKwtNPPw2VSoXbt29jzpw5KCsrQ2JiIhITE1FaWorvv/8eEyZMQO/evTFgwAD861//slku4p+xP2fOHMTHx0MqlaJjx44YPnw4zp07x7osxT6/qDKyYcWKFWjVqhVWrVqF8+fPY86cOcZ5+fn52L17N9auXWtzO5999hkOHDiADRs24NixY2jevDmWLl3KuuyYMWOQm5tr/Pv48eMIDw9HfHy81bJZWVmQSqXYt28ftm/fjry8PGzduhUAkJKSgjNnzhjz2qZNG+Tn5xv/7tOnj/07goNpGmfOnDFL48yZM0hKSoJYLMalS5ewePFiLF26FKdPn8bkyZPxxBNPQK1WW23z/fffx40bN3DgwAF88skn+Prrr62W2bNnD9asWYNvv/0WP//8M3JychAUFITVq1cjOjoa58+fx/nz5xETE4Ply5djxowZOHfuHPbv34/09HSXy+0P/D32GYbB2bNn0blzZ9b5FPv8osrIBQsWLEBQUBCaNWtmc9nNmzfj73//O1q2bAm5XI6nnnoKe/fuZe3myMzMxMGDB1FXVwcA2LlzJ8aMGWO13M2bN3HkyBEsXrwYQUFBUCgUmDVrlrFbwfRkOXv2LObNm2d2QqakpDhddoM+ffrgu+++g16vR35+Ph577DHjlaTpSb9lyxZMnjwZPXv2hEQiwfjx4yGTyXDhwgWrbe7Zswfz5s1D8+bN0bJlS8yYMcNqmenTpyMmJgYtWrTA0KFDcfnyZc48SqVS/PHHH6ioqEBwcDB69erlcrn9nT/E/sqVK6HX6zFx4kTW+RT7/JJ6OgO+rGXLlnYvW1RUhCeffBJi8Z36XywWQ6lUIiYmxmzZdu3aoVOnTjh06BCGDh2KgwcPYvv27azb1Gq1GDRokHGaXq9HbGwsgIaT5Y033kBZWRn0ej3S09Px/vvv488//0
RNTQ3i4uIcLbKVtm3bIjAwEJcvX8Z3332HJ598El9++SWuXLmC/Px8TJ8+3ZjX7du3Y8OGDcZ1NRoNysrKrLZZVlZmLAPAvp+joqKM/w4MDGTdjsHy5cvx3nvvIT09Ha1bt8ZTTz2FoUOHOlVe0kDosb9hwwZs374dmzZtglwuZ12GYp9fVBm5QCQSGf8dGBiI+vp64986nQ4VFRXGv1u2bInXXnsNSUlJdm07IyMDubm50Ov16Ny5M9q1a2e1jOFK89SpU5BKrQ9lu3bt0KxZM2zYsAHJyckICQlBZGQkvvjiC2MXAh9SUlKwd+9eaDQaxMTEICUlBdu3b8dff/1lPOljY2Mxf/58PP744za3FxUVhZKSEmP3SElJid15MT0mBu3bt8dbb70FvV6Pffv2YeHChTh9+jSCgoLs3i4xJ+TY//LLL/HRRx9h48aNNitdin3+UDedHSIjI3H9+vVGl+nQoQNUKhUOHz4MjUaD//73v2Z9wg899BDeeecd3LhxAwBQUVGBAwcOcG5v9OjRyMvLw+eff46MjAzWZaKjozFw4EC8/vrrqK2thV6vxx9//GHsngAarhA3bNhg7Jaw/JuNXq+HSqWCRqMBwzBQqVSs/duWaSQnJwMA+vbtiw0bNiApKck4DHfSpEnYvHkzLl68CIZhcPv2bRw+fBi1tbVW20tPT8eHH36Iv/76C6WlpWZXlLYoFApUVVWhpqbGOG3Hjh2oqKiAWCxGWFgYAPBWEQudv8X+119/jbfffhuffPKJXQN4KPb545258jJz587Ff//7XyQnJ3PesA0NDcWSJUvwz3/+E4MHD0ZgYKDZVdWMGTMwbNgwPProo0hMTMSDDz6I77//njPN6Oho9OrVC+fPn8fo0aM5l3vjjTeg0WgwevRopKSkYOHChSgvLzfOT0lJwa1bt8xOSNO/2eTn56NHjx6YO3cuioqK0KNHD8yePZtzecs0kpKSUF9fbzxBASAhIQHLli3D0qVLkZKSghEjRiAnJ4d1e08++SRatmyJ4cOHY9asWRg5ciRnV4mlTp06YcyYMUhLS0NycjJKS0tx7NgxjBkzBomJiVi+fDnefvttu+51EP+L/XfeeQdVVVV44IEHjKPSXnrpJc7lKfb5I6KP6xFvt2nTJuzevduhq0RChMCfYp9aRsTrlJWVGUcpXblyBZ988gnS0tI8nS1C3M6fY58GMBCvo9FosGTJEvz5558IDQ3FmDFjMHXqVE9nixC38+fYp246QgghHkfddIQQQjyOKiNCCCEeR5URIYQQj/O5AQyVlbeg17Pf5lIoQqBUWj9I5uuEWi5AuGWzLJdYLEJ4eLDL26X4Fw6hlgtwLv59rjLS6xnOk9EwX4iEWi5AuGVzR7ko/oVFqOUCHC8bddMRQgjxOKqMCCGEeJzPddMRQu7Q6bSorCyHVqtGWZnY+GVRIfGGckmlcoSHR0EioZ9Md6E9S4gPq6wsR7NmQQgObgmZTAKtVniVkVQq9mi5GIbBrVvVqKwsR2RkrO0ViFOom44QH6bVqhEcHMb6LRvCD5FIhODgMGi13J9RIa6jyogQH0cVkfvRPnY/qowIIYR4HN0zIoTw6oEHMhEYGIhPP91s/KroAw9k4o033kbHjp1d3v7x40ewZs2HZtMqK5VgGODrr/e6vH3iGVQZEUJ4V1dXh717dyM9nf2z4a4YNCgVgwalGv+uqqrC7NkP44kn/ubQdrRaLaRS+gn0FnQkCPFDJwtKkHOkEMpqFRRhAZiQ2gn941vaXtFOjz46Fx9/vBppaSMhk8mM0//88zpWrHgNVVWVkEgkmDv3SfTrNwAAMGhQMubOfQJHjx7GX3/9hSefXIghQ4Y3mo5Op8OSJS9g6NA0DB9+H4CGbwJ99NEHuHDhO6jVGnTu3BnPPPMCgoKCsHz5y5BIJPjjj99x+/ZtrFu3CRs2rMPevbsBAHFx8Xj66ecQFBTE274g9qF7RoT4mZMFJfh0z09QVqsAAM
pqFT7d8xNOFpTwlkbXrnHo0qUrvvrqS7Ppr7zyT9x330h8+ulmvPjiMixb9iIqKyuN84ODg7FmzXq8+OIreOedf9tM58MP3wfDMHj88QXGaRs3forg4GCsXr0en376ORSKKHz22SfG+b/++gvefHMl1q3bhJMn87B3726sWvUx1q/fAp1Oh3Xr1vCwB4ijqGVEiJ/JOVIItcVzO2qtHjlHCnltHc2d+zgWLJiPjIyxAACGAX777ReMHn0/AKBDh47o3LkLCgp+wKBBgwEAw4ePBADExyfg5s1yqFQqSKWBrNs/dOgADhzYh7VrP4NEIjFOz8s7ilu3buHw4YMAAI1Gjc6d7zbOHzJkOAIDG7Z59uwZDB8+AsHBIQCA+++fgHfftV0JEv5RZUSInzG0iOyd7qy2bdujf/+B2LJlo93ryOVyADBWLjqdDrm5O7B58+cAgKlTp2PEiHRcu3YV//73v7BixbsID48w2wbDAM88k4WkpBTWNIKC2Cs34lm8ddNdvXoVkydPxsiRIzF58mRcu3bNapmVK1eif//+GDt2LMaOHYtXXnmFr+QJIXZShAU4NN0Vjz46Fzk5W3H79m2IREDnzvdgz55cAMC1a1dRWPgL4uMTGt1GRsZYrFu3CevWbcKIEem4ffsWFi9+FnPnPolu3bpbLT9o0GBs2bIRKlU9AOD27Vu4du0q67aTk/vg4MH9uH37FhiGQW7udqSk9HWx1MQZvLWMlixZgqlTp2Ls2LHYsWMHXnrpJaxfv95quXHjxmHRokV8JUsIcdCE1E74dM9PZl11cqkYE1I78Z5WdHQMRo4cjc2bNwAAlix5FStWvIYvvtgEiUSCf/5zKcLDwx3a5rZtW3Hjxp/46qsvre5JffDBajz88CysXfshHntsxv+Glovw6KNz0L59B6tt9e8/EIWFv2LevEcAAF27dsPMmbOdKyxxiYhhGJc/qKFUKjFy5EicPn0aEokEOp0Offv2xb59+xARcacJvXLlSty+fdulykiprOX8TkZUVCjKy2uc3ra3Emq5AOGWzbJcYrEICkWIy9u1jP+Skt/RsmU7AI69w83do+n45Ol30xmY7ms+CDX2Aefin5eWUXFxMWJiYoz9vBKJBNHR0SguLjarjABg165dOH78OKKiorBgwQIkJiY6lJatAkVFhTqWeR8h1HIBwi2bO8plGf9lZWJIpXd6203/3Zh7e7bCvT1b8Zo3d7K3XO4kFot5P6ZCjX3A8bI16QCGKVOmYP78+ZDJZMjLy8MTTzyB3bt3O9RMp5aRsAi1bE3VMtLr9cZWg7e0IPjmLeXS6/W8xqpQYx9wLv55udyIjY1FaWkpdDodgIYRMGVlZYiNNX/delRUlPEBuIEDByI2Nha//vorH1kghBDiw3ipjBQKBeLi4pCb2zBKJjc3F3FxcVZddKWlpcZ/X758GTdu3ECHDtY3FQkhhPgX3rrpXn75ZWRlZeGDDz5AWFgYsrOzAQBz5szBwoULkZCQgLfeegsFBQUQi8WQyWR44403EBUVxVcWCCGE+CjeKqNOnTph69atVtNXr15t/LehgiKEEEJMeX6ICiGEEL9HrwMihPCGvjXUtHzpeTFbqDLimZCCgwhX9akTuJmzDdoKJaQRCkROmIiw/33KwRV8fGuIvjNkH8Pb1w1v0jC8fR2AT/7m0BHnkdCCgwhT9akTKF2/DoxaDQDQVihRun4dAPBSIRlYfmvI2e8MiUQidO3ajb4zZKGp3r7eVOieEY8aCw5CvMXNnG3GisiAUatxM2cbr+lYfmvI2e8Mbdz4BX1niEVTvX29qVDLiEdCCw4iTNoKpUPTncH2rSFnvzMkEonoO0MsFGEBrL8t7nj7elOgyohHQgsOb+FN9+G8KS/OkkYoWCseaYSCl+1zfWuIvjPEr6Z8+3pToG46Hk1I7QS5xQsdvSU4ThaU4LkP8vDo6wfx3Ad5vH5i2p2a4hPZvpgXV0ROmAjR/z5iZyCSyxE5Ya
LL227sW0P0nSF+9Y9viZnpXY0Xu4qwAMxM7+pzF0cG1DLikSEIvO3K2ZcHVnjTTVpvyosrDIMU3DGazta3hj77bJ3D3xkSiUTo0iXOpe8M1dZpUFmjgk6nh0QiRnhoAEICZU5vz1v0j29pjD1Dq331zkte89vjCF6+Z9SU6K3djnvugzzO7sMVTwx0JWsus1W2R18/yDnv46xh7sgSJ0fy4u3fM/Ilrpartk4D5V/1MP2pE4lEUDRv5lCF5M3fM7K84AQaemU81VLy2PeMiHffS/DlgRXedB/Om/JC7FdZo4LlNTfDMKisUQmidQQIo9VO94x44O33Erh+LH3hR9Sb7sN5U16I/XQ69lYV13Rf5MsXnAbUMuKBt1+V+PKoGz7uw/HVavXWe4IMw0AkEnk0D95MIhGzVjwSif3X4t5+N0MIrXaqjHjg7Vcl3vojai/Tm7SO4nvwhit5cQexWAKdTgupVBjdTe4QHhrAes8oPNT+H2qdTguxWOJw2k3Vfe/LF5wGVBnxwNuuSrhOAG/6EW0q3t5qdVVgYAhqaqrQooUC1OvOznBfyNnRdAyjR01NJQIDHRuA0pSjWF254PSW+91UGfHAm65KfHkYtzt4e6vVVSEhzVFZWY7S0j8hFoug1wvnPoiBWCzmpVwyADIRAD1Q+1fDf/YRQS5vhpCQ5g6l19QXQs5ccHrT7wVVRjzwpm4wobcEHOVtrVa+iUQiREREA6BHG9zpZEEJcjafcOj89oULIW/6vaDKiCfe0g1mzwngLc3ypuBNrVbivRo7J5xtPfjChZA3VZjUySwwtoZxe/swdL4J7ZUphH+2zgln38bvC48CeNNjH9QyEgjTKztLpieANzXLm4q3tFqJd7J1TjjbevCm7nsu3tRzQJWRALC9CsTA8gTwpmY5aXr+1EVrL1vnhKPdbb60jx2pMN1dLr+ojHwpOJzBdmUHsL97zhf6sYl78D1ySijnla1zwpHWgzeNTrOXPT0HTVEuwd8z8od7JI60dnyhH5u4B59fIhbSeWXrnHDkvqNQv/bcFOUSfMvIH+6RONLacbQf25GrXz6vlNm25Ui+iTU+u2i94bxqytc82XvfUajd4E1RLsFXRkINDlOO3oS098RypGnORzOeaxCGslqFj3MvQSQWQatjnN6+o3kQWoXHZxetp88rb33Nk1C7wZuiXILvpvOmoYvu4q7hy440zV1txlt2+1jSMTBWRM5s35k8+HLXExs+u2g9fV7ZE2+e+LpxY/vYMj+Hv7vu9vzwpSm69wXfMmJrNQANPzTPfZAnmCtfdwxfduTq19UrZa5BGLbweSXO9QO3af/Pgmgt8TnU2NNDgm3Fmz0tJ3e0grn2MQCr/Ly56RxCAqV4KO0er4+nphimztuXXq9evYqsrCxUVVWhRYsWyM7ORvv27c2W0el0ePXVV3Hs2DGIRCLMnTsXkyZNcigdti+9Vp860fAJ5coKSMMjjJ9QNkzXVChRKwvBwfBeuBzW0bheXPUVDK04j1DtLYjEYkCvhy60BY5EJOKMrA36aK4jteI8JDVVEAUHQyQSQV9bC/xvWdNptj7dbMyjyWeeAfZPP5suy5YW17/Z8sCWLtd8y23tC+2BM7I2VmXpo7mOETXfm21z2TnG7AcirvoKhlScR5j2FmSNpfu/Y7ZNGmd2bGwx3b7YgWPDtm8N858+XM+6/WppMA5HJBrzZ/iCZnzNFc5tdcoc2SRfegXgULzYG6em50xooMx4Hpiuz/ZjbrpPbJ0zXP+2WvbWLeN5XRDakbUr1zIebtdr0Uynsjp2irAAvNhbhBtbvoC4pspsPttxdeScsYwB03W4vrYMNP5FVq60HDnv7f29cjYtNs586ZW3ymjGjBmYOHEixo4dix07dmDbtm1Yv3692TLbt2/Hzp07sXr1alRVVWHcuHHYtGkTWrdubXc6lidj9akTKF2/DoxabZwmkssROmAgak7kmU3XiKTYHdUPl8M6Iq76CkaXn4
SM0VmloRFJcDGkE3rWFrLO5yKSyxEzY5bVgWLLIyQSACJAp7WZb0eY5oFr3zQ23xQjlWF3VD/8ENzBOC3h1lWMLj8FkVZjts3qYeOx+log1Fo96761la5GJMHuqP6cFZJEBOM9o8aOXWP7pbF9K5LL8W3sQJyRtWHdvmX++miuY3gx97Y6P/U4RN0SjdPcVRnZOoZsHIlTLUQARJBC79T6fGKLRwA248H02MVVX8G4qjOcscd2XB05Z9jWARr/ZD3A/hiGI/vT1Tw6mxbA3sK8f8jdDldGvNwzUiqVuHTpEjIyMgAAGRkZuHTpEioqKsyW2717NyZNmgSxWIyIiAikpaXhm2++cSntmznbrHYgo1aj+ugRq+kyRoshFecBAEMqznMGr4zRoXfNrw792BnSvZmzzfi3oY/453UbrQ+yTmdWETWWb2fzwLVvGptvSqTVYFTtD2b3okbV/mBWERm2qT+QC5lUhJBAKeu+tZWujNEZj40lRVgAHs3ohkdGx0ERFtDoseNia98yajVSK85DLhWzbt8yf0k3zjS6rT8+2+hQ/pxl6xiysYzTxrYlBWNWETm6Pp9EWg0GlZ+zmm4rHkyP3bDKC43GHttxdeScYVsHsH0vja3V5Mj+dDWPzqbFdZ/VmfthvNwzKi4uRkxMDCSSho9PSSQSREdHo7i4GBEREWbLtWrVyvh3bGwsSkocu6loWbv+UlnBviDHK+fDtLfM/s9FBOcajNrKCkRFheLwd9ex/pufodLobKZlitHr4eo3Ow154No3tuabktT+hXVLRhn/zhu3hnW5MO0t3KrXIUAm4SyvrXRN14sKD8SM9DgMSTLvJrx/yN2cebDJxmcIJLV/YcGDvSBdus5m/mwdU9VNJaKiQh3Ooi12x78NhmPh7LZcXd9ZbPvdnvMrTHsLATIJQjS1jW7X2dhtbB0AmJURj/e3XoRKw15pRoUHurw/Xcmjs2ltP36S9T7r+j2X8fE/Rzi0TZ8bwGDZTSENj2joF7X0v35SSzXSYABAtTQYzRsJYgYipyokaXgEystrsC63wBh4ttLiI122PHDtG1vzTdVIg/H14V+N/dlc61T/b7+qNDrUykIQynLS20rXsA0AyJ7XHwBYPx1gT75ZccSE6XbvadsCVyIUjeZPLhVDH9oCkpoqzm0FRCqa5J6Rs/vCcCwsp9m7LVfXd5ZpjBhwxZvlMjNGdYGslPvYNnZcHTlnLNcBgPi2LTBjVBds2v8zbtWbV0hyqRjjBnVweX+6kkdn0yqvrGOdf7OyzjPddLGxsSgtLYVO17CTdTodysrKEBsba7VcUVGR8e/i4mK0bOnaaIzICRMhksvNponkcoQNTmWdLkrLgCIsAIcjEqERsdfFGpEE50Lvhkbk2GeGRXK5cWCCabO7IS2LbUkkgMQ8fWfT5coD175pbL5lfg6G9zIb3sy2jkYkweGIO/dHDob3cjhd023Y6tKwlW82XDFxJ30ptknj8NwHeajsk8aSPykORyQah83fNflBzm2J5HK0nT7Nofw5y9l9YTgWtralhQhai58JR9bnEyOV4XhUb7NpcqkYGD6m0XRFcjnunjkV/eNbcsbed3f14TyutmKDK03LfdQ/viVWPp2KOZndoAgLgAiNP4bhyP505Lx2ZR3L9bjO1cjwQLu2ZUry8ssvv+zwWhaCgoJw7NgxSKVSdO3aFTt37kR5eTmmTTM/ITUaDbZv347MzExUVlbitddew/PPP4/mze3/gmJdnRqmQy4CWreBTKFA/bVr0NfXQxqhQPSUqVCMzrwzva7OOL3DiGEYkdIWA4f3RrPoSON8iMUAw0AX2gJHW/bDsbB46ELD0VZfBbG6HqLgYIgDAhr6Vf+3rOk0w/YNN/WOf1+EOlVD5XwzIBxV0hDEqpQI0Gsgi1Ageuo0hCQmov7aNejq6lAtDcb+yBScUSSYLStiSYvr35Z5MNs3JvuAa74oOBh1ejGkjM6Yn8thHaHTM/i9pBojUt
qarWOab9OBB0xUS4welWgz3b9+LYRIXW+2DblUjIfS7kGbaO6rKLZ823NsLGMCYjEYhkGNSfp1Kh0uVsvRPbEzgitKjPmPnTYNw2aOxYiUtmgTHWKVB0O6hrTajBiO27dNboKLRAgKcv2HutH4N8mHvTHCtV91dXWolYVgX2QflER2NJ4H9q5vz3Hh+rfVshoNpBEKxDw0FQFJffF7STXqVDoowgLwUNo96DMkkTNde86J2GnTMHBaptVx1dXV2YwNtnw3to8AoE10CEaktMVj43tgYHwMZ6w3tj8dKaM9v1fOphUaJMePV5TQmbTW5VIx5o5LQHTzZsZp9sQ/b6PpCgsLkZWVherqaoSFhSE7OxsdO3bEnDlzsHDhQiQkJECn02Hp0qXIy8sDAMyZMweTJ092KB22oa0G3vBFSAO2N2lzDeHkGvZpGGHTlOVqbNTPx1nDzP52pIxcCv6owrrcAo89w2Nr3zvLmaGt9uAr/n3pTROeOq/dFRsG3vR75Qq+RtPxds+oU6dO2Lp1q9X01atXG/8tkUjwyiuv8JWkV3PkITFPP0Boyp3vuWMzJKkN4tu2cD7DLvL0a208wRffLO0J/hgbzuDrgXufG8DgS+w9SN70ES53vefOWwn1XWKN8YaXnPqCpo4NX2qtugNVRl7CW37UvalibAre1CptKr54xe+JH+qmjA1qrVJlRFh4S8XYFPyt8gV8rzV4+LvrHvmhborY4HpTPeB/rVXBVUb+3tQljvOnyhfwvdbg+j2XPdat6M7YsKxk2Xhza5VvgqqMPHUFRezDPuqG/7cUkMb5WmvwJseDlb7+Q81WyVry1taqOwiqMvLkFRRpHFefeFhoM4+OpvNXvtQajAwPZH3S39d/qLkqWQNvbq26g6A+rifUKygh4BrBtX7PZQ/liPDF3R+xm5Ee5/YPu3lCY28pEIvuXEgL5eOOtgiqZSTUKyhv4cr9OK4LAltXh8T9XDmuTTEKbEhSG1TX1PtMt6K9ZqTHYeUXF8wu0kw/lQL4160GQVVGbAdXCFdQ3sDVHx2uEVzOvMOK8MfV49pUzyz5UreivdgqWZVGh9o680/L+MutBkFVRkK9gvIGrv7ocI3gmpEex3teif1cPa6++MySN7GsZLlex+UP+1NQlREgzCsob+Dqjw7XCK4hSW0E8X4uX+XqcfW1Z5a8nT/vT8FVRsQ9+DhJ6ELB+12qT1wAACAASURBVLh6XH3tmSVv58/7kyojD/DFB3P9+SQRMlePq689s+Tt/Hl/UmXUxHz1HVT+fJIIGR/HlVq8/PLX/UmVURPz5Tcm++tJInR0XIk3ENRDr76ARh8RQog1ahk1MX8eLUMI8Q7eeN+aKqMmRgMBiC/zxh8x4hg+7lu7Iw6om66J9Y9viZnpXY0tIUVYAGamd6UTmng9w4+YoWVv+BHzl3enCUVj963t4a44oJaRB9ANY+KLfHnwDbnD1fvW7ooDahkRQuxCg2+Egev+tL33rd0VB1QZEULs4uqPGPEOE1I7ufRJDnfFAVVGhBC7uPojRryDq/et3RUHdM+IEGIXeguHcLhy39pdcUCVESHEbjT4hgDuiQPqpiOEEOJx1DIibkUPSRJC7OFyZVRXV4cXXngBBQUFkEgkWLRoEYYOHWq13OnTpzF37ly0b98eACCXy7F161ZXkydezFffUE4IaXouV0Zr165FSEgI9u/fj2vXrmHatGnYt28fgoODrZbt1KkTcnJyXE2S+Ah6SJIQYi+X7xnt2bMHkydPBgC0b98e3bt3x9GjR13OGPF99JAkIcReLreMioqKcNdddxn/jo2NRUkJ+zuKrl27hvHjx0MqlWLq1KkYP368w+kpFCGNzo+KCnV4m77AF8sVFR6I8so61umm5fHFstnDHeWi+BcWoZYLcLxsNiuj8ePHo6ioiHXeiRMn7E4oPj4eR44cQWhoKK5fv45HHnkEMTExGDBggP25BaBU1kKvZ1jnRUWFory8xqHt+Q
JfLde4QR1Y31A+blAHY3l8tWy2WJZLLBbZrEjsQfEvHEItF+Bc/NusjL766qtG57dq1Qo3btxAREQEAKC4uBh9+/a1Wi4k5E5G2rRpg7S0NJw7d87hyoj4DnpIkhBiL5e76UaNGoUtW7YgISEB165dww8//IA333zTarmysjJERUVBJBKhqqoKeXl5+Nvf/uZq8sTL0UOShBB7uFwZzZ49G1lZWbjvvvsgFouxdOlSYyvo3XffRXR0NB566CHs27cPn3/+OaRSKXQ6HcaNG4e0tDSXC0AIIcT3iRiGYe+A9lLUZy4sQi0b3TPiD5XL9zgT//Q6IEIIIR7nc68DEotFLs33VUItFyDcspmWi68yUvwLi1DLBTge/z7XTUcIIUR4qJuOEEKIx1FlRAghxOOoMiKEEOJxVBkRQgjxOKqMCCGEeBxVRoQQQjyOKiNCCCEeR5URIYQQj6PKiBBCiMdRZUQIIcTjqDIihBDicVQZEUII8TiqjJpIVlYW3n77bbuWHTZsGE6cOOFwGqtWrcI//vEPh9fzZmfPnsXIkSPtWvb06dMYPHiwm3NEHEWx7xx/i32f+4QE4TZ//ny3p/HOO+/g22+/RWFhIR5//HEsWLDAreklJydj7969vGwrKysLMTEx+Pvf/87L9oj3cHfsK5VKLF++HGfOnEFdXR3uvvtuvPDCC+jZs6fb0vS32KeWEXFIu3bt8OyzzyI1NdXTWSGkydy+fRsJCQnIycnBmTNnMH78eMydOxe3bt3ydNYEgyojE8OGDcOaNWuQmZmJXr16YfHixbh58yYee+wxJCYmYtasWfjrr7+My3/77bcYM2YMkpOTMX36dBQWFhrnXbp0CePHj0diYiKefvppqFQqs7QOHTqEsWPHIjk5GVOmTMFPP/1kM38XL17EwIEDodPpjNP279+PzMxMAMDKlSvx7LPPGudduHABU6ZMQXJyMu6//36cPn0aAHDq1CnjOgDwyCOPYOLEica/p06digMHDrDmYfz48UhNTUVwcHCjeVWpVOjRowcqKioAAP/973/RrVs31NbWAmhoYS1fvhwAoFarkZ2djSFDhmDAgAF46aWXUF9fD8C6+6GgoADjxo1DYmIiFi5ciKefftqqC+jjjz9G//79MWjQIGzbtg0AsGXLFuzcuRNr165FYmKi8Ur6o48+wr333ovExESMHDkSJ0+ebLRcQkWx34Ar9tu0aYNHHnkE0dHRkEgkmDx5MjQaDa5evWq1LMW+kxhiNHToUGbSpElMeXk5U1JSwvTr148ZN24cU1BQwNTX1zPTp09nVq5cyTAMw1y5coXp2bMnc/z4cUatVjMfffQRk5aWxqhUKkalUjFDhgxhPvnkE0atVjN79uxhunXrxrz11lsMwzBMQUEB069fP+bChQuMVqtlcnJymKFDhzIqlcqYj7y8PNY8Dh8+nDl+/Ljx7wULFjAffvghwzAM89577zHPPPMMwzAMU1JSwvTp04c5fPgwo9PpmOPHjzN9+vRhlEolU1dXx3Tv3p1RKpWMWq1m+vfvzwwaNIipqalh6urqmISEBKaioqLRffXMM88w7733XqPLTJ06lfnmm28YhmGYRx55hBk+fDhz+PBh47x9+/YxDMMwy5cvZ+bNm8dUVlYyNTU1zLx585h///vfDMMwzKlTp5h7772XYRjGuF/XrVvHqNVqZu/evUx8fLxxv546dYqJi4tj3nnnHUatVjOHDx9mevTowVRVVTEMwzCLFi0yLsswDFNYWMgMHjyYKSkpYRiGYa5fv878/vvvjZZJqCj27Y99hmGYS5cuMd27d2eqq6tZ51PsO45aRhYefvhhREZGIiYmBsnJyejRowe6deuGgIAA3Hfffbh06RIAYPfu3UhNTcXAgQMhk8kwe/Zs1NfX4/z587h48SI0Gg1mzpwJmUyGUaNGISEhwZjGli1bMHnyZPTs2RMSiQTjx4+HTCbDhQsXbOZvzJgxyM3NBQDU1tbi6NGjGD
NmjNVyO3bswODBg5GamgqxWIyBAweie/fuOHLkCJo1a4aEhAScPXsWBQUF6Nq1K3r37o1z587hwoULaNeuHcLDw13elykpKcjPz4dWq8XPP/+M6dOnIz8/HyqVCj/88AOSk5PBMAy++OILLF68GC1atEBISAjmzZuHXbt2WW3v4sWL0Gq1mDFjBmQyGUaMGGG2XwFAKpXiySefhEwmQ2pqKoKCglivXgFAIpFArVajsLAQGo0GrVu3Rtu2bV0ut6+i2Lcv9mtra/H888/jqaeeQmhoKOsyFPuOowEMFiIjI43/DggIMPu7WbNmuH37NgCgrKwMrVq1Ms4Ti8WIjY1FaWkpJBIJYmJiIBLd+e676bJFRUXYvn07NmzYYJym0WhQVlZmM3+ZmZmYMmUKXnnlFezfvx/dunXDXXfdZbVcUVERvvnmGxw6dMg4TavVom/fvgAaTpYzZ84gJiYGKSkpCAsLQ35+PuRyOfr06WMzH/bo06cP/vWvf+HSpUu45557MHDgQPzjH/8wO+mVSiXq6uowYcIE43oMw0Cv11ttr6yszGq/xsbGmi3TokULSKV3wjowMNB4zCy1a9cOixcvxsqVK/Hbb79h0KBBxhu9/ohi33bs19fXY/78+ejZsyfmzZvHuRzFvuOoMnJSdHQ0fvnlF+PfDMOguLjYGDClpaVgGMYYPEVFRWjTpg2AhiCaP38+Hn/8cYfT7dy5M1q1aoWjR48iNzcXGRkZrMvFxsZi7NixePXVV1nn9+nTB6+//jpatWqFOXPmoHnz5njxxRchk8kwbdo0h/PFJjExEVevXsX+/fuRkpKCzp07o6ioCEeOHEFKSgoAIDw8HM2aNcOuXbtsnghRUVFW+7W4uNi4X20xPZENMjMzkZmZidraWrz00kv497//jRUrVjhYUv/ir7GvVqvx5JNPIiYmBkuXLm00rxT7jqNuOielp6fjyJEjOHnyJDQaDT7++GPI5XIkJiaiV69ekEqlWL9+PTQaDfbt24cffvjBuO6kSZOwefNmXLx4EQzD4Pbt2zh8+LDxBqctGRkZ+PTTT5Gfn49Ro0axLnP//ffj0KFDOHbsGHQ6HVQqFU6fPo2SkhIAd06W77//Hj169MDdd9+NGzdu4PvvvzeeLGw0Gg1UKhUYhoFWq4VKpTK7qWwqMDAQ3bt3x8aNG41XnImJidi8ebMxDbFYjEmTJuG1116DUqkEAJSWluLYsWNW2+vVqxckEgk2bNgArVaLAwcOmO1XWxQKBf7880/j31euXMHJkyehVqshl8sREBAAsZhOCVv8MfY1Gg0WLlyIgIAAZGdn24wTin3H0ZnnpI4dO2LFihVYtmwZ+vXrh0OHDmHVqlWQy+WQy+VYuXIlvvrqK/Tp0we7d+/GfffdZ1w3ISEBy5Ytw9KlS5GSkoIRI0YgJyfH7rQzMjKQn5+Pfv36ISIignWZ2NhYfPDBB/jwww/Rv39/pKamYu3atcYugKCgIMTHx6Nz586Qy+UAGk6WVq1aQaFQcKb94osvokePHsjNzcWqVavQo0cP7Nixg3P5lJQUaLVa9OjRA0DDVemtW7fMTvrnnnsO7dq1w4MPPojevXtj1qxZrH3dhv365ZdfIiUlBV9//TWGDBlizL8tDzzwAH777TckJyfjiSeegFqtxptvvom+ffti0KBBqKiowP/93//ZtS1/5o+xf/78eRw6dAh5eXlISUlBYmIiEhMTcfbsWc68Uuw7RsQwDOPRHBDigkmTJmHKlClmw3MJ8QdCi31qGRGfcubMGZSXl0Or1eKrr77Czz//jHvvvdfT2SLE7YQe+zSAgfiUq1ev4umnn0ZdXR1at26N9957D9HR0Z7OFiFuJ/TYp246QgghHkfddIQQQjyOKiNCCCEe53P3jCorb0GvZ+9ZVChCoFTa97yCLxFquQDhls2yXGKxCOHhjb9c1h4U/8Ih1HIBzsW/z1VGej3DeTIa5guRr5brZEEJco4UQlmtgiIsAB
NSO6F/fEuzZXy1bLa4o1wU/8Ii1HIBjpfN5yoj4jtOFpTg0z0/Qa1teNhQWa3Cp3saPhdgWSERQvwb3TMibpNzpNBYERmotXrkHCnkWIMQ4q+oMiJuo6xWOTSdEOK/qDIibqMIC3BoOiHEf1FlRNxmQmonyKXmISaXijEhtZOHckQI8VY0gIG4jWGQgq3RdIQQQpURcav+8S2p8iGE2ETddIQQQjyOKiNCCCEeR5URIYQQj6PKiBBCiMfRAAZCiN3sedcgIc6gyogQYhd61yBxJ+qmI4TYhd41SNyJt5bR1atXkZWVhaqqKrRo0QLZ2dlo37692TIrV67Epk2bjN9t7927N5YsWcJXFgghbkTvGiTuxFtltGTJEkydOhVjx47Fjh078NJLL2H9+vVWy40bNw6LFi3iK1lCSBNRhAWwVjz0rkHCB1666ZRKJS5duoSMjAwAQEZGBi5duoSKigo+Nk8I8QL0rkHiTry0jIqLixETEwOJRAIAkEgkiI6ORnFxMSIiIsyW3bVrF44fP46oqCgsWLAAiYmJDqWlUIQ0Oj8qKtSxzPsIoZYLEG7Z3FEuT8b//UNCERbaDOv3XMbNyjpEhgdiRnochiS1cVuaBhQjvsfRsjXpaLopU6Zg/vz5kMlkyMvLwxNPPIHdu3cjPDzc7m0olbWcn7ONigpFeXkNX9n1GkItFyDcslmWSywW2axI7OHp+I9v2wLZ8/qbTXN3mv4SI0LiTPzz0k0XGxuL0tJS6HQ6AIBOp0NZWRliY2MtMhgFmUwGABg4cCBiY2Px66+/8pEFQgghPoyXykihUCAuLg65ubkAgNzcXMTFxVl10ZWWlhr/ffnyZdy4cQMdOnTgIwuEEEJ8GG/ddC+//DKysrLwwQcfICwsDNnZ2QCAOXPmYOHChUhISMBbb72FgoICiMViyGQyvPHGG4iKiuIrC4QQQnwUb5VRp06dsHXrVqvpq1evNv7bUEERQgghpugNDIQQQjyOKiNCCCEeRy9KJYQQP+ONb1+nyogQQvyIt759nbrpCCHEj3jr29epMiKEED/irW9fp8qIEEL8CNdb1j399nWqjAghxI9469vXaQADIYT4EcMgBRpNRwghxKP6x7f0eOVjiSoj4hJvfF6BEOJ7qDIiTvPW5xUIIb6HBjAQp3nr8wqEEN9DLSPiNG99XoE0LeqqJXyglhFxmrc+r0CajqGr1nABYuiqPVlQ4uGcEV9DlRFxmiefV6g+dQJXnn8Gvzw2C1eefwbVp064PU1ijbpqCV+om444zVPPK1SfOoHS9evAqNUAAG2FEqXr1wEAwvoNcGvaxBx11RK+UGXkxapPncDNnG34pbIC0vAIRE6Y6HU/tp54XuFmzjZjRWTAqNW4mbPN5f1D9z8cowgLYK14qKuWOIq66byU4epfW6EEGMZ49U/dUQ0tIUem24vufzjOW18tQ3wPVUZeqrGrf38njVA4NN1edP/Dcf3jW2JmeldjS0gRFoCZ6V2pNUkcRt10XspdV/+exNYFdv+QUIe3Ezlhotk9IwAQyeWInDDRpfzR/Q/neOOrZYjvocrIS0kjFKwVj6tX/57C9baGsNBmiG/bwqFtGe4L3czZBm2FEtIIBS/30+j+ByGeQ5WRl3LX1b+ncHWBrd9zGdnz+ju8vbB+A3gfzDEhtZNZhQnQ/Q9CmgpVRl7K7Orfi0fT2Yurq+tmZV0T54Sbt75aX0gMI0T5bNESYaDKyIsZrv6jokJRXl7j6ey4hKsLLDI80AO54Ub3P9yHng9rnL8/VkCj6UiT4BoCPCM9zkM5Ik2NRohyo8cKqGVEHODKlRtXF9iQpDY+3+oj9hHiCFG+NPZYgb+0jkQMwzB8bOjq1avIyspCVVUVWrRogezsbLRv395sGZ1Oh1dffRXHjh2DSCTC3LlzMWnSJIfSUSprodebZ9nYD11ZAV1IcxyJSMQZWRv00VxHasV5iGuqUCsLwcHwXvgj+m6IRCLU1mmN8yU1VY
BYDOj10IW2sFpfUlMFUXAwRCIR9LW1xmVNp7H1f5v+eJtuy5BGTZ0GwyovIERTCz1HutIIBTrMehg11XXGvnbTdA3/1tXWGstY1jrOWFGw9dEXhHZkzRfbtkz3l0Fc9RUMqTiPMO0t6ENb4K7JDyKs3wDO8rLtG0O+NBVK1nyz4aoMTcto69iwLWt6zA3bja+5YlyObb5pHtn2cafMkWaVrFgsgkIR4lCsOxP/hnuLAPtoQ9N9KBYBegYIbiaxOidMz5nQQBnrsbR1ccK2r7kYtmuab+MxvHXL7Lw25JsrBuw9J0xjj+t3wVZssOXbdB3TfWv6b7Z9zyau+goy6n60+g1ii2+u89re3yu2fci1fa5y3Xlk426H45+3ymjGjBmYOHEixo4dix07dmDbtm1Yv3692TLbt2/Hzp07sXr1alRVVWHcuHHYtGkTWrdubXc6liejZT80AGhEElwM6YSetYWQMTqz6buj+uNyWEfEVV/B6PKTZvNtrW+LSC5HzIxZxhPVMDKLLS0tRABEkOLOCcqZrkQCQATo2APWMu+7o/qjMKIz5rSvQ9jBr8z2DSOVYXdUP/wQ3KHRfWC6rcthHY3T2NYRyeWoHjYeq68FcpbXdN9wHTNDvmemdwUAs8DX6vRQacxDVS4Vs5aRjUguR+iAgag5kce6rGlZE25dxejyUxBpNazz5VKx8cFOtrKI5HJ0fupxiLolGqe5qzJiS58tXiyPERt745RrW7b2i0024pwtHu2JAVvnhK3fBVuxwZZvtrw6ytb5acqR89rA1jnJtX1b5FIxFjzYy+yRDXvin5d7RkqlEpcuXUJGRgYAICMjA5cuXUJFRYXZcrt378akSZMgFosRERGBtLQ0fPPNNy6lzdYPLWN06F3zq9XBkDE6DKk4DwAYUnGe82BxrW+Laf+3abObLS0pGLMTvNF0dTq7KiLDNoZUnG9I+9tdVvtGpNVgUPk5znyxbcsU2zqMWg18u6vR8pruG65jZsj3pv0/m/Wf36rXWVVEADjLyIZRq1F99AjnsqZlHVR+zvzHxmK+6VsZuO6D/PHZRpt54gNb+mzxwqjVCPlmC2b/thVx1VdYt2VvnFoebwNb+8UmG3HOFo/2xICtc8LW74Kt2GDLN1teHWXr/DTlyHltYOuc5Nq+LYZHNhzFyz2j4uJixMTEQCKRAAAkEgmio6NRXFyMiIgIs+VatWpl/Ds2NhYlJY7doLOsXX+prGBdTgT2Bl+Y9pbZ/7lwrW+LtrICUVGhqDAZOWYrLT7SNWVIL0RT2+h8e/JluQzXOqZpcS1j2Ddcx8yw3q16+y8CuMrIqpFuItP0ufJvOr2iWtVoWVQ3lYiKcvztErbYG/9sRACaa29hdPlJALC6anckTrn2u6394iq2PNoTA66eE46cM5brOMvR9Z3Jo61z0tn83Kysczj+fW40nVJZi/LyGuN/0vAI1uUYiFinV0uDzf7PhWt9W/6SBGHWK98gOPBOPW8rLT7SNWVIr1bG3iy2dx+wLcO1jmlaXMtIwyMaPWaO7Ce2dG3R29i3tvaL6fSIsIBGyxIQqTCLU6XSgUqzEfbGf2O4rtod2f9c+93WfjESO/fTw5ZHe2LA1XPCkXPGch1nObq+M3m0dU46m5/I8ECH45+Xyig2NhalpaXQ6RquaHU6HcrKyhAbG2u1XFFRkfHv4uJitGzp2kiRyAkTIZLLzaZpRBKcC70bGpHEavrhiIZ+/MMRiVbzba1vi2H7ymoV6uq1kEpEnGlpIYLWYvdzpiuRABL7GrGGPMilYmD4GKt9w0hlOB7VmzNfbNsyxbaOSC4Hho8xDt3mWsZwo5frmBnyHRJoX1m5yshVlsaOqWlZj0f1BiOVcc43vJWh+tQJ6FT1VtsSyeVoO32aXWVwFdu+tCde2K5y7Y1Ty+NtYPq2CtZ8mawfNjjVruNmii0e7YkBW+eErd8FW7HBtr/Z8moPiQiN/m5wceS8NrB1TnJt3xZnH9mQvP
zyyy87vJaFoKAgHDt2DFKpFF27dsXOnTtRXl6OadPMT0iNRoPt27cjMzMTlZWVeO211/D888+jefPmdqdVV6eG6ZCLgNZtIFMoUH/tGvT19dCFtsDRlv1wLCweutBwtNVXQaSuR60sBPsUKfgj+m4EyCQokjQ3zher6xuu1BiGdX2xuh6i4GCIAwIa+lX/t6xhml6tRrU0GPsjU4xdHwyAwAAJwoLk+ANhZtvShbbA0dj++DHgLtylroBcr4aeI11phAKd5j8GebeEhjLW1ZnlxTQPhjKWtY7DQ2n3oM+QxDv7pq4O0ggFYh6aioCkvvi9pNoqX2zbMuwvtVYPRVgARmb2Q+f4DmbbjJ4yFR1GDIOieTPW7RqWMYzcMT1muro6q3wndFTgxytK6PTcXZYhgVI8PLKLVRkbOzZnFAmokoYgVqVEgF4DkcUxvxjUHoqwANyX0Rcdu7U3btNy/kNp9yC+5krDDd9688pIHBKCmGnT0WbEcNy+bTKoQSRCUJBjP77OxL80QoHoqdMQkph4Z5+wqJYG41x4NzBoGNFleU4Yz5nIPiiJ7Gh1LE2Pd51KZ7Zfbrz3DqrzjkPUrJnVOWNYXzE60/y4cRCHhIDRaMyOgVgEY76lEjHyysRQBzVvNI7ZzgnT2OP6XbAVG2z723Qdw75Va/Vm/zYtg+n5NXVEFyTeHdXo+Wl63jNqtc3z2vKcMKzDdU6y/sbIpOhc8Rt61RSiVhyAuuaRrOUyxMHIAR0cjn/eRtMVFhYiKysL1dXVCAsLQ3Z2Njp27Ig5c+Zg4cKFSEhIgE6nw9KlS5GXlwcAmDNnDiZPnuxQOmxDWw089aaCR18/yDnv46xhLm9fCG9g4MJVNtNhw2xDR+199sJdx+bK889wvsi24xtvWpXLnUO7Ddj2JddoP8MoKj45m9Yvj83inHfPmnWs5bJ88S5gPprPF3jLed3YEH1nj6kz8c/bQ6+dOnXC1q1braavXr3a+G+JRIJXXnmFryS9Br3tmX98vZbHXcfGVx7gdNcbztk4+wVeZ95QTw+J8oPrbfpAwznozq8qW6I3MPCA3vbsvdx1bHzpEx/ueMM5G2craGfeUN/U354S6nvjbFXqTXnR5XOj6bwRfe3Se7nr2LDd8PXlT3zwwdkv8Ib1G4CYGbOMy0kjFDa7gbhatu7ojRDye+NsVeru+qoy6zZ536Kforc9ey93HJum7P7yFa58g8vR1ltT9kYIuUvQVjd2U35XjSojQpzUVN1fvqIpK2hHvz3lSjebkD9Hb6tSb8pjSpWRAAm1f5t4v6asoO1t8dq6SW9LY60HX/9YoD2VelMdU6qMBMbVE48QoXG1m42r9TAluhql6++8dNVXPxboLbcYqDISGCH3b3sDanX6Hle72bhaD+GfroC2iYY9+wOqjARGyP3bnkatTt/Ex7NmbK2HX3zkWTNfQUO7eXayoATPfZCHR18/iOc+yGvy4Z9NOeTV3zTW6iTei+uT93w8a+bIdNI4qox45A3PI7jrxCPU6vRV9KyZb6BuOh55w/0aR4e8EvvRa598Fz1r5v2oMuKRt1w5e8voGKGh1z4RS/SsGX+oMuIRXTm7xttHqlGrkxD3ocqIR3Tl7DxfGalGrU5C3IMGMPCIXpjqPBqpRoh/o5YRz+jK2Tnecr+NEOIZ1DIiXoGejyLEv1HLyE94++AAut9GiH+jysgP+MLgABqp5rqyI0dxdd0GeuaF+CSqjPyANzyMaw+63+a86lMnUPbZp9CrGu6x+eobpIn/ontGfoAGBwjfzZxtxorIwPAGaUJ8AVVGfoAGBwgf15ui6Q3SxFdQZeQH6OWpwkdvkCa+jiojP0AP4wpf5ISJEAeYt3TpDdLEl9AABj9BgwOELazfAISGBdJoOuKzqDIiRCCiUwdD1C3R09kgxCnUTUcIIcTjXG4Z1dXV4YUXXkBBQQEkEgkWLVqEoUOHWi13+vRpzJ07F+3btwcAyOVybN261dXkCSGECI
DLldHatWsREhKC/fv349q1a5g2bRr27duH4OBgq2U7deqEnJwcV5MkhBAiMC530+3ZsweTJ08GALRv3x7du3fH0aNHXc4YIYQQ/+Fyy6ioqAh33XWX8e/Y2FiUlJSwLnvt2jWMHz8eUqkUU6dOxfjxt3+KIQAABc5JREFU4x1OT6EIaXR+VFSow9v0BUItFyDcsrmjXBT/wuJquQ5/dx3r91zGzco6RIYHYkZ6HIYkteEpd65xtGw2K6Px48ejqKiIdd6JEyfsTig+Ph5HjhxBaGgorl+/jkceeQQxMTEYMMCxoadKZS30eoZ1XlRUKMrLaxzani8QarkA4ZbNslxischmRWIPin/hcLVcli9ALq+sw8ovLqC6pt7jj3E4E/82K6Ovvvqq0fmtWrXCjRs3EBERAQAoLi5G3759rZYLCbmTkTZt2iAtLQ3nzp1zuDIihBDiOy9AtpfL94xGjRqFLVu2AGjohvvhhx9w7733Wi1XVlYGhmm4oquqqkJeXh66du3qavKEEOKXhPYCZJfvGc2ePRtZWVm47777IBaLsXTpUmMr6N1330V0dDQeeugh7Nu3D59//jmkUil0Oh3GjRuHtLQ0lwtACCH+SBEWwFrx+OoLkEWMobniI6jPXFiEWja6Z8QfKhc7y3tGQMMLkL3hvZNuuWfkbcRikUvzfZVQywUIt2ym5eKrjBT/wuJKuQYmxCKwmRR7T/2ByloVwkMCMLJfW/S+O4rHHDrP0fj3uZYRIYQQ4aF30xFCCPE4qowIIYR4HFVGhBBCPI4qI0IIIR5HlREhhBCPo8qIEEKIx1FlRAghxOOoMiKEEOJxVBkRQgjxOEFURlevXsXkyZMxcuRITJ48GdeuXfN0lpxSWVmJOXPmYOTIkcjMzMRTTz2FiooKAMCFCxdw//33Y+TIkXj00UehVCo9nFvnvP/+++jSpQt++eUXAMIol0qlwpIlSzBixAhkZmbixRdfBNB0cUnx7zuEFv+8xj4jANOnT2e2b9/OMAzDbN++nZk+fbqHc+ScyspK5tSpU8a/X3/9deaFF15gdDodk5aWxuTn5zMMwzD/+c9/mKysLE9l02k//vgjM3v2bGbo0KHMzz//LJhyLVu2jFm+fDmj1+sZhmGY8vJyhmGaLi4p/n2DEOOfz9j3+cro5s2bTFJSEqPVahmGYRitVsskJSUxSqXSwzlz3TfffMPMnDmTuXjxIjNmzBjjdKVSyfTq1cuDOXOcSqViHnzwQeb69evGk1EI5aqtrWWSkpKY2tpas+lNFZcU/75BiPHPd+z7fDddcXExYmJiIJFIAAASiQTR0dEoLi72cM5co9fr8fnnn2PYsGEoLi5Gq1atjPMiIiKg1+tRVVXlwRw65t1338X999+P1q1bG6cJoVzXr19HixYt8P7772PChAmYPn06zp4922RxSfHvG4QY/3zHvs9XRkK1bNkyBAUF4eGHH/Z0Vlx2/vx5/Pjjj5g6daqns8I7nU6H69evo1u3bsjJycGzzz6LBQsW4Pbt257Omk+j+Pd+fMe+z33PyFJsbCxKS0uh0+kgkUig0+lQVlaG2NhYT2fNadnZ2fj999+xatUqiMVixMbGoqioyDi/oqICYrEYLVq08GAu7Zefn4/CwkIMHz4cAFBSUoLZs2dj+vTpPl0uoCH+pFIpMjIyAAA9e/ZEeHg4mjVr1iRxSfHv/YQa/3zHvs+3jBQKBeLi4pCbmwsAyM3NRVxcHCIiIjycM+e89dZb+PHHH/Gf//wHcrkcANC9e3fU19fj7NmzAIDNmzdj1KhRnsymQ+bOnYvjx4/j4MGDOHjwIFq2bIm1a9fiscce8+lyAQ1dK3379kVeXh6AhlFESqUS7du3b5K4pPj3fkKNf75jXxAf1yssLERWVhaqq6sRFhaG7OxsdOzY0dPZctivv/6KjIwMtG/fHs2aNQMAtG7dGv/5z39w7tw5LFmyBCqVCnfddRdWrFiByMhID+fYOcOGDcOqVa
twzz33CKJc169fx+LFi1FVVQWpVIqnn34aqampTRaXFP++RUjxz2fsC6IyIoQQ4tt8vpuOEEKI76PKiBBCiMdRZUQIIcTjqDIihBDicVQZEUII8TiqjAghhHgcVUaEEEI8jiojQgghHvf/SddH2fTa/BAAAAAASUVORK5CYII=\n", - "text/plain": "
" - }, - "metadata": {} - } - ], - "_view_module": "@jupyter-widgets/output", - "_model_module_version": "1.0.0", - "_view_count": null, - "_view_module_version": "1.0.0", - "layout": "IPY_MODEL_83245b18235e4bad8141bc399ea5cd94", - "_model_module": "@jupyter-widgets/output" - } - }, - "1947b94e59c0410b875b21164719c876": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "b2afc9cdf727400487ce19f9f0ce7d53": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - 
"left": null - } - }, - "16b08cb41fc94870b7578021c7927d1a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "654c62b8b9c3489c9f03fc76e81a2ed9": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "83245b18235e4bad8141bc399ea5cd94": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - 
"justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "d841bfa912f54aec8f0b185e97f33c06": { - "model_module": "@jupyter-widgets/controls", - "model_name": "VBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "VBoxView", - "_dom_classes": [ - "widget-interact" - ], - "_model_name": "VBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_da249cda804d4ca9b5ff93c27bfb3a65", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_71bdb348f5804e45a2adeb256ee72c72", - "IPY_MODEL_116218d7741949e08f547ac9844aa7e1", - "IPY_MODEL_9a3e2484604b49fdbae2dad6d60f43b2" - ] - } - }, - "da249cda804d4ca9b5ff93c27bfb3a65": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": 
"@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "71bdb348f5804e45a2adeb256ee72c72": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatLogSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "FloatLogSliderView", - "orientation": "horizontal", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "disabled": false, - "readout_format": ".9f", - "_model_module": "@jupyter-widgets/controls", - "style": "IPY_MODEL_e93f9ba98fe34786a3064647deae6cfd", - "layout": "IPY_MODEL_f2174b7c3df64f278e748afaf20ff90d", - "min": -10, - "continuous_update": false, - "description_tooltip": null, - "_dom_classes": [], - "description": "c1", - "_model_name": "FloatLogSliderModel", - "max": 0, - "readout": true, - "step": 0.1, - "base": 10, - "value": 0.005, - "_view_module_version": "1.5.0" - } - }, - "116218d7741949e08f547ac9844aa7e1": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatLogSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": 
"FloatLogSliderView", - "orientation": "horizontal", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "disabled": false, - "readout_format": ".9f", - "_model_module": "@jupyter-widgets/controls", - "style": "IPY_MODEL_94b31bb0dadf44eaa94e43614a626040", - "layout": "IPY_MODEL_6c7b054a3a7b4aa6bae9122dfcb83bde", - "min": -10, - "continuous_update": false, - "description_tooltip": null, - "_dom_classes": [], - "description": "c2", - "_model_name": "FloatLogSliderModel", - "max": 0, - "readout": true, - "step": 0.1, - "base": 10, - "value": 0.005, - "_view_module_version": "1.5.0" - } - }, - "9a3e2484604b49fdbae2dad6d60f43b2": { - "model_module": "@jupyter-widgets/output", - "model_name": "OutputModel", - "model_module_version": "1.0.0", - "state": { - "_view_name": "OutputView", - "msg_id": "", - "_dom_classes": [], - "_model_name": "OutputModel", - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": "
" - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": "
" - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": "
" - }, - "metadata": {} - } - ], - "_view_module": "@jupyter-widgets/output", - "_model_module_version": "1.0.0", - "_view_count": null, - "_view_module_version": "1.0.0", - "layout": "IPY_MODEL_f2be86b07cc34fec98d8a0ac03f6e410", - "_model_module": "@jupyter-widgets/output" - } - }, - "e93f9ba98fe34786a3064647deae6cfd": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "f2174b7c3df64f278e748afaf20ff90d": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - 
"left": null - } - }, - "94b31bb0dadf44eaa94e43614a626040": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "6c7b054a3a7b4aa6bae9122dfcb83bde": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "f2be86b07cc34fec98d8a0ac03f6e410": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - 
"justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "e7c5993c7704484db6390038d50e8e52": { - "model_module": "@jupyter-widgets/controls", - "model_name": "VBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "VBoxView", - "_dom_classes": [ - "widget-interact" - ], - "_model_name": "VBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_156c40392f954961b7c6cebee3912d7d", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_13fc9e6c26604f58a3a821c119bd268a", - "IPY_MODEL_73ae601863af45f6b0d97be47ab8149b", - "IPY_MODEL_04140cf3fa164a64930ec261b9f556f6" - ] - } - }, - "156c40392f954961b7c6cebee3912d7d": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": 
"@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "13fc9e6c26604f58a3a821c119bd268a": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatLogSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "FloatLogSliderView", - "orientation": "horizontal", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "disabled": false, - "readout_format": ".9f", - "_model_module": "@jupyter-widgets/controls", - "style": "IPY_MODEL_1172e605bb634b31ad0c9a437db7c047", - "layout": "IPY_MODEL_33780801a70f4d168ec7fd904ee99685", - "min": -10, - "continuous_update": false, - "description_tooltip": null, - "_dom_classes": [], - "description": "cx", - "_model_name": "FloatLogSliderModel", - "max": 0, - "readout": true, - "step": 0.1, - "base": 10, - "value": 1e-10, - "_view_module_version": "1.5.0" - } - }, - "73ae601863af45f6b0d97be47ab8149b": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatLogSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": 
"FloatLogSliderView", - "orientation": "horizontal", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "disabled": false, - "readout_format": ".9f", - "_model_module": "@jupyter-widgets/controls", - "style": "IPY_MODEL_e7a68ca2f9794421b2f047ddadfc6741", - "layout": "IPY_MODEL_6aa67cbfb05b4ff9ba03eb365e94c458", - "min": -10, - "continuous_update": false, - "description_tooltip": null, - "_dom_classes": [], - "description": "cy", - "_model_name": "FloatLogSliderModel", - "max": 0, - "readout": true, - "step": 0.1, - "base": 10, - "value": 1e-10, - "_view_module_version": "1.5.0" - } - }, - "04140cf3fa164a64930ec261b9f556f6": { - "model_module": "@jupyter-widgets/output", - "model_name": "OutputModel", - "model_module_version": "1.0.0", - "state": { - "_view_name": "OutputView", - "msg_id": "", - "_dom_classes": [], - "_model_name": "OutputModel", - "outputs": [ - { - "output_type": "error", - "ename": "AttributeError", - "evalue": "ignored", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/ipywidgets/widgets/interaction.py\u001b[0m in \u001b[0;36mupdate\u001b[0;34m(self, *args)\u001b[0m\n\u001b[1;32m 255\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mwidget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_interact_value\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 256\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mwidget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwarg\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 257\u001b[0;31m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 258\u001b[0m \u001b[0mshow_inline_matplotlib_plots\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_display\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mresult\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m\u001b[0m in \u001b[0;36minteractive_cca\u001b[0;34m(cx, cy)\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mrcca\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrCCA\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlatent_dims\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mcx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mcy\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mHX_tr\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mHY_tr\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0mtest_scores\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransform\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mHX_te\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mHY_te\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 8\u001b[0;31m 
\u001b[0mplot_latent_train_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscores\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_scores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 9\u001b[0m \u001b[0mplot_train_test_corrs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mscores\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtest_scores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mplot_model_weights\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mrcca\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweights\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtitle\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'Model weights'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'rCCA' object has no attribute 'scores'" - ] - } - ], - "_view_module": "@jupyter-widgets/output", - "_model_module_version": "1.0.0", - "_view_count": null, - "_view_module_version": "1.0.0", - "layout": "IPY_MODEL_d83a784c9575476e86292a2f795a439c", - "_model_module": "@jupyter-widgets/output" - } - }, - "1172e605bb634b31ad0c9a437db7c047": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "33780801a70f4d168ec7fd904ee99685": { - "model_module": 
"@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "e7a68ca2f9794421b2f047ddadfc6741": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "6aa67cbfb05b4ff9ba03eb365e94c458": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": 
null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "d83a784c9575476e86292a2f795a439c": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - 
"grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "2c3a73827f794598bb0bb9a749c4b70f": { - "model_module": "@jupyter-widgets/controls", - "model_name": "VBoxModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "VBoxView", - "_dom_classes": [ - "widget-interact" - ], - "_model_name": "VBoxModel", - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.5.0", - "box_style": "", - "layout": "IPY_MODEL_f4a50fa01e66441090ebd78f4be413d4", - "_model_module": "@jupyter-widgets/controls", - "children": [ - "IPY_MODEL_6590606b34d643fb8c6c7e06599d5bca", - "IPY_MODEL_886e8a7d34ca460db623ee35abce05a6", - "IPY_MODEL_6097fdf1adb94dc9bf89bd5e0a101fc8" - ] - } - }, - "f4a50fa01e66441090ebd78f4be413d4": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": 
null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "6590606b34d643fb8c6c7e06599d5bca": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "FloatSliderView", - "style": "IPY_MODEL_75aa1b12f01448198ad356465332fffd", - "_dom_classes": [], - "description": "c1", - "step": 0.1, - "_model_name": "FloatSliderModel", - "orientation": "horizontal", - "max": 21.540659228538015, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 3, - "_view_count": null, - "disabled": false, - "_view_module_version": "1.5.0", - "min": 1, - "continuous_update": false, - "readout_format": ".5f", - "description_tooltip": null, - "readout": true, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_6049c8f1a0c54f11955956952181b242" - } - }, - "886e8a7d34ca460db623ee35abce05a6": { - "model_module": "@jupyter-widgets/controls", - "model_name": "FloatSliderModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "FloatSliderView", - "style": "IPY_MODEL_ed894f48599f4920bf247584398d2e0e", - "_dom_classes": [], - "description": "c2", - "step": 0.1, - "_model_name": "FloatSliderModel", - "orientation": "horizontal", - "max": 4.898979485566356, - "_view_module": "@jupyter-widgets/controls", - "_model_module_version": "1.5.0", - "value": 3, - "_view_count": null, - "disabled": false, - "_view_module_version": "1.5.0", - "min": 1, - "continuous_update": false, - "readout_format": ".5f", - "description_tooltip": null, - "readout": true, - "_model_module": "@jupyter-widgets/controls", - "layout": "IPY_MODEL_ca3f1c4c5a804634a7f13b0b9ab0ebdd" - } - }, - "6097fdf1adb94dc9bf89bd5e0a101fc8": { - "model_module": "@jupyter-widgets/output", - "model_name": "OutputModel", - "model_module_version": "1.0.0", - "state": { - "_view_name": "OutputView", - "msg_id": "", - 
"_dom_classes": [], - "_model_name": "OutputModel", - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": "
" - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": "
" - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": "
" - }, - "metadata": {} - } - ], - "_view_module": "@jupyter-widgets/output", - "_model_module_version": "1.0.0", - "_view_count": null, - "_view_module_version": "1.0.0", - "layout": "IPY_MODEL_64aa754a29844764b480dc5c91b700be", - "_model_module": "@jupyter-widgets/output" - } - }, - "75aa1b12f01448198ad356465332fffd": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "6049c8f1a0c54f11955956952181b242": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - 
"left": null - } - }, - "ed894f48599f4920bf247584398d2e0e": { - "model_module": "@jupyter-widgets/controls", - "model_name": "SliderStyleModel", - "model_module_version": "1.5.0", - "state": { - "_view_name": "StyleView", - "handle_color": null, - "_model_name": "SliderStyleModel", - "description_width": "initial", - "_view_module": "@jupyter-widgets/base", - "_model_module_version": "1.5.0", - "_view_count": null, - "_view_module_version": "1.2.0", - "_model_module": "@jupyter-widgets/controls" - } - }, - "ca3f1c4c5a804634a7f13b0b9ab0ebdd": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - "justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - }, - "64aa754a29844764b480dc5c91b700be": { - "model_module": "@jupyter-widgets/base", - "model_name": "LayoutModel", - "model_module_version": "1.2.0", - "state": { - "_view_name": "LayoutView", - "grid_template_rows": null, - "right": null, - 
"justify_content": null, - "_view_module": "@jupyter-widgets/base", - "overflow": null, - "_model_module_version": "1.2.0", - "_view_count": null, - "flex_flow": null, - "width": null, - "min_width": null, - "border": null, - "align_items": null, - "bottom": null, - "_model_module": "@jupyter-widgets/base", - "top": null, - "grid_column": null, - "overflow_y": null, - "overflow_x": null, - "grid_auto_flow": null, - "grid_area": null, - "grid_template_columns": null, - "flex": null, - "_model_name": "LayoutModel", - "justify_items": null, - "grid_row": null, - "max_height": null, - "align_content": null, - "visibility": null, - "align_self": null, - "height": null, - "min_height": null, - "padding": null, - "grid_auto_rows": null, - "grid_gap": null, - "max_width": null, - "order": null, - "_view_module_version": "1.2.0", - "grid_template_areas": null, - "object_position": null, - "object_fit": null, - "grid_auto_columns": null, - "margin": null, - "display": null, - "left": null - } - } - } - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "h9McsiUZLsp0" - }, - "source": [ - "# An interactive introduction to Multiview Learning with Canonical Correlation Analysis and Partial Least Squares using cca-zoo\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "UmQifwjwI1mh" - }, - "source": [ - "## Table of Contents\n", - "1. [Objective of the Tutorial](#obj)\n", - "2. [Set-up](#set-up)\n", - "3. [Introduction to multiview machine learning](#mv)\n", - " * [Machine Learning Framework](#mlframework)\n", - " * [Latent Variable Models](#lv)\n", - "4. [Simulating Data](#datagen)\n", - "5. [Overfitting and Sample Size](#samplesize)\n", - "6. [Ridge Regularisation: From CCA to PLS](#ridge)\n", - "7. 
[Lasso Regularisation](#lasso)\n", - " * [Sparse Partial Least Squares](#spls)\n", - " * [Sparse Canonical Correlation Analysis](#scca)\n", - "8. [Application: a neuroimaging case study](#haxby)\n", - " * [Regularised CCA](#rccahaxby)\n", - " * [Sparse Partial Least Squares](#splshaxby)\n", - " * [Sparse Canonical Correlation Analysis](#sccahaxby)\n", - "9. [Conclusion](#conclusion)\n", - "10. [References](#references)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WoKA75AsKSTp" - }, - "source": [ - "## Objective \n", - "With the large multiview and multimodal datasets available in medical imaging and healthcare informatics, there is a strong need for multivariate statistical and machine learning methods which can capture relationships across different views of data.\n", - "\n", - "Canonical correlation analysis and partial least squares are arguably the two most commonly used methods in modelling two-view data with latent variable models. Extensions to these methods have included Generalized CCA (and Multiset CCA) which can capture relationships between more than 2 views and various forms of regularised CCA and PLS.\n", - "\n", - "This tutorial has three main aims \n", - "\n", - "* to give an interactive and visual introduction to CCA and PLS and to demonstrate the effect and benefits of using regularised versions of these classical methods\n", - "* to demonstrate how we might use these methods in a neuroimaging example\n", - "* to introduce cca-zoo, an open source python package containing implementations of this family of algorithms. 
The source code is available at https://github.com/jameschapman19/cca_zoo and documentation at https://cca-zoo.readthedocs.io/en/latest/index.html" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Jp_CSLIvJKBJ" - }, - "source": [ - "## Set-up \n", - "This tutorial depends on the package cca-zoo and nilearn for a simulated neuroimaging example.\n", - "\n", - "In order to make use of the interactive elements, the notebook should be run in a google colab instance." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mvcwPQARKlkY" - }, - "source": [ - "### Installation and Imports" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0OhyKPp0Lwit", - "outputId": "b6e6666b-8247-4a90-dae5-edc43f539f1f" - }, - "source": [ - "# @markdown Execute this cell to install cca-zoo and nilearn\n", - "!pip install cca-zoo --upgrade\n", - "!pip install nilearn" - ], - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting cca-zoo\n", - " Downloading cca_zoo-1.9.0-py3-none-any.whl (68 kB)\n", - "\u001b[?25l\r\u001b[K |████▊ | 10 kB 22.2 MB/s eta 0:00:01\r\u001b[K |█████████▌ | 20 kB 25.8 MB/s eta 0:00:01\r\u001b[K |██████████████▎ | 30 kB 13.4 MB/s eta 0:00:01\r\u001b[K |███████████████████ | 40 kB 9.9 MB/s eta 0:00:01\r\u001b[K |███████████████████████▊ | 51 kB 7.2 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▌ | 61 kB 7.6 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 68 kB 2.7 MB/s \n", - "\u001b[?25hRequirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (1.0.1)\n", - "Collecting scikit-learn>=0.23\n", - " Downloading scikit_learn-1.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (23.1 MB)\n", - "\u001b[K |████████████████████████████████| 23.1 MB 1.6 MB/s \n", - "\u001b[?25hCollecting mvlearn\n", - " Downloading mvlearn-0.4.1-py3-none-any.whl (2.1 
MB)\n", - "\u001b[K |████████████████████████████████| 2.1 MB 37.2 MB/s \n", - "\u001b[?25hCollecting tensorly\n", - " Downloading tensorly-0.6.0-py3-none-any.whl (160 kB)\n", - "\u001b[K |████████████████████████████████| 160 kB 86.6 MB/s \n", - "\u001b[?25hCollecting scipy>=1.7\n", - " Downloading scipy-1.7.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (28.5 MB)\n", - "\u001b[K |████████████████████████████████| 28.5 MB 50 kB/s \n", - "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (1.19.5)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (1.1.5)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (3.2.2)\n", - "Requirement already satisfied: seaborn in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (0.11.2)\n", - "Collecting threadpoolctl>=2.0.0\n", - " Downloading threadpoolctl-3.0.0-py3-none-any.whl (14 kB)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo) (0.10.0)\n", - "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo) (2.4.7)\n", - "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo) (2.8.2)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo) (1.3.2)\n", - "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->cca-zoo) (1.15.0)\n", - "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->cca-zoo) (2018.9)\n", - "Collecting nose\n", - " Downloading nose-1.3.7-py3-none-any.whl (154 kB)\n", - "\u001b[K |████████████████████████████████| 154 kB 87.2 MB/s \n", - "\u001b[?25hInstalling 
collected packages: threadpoolctl, scipy, scikit-learn, nose, tensorly, mvlearn, cca-zoo\n", - " Attempting uninstall: scipy\n", - " Found existing installation: scipy 1.4.1\n", - " Uninstalling scipy-1.4.1:\n", - " Successfully uninstalled scipy-1.4.1\n", - " Attempting uninstall: scikit-learn\n", - " Found existing installation: scikit-learn 0.22.2.post1\n", - " Uninstalling scikit-learn-0.22.2.post1:\n", - " Successfully uninstalled scikit-learn-0.22.2.post1\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n", - "Successfully installed cca-zoo-1.9.0 mvlearn-0.4.1 nose-1.3.7 scikit-learn-1.0 scipy-1.7.1 tensorly-0.6.0 threadpoolctl-3.0.0\n", - "Collecting nilearn\n", - " Downloading nilearn-0.8.1-py3-none-any.whl (10.0 MB)\n", - "\u001b[K |████████████████████████████████| 10.0 MB 5.8 MB/s \n", - "\u001b[?25hRequirement already satisfied: requests>=2 in /usr/local/lib/python3.7/dist-packages (from nilearn) (2.23.0)\n", - "Requirement already satisfied: numpy>=1.16 in /usr/local/lib/python3.7/dist-packages (from nilearn) (1.19.5)\n", - "Requirement already satisfied: joblib>=0.12 in /usr/local/lib/python3.7/dist-packages (from nilearn) (1.0.1)\n", - "Requirement already satisfied: pandas>=0.24.0 in /usr/local/lib/python3.7/dist-packages (from nilearn) (1.1.5)\n", - "Requirement already satisfied: scipy>=1.2 in /usr/local/lib/python3.7/dist-packages (from nilearn) (1.7.1)\n", - "Requirement already satisfied: nibabel>=2.5 in /usr/local/lib/python3.7/dist-packages (from nilearn) (3.0.2)\n", - "Requirement already satisfied: scikit-learn>=0.21 in /usr/local/lib/python3.7/dist-packages (from nilearn) (1.0)\n", - "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from 
pandas>=0.24.0->nilearn) (2018.9)\n", - "Requirement already satisfied: python-dateutil>=2.7.3 in /usr/local/lib/python3.7/dist-packages (from pandas>=0.24.0->nilearn) (2.8.2)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil>=2.7.3->pandas>=0.24.0->nilearn) (1.15.0)\n", - "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests>=2->nilearn) (1.24.3)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests>=2->nilearn) (2021.5.30)\n", - "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests>=2->nilearn) (2.10)\n", - "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests>=2->nilearn) (3.0.4)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn>=0.21->nilearn) (3.0.0)\n", - "Installing collected packages: nilearn\n", - "Successfully installed nilearn-0.8.1\n" - ] - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "h9PE1A0fS266", - "cellView": "form", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "076de56d-adde-44ea-f240-b116cdef4470" - }, - "source": [ - "# @markdown Execute this cell to import modules\n", - "import ipywidgets as widgets\n", - "import seaborn as sns\n", - "from cca_zoo.models import rCCA, PMD, SCCA\n", - "from cca_zoo.data import generate_covariance_data\n", - "import numpy as np\n", - "import pandas as pd\n", - "from sklearn.model_selection import train_test_split\n", - "import matplotlib.pyplot as plt\n", - "from nilearn import datasets\n", - "from nilearn.image import index_img\n", - "from nilearn.input_data import NiftiMasker\n", - "from nilearn.plotting import view_img\n", - "from sklearn.preprocessing import OneHotEncoder\n", - "from nilearn.plotting import 
plot_stat_map, show\n", - "from sklearn.model_selection import train_test_split\n", - "np.random.seed(42)\n", - "sns.set(font_scale=1)" - ], - "execution_count": 2, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/nilearn/datasets/__init__.py:96: FutureWarning: Fetchers from the nilearn.datasets module will be updated in version 0.9 to return python strings instead of bytes and Pandas dataframes instead of Numpy arrays.\n", - " \"Numpy arrays.\", FutureWarning)\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TMApL_uRnguy" - }, - "source": [ - "### Plotting Helpers\n", - "We will use a number of standard figures to visualise the results of models used throughout this notebook" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "MieUEzyKS0eA", - "cellView": "form" - }, - "source": [ - "# @markdown Execute this cell to access plotting functions\n", - "def plot_latent_train_test(train_scores, test_scores, title='Projections of data in latent space'):\n", - " train_data = pd.DataFrame(\n", - " {'phase': np.asarray(['train'] * train_scores[0].shape[0]).astype(str)})\n", - " x_vars=[f'View 1 dimension {f}' for f in range(1,train_scores[0].shape[1]+1)]\n", - " y_vars=[f'View 2 dimension {f}' for f in range(1,train_scores[1].shape[1]+1)]\n", - " train_data[x_vars] = train_scores[0]\n", - " train_data[y_vars] = train_scores[1]\n", - " test_data = pd.DataFrame(\n", - " {'phase': np.asarray(['test'] * test_scores[0].shape[0]).astype(str)})\n", - " test_data[x_vars] = test_scores[0]\n", - " test_data[y_vars] = test_scores[1]\n", - " data = pd.concat([train_data, test_data], axis=0)\n", - " cca_pp = sns.pairplot(data, hue='phase',x_vars=x_vars,y_vars=y_vars, corner=True)\n", - " cca_pp.fig.set_size_inches(10,5)\n", - " if title:\n", - " cca_pp.fig.suptitle(title)\n", - " return cca_pp\n", - "\n", - "def plot_train_test_corrs(train_scores, test_scores, title='Train vs Test 
correlations'):\n", - " latent_dims=train_scores[0].shape[1]\n", - " train_corrs=np.diag(np.corrcoef(train_scores[0],train_scores[1],rowvar=False)[:latent_dims,latent_dims:])\n", - " test_corrs=np.diag(np.corrcoef(test_scores[0],test_scores[1],rowvar=False)[:latent_dims,latent_dims:])\n", - " train_corr_data=pd.DataFrame({'correlation':train_corrs,'dimension':np.arange(latent_dims)+1,'phase': np.asarray(['train'] * latent_dims).astype(str)})\n", - " test_corr_data=pd.DataFrame({'correlation':test_corrs,'dimension':np.arange(latent_dims)+1,'phase': np.asarray(['test'] * latent_dims).astype(str)})\n", - " corr_data = pd.concat([train_corr_data, test_corr_data], axis=0)\n", - " # setting the dimensions of the plot\n", - " fig2, ax = plt.subplots()\n", - " cca_bp=sns.barplot(x=\"dimension\", y=\"correlation\", hue=\"phase\", data=corr_data,ax=ax)\n", - " cca_bp.set_ylim(bottom=-1, top=1)\n", - " if title:\n", - " cca_bp.set_title(title)\n", - " return cca_bp\n", - "\n", - "def plot_true_weights_coloured(ax, weights, true_weights=None, title=''):\n", - " if true_weights is None:\n", - " true_weights=np.ones(len(weights))\n", - " ind = np.arange(len(true_weights))\n", - " mask = np.squeeze(true_weights == 0)\n", - " ax.scatter(ind[~mask], weights[~mask], c='b',label='Non-Zero')\n", - " ax.scatter(ind[mask], weights[mask], c='r',label='Zero')\n", - " ax.set_title(title)\n", - "\n", - "def plot_model_weights(wx,wy,tx=None,ty=None, title='Model weights vs. 
True weights'):\n", - " if tx is None and ty is None:\n", - " fig,axs=plt.subplots(1,2)\n", - " plot_true_weights_coloured(axs[0],wx,title='model view 1 weights')\n", - " plot_true_weights_coloured(axs[1],wy,title='model view 2 weights')\n", - " axs[1].legend()\n", - " else:\n", - " fig,axs=plt.subplots(2,2,sharex=True,sharey=True)\n", - " plot_true_weights_coloured(axs[0,0],tx,tx,title='true view 1 weights')\n", - " plot_true_weights_coloured(axs[0,1],ty,ty,title='true view 2 weights')\n", - " plot_true_weights_coloured(axs[1,0],wx,tx,title='model view 1 weights')\n", - " plot_true_weights_coloured(axs[1,1],wy,ty,title='model view 2 weights')\n", - " axs[0,1].legend()\n", - " plt.tight_layout()\n", - " return fig" - ], - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cqonSYEbE4c6" - }, - "source": [ - "## Introduction to Multivariate Machine Learning " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "hhiQ_Yg8qc1b" - }, - "source": [ - "### Latent Variable Models \n", - "Latent variable models assume that the two views of data are derived from some shared latent (unobserved) variables. For neuroimaging and behaviour studies we might expect an underlying health condition to influence both an MRI scan of a patient and data relating to their behaviour.\n", - "\n", - "PLS and CCA models work by projecting the observed data from each view into latent variables for each view that are highly correlated. These can be interpreted as estimates of the true latent variable.
Often in neuroimaging studies the biggest components can be shown to be related to variables like Age, Scanner Type and Gender.\n", - "\n", - "![Latent Variable Model of Brain-Behaviour data](https://raw.githubusercontent.com/jameschapman19/education-challenge/main/image.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SmxsMVGQ6VV2" - }, - "source": [ - "### Optimisation Problems\n", - "Partial Least Squares and Canonical Correlation Analysis share the same objective function (maximising covariance between the projections of each view). The difference is that PLS constrains the variance of the weights (their 2 norm) whereas CCA constrains the variance of the projections." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-NDpDlMDdCqj" - }, - "source": [ - "#### Partial Least Squares\n", - "$w_{opt}=\\underset{w}{\\mathrm{argmax}}\\{ w_1^TX_1^TX_2w_2 \\}$\n", - "\n", - "$\\text{subject to:}$\n", - "\n", - "$\\|w_1\\|_2=1$,\n", - "$\\|w_2\\|_2=1$" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BdAh2shKdhKr" - }, - "source": [ - "#### Canonical Correlation Analysis\n", - "$w_{opt}=\\underset{w}{\\mathrm{argmax}}\\{ w_1^TX_1^TX_2w_2 \\}$\n", - "\n", - "$\\text{subject to:}$\n", - "\n", - "$\\|X_1w_1\\|_2=1$\n", - "\n", - "$\\|X_2w_2\\|_2=1$\n", - "\n", - "\n", - "The result of these optimisations is that PLS is much more biased towards the largest principal components in the data whereas CCA is sensitive to high correlations, even if they only explain small amounts of the data.\n", - "\n", - "Also notice that in the special case where the covariance matrices for each view ($X_1^TX_1$ and $X_2^TX_2$) are identity matrices $I$, the CCA and PLS problems are identical (since $\\|X_iw_i\\|_2^2=w_i^TX_i^TX_iw_i=w_i^Tw_i$)."
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2YYL_xght8Dk" - }, - "source": [ - "#### Relationship to PCA\n", - "It is interesting to compare these objectives to the more familiar principal components analysis. In PCA, we find weights that project the data in directions that maximise variance. If our two views are the same then PLS is exactly the same model as PCA!\n", - "\n", - "$w_{opt}=\\underset{w}{\\mathrm{argmax}}\\{ w^TX^TXw \\}$\n", - "\n", - "$\\text{subject to:}$\n", - "\n", - "$\\|w\\|_2=1$," - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_irxVD8gTY68" - }, - "source": [ - "### Machine Learning Framework \n", - "\n", - "In a machine learning framework, we are primarily interested in the generalization of a model to data not used in the training process. In this tutorial we will generate simulated data which we will split into train and test data to demonstrate how different models generalize.\n", - "\n", - "![Framework](https://raw.githubusercontent.com/jameschapman19/education-challenge/main/image2.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "on2NP1JlnmCf" - }, - "source": [ - "### Simulated Data \n", - "We will use a joint multivariate normal distribution to generate two views of data: view 1 and view 2. This method, described in detail by Helmer et al [1], allows us to control:\n", - "\n", - "* Number of training samples\n", - "* Number of test samples\n", - "* Number of features in view 1\n", - "* Number of features in view 2\n", - "* Level of sparsity in view 1 (fraction of variables from view 1 involved in correlation)\n", - "* Level of sparsity in view 2 (fraction of variables from view 2 involved in correlation)\n", - "\n", - "As well as the strength of the population correlation and the covariance structure within each view. 
We assume that the underlying correlation is perfect (1) and that the within-view variance is identity.\n", - "\n", - "This data generation process is also equivalent to a latent variable model where view 1 and view 2 are conditionally independent given some latent variable Z." - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "CQ_1X7_inoCH", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 524 - }, - "cellView": "form", - "outputId": "2b82696b-dad7-45c2-b5f2-6c66adc35968" - }, - "source": [ - "# @markdown Execute this cell to choose data parameters! You can see the projections of the generated data using the true weights. These should have almost perfect correlations in both the training and test data.\n", - "style = {'description_width': 'initial'}\n", - "\n", - "N_train= widgets.IntSlider(value=100,min=10,max=500,description='Train Samples',style=style,continuous_update=False)\n", - "N_test= widgets.IntSlider(value=100,min=10,max=500,description='Test Samples',style=style,continuous_update=False)\n", - "X_features=widgets.IntSlider(value=60,min=10,max=500,description='View 1 features',style=style,continuous_update=False)\n", - "Y_features=widgets.IntSlider(value=60,min=10,max=500,description='View 2 features',style=style,continuous_update=False)\n", - "view_1_sparsity=widgets.FloatSlider(value=0.5,min=0,max=1,description='View 1 Sparsity',style=style,continuous_update=False)\n", - "view_2_sparsity=widgets.FloatSlider(value=0.5,min=0,max=1,description='View 2 Sparsity',style=style,continuous_update=False)\n", - "\n", - "def generate_data(N_train,N_test,X_features,Y_features, view_1_sparsity, view_2_sparsity):\n", - " (X,Y),(tx,ty)=generate_covariance_data(N_train+N_test,view_features=[X_features,Y_features],latent_dims=1,correlation=1, view_sparsity=[view_1_sparsity, view_2_sparsity])\n", - " X_tr,X_te,Y_tr,Y_te=train_test_split(X,Y,train_size=N_train)\n", - " plot_latent_train_test(np.stack((X_tr@tx,Y_tr@ty)), 
np.stack((X_te@tx,Y_te@ty)), title='Projections of data in latent space using true model weights')\n", - " return (X_tr,X_te,Y_tr,Y_te,tx,ty)\n", - "\n", - "out=widgets.interactive(generate_data, N_train=N_train,N_test=N_test,X_features=X_features,Y_features=Y_features,view_1_sparsity=view_1_sparsity,view_2_sparsity=view_2_sparsity)\n", - "display(out)" - ], - "execution_count": 4, - "outputs": [ - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "484098178a6341af93cb9d9f95172a31", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "interactive(children=(IntSlider(value=100, continuous_update=False, description='Train Samples', max=500, min=…" - ] - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7bQazPYUFtVD" - }, - "source": [ - "## Overfitting and Sample Size \n", - "When the sample size is smaller than the number of features, CCA will overfit the data (you can try this by changing the data parameters). Visually, CCA overfitting looks like all the training samples are projected with perfect correlation but the test samples are almost completely uncorrelated. The intuition behind this is that there will be at least one spurious relationship when the data are not full rank.\n", - "\n", - "PLS on the other hand is always well-posed even when the number of samples is less than the number of features. However since PLS maximises covariance rather than correlation, the correlation of the latent variables in the training data is typically much lower. 
" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 383, - "referenced_widgets": [ - "d5f617ab68ec42b0a81333040a03ce70", - "082d16684fce464e8ab300fd88db366e", - "42a064361141414d9735dd76a475268a", - "00c063db01344e58b6f4155605f5036e", - "eb5dd98764d741b8ac1d86b5ac566ad2", - "76bd49f092e24228924699e3fd9b735c", - "f2c9bdfd43f4436dab5dfcf0cf8e8b80" - ] - }, - "id": "p8Yy3n_mPsJI", - "outputId": "2d1a9ff0-f548-4417-c52a-3b87a4720cf7" - }, - "source": [ - "# @markdown Execute this cell to toggle between CCA and PLS models (you may need to toggle when you change the data!)\n", - "X_tr,X_te,Y_tr,Y_te, tx,ty=out.result\n", - "tog=widgets.ToggleButtons(\n", - " options=['CCA', 'PLS'],\n", - " description='Model:',\n", - " disabled=False,\n", - " button_style='', # 'success', 'info', 'warning', 'danger' or ''\n", - ")\n", - "def interactive_cca(tog):\n", - " if tog=='CCA':\n", - " rcca=rCCA(latent_dims=1,c=1e-9).fit([X_tr,Y_tr])\n", - " elif tog=='PLS':\n", - " rcca=rCCA(latent_dims=1,c=1).fit([X_tr,Y_tr])\n", - " test_scores=rcca.transform([X_te,Y_te])\n", - " plot_latent_train_test(rcca.scores,test_scores)\n", - "\n", - "plot_widget=widgets.interactive(interactive_cca,tog=tog)\n", - "display(plot_widget)" - ], - "execution_count": 5, - "outputs": [ - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d5f617ab68ec42b0a81333040a03ce70", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "interactive(children=(ToggleButtons(description='Model:', options=('CCA', 'PLS'), value='CCA'), Output()), _do…" - ] - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "jDWExSvJKvfv" - }, - "source": [ - "## Ridge Regularised CCA: from CCA to PLS \n", - "\n", - "Given that PLS has an effect that is similar to ridge regularisation, a natural extension is to mix CCA and PLS to control this regularisation effect. 
Vinod proposed the Canonical Ridge model [2] which combines the constraints from the two models with mixing parameters $c_1$ and $c_2$ which make the model more or less CCA-like or PLS-like.\n", - "\n", - "$w_{opt}=\\underset{w}{\\mathrm{argmax}}\\{ w_1^TX_1^TX_2w_2 \\}$\n", - "\n", - "$\\text{subject to:}$\n", - "\n", - "$(1-c_1)\\|X_1w_1\\|_2+c_1\\|w_1\\|_2=1$\n", - "\n", - "$(1-c_2)\\|X_2w_2\\|_2+c_2\\|w_2\\|_2=1$\n", - "\n", - "In the next interactive widget, you can vary $c_1$ and $c_2$ yourself. The plots are provided to help visualise:\n", - "\n", - "* The projections into latent space\n", - "* The train and test correlations\n", - "* The model weights vs. the true weights" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 642, - "referenced_widgets": [ - "4186484e492348708679feb6fc1a5937", - "e9033710a974478088014f255d040332", - "a10e2175e99542f8984a5e1ac038d6c4", - "07f1da8af80c4a33873945c3d206835d", - "10adc5e3157243a5ba6c391d59254072", - "e17dc362cb1c46379c2ed85b398d9250", - "098c336771204a41aa40df3c5a820c70", - "690cbdac87c94770afaa6fae82747ec4", - "1b30b8c885004e3e8e4aaeb9c703de24", - "adc24ca24d0748edb2e13da7e26f635d" - ] - }, - "id": "AmfuhADyLyrz", - "outputId": "f5c0bf75-22ce-41ad-e646-8abe5e105213" - }, - "source": [ - "# @markdown Execute this cell to change model regularisation (there's a bit of a lag as the model needs to fit in the background!)\n", - "X_tr,X_te,Y_tr,Y_te, tx,ty=out.result\n", - "style = {'description_width': 'initial'}\n", - "cx=widgets.FloatLogSlider(base=10,value=-1, min=-10, max=0,description='cx',readout=True,readout_format='.9f',style=style,continuous_update=False)\n", - "cy=widgets.FloatLogSlider(base=10,value=-1, min=-10, max=0,description='cy',readout=True,readout_format='.9f',style=style,continuous_update=False)\n", - "def interactive_cca(cx,cy):\n", - " rcca=rCCA(latent_dims=1,c=[cx,cy]).fit([X_tr,Y_tr])\n", - " 
test_scores=rcca.transform(X_te,Y_te)\n", - " plot_latent_train_test(rcca.scores,test_scores)\n", - " plot_train_test_corrs(rcca.scores,test_scores)\n", - " plot_model_weights(rcca.weights[0],rcca.weights[1],tx,ty, title='Model weights vs. True weights')\n", - "\n", - "plot_widget=widgets.interactive(interactive_cca, cx=cx,cy=cy)\n", - "display(plot_widget)" - ], - "execution_count": 6, - "outputs": [ - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "4186484e492348708679feb6fc1a5937", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "interactive(children=(FloatLogSlider(value=1e-10, continuous_update=False, description='cx', max=0.0, min=-10.…" - ] - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2CYEeK4Gp51n" - }, - "source": [ - "As we increase $c_1$ and $c_2$, the correlation of the training data decreases (and the covariance increases). \n", - "\n", - "There is often a 'sweet spot' somewhere between CCA and PLS where the test correlation is maximised. You can see the train and test correlations in the bar plot.\n", - "\n", - "We now also display the learnt model weights alongside the true weights." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DSSQVqm80IX3" - }, - "source": [ - "## Lasso Regularisation \n", - "Just as the canonical ridge is analogous to ridge regularised regression, we can also use lasso regularisation in PLS and CCA models. Lasso regularisation is helpful in biomedical applications as the effects we are looking for are often parsimonious (only a small number of the features contain a signal) and the natural feature selection mechanism helps interpretability.\n", - "\n", - "In this section, we introduce Sparse Partial Least Squares and Sparse Canonical Correlation Analysis."
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "r_r4wkcoLClz" - }, - "source": [ - "### Sparse Partial Least Squares \n", - "By adding an additional constraint on the weights, Witten demonstrated a sparse partial least squares model [3] able to recover signals that were only present in a fraction of the variables. Since this method adapts the PLS optimization problem, it also tends to show the same ridge regularisation properties.\n", - "\n", - "$w_{opt}=\\underset{w}{\\mathrm{argmax}}\\{ w_1^TX_1^TX_2w_2 \\}$\n", - "\n", - "$\\text{subject to:}$\n", - "\n", - "$\\|w_1\\|_2=1$,\n", - "$\\|w_2\\|_2=1$\n", - "\n", - "$\\|w_1\\|_1\\leq c_1$,\n", - "$\\|w_2\\|_1\\leq c_2$\n", - "\n", - "Reducing the value of $c_1$ and $c_2$ leads to a sparser solution with more zero weights \n", - "\n", - "Have a look for yourself!" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "vOowgyeJ2xV_", - "colab": { - "base_uri": "https://localhost:8080/", - "height": 956, - "referenced_widgets": [ - "fd8e3e038fa446c48f29f3c719a88ee8", - "f5588cc41bfe4dfc99d34f0b8c7092ed", - "9032567411404f3d9e85862bff74b0be", - "08ff7d25488b4b53ad841974fa97e8bd", - "53f71b9ae00a4b4cb836d40358392790", - "1947b94e59c0410b875b21164719c876", - "b2afc9cdf727400487ce19f9f0ce7d53", - "16b08cb41fc94870b7578021c7927d1a", - "654c62b8b9c3489c9f03fc76e81a2ed9", - "83245b18235e4bad8141bc399ea5cd94" - ] - }, - "outputId": "92f7440c-26d2-43cb-dbff-482d6eadbe6a" - }, - "source": [ - "# @markdown Execute this cell to change model regularisation\n", - "X_tr,X_te,Y_tr,Y_te, tx,ty=out.result\n", - "style = {'description_width': 'initial'}\n", - "c1=widgets.FloatSlider(value=3, min=1, max=np.sqrt(X_tr.shape[1]),description='c1',readout=True,readout_format='.5f',style=style,continuous_update=False)\n", - "c2=widgets.FloatSlider(value=3, min=1, max=np.sqrt(Y_tr.shape[1]),description='c2',readout=True,readout_format='.5f',style=style,continuous_update=False)\n", - "def interactive_cca(c1,c2):\n", - "
spls=PMD(latent_dims=1,c=[c1,c2]).fit([X_tr,Y_tr])\n", - " test_scores=spls.transform([X_te,Y_te])\n", - " plot_latent_train_test(spls.scores,test_scores)\n", - " plot_train_test_corrs(spls.scores,test_scores)\n", - " plot_model_weights(spls.weights[0],spls.weights[1],tx,ty, title='Model weights vs. True weights')\n", - "\n", - "plot_widget=widgets.interactive(interactive_cca, c1=c1,c2=c2)\n", - "display(plot_widget)" - ], - "execution_count": 7, - "outputs": [ - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fd8e3e038fa446c48f29f3c719a88ee8", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "interactive(children=(FloatSlider(value=3.0, continuous_update=False, description='c1', max=7.745966692414834,…" - ] - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "53q1yj1CLGG6" - }, - "source": [ - "### Sparse Canonical Correlation Analysis \n", - "There have been a number of attempts to construct sparse CCA models. One of the most successful is Mai's variant [4] which uses a lasso penalty with a form much like lasso regression [7].\n", - "\n", - "\n", - "$w_{opt}=\\underset{w}{\\mathrm{argmax}}\\{ w_1^TX_1^TX_2w_2 - c_1\\|w_1\\|_1- c_2\\|w_2\\|_1\\}$\n", - "\n", - "$\\text{subject to:}$\n", - "\n", - "$w_1^TX_1^TX_1w_1=1$\n", - "\n", - "$w_2^TX_2^TX_2w_2=1$\n", - "\n", - "This is a slightly different form to the sparse PLS model. Since the 1-norm of the weights penalizes the objective rather than being a constraint, a higher value of $c_1$ and $c_2$ leads to a more sparse (more weights equal to zero) solution.\n", - "\n", - "Have a go for yourself!"
- ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 956, - "referenced_widgets": [ - "d841bfa912f54aec8f0b185e97f33c06", - "da249cda804d4ca9b5ff93c27bfb3a65", - "71bdb348f5804e45a2adeb256ee72c72", - "116218d7741949e08f547ac9844aa7e1", - "9a3e2484604b49fdbae2dad6d60f43b2", - "e93f9ba98fe34786a3064647deae6cfd", - "f2174b7c3df64f278e748afaf20ff90d", - "94b31bb0dadf44eaa94e43614a626040", - "6c7b054a3a7b4aa6bae9122dfcb83bde", - "f2be86b07cc34fec98d8a0ac03f6e410" - ] - }, - "id": "UYKWAgTTLI5C", - "outputId": "88b69d24-5768-4880-9371-6619f5b21320" - }, - "source": [ - "# @markdown Execute this cell to change model regularisation\n", - "X_tr,X_te,Y_tr,Y_te, tx,ty=out.result\n", - "style = {'description_width': 'initial'}\n", - "c1=widgets.FloatLogSlider(value=0.005,base=10,min=-10, max=0,description='c1',readout=True,readout_format='.9f',style=style,continuous_update=False)\n", - "c2=widgets.FloatLogSlider(value=0.005,base=10,min=-10, max=0,description='c2',readout=True,readout_format='.9f',style=style,continuous_update=False)\n", - "def interactive_cca(c1,c2):\n", - " scca=SCCA(latent_dims=1,c=[c1,c2]).fit([X_tr,Y_tr])\n", - " test_scores=scca.transform([X_te,Y_te])\n", - " plot_latent_train_test(scca.scores,test_scores)\n", - " plot_train_test_corrs(scca.scores,test_scores)\n", - " plot_model_weights(scca.weights[0],scca.weights[1],tx,ty, title='Model weights vs. 
True weights')\n", - "\n", - "plot_widget=widgets.interactive(interactive_cca, c1=c1,c2=c2)\n", - "display(plot_widget)" - ], - "execution_count": 8, - "outputs": [ - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "d841bfa912f54aec8f0b185e97f33c06", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "interactive(children=(FloatLogSlider(value=0.005, continuous_update=False, description='c1', max=0.0, min=-10.…" - ] - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "cnhys1o9Xdh7" - }, - "source": [ - "## Apply what we have learnt: a neuroimaging example \n", - "\n", - "This part leans heavily on the nilearn [5,6] tutorial in:\n", - "https://nilearn.github.io/auto_examples/02_decoding/plot_haxby_different_estimators.html#sphx-glr-auto-examples-02-decoding-plot-haxby-different-estimators-py.\n", - "\n", - "The Haxby dataset a well known neuroscience dataset which contains functional magnetic resonance images (fMRI) from various subjects while they observe different tasks (e.g. house, cat). \n", - "\n", - "We will apply our multivariate models to this data.\n", - "\n", - "We take masked versions of the fMRI data as view 1. As the other view we will take one hot encoded matrix of the 12 task labels. 
We will also add 12 random columns to view 2 which we hope our model will learn to ignore since there should be no correlated signal with the fMRI data.\n", - "\n", - "![Hax](https://raw.githubusercontent.com/jameschapman19/education-challenge/main/image3.png)" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "1YS65dje8BHO", - "cellView": "form", - "outputId": "a4060ed6-6733-4d67-ee19-44c8ad33845d" - }, - "source": [ - "# @markdown Execute this cell to fetch and preprocess the Haxby Dataset!\n", - "haxby_dataset = datasets.fetch_haxby()\n", - "mask_filename = haxby_dataset.mask_vt[0]\n", - "fmri_filename = haxby_dataset.func[0]\n", - "behavioural = pd.read_csv(haxby_dataset.session_target[0], delimiter=' ')\n", - "condition_mask=np.ones(len(behavioural),dtype=bool)\n", - "fmri_niimgs = index_img(fmri_filename,condition_mask)\n", - "session_label = behavioural['chunks']\n", - "masker = NiftiMasker(mask_img=mask_filename, sessions=session_label,\n", - " smoothing_fwhm=4, standardize=True,\n", - " memory=\"nilearn_cache\", memory_level=1)\n", - "#Transform data so it is usable\n", - "fmri_masked = masker.fit_transform(fmri_niimgs)\n", - "session_label = OneHotEncoder(handle_unknown='ignore').fit_transform(session_label[:,None])\n", - "session_label = np.hstack((session_label.toarray(),np.zeros((session_label.shape[0],12))))\n", - "session_label = session_label + 0.00001*np.random.normal(size=(session_label.shape))\n", - "HX_tr,HX_te,HY_tr,HY_te=train_test_split(fmri_masked,session_label)" - ], - "execution_count": 9, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n", - "Dataset created in /root/nilearn_data/haxby2001\n", - "\n", - "Downloading data from https://www.nitrc.org/frs/download.php/7868/mask.nii.gz ...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - " ...done. (0 seconds, 0 min)\n", - " ...done. 
(0 seconds, 0 min)\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Downloading data from http://data.pymvpa.org/datasets/haxby2001/MD5SUMS ...\n", - "Downloading data from http://data.pymvpa.org/datasets/haxby2001/subj2-2010.01.14.tar.gz ...\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "Downloaded 290586624 of 291168628 bytes (99.8%, 0.0s remaining) ...done. (9 seconds, 0 min)\n", - "Extracting data from /root/nilearn_data/haxby2001/f33ff337e914bf7fded743c7107979f9/subj2-2010.01.14.tar.gz..... done.\n", - "/usr/local/lib/python3.7/dist-packages/nilearn/_utils/helpers.py:145: FutureWarning: The parameter \"sessions\" will be removed in 0.9.0 release of Nilearn. Please use the parameter \"runs\" instead.\n", - " return func(*args, **kwargs)\n", - "/usr/local/lib/python3.7/dist-packages/nilearn/input_data/nifti_masker.py:529: UserWarning: Persisting input arguments took 1.69s to run.\n", - "If this happens often in your code, it can cause performance problems \n", - "(results will be correct in all cases). \n", - "The reason for this is probably some large input arguments for a wrapped\n", - " function (e.g. large strings).\n", - "THIS IS A JOBLIB ISSUE. If you can, kindly provide the joblib's team with an\n", - " example so that they can fix the problem.\n", - " dtype=self.dtype\n", - "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:14: FutureWarning: Support for multi-dimensional indexing (e.g. `obj[:, None]`) is deprecated and will be removed in a future version. 
Convert to a numpy array before indexing instead.\n", - " \n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xOG2MTBtMD0M" - }, - "source": [ - "### Regularised CCA " - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 414, - "referenced_widgets": [ - "e7c5993c7704484db6390038d50e8e52", - "156c40392f954961b7c6cebee3912d7d", - "13fc9e6c26604f58a3a821c119bd268a", - "73ae601863af45f6b0d97be47ab8149b", - "04140cf3fa164a64930ec261b9f556f6", - "1172e605bb634b31ad0c9a437db7c047", - "33780801a70f4d168ec7fd904ee99685", - "e7a68ca2f9794421b2f047ddadfc6741", - "6aa67cbfb05b4ff9ba03eb365e94c458", - "d83a784c9575476e86292a2f795a439c" - ] - }, - "id": "kAXZD20_Mhl8", - "outputId": "9a420f5a-9182-4c2d-832c-1be65f174bd2" - }, - "source": [ - "# @markdown Execute this cell to change model regularisation (there's a bit of a lag as the model needs to fit in the background!)\n", - "style = {'description_width': 'initial'}\n", - "cx=widgets.FloatLogSlider(base=10,value=-1, min=-10, max=0,description='cx',readout=True,readout_format='.9f',style=style,continuous_update=False)\n", - "cy=widgets.FloatLogSlider(base=10,value=-1, min=-10, max=0,description='cy',readout=True,readout_format='.9f',style=style,continuous_update=False)\n", - "def interactive_cca(cx,cy):\n", - " rcca=rCCA(latent_dims=1,c=[cx,cy]).fit([HX_tr,HY_tr])\n", - " test_scores=rcca.transform([HX_te,HY_te])\n", - " plot_latent_train_test(rcca.scores,test_scores)\n", - " plot_train_test_corrs(rcca.scores,test_scores)\n", - " plot_model_weights(rcca.weights[0],rcca.weights[1], title='Model weights')\n", - "\n", - "plot_widget=widgets.interactive(interactive_cca, cx=cx,cy=cy)\n", - "display(plot_widget)" - ], - "execution_count": 10, - "outputs": [ - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "e7c5993c7704484db6390038d50e8e52", - "version_minor": 0, - "version_major": 2 - 
}, - "text/plain": [ - "interactive(children=(FloatLogSlider(value=1e-10, continuous_update=False, description='cx', max=0.0, min=-10.…" - ] - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O3tNrjn4MDp5" - }, - "source": [ - "### Sparse PLS " - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 956, - "referenced_widgets": [ - "2c3a73827f794598bb0bb9a749c4b70f", - "f4a50fa01e66441090ebd78f4be413d4", - "6590606b34d643fb8c6c7e06599d5bca", - "886e8a7d34ca460db623ee35abce05a6", - "6097fdf1adb94dc9bf89bd5e0a101fc8", - "75aa1b12f01448198ad356465332fffd", - "6049c8f1a0c54f11955956952181b242", - "ed894f48599f4920bf247584398d2e0e", - "ca3f1c4c5a804634a7f13b0b9ab0ebdd", - "64aa754a29844764b480dc5c91b700be" - ] - }, - "id": "OmbViQYYMjjL", - "outputId": "8f6727d4-9fad-40b4-bf73-920ed56a20fd" - }, - "source": [ - "# @markdown Execute this cell to change model regularisation\n", - "style = {'description_width': 'initial'}\n", - "c1=widgets.FloatSlider(value=3, min=1, max=np.sqrt(HX_tr.shape[1]),description='c1',readout=True,readout_format='.5f',style=style,continuous_update=False)\n", - "c2=widgets.FloatSlider(value=3, min=1, max=np.sqrt(HY_tr.shape[1]),description='c2',readout=True,readout_format='.5f',style=style,continuous_update=False)\n", - "def interactive_cca(c1,c2):\n", - " spls=PMD(latent_dims=1,c=[c1,c2]).fit([HX_tr,HY_tr])\n", - " test_scores=spls.transform([HX_te,HY_te])\n", - " plot_latent_train_test(spls.scores,test_scores)\n", - " plot_train_test_corrs(spls.scores,test_scores)\n", - " plot_model_weights(spls.weights[0],spls.weights[1], title='Model weights')\n", - "\n", - "plot_widget=widgets.interactive(interactive_cca, c1=c1,c2=c2)\n", - "display(plot_widget)" - ], - "execution_count": 11, - "outputs": [ - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "2c3a73827f794598bb0bb9a749c4b70f", - 
"version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "interactive(children=(FloatSlider(value=3.0, continuous_update=False, description='c1', max=21.540659228538015…" - ] - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "BiFXzjF2MDMh" - }, - "source": [ - "### Sparse CCA " - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 956 - }, - "id": "xzUD-ylNgq2L", - "cellView": "form", - "outputId": "8a9780ef-5566-4bc1-eade-6ba08a0acc66" - }, - "source": [ - "# @markdown Execute this cell to change model regularisation\n", - "style = {'description_width': 'initial'}\n", - "c1=widgets.FloatLogSlider(value=0.005,base=10,min=-10, max=0,description='c1',readout=True,readout_format='.9f',style=style,continuous_update=False)\n", - "c2=widgets.FloatLogSlider(value=0.005,base=10,min=-10, max=0,description='c2',readout=True,readout_format='.9f',style=style,continuous_update=False)\n", - "def interactive_cca(c1,c2):\n", - " scca=SCCA(latent_dims=1,c=0.01,initialization='random').fit([HX_tr,HY_tr])\n", - " test_scores=scca.transform([HX_te,HY_te])\n", - " plot_latent_train_test(scca.scores,test_scores)\n", - " plot_train_test_corrs(scca.scores,test_scores)\n", - " plot_model_weights(scca.weights[0],scca.weights[1], title='Model weights')\n", - " \n", - "plot_widget=widgets.interactive(interactive_cca, c1=c1,c2=c2)\n", - "display(plot_widget)" - ], - "execution_count": 12, - "outputs": [ - { - "output_type": "display_data", - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "5df438a7257a4684b33107ad602d5bb4", - "version_minor": 0, - "version_major": 2 - }, - "text/plain": [ - "interactive(children=(FloatLogSlider(value=0.005, continuous_update=False, description='c1', max=0.0, min=-10.…" - ] - }, - "metadata": {} - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JWvWx2SdwDlB" - }, - "source": [ - "## Conclusion \n", - "We 
have introduced Canonical Correlation Analysis and Partial Least Squares in an interactive way.\n", - "\n", - "We showed how regularisation can be used to improve the generalizability and interpretability of our models.\n", - "\n", - "We applied these models to a well-known neuroimaging dataset.\n", - "\n", - "We introduced cca-zoo, an open-source python package that implements variants of CCA and PLS models with different regularisation effects. The goal of the package is to give researchers access to flexible regularisation methods. \n", - "\n", - "Further documentation is available at https://cca-zoo.readthedocs.io/en/latest/index.html \n", - "\n", - "the source code is available at https://github.com/jameschapman19/cca_zoo\n", - "\n", - "\n", - "### Acknowledgements\n", - "Thanks to the authors of the winners of last year's MEC 'Introduction to Medical Image Registration with DeepReg, Between Old and New'[8] for the basic structure of a tutorial notebook. Also thanks to the Neuromatch Academy summer school [9] who's extensive use of notebooks with widgets taught me everything I know!\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "holjn-oSNdy6" - }, - "source": [ - "## References " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IvJXI8Vg-OMM" - }, - "source": [ - "[1] Helmer, Markus, et al. \"On stability of Canonical Correlation Analysis and Partial Least Squares with application to brain-behavior associations.\" BioRxiv (2021): 2020-08.\n", - "\n", - "[2] Vinod, Hrishikesh D. \"Canonical ridge and econometrics of joint production.\" Journal of econometrics 4.2 (1976): 147-166.\n", - "\n", - "[3] Witten, Daniela M., Robert Tibshirani, and Trevor Hastie. \"A penalized matrix decomposition, with applications to sparse principal components and canonical correlation analysis.\" Biostatistics 10.3 (2009): 515-534.\n", - "\n", - "[4] Mai, Qing, and Xin Zhang. 
\"An iterative penalized least squares approach to sparse canonical correlation analysis.\" Biometrics 75.3 (2019): 734-744.\n", - "\n", - "[5] Abraham, Alexandre, et al. \"Machine learning for neuroimaging with scikit-learn.\" Frontiers in neuroinformatics 8 (2014): 14.\n", - "\n", - "[6] Pedregosa, Fabian, et al. \"Scikit-learn: Machine learning in Python.\" the Journal of machine Learning research 12 (2011): 2825-2830.\n", - "\n", - "[7] Tibshirani, Robert. \"Regression shrinkage and selection via the lasso.\" Journal of the Royal Statistical Society: Series B (Methodological) 58.1 (1996): 267-288.\n", - "\n", - "[8] Fu, Yunguan, et al. \"DeepReg: a deep learning toolkit for medical image registration.\" arXiv preprint arXiv:2011.02580 (2020).\n", - "\n", - "[9] van Viegen, Tara, et al. \"Neuromatch Academy: teaching computational neuroscience with global accessibility.\" arXiv preprint arXiv:2012.08973 (2020)." - ] - } - ] -} \ No newline at end of file diff --git a/tutorial_notebooks/cca_zoo_mnist.ipynb b/tutorial_notebooks/cca_zoo_mnist.ipynb deleted file mode 100644 index 09f8fda5..00000000 --- a/tutorial_notebooks/cca_zoo_mnist.ipynb +++ /dev/null @@ -1,2356 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 2 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" - }, - "colab": { - "name": "cca_zoo_tutorial.ipynb", - "provenance": [], - "include_colab_link": true - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "Gn1KbNv1vLTj" - }, - "source": [ - "# A tutorial comparing the 
train and test correlations of different models on MNIST data" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "gYoDpAd1vLTk", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "3e927908-a3b5-4e5a-bce3-d807688c7a9a" - }, - "source": [ - "!pip install --upgrade cca-zoo[deep,probabilistic]" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Requirement already satisfied: cca-zoo[deep,probabilistic] in /usr/local/lib/python3.7/dist-packages (1.8.0)\n", - "Requirement already satisfied: mvlearn in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (0.4.1)\n", - "Requirement already satisfied: scipy>=1.7 in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (1.7.1)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (3.2.2)\n", - "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (1.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (1.19.5)\n", - "Requirement already satisfied: seaborn in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (0.11.2)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (1.1.5)\n", - "Requirement already satisfied: tensorly in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (0.6.0)\n", - "Requirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (1.0.1)\n", - "Requirement already satisfied: jax~=0.2.20 in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (0.2.20)\n", - "Requirement already satisfied: arviz in /usr/local/lib/python3.7/dist-packages (from 
cca-zoo[deep,probabilistic]) (0.11.2)\n", - "Requirement already satisfied: numpyro in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (0.8.0)\n", - "Requirement already satisfied: torch>=1.9.0 in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (1.9.0+cu102)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (0.10.0+cu102)\n", - "Requirement already satisfied: Pillow in /usr/local/lib/python3.7/dist-packages (from cca-zoo[deep,probabilistic]) (7.1.2)\n", - "Requirement already satisfied: absl-py in /usr/local/lib/python3.7/dist-packages (from jax~=0.2.20->cca-zoo[deep,probabilistic]) (0.12.0)\n", - "Requirement already satisfied: opt-einsum in /usr/local/lib/python3.7/dist-packages (from jax~=0.2.20->cca-zoo[deep,probabilistic]) (3.3.0)\n", - "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=1.9.0->cca-zoo[deep,probabilistic]) (3.7.4.3)\n", - "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from absl-py->jax~=0.2.20->cca-zoo[deep,probabilistic]) (1.15.0)\n", - "Requirement already satisfied: xarray>=0.16.1 in /usr/local/lib/python3.7/dist-packages (from arviz->cca-zoo[deep,probabilistic]) (0.18.2)\n", - "Requirement already satisfied: netcdf4 in /usr/local/lib/python3.7/dist-packages (from arviz->cca-zoo[deep,probabilistic]) (1.5.7)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.7/dist-packages (from arviz->cca-zoo[deep,probabilistic]) (21.0)\n", - "Requirement already satisfied: setuptools>=38.4 in /usr/local/lib/python3.7/dist-packages (from arviz->cca-zoo[deep,probabilistic]) (57.4.0)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo[deep,probabilistic]) (1.3.2)\n", - "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in 
/usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo[deep,probabilistic]) (2.4.7)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo[deep,probabilistic]) (0.10.0)\n", - "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo[deep,probabilistic]) (2.8.2)\n", - "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->cca-zoo[deep,probabilistic]) (2018.9)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.7/dist-packages (from scikit-learn->cca-zoo[deep,probabilistic]) (2.2.0)\n", - "Requirement already satisfied: cftime in /usr/local/lib/python3.7/dist-packages (from netcdf4->arviz->cca-zoo[deep,probabilistic]) (1.5.0)\n", - "Requirement already satisfied: jaxlib>=0.1.65 in /usr/local/lib/python3.7/dist-packages (from numpyro->cca-zoo[deep,probabilistic]) (0.1.71+cuda111)\n", - "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from numpyro->cca-zoo[deep,probabilistic]) (4.62.2)\n", - "Requirement already satisfied: flatbuffers<3.0,>=1.12 in /usr/local/lib/python3.7/dist-packages (from jaxlib>=0.1.65->numpyro->cca-zoo[deep,probabilistic]) (1.12)\n", - "Requirement already satisfied: nose in /usr/local/lib/python3.7/dist-packages (from tensorly->cca-zoo[deep,probabilistic]) (1.3.7)\n" - ] - } - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "O5VEWk1BvLTl", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "5870455a-2f40-40ec-8878-6ecc9d9e2a89" - }, - "source": [ - "# Imports\n", - "import numpy as np\n", - "from cca_zoo.data import Noisy_MNIST_Dataset\n", - "import itertools\n", - "import matplotlib.pyplot as plt\n", - "from torch.utils.data import Subset\n", - "from torch import optim\n", - "from cca_zoo.deepmodels import objectives, architectures, 
DeepWrapper, DCCA,DCCA_NOI,DVCCA,DCCAE,DTCCA\n", - "from sklearn.utils.fixes import loguniform\n", - "# Load MNIST Data\n", - "N = 500\n", - "dataset = Noisy_MNIST_Dataset(mnist_type='FashionMNIST', train=True)\n", - "ids = np.arange(min(2 * N, len(dataset)))\n", - "np.random.shuffle(ids)\n", - "train_ids, val_ids = np.array_split(ids, 2)\n", - "val_dataset = Subset(dataset, val_ids)\n", - "train_dataset = Subset(dataset, train_ids)\n", - "test_dataset = Noisy_MNIST_Dataset(mnist_type='FashionMNIST', train=False)\n", - "test_ids = np.arange(min(N, len(test_dataset)))\n", - "np.random.shuffle(test_ids)\n", - "test_dataset = Subset(test_dataset, test_ids)\n", - "(train_view_1, train_view_2),_ = train_dataset.dataset.to_numpy(\n", - " train_dataset.indices)\n", - "(val_view_1, val_view_2),_ = val_dataset.dataset.to_numpy(val_dataset.indices)\n", - "(test_view_1, test_view_2),_ = test_dataset.dataset.to_numpy(\n", - " test_dataset.indices)\n", - "\n", - "# Settings\n", - "\n", - "# The number of latent dimensions across models\n", - "latent_dims = 2\n", - "# The number of cv used for cross-validation/hyperparameter tuning\n", - "cv = 3\n", - "# For running hyperparameter tuning in parallel (0 if not)\n", - "jobs = 4\n", - "# Number of iterations for iterative algorithms\n", - "max_iter = 2\n", - "# number of epochs for deep models\n", - "epochs = 50" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. 
(Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)\n", - " return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "C97m5-5tvLTn" - }, - "source": [ - "# Canonical Correlation Analysis" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "gKYi0wtkvLTn" - }, - "source": [ - "from cca_zoo.models import CCA, CCA_ALS\n", - "\"\"\"\n", - "### Linear CCA by eigendecomposition\n", - "\"\"\"\n", - "linear_cca = CCA(latent_dims=latent_dims)\n", - "\n", - "linear_cca.fit((train_view_1, train_view_2))\n", - "\n", - "linear_cca_results = np.stack(\n", - " (linear_cca.score((train_view_1, train_view_2)), linear_cca.score((test_view_1, test_view_2))))\n", - "\n", - "\"\"\"\n", - "### Linear CCA by alternating least squares (can pass more than 2 views)\n", - "\"\"\"\n", - "\n", - "linear_cca_als = CCA_ALS(latent_dims=latent_dims)\n", - "\n", - "linear_cca_als.fit((train_view_1, train_view_2))\n", - "\n", - "linear_cca_als_results = np.stack(\n", - " (linear_cca_als.score((train_view_1, train_view_2)), linear_cca_als.score((test_view_1, test_view_2))))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "OeqtGYW6vLTo" - }, - "source": [ - "# Partial Least Squares\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "-Z4jHKvovLTp" - }, - "source": [ - "from cca_zoo.models import PLS, PLS_ALS\n", - "\"\"\"\n", - "### PLS (2 views)\n", - "\"\"\"\n", - "pls = PLS(latent_dims=latent_dims)\n", - "\n", - "pls.fit((train_view_1, train_view_2))\n", - "\n", - "pls_results = np.stack(\n", - " (pls.score((train_view_1, train_view_2)), pls.score((test_view_1, test_view_2))))\n", - "\n", - "pls_als = PLS_ALS(latent_dims=latent_dims)\n", - "\n", - "pls_als.fit((train_view_1, train_view_2))\n", - 
"\n", - "pls_als_results = np.stack(\n", - " (pls_als.score((train_view_1, train_view_2)), pls_als.score((test_view_1, test_view_2))))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "thDJioUZvLTp" - }, - "source": [ - "# Extension to multiple views\n", - "\n" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "CYTlHO8qvLTq", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "7f043566-28b6-4a0c-91ea-bd4efb9083f1" - }, - "source": [ - "from cca_zoo.models import GCCA, MCCA, PLS_ALS\n", - "\"\"\"\n", - "### (Regularized) Generalized CCA(can pass more than 2 views)\n", - "\"\"\"\n", - "train_view_3=train_view_1+np.random.rand(*train_view_1.shape)\n", - "test_view_3=test_view_1+np.random.rand(*test_view_1.shape)\n", - "\n", - "# small ammount of regularisation added since data is not full rank\n", - "c=[0.5,0.5,0.5]\n", - "\n", - "gcca = GCCA(latent_dims=latent_dims,c=c)\n", - "\n", - "gcca.fit((train_view_1, train_view_2,train_view_3))\n", - "\n", - "gcca_results = np.stack((gcca.score((train_view_1, train_view_2, train_view_3)), gcca.score((test_view_1, test_view_2, test_view_3))))\n", - "\n", - "\"\"\"\n", - "### (Regularized) Multiset CCA(can pass more than 2 views)\n", - "\"\"\"\n", - "\n", - "mcca = MCCA(latent_dims=latent_dims, c=c)\n", - "\n", - "mcca.fit((train_view_1, train_view_2,train_view_1))\n", - "\n", - "mcca_results = np.stack((mcca.score((train_view_1, train_view_2, train_view_3)), mcca.score((test_view_1, test_view_2, test_view_3))))\n", - "\n", - "\"\"\"\n", - "### Multiset CCA by alternating least squares\n", - "\"\"\"\n", - "mcca_als = CCA_ALS(latent_dims=latent_dims, max_iter=max_iter)\n", - "\n", - "mcca_als.fit((train_view_1, train_view_2,train_view_3))\n", - "\n", - "mcca_als_results = np.stack(\n", - " (mcca_als.score((train_view_1, train_view_2, train_view_3)), mcca_als.score((test_view_1, 
test_view_2, test_view_3))))\n", - "\n", - "\"\"\"\n", - "### Multiset PLS by alternating least squares\n", - "\"\"\"\n", - "mcca_pls = PLS_ALS(latent_dims=latent_dims)\n", - "\n", - "mcca_pls.fit((train_view_1, train_view_2,train_view_1))\n", - "\n", - "mcca_pls_results = np.stack(\n", - " (mcca_als.score((train_view_1, train_view_2, train_view_3)), mcca_pls.score((test_view_1, test_view_2, test_view_3))))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/cca_zoo/models/innerloop.py:82: UserWarning: For more than 2 views require generalized=True\n", - " warnings.warn(\"For more than 2 views require generalized=True\")\n", - "/usr/local/lib/python3.7/dist-packages/cca_zoo/models/innerloop.py:82: UserWarning: For more than 2 views require generalized=True\n", - " warnings.warn(\"For more than 2 views require generalized=True\")\n", - "/usr/local/lib/python3.7/dist-packages/cca_zoo/models/innerloop.py:82: UserWarning: For more than 2 views require generalized=True\n", - " warnings.warn(\"For more than 2 views require generalized=True\")\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "7gMOkBrsvLTr" - }, - "source": [ - "# Tensor CCA" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "7ee38U6BvLTr", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "1543df60-fedf-447d-bada-4e97b15fb816" - }, - "source": [ - "from cca_zoo.models import TCCA\n", - "\"\"\"\n", - "### (Regularized) Tensor CCA(can pass more than 2 views)\n", - "\"\"\"\n", - "\n", - "tcca = TCCA(latent_dims=latent_dims, c=c)\n", - "\n", - "#memory requirement for tensor is massive so take first 100 features\n", - "tcca.fit((train_view_1[:,:100], train_view_2[:,:100],train_view_3[:,:100]))\n", - "\n", - "tcca_results = np.stack((tcca.score((train_view_1[:,:100], train_view_2[:,:100], 
train_view_3[:,:100])), tcca.score((test_view_1[:,:100], test_view_2[:,:100], test_view_3[:,:100]))))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "reconstruction error=0.9661771276746018\n", - "iteration 1, reconstruction error: 0.9518385892645673, decrease = 0.014338538410034518, unnormalized = 24.974625897501497\n", - "iteration 2, reconstruction error: 0.9507160541458431, decrease = 0.0011225351187241772, unnormalized = 24.945172484955357\n", - "iteration 3, reconstruction error: 0.9507116250799326, decrease = 4.429065910582786e-06, unnormalized = 24.945056273797878\n", - "iteration 4, reconstruction error: 0.950711606421758, decrease = 1.8658174560926e-08, unnormalized = 24.945055784239106\n", - "iteration 5, reconstruction error: 0.950711606295279, decrease = 1.26479049455952e-10, unnormalized = 24.945055780920512\n", - "PARAFAC converged after 5 iterations\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "lsuPKE35vLTs" - }, - "source": [ - "# Weighted GCCA/Missing Observation GCCA" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "QgHS4svTvLTt" - }, - "source": [ - "#observation_matrix\n", - "K = np.ones((3, N))\n", - "K[0, 200:] = 0\n", - "K[1, :100] = 0\n", - "\n", - "#view weights\n", - "view_weights=[1,2,1.2]\n", - "\n", - "c=[0.5,0.5,0.5]\n", - "\n", - "gcca = GCCA(latent_dims=latent_dims,c=c,view_weights=view_weights)\n", - "\n", - "gcca.fit((train_view_1, train_view_2,train_view_1),K=K)\n", - "\n", - "gcca_results = np.stack((gcca.score((train_view_1, train_view_2)), gcca.score((test_view_1, test_view_2))))" - ], - "execution_count": null, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "4O9JHGcBvLTt" - }, - "source": [ - "# Regularised CCA solutions based on alternating minimisation/alternating least squares\n", - "\n", - "We implement 
Witten's penalized matrix decomposition form of sparse CCA using 'pmd'\n", - "\n", - "We implement Waaijenborg's penalized CCA using elastic net using 'elastic'\n", - "\n", - "We implement Mai's sparse CCA using 'scca'\n", - "\n", - "Furthermore, any of these methods can be extended to multiple views. Witten describes this method explicitly." - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "-PdP_V7WvLTt", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "0ef4ca0c-36d5-43c3-e38b-914b52687254" - }, - "source": [ - "from cca_zoo.model_selection import GridSearchCV, RandomizedSearchCV\n", - "from cca_zoo.models import rCCA, PMD,SCCA,ElasticCCA\n", - "\n", - "def scorer(estimator,X):\n", - " dim_corrs=estimator.score(X)\n", - " return dim_corrs.mean()\n", - "\n", - "\"\"\"\n", - "### Ridge CCA (can pass more than 2 views)\n", - "\"\"\"\n", - "c1 = [0.1, 0.3, 0.7, 0.9]\n", - "c2 = [0.1, 0.3, 0.7, 0.9]\n", - "param_grid = {'c': [c1,c2]}\n", - "\n", - "ridge = GridSearchCV(rCCA(latent_dims=latent_dims),param_grid=param_grid,\n", - " cv=cv,\n", - " verbose=True,scoring=scorer).fit([train_view_1,train_view_2]).best_estimator_\n", - "\n", - "ridge_results = np.stack((ridge.score((train_view_1,train_view_2)), ridge.score((test_view_1, test_view_2))))\n", - "\n", - "\"\"\"\n", - "### Sparse CCA (Penalized Matrix Decomposition) (can pass more than 2 views)\n", - "\"\"\"\n", - "\n", - "# PMD\n", - "c1 = [1, 3, 7, 9]\n", - "c2 = [1, 3, 7, 9]\n", - "param_grid = {'c': [c1,c2]}\n", - "\n", - "pmd = GridSearchCV(PMD(latent_dims=latent_dims),param_grid=param_grid,\n", - " cv=cv,\n", - " verbose=True,scoring=scorer).fit([train_view_1,train_view_2]).best_estimator_\n", - "\n", - "pmd_results = np.stack((pmd.score((train_view_1,train_view_2)), pmd.score((test_view_1, test_view_2))))\n", - "\n", - "\"\"\"\n", - "### Sparse CCA (can pass more than 2 views)\n", - "\"\"\"\n", - "\n", - "# Sparse CCA\n", - "c1 = 
[0.00001, 0.0001]\n", - "c2 = [0.00001, 0.0001]\n", - "param_grid = {'c': [c1,c2]}\n", - "\n", - "scca = GridSearchCV(SCCA(latent_dims=latent_dims),param_grid=param_grid,\n", - " cv=cv,\n", - " verbose=True,scoring=scorer).fit([train_view_1,train_view_2]).best_estimator_\n", - "\n", - "scca_results = np.stack(\n", - " (scca.score((train_view_1,train_view_2)), scca.score((test_view_1, test_view_2))))\n", - "\n", - "\n", - "\"\"\"\n", - "### Elastic CCA (can pass more than 2 views)\n", - "\"\"\"\n", - "\n", - "# Elastic CCA\n", - "c1 = loguniform(1e-4, 1e0)\n", - "c2 = loguniform(1e-4, 1e0)\n", - "l1_1 = loguniform(1e-4, 1e0)\n", - "l1_2 = loguniform(1e-4, 1e0)\n", - "param_grid = {'c': [c1,c2], 'l1_ratio': [l1_1,l1_2]}\n", - "\n", - "elastic = RandomizedSearchCV(ElasticCCA(latent_dims=latent_dims),param_distributions=param_grid,\n", - " cv=cv,\n", - " verbose=True,n_iter=5,scoring=scorer).fit([train_view_1,train_view_2]).best_estimator_\n", - "\n", - "elastic_results = np.stack(\n", - " (elastic.score((train_view_1,train_view_2)), elastic.score((test_view_1, test_view_2))))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n", - "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n", - "Fitting 3 folds for each of 4 candidates, totalling 12 fits\n", - "Fitting 3 folds for each of 5 candidates, totalling 15 fits\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "RLvrF3bGvLTu" - }, - "source": [ - "# Kernel CCA" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "cdu62DxIvLTv", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "1d94ef2a-0242-421e-bdd2-69a80414d682" - }, - "source": [ - "from cca_zoo.models import KCCA\n", - "\"\"\"\n", - "### Kernel CCA\n", - "\n", - "Similarly, we can use kernel CCA methods with 
[method='kernel']\n", - "\n", - "We can use different kernels and their associated parameters in a similar manner to before\n", - "- regularized linear kernel CCA: parameters : 'kernel'='linear', 0<'c'<1\n", - "- polynomial kernel CCA: parameters : 'kernel'='poly', 'degree', 0<'c'<1\n", - "- gaussian rbf kernel CCA: parameters : 'kernel'='gaussian', 'sigma', 0<'c'<1\n", - "\"\"\"\n", - "# %%\n", - "# r-kernel cca\n", - "c1 = [0.9, 0.99]\n", - "c2 = [0.9, 0.99]\n", - "\n", - "param_grid = {'kernel': ['linear'], 'c': [c1,c2]}\n", - "\n", - "kernel_reg = GridSearchCV(KCCA(latent_dims=latent_dims),param_grid=param_grid,\n", - " cv=cv,\n", - " verbose=True,scoring=scorer).fit([train_view_1,train_view_2]).best_estimator_\n", - "kernel_reg_results = np.stack((\n", - " kernel_reg.score((train_view_1,train_view_2)),\n", - " kernel_reg.score((test_view_1, test_view_2))))\n", - "\n", - "# kernel cca (poly)\n", - "degree1 = [2, 3]\n", - "degree2 = [2, 3]\n", - "\n", - "param_grid = {'kernel': ['poly'], 'degree': [degree1,degree2],\n", - " 'c': [c1,c2]}\n", - "\n", - "kernel_poly = GridSearchCV(KCCA(latent_dims=latent_dims),param_grid=param_grid,\n", - " cv=cv,\n", - " verbose=True,scoring=scorer).fit([train_view_1,train_view_2]).best_estimator_\n", - "\n", - "kernel_poly_results = np.stack((\n", - " kernel_poly.score((train_view_1,train_view_2)),\n", - " kernel_poly.score((test_view_1, test_view_2))))\n", - "\n", - "# kernel cca (gaussian)\n", - "gamma1 = [1e+1, 1e+2, 1e+3]\n", - "gamma2 = [1e+1, 1e+2, 1e+3]\n", - "\n", - "param_grid = {'kernel': ['rbf'], 'gamma': [gamma1,gamma2],\n", - " 'c': [c1,c2]}\n", - "\n", - "kernel_gaussian = GridSearchCV(KCCA(latent_dims=latent_dims),param_grid=param_grid,\n", - " cv=cv,\n", - " verbose=True,scoring=scorer).fit([train_view_1,train_view_2]).best_estimator_\n", - "\n", - "kernel_gaussian_results = np.stack((\n", - " kernel_gaussian.score((train_view_1,train_view_2)),\n", - " kernel_gaussian.score((test_view_1, test_view_2))))" - ], - 
"execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Fitting 3 folds for each of 4 candidates, totalling 12 fits\n", - "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n", - "Fitting 3 folds for each of 36 candidates, totalling 108 fits\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", 
- " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, 
None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - 
"/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - 
"/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - 
"/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - 
"/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - 
"/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - 
"/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - 
"/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n", - "/usr/local/lib/python3.7/dist-packages/sklearn/model_selection/_search.py:972: UserWarning: One or more of the test scores are non-finite: [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan\n", - " nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan]\n", - " category=UserWarning,\n", - "/usr/local/lib/python3.7/dist-packages/numpy/lib/function_base.py:2559: RuntimeWarning: invalid value encountered in true_divide\n", - " c /= stddev[:, None]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "iJfzMn4KvLTv" - }, - "source": [ - "# Deep CCA\n", - "\n", - "DCCA can be optimized using Andrew's original tracenorm objective or Wang's DCCA by nonlinear orthogonal iterations using the argument als=True." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "KvW0LO4KvLTw", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "3e57f638-f6c1-483b-a088-aaf23fd1b333" - }, - "source": [ - "\"\"\"\n", - "### Deep Learning\n", - "\n", - "We also have deep CCA methods (and autoencoder variants)\n", - "- Deep CCA (DCCA)\n", - "- Deep Canonically Correlated Autoencoders (DCCAE)\n", - "\n", - "\"\"\"\n", - "\n", - "# %%\n", - "# DCCA\n", - "print('DCCA')\n", - "encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "dcca_model = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])\n", - "\n", - "dcca_model = DeepWrapper(dcca_model)\n", - "\n", - "dcca_model.fit(train_dataset, val_dataset=val_dataset, epochs=epochs)\n", - "\n", - "dcca_results = np.stack((dcca_model.score(train_dataset), dcca_model.score(test_dataset)))\n", - "\n", - "# DCCA_NOI\n", - "print('DCCA by non-linear orthogonal iterations')\n", - "encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "dcca_noi_model = DCCA_NOI(latent_dims=latent_dims, encoders=[encoder_1, encoder_2],N=len(train_dataset))\n", - "\n", - "dcca_noi_model = DeepWrapper(dcca_noi_model)\n", - "\n", - "dcca_noi_model.fit(train_dataset, val_dataset=val_dataset, epochs=epochs)\n", - "\n", - "dcca_noi_results = np.stack(\n", - " (dcca_noi_model.score(train_dataset), dcca_noi_model.score(test_dataset)))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "DCCA\n", - "total parameters: 201476\n", - "====> Epoch: 1 Average train loss: -0.2531\n", - "====> Epoch: 1 Average val loss: -0.1096\n", - "Min loss -0.11\n", - "====> Epoch: 2 Average train loss: -0.1009\n", - "====> Epoch: 2 Average val 
loss: -0.1308\n", - "Min loss -0.13\n", - "====> Epoch: 3 Average train loss: -0.2040\n", - "====> Epoch: 3 Average val loss: -0.1842\n", - "Min loss -0.18\n", - "====> Epoch: 4 Average train loss: -0.1796\n", - "====> Epoch: 4 Average val loss: -0.1576\n", - "====> Epoch: 5 Average train loss: -0.2058\n", - "====> Epoch: 5 Average val loss: -0.3233\n", - "Min loss -0.32\n", - "====> Epoch: 6 Average train loss: -0.3632\n", - "====> Epoch: 6 Average val loss: -0.4711\n", - "Min loss -0.47\n", - "====> Epoch: 7 Average train loss: -0.4117\n", - "====> Epoch: 7 Average val loss: -0.4762\n", - "Min loss -0.48\n", - "====> Epoch: 8 Average train loss: -0.4564\n", - "====> Epoch: 8 Average val loss: -0.4032\n", - "====> Epoch: 9 Average train loss: -0.5325\n", - "====> Epoch: 9 Average val loss: -0.4755\n", - "====> Epoch: 10 Average train loss: -0.5422\n", - "====> Epoch: 10 Average val loss: -0.5092\n", - "Min loss -0.51\n", - "====> Epoch: 11 Average train loss: -0.7966\n", - "====> Epoch: 11 Average val loss: -0.7684\n", - "Min loss -0.77\n", - "====> Epoch: 12 Average train loss: -0.8014\n", - "====> Epoch: 12 Average val loss: -0.7995\n", - "Min loss -0.80\n", - "====> Epoch: 13 Average train loss: -0.8132\n", - "====> Epoch: 13 Average val loss: -0.7553\n", - "====> Epoch: 14 Average train loss: -0.8314\n", - "====> Epoch: 14 Average val loss: -0.7947\n", - "====> Epoch: 15 Average train loss: -0.8431\n", - "====> Epoch: 15 Average val loss: -0.7909\n", - "====> Epoch: 16 Average train loss: -0.7976\n", - "====> Epoch: 16 Average val loss: -0.8460\n", - "Min loss -0.85\n", - "====> Epoch: 17 Average train loss: -0.8252\n", - "====> Epoch: 17 Average val loss: -0.8178\n", - "====> Epoch: 18 Average train loss: -0.5628\n", - "====> Epoch: 18 Average val loss: -0.5937\n", - "====> Epoch: 19 Average train loss: -0.6322\n", - "====> Epoch: 19 Average val loss: -0.6510\n", - "====> Epoch: 20 Average train loss: -0.7647\n", - "====> Epoch: 20 Average val loss: 
-0.7421\n", - "====> Epoch: 21 Average train loss: -0.7851\n", - "====> Epoch: 21 Average val loss: -0.7247\n", - "====> Epoch: 22 Average train loss: -0.7602\n", - "====> Epoch: 22 Average val loss: -0.8007\n", - "====> Epoch: 23 Average train loss: -0.8833\n", - "====> Epoch: 23 Average val loss: -0.7915\n", - "====> Epoch: 24 Average train loss: -0.9343\n", - "====> Epoch: 24 Average val loss: -1.0192\n", - "Min loss -1.02\n", - "====> Epoch: 25 Average train loss: -1.0146\n", - "====> Epoch: 25 Average val loss: -0.9402\n", - "====> Epoch: 26 Average train loss: -1.1286\n", - "====> Epoch: 26 Average val loss: -1.1075\n", - "Min loss -1.11\n", - "====> Epoch: 27 Average train loss: -1.3057\n", - "====> Epoch: 27 Average val loss: -1.1418\n", - "Min loss -1.14\n", - "====> Epoch: 28 Average train loss: -1.2272\n", - "====> Epoch: 28 Average val loss: -1.2574\n", - "Min loss -1.26\n", - "====> Epoch: 29 Average train loss: -1.2456\n", - "====> Epoch: 29 Average val loss: -1.2860\n", - "Min loss -1.29\n", - "====> Epoch: 30 Average train loss: -1.1873\n", - "====> Epoch: 30 Average val loss: -1.2548\n", - "====> Epoch: 31 Average train loss: -0.4281\n", - "====> Epoch: 31 Average val loss: -0.4875\n", - "====> Epoch: 32 Average train loss: -0.4639\n", - "====> Epoch: 32 Average val loss: -0.4214\n", - "====> Epoch: 33 Average train loss: -0.5640\n", - "====> Epoch: 33 Average val loss: -0.5791\n", - "====> Epoch: 34 Average train loss: -0.5631\n", - "====> Epoch: 34 Average val loss: -0.5806\n", - "====> Epoch: 35 Average train loss: -0.6183\n", - "====> Epoch: 35 Average val loss: -0.7391\n", - "====> Epoch: 36 Average train loss: -0.7337\n", - "====> Epoch: 36 Average val loss: -0.7412\n", - "====> Epoch: 37 Average train loss: -0.8041\n", - "====> Epoch: 37 Average val loss: -0.7272\n", - "====> Epoch: 38 Average train loss: -0.8215\n", - "====> Epoch: 38 Average val loss: -0.7341\n", - "====> Epoch: 39 Average train loss: -0.9231\n", - "====> Epoch: 39 Average 
val loss: -0.8679\n", - "====> Epoch: 40 Average train loss: -0.9075\n", - "====> Epoch: 40 Average val loss: -0.9036\n", - "====> Epoch: 41 Average train loss: -0.9635\n", - "====> Epoch: 41 Average val loss: -1.0391\n", - "====> Epoch: 42 Average train loss: -0.5608\n", - "====> Epoch: 42 Average val loss: -0.5265\n", - "====> Epoch: 43 Average train loss: -0.5490\n", - "====> Epoch: 43 Average val loss: -0.4652\n", - "====> Epoch: 44 Average train loss: -0.5735\n", - "====> Epoch: 44 Average val loss: -0.5087\n", - "====> Epoch: 45 Average train loss: -0.4709\n", - "====> Epoch: 45 Average val loss: -0.5251\n", - "====> Epoch: 46 Average train loss: -0.5153\n", - "====> Epoch: 46 Average val loss: -0.5300\n", - "====> Epoch: 47 Average train loss: -0.6953\n", - "====> Epoch: 47 Average val loss: -0.6359\n", - "====> Epoch: 48 Average train loss: -0.7484\n", - "====> Epoch: 48 Average val loss: -0.7373\n", - "====> Epoch: 49 Average train loss: -0.7750\n", - "====> Epoch: 49 Average val loss: -0.6816\n", - "====> Epoch: 50 Average train loss: -0.4636\n", - "====> Epoch: 50 Average val loss: -0.3775\n", - "DCCA by non-linear orthogonal iterations\n", - "total parameters: 201484\n", - "====> Epoch: 1 Average train loss: 0.0077\n", - "====> Epoch: 1 Average val loss: 0.0074\n", - "Min loss 0.01\n", - "====> Epoch: 2 Average train loss: 0.0071\n", - "====> Epoch: 2 Average val loss: 0.0069\n", - "Min loss 0.01\n", - "====> Epoch: 3 Average train loss: 0.0072\n", - "====> Epoch: 3 Average val loss: 0.0069\n", - "====> Epoch: 4 Average train loss: 0.0067\n", - "====> Epoch: 4 Average val loss: 0.0070\n", - "====> Epoch: 5 Average train loss: 0.0064\n", - "====> Epoch: 5 Average val loss: 0.0069\n", - "====> Epoch: 6 Average train loss: 0.0066\n", - "====> Epoch: 6 Average val loss: 0.0070\n", - "====> Epoch: 7 Average train loss: 0.0067\n", - "====> Epoch: 7 Average val loss: 0.0068\n", - "Min loss 0.01\n", - "====> Epoch: 8 Average train loss: 0.0066\n", - "====> 
Epoch: 8 Average val loss: 0.0063\n", - "Min loss 0.01\n", - "====> Epoch: 9 Average train loss: 0.0068\n", - "====> Epoch: 9 Average val loss: 0.0064\n", - "====> Epoch: 10 Average train loss: 0.0066\n", - "====> Epoch: 10 Average val loss: 0.0065\n", - "====> Epoch: 11 Average train loss: 0.0067\n", - "====> Epoch: 11 Average val loss: 0.0063\n", - "Min loss 0.01\n", - "====> Epoch: 12 Average train loss: 0.0066\n", - "====> Epoch: 12 Average val loss: 0.0068\n", - "====> Epoch: 13 Average train loss: 0.0065\n", - "====> Epoch: 13 Average val loss: 0.0066\n", - "====> Epoch: 14 Average train loss: 0.0065\n", - "====> Epoch: 14 Average val loss: 0.0066\n", - "====> Epoch: 15 Average train loss: 0.0063\n", - "====> Epoch: 15 Average val loss: 0.0064\n", - "====> Epoch: 16 Average train loss: 0.0060\n", - "====> Epoch: 16 Average val loss: 0.0057\n", - "Min loss 0.01\n", - "====> Epoch: 17 Average train loss: 0.0057\n", - "====> Epoch: 17 Average val loss: 0.0054\n", - "Min loss 0.01\n", - "====> Epoch: 18 Average train loss: 0.0045\n", - "====> Epoch: 18 Average val loss: 0.0045\n", - "Min loss 0.00\n", - "====> Epoch: 19 Average train loss: 0.0044\n", - "====> Epoch: 19 Average val loss: 0.0043\n", - "Min loss 0.00\n", - "====> Epoch: 20 Average train loss: 0.0042\n", - "====> Epoch: 20 Average val loss: 0.0045\n", - "====> Epoch: 21 Average train loss: 0.0042\n", - "====> Epoch: 21 Average val loss: 0.0042\n", - "Min loss 0.00\n", - "====> Epoch: 22 Average train loss: 0.0042\n", - "====> Epoch: 22 Average val loss: 0.0042\n", - "Min loss 0.00\n", - "====> Epoch: 23 Average train loss: 0.0042\n", - "====> Epoch: 23 Average val loss: 0.0041\n", - "Min loss 0.00\n", - "====> Epoch: 24 Average train loss: 0.0042\n", - "====> Epoch: 24 Average val loss: 0.0040\n", - "Min loss 0.00\n", - "====> Epoch: 25 Average train loss: 0.0042\n", - "====> Epoch: 25 Average val loss: 0.0043\n", - "====> Epoch: 26 Average train loss: 0.0041\n", - "====> Epoch: 26 Average val loss: 
0.0043\n", - "====> Epoch: 27 Average train loss: 0.0042\n", - "====> Epoch: 27 Average val loss: 0.0044\n", - "====> Epoch: 28 Average train loss: 0.0041\n", - "====> Epoch: 28 Average val loss: 0.0042\n", - "====> Epoch: 29 Average train loss: 0.0041\n", - "====> Epoch: 29 Average val loss: 0.0043\n", - "====> Epoch: 30 Average train loss: 0.0039\n", - "====> Epoch: 30 Average val loss: 0.0039\n", - "Min loss 0.00\n", - "====> Epoch: 31 Average train loss: 0.0038\n", - "====> Epoch: 31 Average val loss: 0.0037\n", - "Min loss 0.00\n", - "====> Epoch: 32 Average train loss: 0.0079\n", - "====> Epoch: 32 Average val loss: 0.0079\n", - "====> Epoch: 33 Average train loss: 0.0081\n", - "====> Epoch: 33 Average val loss: 0.0077\n", - "====> Epoch: 34 Average train loss: 0.0072\n", - "====> Epoch: 34 Average val loss: 0.0075\n", - "====> Epoch: 35 Average train loss: 0.0065\n", - "====> Epoch: 35 Average val loss: 0.0062\n", - "====> Epoch: 36 Average train loss: 0.0060\n", - "====> Epoch: 36 Average val loss: 0.0056\n", - "====> Epoch: 37 Average train loss: 0.0057\n", - "====> Epoch: 37 Average val loss: 0.0057\n", - "====> Epoch: 38 Average train loss: 0.0045\n", - "====> Epoch: 38 Average val loss: 0.0045\n", - "====> Epoch: 39 Average train loss: 0.0043\n", - "====> Epoch: 39 Average val loss: 0.0044\n", - "====> Epoch: 40 Average train loss: 0.0042\n", - "====> Epoch: 40 Average val loss: 0.0042\n", - "====> Epoch: 41 Average train loss: 0.0040\n", - "====> Epoch: 41 Average val loss: 0.0039\n", - "====> Epoch: 42 Average train loss: 0.0040\n", - "====> Epoch: 42 Average val loss: 0.0040\n", - "====> Epoch: 43 Average train loss: 0.0042\n", - "====> Epoch: 43 Average val loss: 0.0041\n", - "====> Epoch: 44 Average train loss: 0.0040\n", - "====> Epoch: 44 Average val loss: 0.0038\n", - "====> Epoch: 45 Average train loss: 0.0040\n", - "====> Epoch: 45 Average val loss: 0.0037\n", - "Min loss 0.00\n", - "====> Epoch: 46 Average train loss: 0.0041\n", - "====> 
Epoch: 46 Average val loss: 0.0038\n", - "====> Epoch: 47 Average train loss: 0.0039\n", - "====> Epoch: 47 Average val loss: 0.0040\n", - "====> Epoch: 48 Average train loss: 0.0035\n", - "====> Epoch: 48 Average val loss: 0.0034\n", - "Min loss 0.00\n", - "====> Epoch: 49 Average train loss: 0.0033\n", - "====> Epoch: 49 Average val loss: 0.0035\n", - "====> Epoch: 50 Average train loss: 0.0034\n", - "====> Epoch: 50 Average val loss: 0.0033\n", - "Min loss 0.00\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "LvhEwkx2vLTx" - }, - "source": [ - "# DCCA with custom optimizers and schedulers" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "pyNA7lBEvLTx", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "ff1b9c97-a2d6-4f18-9204-343ff3fd0dc8" - }, - "source": [ - "# DCCA\n", - "print('DCCA')\n", - "encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "dcca_model = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])\n", - "optimizer = optim.Adam(dcca_model.parameters(), lr=1e-4)\n", - "scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 1)\n", - "dcca_model = DeepWrapper(dcca_model,optimizer=optimizer,scheduler=scheduler)\n", - "\n", - "dcca_model.fit(train_dataset, val_dataset=val_dataset, epochs=epochs)\n", - "\n", - "dcca_results = np.stack((dcca_model.score(train_dataset), dcca_model.score(test_dataset)))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "DCCA\n", - "total parameters: 201476\n", - "====> Epoch: 1 Average train loss: -0.2345\n", - "====> Epoch: 1 Average val loss: -0.1404\n", - "Min loss -0.14\n", - "====> Epoch: 2 Average train loss: -0.1096\n", - "====> Epoch: 2 Average val loss: -0.1372\n", - "====> Epoch: 3 Average train loss: 
-0.0880\n", - "====> Epoch: 3 Average val loss: -0.0736\n", - "====> Epoch: 4 Average train loss: -0.0473\n", - "====> Epoch: 4 Average val loss: -0.1699\n", - "Min loss -0.17\n", - "====> Epoch: 5 Average train loss: -0.1097\n", - "====> Epoch: 5 Average val loss: -0.1187\n", - "====> Epoch: 6 Average train loss: -0.1023\n", - "====> Epoch: 6 Average val loss: -0.0718\n", - "====> Epoch: 7 Average train loss: -0.1285\n", - "====> Epoch: 7 Average val loss: -0.1539\n", - "====> Epoch: 8 Average train loss: -0.0861\n", - "====> Epoch: 8 Average val loss: -0.0859\n", - "====> Epoch: 9 Average train loss: -0.0531\n", - "====> Epoch: 9 Average val loss: -0.0679\n", - "====> Epoch: 10 Average train loss: -0.0867\n", - "====> Epoch: 10 Average val loss: -0.0726\n", - "====> Epoch: 11 Average train loss: -0.0745\n", - "====> Epoch: 11 Average val loss: -0.0786\n", - "====> Epoch: 12 Average train loss: -0.0469\n", - "====> Epoch: 12 Average val loss: -0.1140\n", - "====> Epoch: 13 Average train loss: -0.0485\n", - "====> Epoch: 13 Average val loss: -0.1343\n", - "====> Epoch: 14 Average train loss: -0.1541\n", - "====> Epoch: 14 Average val loss: -0.0680\n", - "====> Epoch: 15 Average train loss: -0.1251\n", - "====> Epoch: 15 Average val loss: -0.1210\n", - "====> Epoch: 16 Average train loss: -0.0867\n", - "====> Epoch: 16 Average val loss: -0.0625\n", - "====> Epoch: 17 Average train loss: -0.0633\n", - "====> Epoch: 17 Average val loss: -0.0377\n", - "====> Epoch: 18 Average train loss: -0.0532\n", - "====> Epoch: 18 Average val loss: -0.0286\n", - "====> Epoch: 19 Average train loss: -0.1026\n", - "====> Epoch: 19 Average val loss: -0.1444\n", - "====> Epoch: 20 Average train loss: -0.1870\n", - "====> Epoch: 20 Average val loss: -0.1416\n", - "====> Epoch: 21 Average train loss: -0.1480\n", - "====> Epoch: 21 Average val loss: -0.0909\n", - "====> Epoch: 22 Average train loss: -0.1719\n", - "====> Epoch: 22 Average val loss: -0.1753\n", - "Min loss -0.18\n", - 
"====> Epoch: 23 Average train loss: -0.0477\n", - "====> Epoch: 23 Average val loss: -0.1516\n", - "====> Epoch: 24 Average train loss: -0.1052\n", - "====> Epoch: 24 Average val loss: -0.0912\n", - "====> Epoch: 25 Average train loss: -0.0992\n", - "====> Epoch: 25 Average val loss: -0.0723\n", - "====> Epoch: 26 Average train loss: -0.1183\n", - "====> Epoch: 26 Average val loss: -0.0333\n", - "====> Epoch: 27 Average train loss: -0.0867\n", - "====> Epoch: 27 Average val loss: -0.1468\n", - "====> Epoch: 28 Average train loss: -0.0383\n", - "====> Epoch: 28 Average val loss: -0.0647\n", - "====> Epoch: 29 Average train loss: -0.1259\n", - "====> Epoch: 29 Average val loss: -0.0760\n", - "====> Epoch: 30 Average train loss: -0.1398\n", - "====> Epoch: 30 Average val loss: -0.0601\n", - "====> Epoch: 31 Average train loss: -0.1736\n", - "====> Epoch: 31 Average val loss: -0.1128\n", - "====> Epoch: 32 Average train loss: -0.1219\n", - "====> Epoch: 32 Average val loss: -0.0435\n", - "====> Epoch: 33 Average train loss: -0.0648\n", - "====> Epoch: 33 Average val loss: -0.1038\n", - "====> Epoch: 34 Average train loss: -0.2130\n", - "====> Epoch: 34 Average val loss: -0.0537\n", - "====> Epoch: 35 Average train loss: -0.0897\n", - "====> Epoch: 35 Average val loss: -0.1286\n", - "====> Epoch: 36 Average train loss: -0.0613\n", - "====> Epoch: 36 Average val loss: -0.1210\n", - "====> Epoch: 37 Average train loss: -0.2377\n", - "====> Epoch: 37 Average val loss: -0.0619\n", - "====> Epoch: 38 Average train loss: -0.0458\n", - "====> Epoch: 38 Average val loss: -0.1473\n", - "====> Epoch: 39 Average train loss: -0.0578\n", - "====> Epoch: 39 Average val loss: -0.1190\n", - "====> Epoch: 40 Average train loss: -0.1635\n", - "====> Epoch: 40 Average val loss: -0.1229\n", - "====> Epoch: 41 Average train loss: -0.1248\n", - "====> Epoch: 41 Average val loss: -0.0981\n", - "====> Epoch: 42 Average train loss: -0.0984\n", - "====> Epoch: 42 Average val loss: -0.0647\n", - 
"====> Epoch: 43 Average train loss: -0.1104\n", - "====> Epoch: 43 Average val loss: -0.1612\n", - "====> Epoch: 44 Average train loss: -0.0655\n", - "====> Epoch: 44 Average val loss: -0.1284\n", - "====> Epoch: 45 Average train loss: -0.0875\n", - "====> Epoch: 45 Average val loss: -0.0430\n", - "====> Epoch: 46 Average train loss: -0.1070\n", - "====> Epoch: 46 Average val loss: -0.1176\n", - "====> Epoch: 47 Average train loss: -0.0518\n", - "====> Epoch: 47 Average val loss: -0.0676\n", - "====> Epoch: 48 Average train loss: -0.2479\n", - "====> Epoch: 48 Average val loss: -0.1518\n", - "====> Epoch: 49 Average train loss: -0.0629\n", - "====> Epoch: 49 Average val loss: -0.1117\n", - "====> Epoch: 50 Average train loss: -0.1809\n", - "====> Epoch: 50 Average val loss: -0.1175\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "5eV-WnWRvLTy" - }, - "source": [ - "# DGCCA and DMCCA for more than 2 views\n", - "\n", - "The only change we need to make is to the objective argument to perform DGCCA and DMCCA." 
- ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "_XKkNitdvLTy", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "3e35ad57-ea2b-4631-e0d2-8f14282affb5" - }, - "source": [ - "# DGCCA\n", - "print('DGCCA')\n", - "encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "dgcca_model = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.GCCA)\n", - "\n", - "dgcca_model = DeepWrapper(dgcca_model)\n", - "\n", - "dgcca_model.fit(train_dataset, val_dataset=val_dataset, epochs=epochs)\n", - "\n", - "dgcca_results = np.stack(\n", - " (dgcca_model.score(train_dataset), dgcca_model.score(test_dataset)))\n", - "\n", - "# DMCCA\n", - "print('DMCCA')\n", - "encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "dmcca_model = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2], objective=objectives.MCCA)\n", - "\n", - "dmcca_model = DeepWrapper(dmcca_model)\n", - "\n", - "dmcca_model.fit(train_dataset, val_dataset=val_dataset, epochs=epochs)\n", - "\n", - "dmcca_results = np.stack(\n", - " (dmcca_model.score(train_dataset), dmcca_model.score(test_dataset)))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "DGCCA\n", - "total parameters: 201476\n", - "====> Epoch: 1 Average train loss: -0.3717\n", - "====> Epoch: 1 Average val loss: -0.2874\n", - "Min loss -0.29\n", - "====> Epoch: 2 Average train loss: -0.4010\n", - "====> Epoch: 2 Average val loss: -0.3402\n", - "Min loss -0.34\n", - "====> Epoch: 3 Average train loss: -0.5026\n", - "====> Epoch: 3 Average val loss: -0.4891\n", - "Min loss -0.49\n", - "====> Epoch: 4 Average train loss: -0.4654\n", - "====> Epoch: 4 Average val loss: 
-0.4745\n", - "====> Epoch: 5 Average train loss: -0.6690\n", - "====> Epoch: 5 Average val loss: -0.5786\n", - "Min loss -0.58\n", - "====> Epoch: 6 Average train loss: -0.5903\n", - "====> Epoch: 6 Average val loss: -0.5848\n", - "Min loss -0.58\n", - "====> Epoch: 7 Average train loss: -0.6430\n", - "====> Epoch: 7 Average val loss: -0.6894\n", - "Min loss -0.69\n", - "====> Epoch: 8 Average train loss: -0.7384\n", - "====> Epoch: 8 Average val loss: -0.6960\n", - "Min loss -0.70\n", - "====> Epoch: 9 Average train loss: -0.7716\n", - "====> Epoch: 9 Average val loss: -0.7567\n", - "Min loss -0.76\n", - "====> Epoch: 10 Average train loss: -0.8274\n", - "====> Epoch: 10 Average val loss: -0.7267\n", - "====> Epoch: 11 Average train loss: -0.8185\n", - "====> Epoch: 11 Average val loss: -0.7173\n", - "====> Epoch: 12 Average train loss: -0.8574\n", - "====> Epoch: 12 Average val loss: -0.7553\n", - "====> Epoch: 13 Average train loss: -0.8279\n", - "====> Epoch: 13 Average val loss: -0.7826\n", - "Min loss -0.78\n", - "====> Epoch: 14 Average train loss: -0.8461\n", - "====> Epoch: 14 Average val loss: -0.7677\n", - "====> Epoch: 15 Average train loss: -0.9792\n", - "====> Epoch: 15 Average val loss: -0.8070\n", - "Min loss -0.81\n", - "====> Epoch: 16 Average train loss: -0.8722\n", - "====> Epoch: 16 Average val loss: -0.8461\n", - "Min loss -0.85\n", - "====> Epoch: 17 Average train loss: -0.8155\n", - "====> Epoch: 17 Average val loss: -0.8062\n", - "====> Epoch: 18 Average train loss: -0.8851\n", - "====> Epoch: 18 Average val loss: -0.8348\n", - "====> Epoch: 19 Average train loss: -1.0884\n", - "====> Epoch: 19 Average val loss: -0.8903\n", - "Min loss -0.89\n", - "====> Epoch: 20 Average train loss: -0.9305\n", - "====> Epoch: 20 Average val loss: -0.9799\n", - "Min loss -0.98\n", - "====> Epoch: 21 Average train loss: -0.8848\n", - "====> Epoch: 21 Average val loss: -0.9489\n", - "====> Epoch: 22 Average train loss: -0.9700\n", - "====> Epoch: 22 Average 
val loss: -0.9913\n", - "Min loss -0.99\n", - "====> Epoch: 23 Average train loss: -1.0768\n", - "====> Epoch: 23 Average val loss: -0.9465\n", - "====> Epoch: 24 Average train loss: -0.9855\n", - "====> Epoch: 24 Average val loss: -0.9430\n", - "====> Epoch: 25 Average train loss: -0.9074\n", - "====> Epoch: 25 Average val loss: -0.9642\n", - "====> Epoch: 26 Average train loss: -1.0501\n", - "====> Epoch: 26 Average val loss: -1.0461\n", - "Min loss -1.05\n", - "====> Epoch: 27 Average train loss: -1.0541\n", - "====> Epoch: 27 Average val loss: -1.0062\n", - "====> Epoch: 28 Average train loss: -1.0942\n", - "====> Epoch: 28 Average val loss: -1.0910\n", - "Min loss -1.09\n", - "====> Epoch: 29 Average train loss: -1.0717\n", - "====> Epoch: 29 Average val loss: -1.0549\n", - "====> Epoch: 30 Average train loss: -1.0560\n", - "====> Epoch: 30 Average val loss: -1.0381\n", - "====> Epoch: 31 Average train loss: -1.0791\n", - "====> Epoch: 31 Average val loss: -1.0678\n", - "====> Epoch: 32 Average train loss: -0.7371\n", - "====> Epoch: 32 Average val loss: -0.6544\n", - "====> Epoch: 33 Average train loss: -0.7659\n", - "====> Epoch: 33 Average val loss: -0.6864\n", - "====> Epoch: 34 Average train loss: -0.7129\n", - "====> Epoch: 34 Average val loss: -0.6451\n", - "====> Epoch: 35 Average train loss: -0.7985\n", - "====> Epoch: 35 Average val loss: -0.8138\n", - "====> Epoch: 36 Average train loss: -0.7538\n", - "====> Epoch: 36 Average val loss: -0.6864\n", - "====> Epoch: 37 Average train loss: -0.7906\n", - "====> Epoch: 37 Average val loss: -0.7654\n", - "====> Epoch: 38 Average train loss: -0.8417\n", - "====> Epoch: 38 Average val loss: -0.8218\n", - "====> Epoch: 39 Average train loss: -0.8689\n", - "====> Epoch: 39 Average val loss: -0.8711\n", - "====> Epoch: 40 Average train loss: -0.8193\n", - "====> Epoch: 40 Average val loss: -0.7741\n", - "====> Epoch: 41 Average train loss: -0.9110\n", - "====> Epoch: 41 Average val loss: -0.8667\n", - "====> 
Epoch: 42 Average train loss: -0.9737\n", - "====> Epoch: 42 Average val loss: -0.9454\n", - "====> Epoch: 43 Average train loss: -0.9981\n", - "====> Epoch: 43 Average val loss: -0.9658\n", - "====> Epoch: 44 Average train loss: -1.1852\n", - "====> Epoch: 44 Average val loss: -1.1355\n", - "Min loss -1.14\n", - "====> Epoch: 45 Average train loss: -0.9776\n", - "====> Epoch: 45 Average val loss: -0.7820\n", - "====> Epoch: 46 Average train loss: -0.7977\n", - "====> Epoch: 46 Average val loss: -0.7567\n", - "====> Epoch: 47 Average train loss: -1.1455\n", - "====> Epoch: 47 Average val loss: -0.9938\n", - "====> Epoch: 48 Average train loss: -1.0858\n", - "====> Epoch: 48 Average val loss: -1.0856\n", - "====> Epoch: 49 Average train loss: -1.0088\n", - "====> Epoch: 49 Average val loss: -1.0482\n", - "====> Epoch: 50 Average train loss: -1.0798\n", - "====> Epoch: 50 Average val loss: -1.1690\n", - "Min loss -1.17\n", - "DMCCA\n", - "total parameters: 201476\n", - "====> Epoch: 1 Average train loss: -0.7272\n", - "====> Epoch: 1 Average val loss: -0.5648\n", - "Min loss -0.56\n", - "====> Epoch: 2 Average train loss: -0.5396\n", - "====> Epoch: 2 Average val loss: -0.5257\n", - "====> Epoch: 3 Average train loss: -0.5592\n", - "====> Epoch: 3 Average val loss: -0.5761\n", - "Min loss -0.58\n", - "====> Epoch: 4 Average train loss: -0.8326\n", - "====> Epoch: 4 Average val loss: -0.6956\n", - "Min loss -0.70\n", - "====> Epoch: 5 Average train loss: -0.7508\n", - "====> Epoch: 5 Average val loss: -0.7319\n", - "Min loss -0.73\n", - "====> Epoch: 6 Average train loss: -0.7918\n", - "====> Epoch: 6 Average val loss: -0.7501\n", - "Min loss -0.75\n", - "====> Epoch: 7 Average train loss: -0.7197\n", - "====> Epoch: 7 Average val loss: -0.7025\n", - "====> Epoch: 8 Average train loss: -0.8067\n", - "====> Epoch: 8 Average val loss: -0.7396\n", - "====> Epoch: 9 Average train loss: -0.7109\n", - "====> Epoch: 9 Average val loss: -0.7343\n", - "====> Epoch: 10 Average 
train loss: -0.7181\n", - "====> Epoch: 10 Average val loss: -0.6935\n", - "====> Epoch: 11 Average train loss: -0.7715\n", - "====> Epoch: 11 Average val loss: -0.7396\n", - "====> Epoch: 12 Average train loss: -0.8081\n", - "====> Epoch: 12 Average val loss: -0.7593\n", - "Min loss -0.76\n", - "====> Epoch: 13 Average train loss: -0.8056\n", - "====> Epoch: 13 Average val loss: -0.7780\n", - "Min loss -0.78\n", - "====> Epoch: 14 Average train loss: -0.7139\n", - "====> Epoch: 14 Average val loss: -0.7234\n", - "====> Epoch: 15 Average train loss: -0.8010\n", - "====> Epoch: 15 Average val loss: -0.7624\n", - "====> Epoch: 16 Average train loss: -0.8537\n", - "====> Epoch: 16 Average val loss: -0.8110\n", - "Min loss -0.81\n", - "====> Epoch: 17 Average train loss: -0.8014\n", - "====> Epoch: 17 Average val loss: -0.7729\n", - "====> Epoch: 18 Average train loss: -0.9397\n", - "====> Epoch: 18 Average val loss: -0.7617\n", - "====> Epoch: 19 Average train loss: -0.9068\n", - "====> Epoch: 19 Average val loss: -0.8688\n", - "Min loss -0.87\n", - "====> Epoch: 20 Average train loss: -0.9523\n", - "====> Epoch: 20 Average val loss: -0.9657\n", - "Min loss -0.97\n", - "====> Epoch: 21 Average train loss: -0.9830\n", - "====> Epoch: 21 Average val loss: -0.9828\n", - "Min loss -0.98\n", - "====> Epoch: 22 Average train loss: -0.8908\n", - "====> Epoch: 22 Average val loss: -1.0201\n", - "Min loss -1.02\n", - "====> Epoch: 23 Average train loss: -1.0304\n", - "====> Epoch: 23 Average val loss: -1.0090\n", - "====> Epoch: 24 Average train loss: -1.0930\n", - "====> Epoch: 24 Average val loss: -1.0294\n", - "Min loss -1.03\n", - "====> Epoch: 25 Average train loss: -1.1140\n", - "====> Epoch: 25 Average val loss: -1.0150\n", - "====> Epoch: 26 Average train loss: -0.8270\n", - "====> Epoch: 26 Average val loss: -0.8099\n", - "====> Epoch: 27 Average train loss: -0.8180\n", - "====> Epoch: 27 Average val loss: -0.7855\n", - "====> Epoch: 28 Average train loss: -0.8484\n", 
- "====> Epoch: 28 Average val loss: -0.8272\n", - "====> Epoch: 29 Average train loss: -0.9530\n", - "====> Epoch: 29 Average val loss: -0.9293\n", - "====> Epoch: 30 Average train loss: -1.0448\n", - "====> Epoch: 30 Average val loss: -1.0305\n", - "Min loss -1.03\n", - "====> Epoch: 31 Average train loss: -1.1297\n", - "====> Epoch: 31 Average val loss: -1.1626\n", - "Min loss -1.16\n", - "====> Epoch: 32 Average train loss: -1.2066\n", - "====> Epoch: 32 Average val loss: -1.2592\n", - "Min loss -1.26\n", - "====> Epoch: 33 Average train loss: -1.2493\n", - "====> Epoch: 33 Average val loss: -1.2980\n", - "Min loss -1.30\n", - "====> Epoch: 34 Average train loss: -1.1872\n", - "====> Epoch: 34 Average val loss: -1.2153\n", - "====> Epoch: 35 Average train loss: -1.2708\n", - "====> Epoch: 35 Average val loss: -1.2851\n", - "====> Epoch: 36 Average train loss: -1.2831\n", - "====> Epoch: 36 Average val loss: -1.3056\n", - "Min loss -1.31\n", - "====> Epoch: 37 Average train loss: -1.2999\n", - "====> Epoch: 37 Average val loss: -1.3194\n", - "Min loss -1.32\n", - "====> Epoch: 38 Average train loss: -1.3256\n", - "====> Epoch: 38 Average val loss: -1.3528\n", - "Min loss -1.35\n", - "====> Epoch: 39 Average train loss: -1.3218\n", - "====> Epoch: 39 Average val loss: -1.2652\n", - "====> Epoch: 40 Average train loss: -1.3106\n", - "====> Epoch: 40 Average val loss: -1.2634\n", - "====> Epoch: 41 Average train loss: -0.8961\n", - "====> Epoch: 41 Average val loss: -0.7964\n", - "====> Epoch: 42 Average train loss: -1.3604\n", - "====> Epoch: 42 Average val loss: -1.3359\n", - "====> Epoch: 43 Average train loss: -1.2987\n", - "====> Epoch: 43 Average val loss: -1.2966\n", - "====> Epoch: 44 Average train loss: -1.4320\n", - "====> Epoch: 44 Average val loss: -1.4049\n", - "Min loss -1.40\n", - "====> Epoch: 45 Average train loss: -1.4297\n", - "====> Epoch: 45 Average val loss: -1.3719\n", - "====> Epoch: 46 Average train loss: -1.3900\n", - "====> Epoch: 46 
Average val loss: -1.4160\n", - "Min loss -1.42\n", - "====> Epoch: 47 Average train loss: -1.4094\n", - "====> Epoch: 47 Average val loss: -1.4063\n", - "====> Epoch: 48 Average train loss: -1.4272\n", - "====> Epoch: 48 Average val loss: -1.4198\n", - "Min loss -1.42\n", - "====> Epoch: 49 Average train loss: -1.3373\n", - "====> Epoch: 49 Average val loss: -1.3366\n", - "====> Epoch: 50 Average train loss: -1.3213\n", - "====> Epoch: 50 Average val loss: -1.3334\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "i6Czf36DvLTy" - }, - "source": [ - "# Deep Canonically Correlated Autoencoders\n", - "We need to add decoders in order to model deep canonically correlated autoencoders and we also use the DCCAE class which inherits from DCCA" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "jphk92IhvLTz", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "4982a7e6-e80a-4bef-cd4d-2557d39a8a45" - }, - "source": [ - "# DCCAE\n", - "print('DCCAE')\n", - "encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "decoder_1 = architectures.Decoder(latent_dims=latent_dims, feature_size=784)\n", - "decoder_2 = architectures.Decoder(latent_dims=latent_dims, feature_size=784)\n", - "dccae_model = DCCAE(latent_dims=latent_dims, encoders=[encoder_1, encoder_2], decoders=[decoder_1, decoder_2])\n", - "\n", - "dccae_model = DeepWrapper(dccae_model)\n", - "\n", - "#can also pass a tuple of numpy arrays\n", - "dccae_model.fit((train_view_1, train_view_2), epochs=epochs)\n", - "\n", - "dccae_results = np.stack(\n", - " (dccae_model.score(train_dataset), dccae_model.score(test_dataset)))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "DCCAE\n", - "total parameters: 404516\n", - "====> Epoch: 1 
Average train loss: -0.0019\n", - "====> Epoch: 2 Average train loss: -0.0677\n", - "====> Epoch: 3 Average train loss: -0.1219\n", - "====> Epoch: 4 Average train loss: -0.1684\n", - "====> Epoch: 5 Average train loss: -0.2075\n", - "====> Epoch: 6 Average train loss: -0.2417\n", - "====> Epoch: 7 Average train loss: -0.2722\n", - "====> Epoch: 8 Average train loss: -0.3001\n", - "====> Epoch: 9 Average train loss: -0.3262\n", - "====> Epoch: 10 Average train loss: -0.3502\n", - "====> Epoch: 11 Average train loss: -0.3730\n", - "====> Epoch: 12 Average train loss: -0.3943\n", - "====> Epoch: 13 Average train loss: -0.4143\n", - "====> Epoch: 14 Average train loss: -0.4332\n", - "====> Epoch: 15 Average train loss: -0.4514\n", - "====> Epoch: 16 Average train loss: -0.4691\n", - "====> Epoch: 17 Average train loss: -0.4865\n", - "====> Epoch: 18 Average train loss: -0.5030\n", - "====> Epoch: 19 Average train loss: -0.5188\n", - "====> Epoch: 20 Average train loss: -0.5341\n", - "====> Epoch: 21 Average train loss: -0.5490\n", - "====> Epoch: 22 Average train loss: -0.5638\n", - "====> Epoch: 23 Average train loss: -0.5781\n", - "====> Epoch: 24 Average train loss: -0.5919\n", - "====> Epoch: 25 Average train loss: -0.6053\n", - "====> Epoch: 26 Average train loss: -0.6185\n", - "====> Epoch: 27 Average train loss: -0.6317\n", - "====> Epoch: 28 Average train loss: -0.6447\n", - "====> Epoch: 29 Average train loss: -0.6575\n", - "====> Epoch: 30 Average train loss: -0.6700\n", - "====> Epoch: 31 Average train loss: -0.6821\n", - "====> Epoch: 32 Average train loss: -0.6939\n", - "====> Epoch: 33 Average train loss: -0.7054\n", - "====> Epoch: 34 Average train loss: -0.7167\n", - "====> Epoch: 35 Average train loss: -0.7276\n", - "====> Epoch: 36 Average train loss: -0.7384\n", - "====> Epoch: 37 Average train loss: -0.7490\n", - "====> Epoch: 38 Average train loss: -0.7595\n", - "====> Epoch: 39 Average train loss: -0.7698\n", - "====> Epoch: 40 Average train 
loss: -0.7800\n", - "====> Epoch: 41 Average train loss: -0.7901\n", - "====> Epoch: 42 Average train loss: -0.7999\n", - "====> Epoch: 43 Average train loss: -0.8096\n", - "====> Epoch: 44 Average train loss: -0.8191\n", - "====> Epoch: 45 Average train loss: -0.8285\n", - "====> Epoch: 46 Average train loss: -0.8378\n", - "====> Epoch: 47 Average train loss: -0.8471\n", - "====> Epoch: 48 Average train loss: -0.8563\n", - "====> Epoch: 49 Average train loss: -0.8654\n", - "====> Epoch: 50 Average train loss: -0.8743\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "WEK3sUSuvLTz" - }, - "source": [ - "# Deep Variational CCA" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "9lqcopiSvLTz", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "484a9a77-0b35-41b7-ddef-6aacbb8be6c0" - }, - "source": [ - "\"\"\"\n", - "### Deep Variational Learning\n", - "Finally we have Deep Variational CCA methods.\n", - "- Deep Variational CCA (DVCCA)\n", - "- Deep Variational CCA - private (DVVCA_p)\n", - "\n", - "These are both implemented by the DVCCA class with private=True/False and both_encoders=True/False. 
If both_encoders,\n", - "the encoder to the shared information Q(z_shared|x) is modelled for both x_1 and x_2 whereas if both_encoders is false\n", - "it is modelled for x_1 as in the paper\n", - "\"\"\"\n", - "\n", - "# %%\n", - "# DVCCA (technically bi-DVCCA)\n", - "print('DVCCA')\n", - "encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784, variational=True)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784, variational=True)\n", - "decoder_1 = architectures.Decoder(latent_dims=latent_dims, feature_size=784, norm_output=True)\n", - "decoder_2 = architectures.Decoder(latent_dims=latent_dims, feature_size=784, norm_output=True)\n", - "dvcca_model = DVCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2], decoders=[decoder_1, decoder_2])\n", - "\n", - "dvcca_model = DeepWrapper(dvcca_model)\n", - "\n", - "dvcca_model.fit(train_dataset, val_dataset=val_dataset, epochs=epochs)\n", - "\n", - "dvcca_model_results = np.stack(\n", - " (dvcca_model.score(train_dataset), dvcca_model.score(test_dataset)))\n", - "\n", - "# DVCCA_private (technically bi-DVCCA_private)\n", - "print('DVCCA_private')\n", - "encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784, variational=True)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784, variational=True)\n", - "private_encoder_1 = architectures.Encoder(latent_dims=latent_dims, feature_size=784, variational=True)\n", - "private_encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784, variational=True)\n", - "decoder_1 = architectures.Decoder(latent_dims=latent_dims * 2, feature_size=784, norm_output=True)\n", - "decoder_2 = architectures.Decoder(latent_dims=latent_dims * 2, feature_size=784, norm_output=True)\n", - "dvccap_model = DVCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2], decoders=[decoder_1, decoder_2],\n", - " private_encoders=[private_encoder_1, private_encoder_2])\n", - "\n", - 
"dvccap_model = DeepWrapper(dvccap_model)\n", - "\n", - "dvccap_model.fit(train_dataset, val_dataset=val_dataset, epochs=epochs)\n", - "\n", - "dvccap_model_results = np.stack(\n", - " (dvccap_model.score(train_dataset), dvccap_model.score(test_dataset)))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "DVCCA\n", - "total parameters: 405032\n", - "====> Epoch: 1 Average train loss: 1109.5615\n", - "====> Epoch: 1 Average val loss: 1106.4091\n", - "Min loss 1106.41\n", - "====> Epoch: 2 Average train loss: 1106.5356\n", - "====> Epoch: 2 Average val loss: 1103.5577\n", - "Min loss 1103.56\n", - "====> Epoch: 3 Average train loss: 1103.5905\n", - "====> Epoch: 3 Average val loss: 1100.6377\n", - "Min loss 1100.64\n", - "====> Epoch: 4 Average train loss: 1100.4827\n", - "====> Epoch: 4 Average val loss: 1097.5350\n", - "Min loss 1097.54\n", - "====> Epoch: 5 Average train loss: 1097.5166\n", - "====> Epoch: 5 Average val loss: 1094.5754\n", - "Min loss 1094.58\n", - "====> Epoch: 6 Average train loss: 1094.8094\n", - "====> Epoch: 6 Average val loss: 1091.9133\n", - "Min loss 1091.91\n", - "====> Epoch: 7 Average train loss: 1092.1809\n", - "====> Epoch: 7 Average val loss: 1089.2421\n", - "Min loss 1089.24\n", - "====> Epoch: 8 Average train loss: 1089.0867\n", - "====> Epoch: 8 Average val loss: 1086.4995\n", - "Min loss 1086.50\n", - "====> Epoch: 9 Average train loss: 1086.4094\n", - "====> Epoch: 9 Average val loss: 1083.7737\n", - "Min loss 1083.77\n", - "====> Epoch: 10 Average train loss: 1083.9321\n", - "====> Epoch: 10 Average val loss: 1081.5156\n", - "Min loss 1081.52\n", - "====> Epoch: 11 Average train loss: 1081.5007\n", - "====> Epoch: 11 Average val loss: 1078.8384\n", - "Min loss 1078.84\n", - "====> Epoch: 12 Average train loss: 1078.7391\n", - "====> Epoch: 12 Average val loss: 1076.0708\n", - "Min loss 1076.07\n", - "====> Epoch: 13 Average train loss: 1076.5303\n", - "====> Epoch: 
13 Average val loss: 1073.5686\n", - "Min loss 1073.57\n", - "====> Epoch: 14 Average train loss: 1073.3767\n", - "====> Epoch: 14 Average val loss: 1071.1628\n", - "Min loss 1071.16\n", - "====> Epoch: 15 Average train loss: 1071.1952\n", - "====> Epoch: 15 Average val loss: 1068.8904\n", - "Min loss 1068.89\n", - "====> Epoch: 16 Average train loss: 1069.1177\n", - "====> Epoch: 16 Average val loss: 1066.5808\n", - "Min loss 1066.58\n", - "====> Epoch: 17 Average train loss: 1066.5830\n", - "====> Epoch: 17 Average val loss: 1064.4114\n", - "Min loss 1064.41\n", - "====> Epoch: 18 Average train loss: 1063.8755\n", - "====> Epoch: 18 Average val loss: 1061.8417\n", - "Min loss 1061.84\n", - "====> Epoch: 19 Average train loss: 1062.0535\n", - "====> Epoch: 19 Average val loss: 1059.6287\n", - "Min loss 1059.63\n", - "====> Epoch: 20 Average train loss: 1059.5776\n", - "====> Epoch: 20 Average val loss: 1057.6716\n", - "Min loss 1057.67\n", - "====> Epoch: 21 Average train loss: 1057.6267\n", - "====> Epoch: 21 Average val loss: 1055.5505\n", - "Min loss 1055.55\n", - "====> Epoch: 22 Average train loss: 1055.3225\n", - "====> Epoch: 22 Average val loss: 1053.0654\n", - "Min loss 1053.07\n", - "====> Epoch: 23 Average train loss: 1053.1990\n", - "====> Epoch: 23 Average val loss: 1051.1343\n", - "Min loss 1051.13\n", - "====> Epoch: 24 Average train loss: 1051.1526\n", - "====> Epoch: 24 Average val loss: 1049.0571\n", - "Min loss 1049.06\n", - "====> Epoch: 25 Average train loss: 1048.9647\n", - "====> Epoch: 25 Average val loss: 1046.9143\n", - "Min loss 1046.91\n", - "====> Epoch: 26 Average train loss: 1047.0123\n", - "====> Epoch: 26 Average val loss: 1044.8156\n", - "Min loss 1044.82\n", - "====> Epoch: 27 Average train loss: 1044.7405\n", - "====> Epoch: 27 Average val loss: 1043.0496\n", - "Min loss 1043.05\n", - "====> Epoch: 28 Average train loss: 1043.0582\n", - "====> Epoch: 28 Average val loss: 1041.3112\n", - "Min loss 1041.31\n", - "====> Epoch: 29 
Average train loss: 1041.1470\n", - "====> Epoch: 29 Average val loss: 1039.2390\n", - "Min loss 1039.24\n", - "====> Epoch: 30 Average train loss: 1039.4268\n", - "====> Epoch: 30 Average val loss: 1037.7975\n", - "Min loss 1037.80\n", - "====> Epoch: 31 Average train loss: 1037.5386\n", - "====> Epoch: 31 Average val loss: 1035.5798\n", - "Min loss 1035.58\n", - "====> Epoch: 32 Average train loss: 1035.8229\n", - "====> Epoch: 32 Average val loss: 1033.9326\n", - "Min loss 1033.93\n", - "====> Epoch: 33 Average train loss: 1033.8650\n", - "====> Epoch: 33 Average val loss: 1031.8835\n", - "Min loss 1031.88\n", - "====> Epoch: 34 Average train loss: 1032.1606\n", - "====> Epoch: 34 Average val loss: 1030.4136\n", - "Min loss 1030.41\n", - "====> Epoch: 35 Average train loss: 1030.4910\n", - "====> Epoch: 35 Average val loss: 1028.5079\n", - "Min loss 1028.51\n", - "====> Epoch: 36 Average train loss: 1028.8259\n", - "====> Epoch: 36 Average val loss: 1026.8826\n", - "Min loss 1026.88\n", - "====> Epoch: 37 Average train loss: 1027.3547\n", - "====> Epoch: 37 Average val loss: 1025.3986\n", - "Min loss 1025.40\n", - "====> Epoch: 38 Average train loss: 1025.0544\n", - "====> Epoch: 38 Average val loss: 1023.9811\n", - "Min loss 1023.98\n", - "====> Epoch: 39 Average train loss: 1024.1343\n", - "====> Epoch: 39 Average val loss: 1021.9122\n", - "Min loss 1021.91\n", - "====> Epoch: 40 Average train loss: 1022.0138\n", - "====> Epoch: 40 Average val loss: 1019.9561\n", - "Min loss 1019.96\n", - "====> Epoch: 41 Average train loss: 1020.5950\n", - "====> Epoch: 41 Average val loss: 1019.5484\n", - "Min loss 1019.55\n", - "====> Epoch: 42 Average train loss: 1018.7665\n", - "====> Epoch: 42 Average val loss: 1017.3838\n", - "Min loss 1017.38\n", - "====> Epoch: 43 Average train loss: 1017.5677\n", - "====> Epoch: 43 Average val loss: 1016.0297\n", - "Min loss 1016.03\n", - "====> Epoch: 44 Average train loss: 1016.1881\n", - "====> Epoch: 44 Average val loss: 
1014.4371\n", - "Min loss 1014.44\n", - "====> Epoch: 45 Average train loss: 1014.6393\n", - "====> Epoch: 45 Average val loss: 1013.1492\n", - "Min loss 1013.15\n", - "====> Epoch: 46 Average train loss: 1013.3342\n", - "====> Epoch: 46 Average val loss: 1012.0658\n", - "Min loss 1012.07\n", - "====> Epoch: 47 Average train loss: 1011.7484\n", - "====> Epoch: 47 Average val loss: 1010.7506\n", - "Min loss 1010.75\n", - "====> Epoch: 48 Average train loss: 1010.6251\n", - "====> Epoch: 48 Average val loss: 1008.8579\n", - "Min loss 1008.86\n", - "====> Epoch: 49 Average train loss: 1009.0502\n", - "====> Epoch: 49 Average val loss: 1007.8226\n", - "Min loss 1007.82\n", - "====> Epoch: 50 Average train loss: 1007.1901\n", - "====> Epoch: 50 Average val loss: 1005.9260\n", - "Min loss 1005.93\n", - "DVCCA_private\n", - "total parameters: 607536\n", - "====> Epoch: 1 Average train loss: 1111.1758\n", - "====> Epoch: 1 Average val loss: 1108.3516\n", - "Min loss 1108.35\n", - "====> Epoch: 2 Average train loss: 1108.4519\n", - "====> Epoch: 2 Average val loss: 1105.1642\n", - "Min loss 1105.16\n", - "====> Epoch: 3 Average train loss: 1105.5024\n", - "====> Epoch: 3 Average val loss: 1102.2720\n", - "Min loss 1102.27\n", - "====> Epoch: 4 Average train loss: 1102.6145\n", - "====> Epoch: 4 Average val loss: 1099.6088\n", - "Min loss 1099.61\n", - "====> Epoch: 5 Average train loss: 1099.9021\n", - "====> Epoch: 5 Average val loss: 1097.1158\n", - "Min loss 1097.12\n", - "====> Epoch: 6 Average train loss: 1097.0347\n", - "====> Epoch: 6 Average val loss: 1094.3152\n", - "Min loss 1094.32\n", - "====> Epoch: 7 Average train loss: 1094.4141\n", - "====> Epoch: 7 Average val loss: 1091.6289\n", - "Min loss 1091.63\n", - "====> Epoch: 8 Average train loss: 1091.6658\n", - "====> Epoch: 8 Average val loss: 1088.9984\n", - "Min loss 1089.00\n", - "====> Epoch: 9 Average train loss: 1089.0031\n", - "====> Epoch: 9 Average val loss: 1086.1648\n", - "Min loss 1086.16\n", - 
"====> Epoch: 10 Average train loss: 1086.3365\n", - "====> Epoch: 10 Average val loss: 1083.7288\n", - "Min loss 1083.73\n", - "====> Epoch: 11 Average train loss: 1083.8301\n", - "====> Epoch: 11 Average val loss: 1081.3557\n", - "Min loss 1081.36\n", - "====> Epoch: 12 Average train loss: 1081.5806\n", - "====> Epoch: 12 Average val loss: 1078.7136\n", - "Min loss 1078.71\n", - "====> Epoch: 13 Average train loss: 1078.9102\n", - "====> Epoch: 13 Average val loss: 1076.4788\n", - "Min loss 1076.48\n", - "====> Epoch: 14 Average train loss: 1076.6189\n", - "====> Epoch: 14 Average val loss: 1074.0825\n", - "Min loss 1074.08\n", - "====> Epoch: 15 Average train loss: 1074.3435\n", - "====> Epoch: 15 Average val loss: 1071.4995\n", - "Min loss 1071.50\n", - "====> Epoch: 16 Average train loss: 1072.0078\n", - "====> Epoch: 16 Average val loss: 1069.3938\n", - "Min loss 1069.39\n", - "====> Epoch: 17 Average train loss: 1069.2450\n", - "====> Epoch: 17 Average val loss: 1067.1929\n", - "Min loss 1067.19\n", - "====> Epoch: 18 Average train loss: 1067.4008\n", - "====> Epoch: 18 Average val loss: 1064.8755\n", - "Min loss 1064.88\n", - "====> Epoch: 19 Average train loss: 1065.0360\n", - "====> Epoch: 19 Average val loss: 1062.5830\n", - "Min loss 1062.58\n", - "====> Epoch: 20 Average train loss: 1062.9507\n", - "====> Epoch: 20 Average val loss: 1060.5204\n", - "Min loss 1060.52\n", - "====> Epoch: 21 Average train loss: 1060.8132\n", - "====> Epoch: 21 Average val loss: 1058.4624\n", - "Min loss 1058.46\n", - "====> Epoch: 22 Average train loss: 1058.6918\n", - "====> Epoch: 22 Average val loss: 1056.5654\n", - "Min loss 1056.57\n", - "====> Epoch: 23 Average train loss: 1056.6611\n", - "====> Epoch: 23 Average val loss: 1054.4075\n", - "Min loss 1054.41\n", - "====> Epoch: 24 Average train loss: 1054.8269\n", - "====> Epoch: 24 Average val loss: 1052.1930\n", - "Min loss 1052.19\n", - "====> Epoch: 25 Average train loss: 1052.5613\n", - "====> Epoch: 25 Average 
val loss: 1050.3650\n", - "Min loss 1050.36\n", - "====> Epoch: 26 Average train loss: 1050.4044\n", - "====> Epoch: 26 Average val loss: 1048.1295\n", - "Min loss 1048.13\n", - "====> Epoch: 27 Average train loss: 1048.4557\n", - "====> Epoch: 27 Average val loss: 1046.2268\n", - "Min loss 1046.23\n", - "====> Epoch: 28 Average train loss: 1046.7646\n", - "====> Epoch: 28 Average val loss: 1044.4366\n", - "Min loss 1044.44\n", - "====> Epoch: 29 Average train loss: 1044.8062\n", - "====> Epoch: 29 Average val loss: 1042.5750\n", - "Min loss 1042.57\n", - "====> Epoch: 30 Average train loss: 1042.6490\n", - "====> Epoch: 30 Average val loss: 1040.4587\n", - "Min loss 1040.46\n", - "====> Epoch: 31 Average train loss: 1040.9629\n", - "====> Epoch: 31 Average val loss: 1038.8336\n", - "Min loss 1038.83\n", - "====> Epoch: 32 Average train loss: 1038.8879\n", - "====> Epoch: 32 Average val loss: 1036.8384\n", - "Min loss 1036.84\n", - "====> Epoch: 33 Average train loss: 1037.3240\n", - "====> Epoch: 33 Average val loss: 1035.5549\n", - "Min loss 1035.55\n", - "====> Epoch: 34 Average train loss: 1035.4932\n", - "====> Epoch: 34 Average val loss: 1033.6082\n", - "Min loss 1033.61\n", - "====> Epoch: 35 Average train loss: 1033.9220\n", - "====> Epoch: 35 Average val loss: 1031.9670\n", - "Min loss 1031.97\n", - "====> Epoch: 36 Average train loss: 1032.3682\n", - "====> Epoch: 36 Average val loss: 1030.0439\n", - "Min loss 1030.04\n", - "====> Epoch: 37 Average train loss: 1030.5251\n", - "====> Epoch: 37 Average val loss: 1028.3313\n", - "Min loss 1028.33\n", - "====> Epoch: 38 Average train loss: 1028.4504\n", - "====> Epoch: 38 Average val loss: 1026.9329\n", - "Min loss 1026.93\n", - "====> Epoch: 39 Average train loss: 1027.0562\n", - "====> Epoch: 39 Average val loss: 1025.2065\n", - "Min loss 1025.21\n", - "====> Epoch: 40 Average train loss: 1025.8102\n", - "====> Epoch: 40 Average val loss: 1023.5693\n", - "Min loss 1023.57\n", - "====> Epoch: 41 Average 
train loss: 1023.5449\n", - "====> Epoch: 41 Average val loss: 1022.0502\n", - "Min loss 1022.05\n", - "====> Epoch: 42 Average train loss: 1022.1390\n", - "====> Epoch: 42 Average val loss: 1020.6116\n", - "Min loss 1020.61\n", - "====> Epoch: 43 Average train loss: 1020.8058\n", - "====> Epoch: 43 Average val loss: 1019.0730\n", - "Min loss 1019.07\n", - "====> Epoch: 44 Average train loss: 1019.7094\n", - "====> Epoch: 44 Average val loss: 1018.3165\n", - "Min loss 1018.32\n", - "====> Epoch: 45 Average train loss: 1017.3981\n", - "====> Epoch: 45 Average val loss: 1016.1403\n", - "Min loss 1016.14\n", - "====> Epoch: 46 Average train loss: 1016.0779\n", - "====> Epoch: 46 Average val loss: 1014.8041\n", - "Min loss 1014.80\n", - "====> Epoch: 47 Average train loss: 1014.8489\n", - "====> Epoch: 47 Average val loss: 1013.4117\n", - "Min loss 1013.41\n", - "====> Epoch: 48 Average train loss: 1013.5171\n", - "====> Epoch: 48 Average val loss: 1011.6441\n", - "Min loss 1011.64\n", - "====> Epoch: 49 Average train loss: 1012.0751\n", - "====> Epoch: 49 Average val loss: 1010.4889\n", - "Min loss 1010.49\n", - "====> Epoch: 50 Average train loss: 1011.1703\n", - "====> Epoch: 50 Average val loss: 1009.1624\n", - "Min loss 1009.16\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "-li2xlrmvLT0" - }, - "source": [ - "# Convolutional Deep CCA (and using other architectures)\n", - "We provide a standard CNN encoder and decoder but users can build their own encoders and decoders by inheriting BaseEncoder and BaseDecoder for seamless integration with the pipeline" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "AzfBwb3NvLT0", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "93bed1b4-1d78-40b8-c505-6caad6375276" - }, - "source": [ - "print('Convolutional DCCA')\n", - "encoder_1 = architectures.CNNEncoder(latent_dims=latent_dims, channels=[3, 3])\n", - 
"encoder_2 = architectures.CNNEncoder(latent_dims=latent_dims, channels=[3, 3])\n", - "dcca_conv_model = DCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])\n", - "\n", - "dcca_conv_model = DeepWrapper(dcca_conv_model)\n", - "\n", - "# to change the models used change the cfg.encoder_models. We implement a CNN_Encoder and CNN_decoder as well\n", - "# as some based on brainnet architecture in cca_zoo.architectures. Equally you could pass your own encoder/decoder models\n", - "\n", - "dcca_conv_model.fit((train_view_1.reshape((-1, 1, 28, 28)), train_view_2.reshape((-1, 1, 28, 28))), epochs=10)\n", - "\n", - "dcca_conv_results = np.stack((\n", - " dcca_conv_model.score((test_view_1.reshape((-1, 1, 28, 28)),test_view_2.reshape((-1, 1, 28, 28)))), \n", - " dcca_conv_model.score((test_view_1.reshape((-1, 1, 28, 28)),test_view_2.reshape((-1, 1, 28, 28))))))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Convolutional DCCA\n", - "total parameters: 9568\n", - "====> Epoch: 1 Average train loss: -0.6154\n", - "====> Epoch: 2 Average train loss: -0.7844\n", - "====> Epoch: 3 Average train loss: -0.9141\n", - "====> Epoch: 4 Average train loss: -1.0193\n", - "====> Epoch: 5 Average train loss: -1.1371\n", - "====> Epoch: 6 Average train loss: -1.2470\n", - "====> Epoch: 7 Average train loss: -1.3430\n", - "====> Epoch: 8 Average train loss: -1.4286\n", - "====> Epoch: 9 Average train loss: -1.4995\n", - "====> Epoch: 10 Average train loss: -1.5591\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "collapsed": false, - "id": "78IxzigYvLT0" - }, - "source": [ - "# DTCCA" - ] - }, - { - "cell_type": "code", - "metadata": { - "pycharm": { - "name": "#%%\n" - }, - "id": "MHvAzaiGvLT1", - "colab": { - "base_uri": "https://localhost:8080/" - }, - "outputId": "2fd16f5f-0579-44c6-acc0-1f66a450ba40" - }, - "source": [ - "# %%\n", - "# DTCCA\n", - "print('DTCCA')\n", - "encoder_1 = 
architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "encoder_2 = architectures.Encoder(latent_dims=latent_dims, feature_size=784)\n", - "dtcca_model = DTCCA(latent_dims=latent_dims, encoders=[encoder_1, encoder_2])\n", - "\n", - "dtcca_model = DeepWrapper(dtcca_model)\n", - "\n", - "dtcca_model.fit(train_dataset, val_dataset=val_dataset, epochs=epochs)\n", - "\n", - "dtcca_results = np.stack((dtcca_model.score(train_dataset), dtcca_model.score(test_dataset)))" - ], - "execution_count": null, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "DTCCA\n", - "total parameters: 201476\n" - ] - }, - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.7/dist-packages/tensorly/backend/core.py:885: UserWarning: In partial_svd: converting to NumPy. Check SVD_FUNS for available alternatives if you want to avoid this.\n", - " warnings.warn('In partial_svd: converting to NumPy.'\n" - ] - }, - { - "output_type": "stream", - "name": "stdout", - "text": [ - "====> Epoch: 1 Average train loss: 0.0000\n", - "====> Epoch: 1 Average val loss: 0.0000\n", - "Min loss 0.00\n", - "====> Epoch: 2 Average train loss: 0.0000\n", - "====> Epoch: 2 Average val loss: 0.0000\n", - "Min loss 0.00\n", - "====> Epoch: 3 Average train loss: 0.0000\n", - "====> Epoch: 3 Average val loss: 0.0000\n", - "====> Epoch: 4 Average train loss: 0.0000\n", - "====> Epoch: 4 Average val loss: 0.0000\n", - "====> Epoch: 5 Average train loss: 0.0000\n", - "====> Epoch: 5 Average val loss: 0.0000\n", - "====> Epoch: 6 Average train loss: 0.0000\n", - "====> Epoch: 6 Average val loss: 0.0000\n", - "====> Epoch: 7 Average train loss: 0.0000\n", - "====> Epoch: 7 Average val loss: 0.0000\n", - "====> Epoch: 8 Average train loss: 0.0000\n", - "====> Epoch: 8 Average val loss: 0.0000\n", - "====> Epoch: 9 Average train loss: 0.0000\n", - "====> Epoch: 9 Average val loss: 0.0000\n", - "====> Epoch: 10 Average train loss: 0.0000\n", - 
"====> Epoch: 10 Average val loss: 0.0000\n", - "====> Epoch: 11 Average train loss: 0.0000\n", - "====> Epoch: 11 Average val loss: 0.0000\n", - "====> Epoch: 12 Average train loss: 0.0000\n", - "====> Epoch: 12 Average val loss: 0.0000\n", - "====> Epoch: 13 Average train loss: 0.0000\n", - "====> Epoch: 13 Average val loss: 0.0000\n", - "====> Epoch: 14 Average train loss: 0.0000\n", - "====> Epoch: 14 Average val loss: 0.0000\n", - "====> Epoch: 15 Average train loss: 0.0000\n", - "====> Epoch: 15 Average val loss: 0.0000\n", - "====> Epoch: 16 Average train loss: 0.0000\n", - "====> Epoch: 16 Average val loss: 0.0000\n", - "====> Epoch: 17 Average train loss: 0.0000\n", - "====> Epoch: 17 Average val loss: 0.0000\n", - "====> Epoch: 18 Average train loss: 0.0000\n", - "====> Epoch: 18 Average val loss: 0.0000\n", - "====> Epoch: 19 Average train loss: 0.0000\n", - "====> Epoch: 19 Average val loss: 0.0000\n", - "====> Epoch: 20 Average train loss: 0.0000\n", - "====> Epoch: 20 Average val loss: 0.0000\n", - "====> Epoch: 21 Average train loss: 0.0000\n", - "====> Epoch: 21 Average val loss: 0.0000\n", - "====> Epoch: 22 Average train loss: 0.0000\n", - "====> Epoch: 22 Average val loss: 0.0000\n", - "====> Epoch: 23 Average train loss: 0.0000\n", - "====> Epoch: 23 Average val loss: 0.0000\n", - "====> Epoch: 24 Average train loss: 0.0000\n", - "====> Epoch: 24 Average val loss: 0.0000\n", - "====> Epoch: 25 Average train loss: 0.0000\n", - "====> Epoch: 25 Average val loss: 0.0000\n", - "====> Epoch: 26 Average train loss: 0.0000\n", - "====> Epoch: 26 Average val loss: 0.0000\n", - "====> Epoch: 27 Average train loss: 0.0000\n", - "====> Epoch: 27 Average val loss: 0.0000\n", - "====> Epoch: 28 Average train loss: 0.0000\n", - "====> Epoch: 28 Average val loss: 0.0000\n", - "====> Epoch: 29 Average train loss: 0.0000\n", - "====> Epoch: 29 Average val loss: 0.0000\n", - "====> Epoch: 30 Average train loss: 0.0000\n", - "====> Epoch: 30 Average val loss: 
0.0000\n", - "====> Epoch: 31 Average train loss: 0.0000\n", - "====> Epoch: 31 Average val loss: 0.0000\n", - "====> Epoch: 32 Average train loss: 0.0000\n", - "====> Epoch: 32 Average val loss: 0.0000\n", - "====> Epoch: 33 Average train loss: 0.0000\n", - "====> Epoch: 33 Average val loss: 0.0000\n", - "====> Epoch: 34 Average train loss: 0.0000\n", - "====> Epoch: 34 Average val loss: 0.0000\n", - "====> Epoch: 35 Average train loss: 0.0000\n", - "====> Epoch: 35 Average val loss: 0.0000\n", - "====> Epoch: 36 Average train loss: 0.0000\n", - "====> Epoch: 36 Average val loss: 0.0000\n", - "====> Epoch: 37 Average train loss: 0.0000\n", - "====> Epoch: 37 Average val loss: 0.0000\n", - "====> Epoch: 38 Average train loss: 0.0000\n", - "====> Epoch: 38 Average val loss: 0.0000\n", - "====> Epoch: 39 Average train loss: 0.0000\n", - "====> Epoch: 39 Average val loss: 0.0000\n", - "====> Epoch: 40 Average train loss: 0.0000\n", - "====> Epoch: 40 Average val loss: 0.0000\n", - "====> Epoch: 41 Average train loss: 0.0000\n", - "====> Epoch: 41 Average val loss: 0.0000\n", - "====> Epoch: 42 Average train loss: 0.0000\n", - "====> Epoch: 42 Average val loss: 0.0000\n", - "====> Epoch: 43 Average train loss: 0.0000\n", - "====> Epoch: 43 Average val loss: 0.0000\n", - "====> Epoch: 44 Average train loss: 0.0000\n", - "====> Epoch: 44 Average val loss: 0.0000\n", - "====> Epoch: 45 Average train loss: 0.0000\n", - "====> Epoch: 45 Average val loss: 0.0000\n", - "====> Epoch: 46 Average train loss: 0.0000\n", - "====> Epoch: 46 Average val loss: 0.0000\n", - "====> Epoch: 47 Average train loss: 0.0000\n", - "====> Epoch: 47 Average val loss: 0.0000\n", - "====> Epoch: 48 Average train loss: 0.0000\n", - "====> Epoch: 48 Average val loss: 0.0000\n", - "====> Epoch: 49 Average train loss: 0.0000\n", - "====> Epoch: 49 Average val loss: 0.0000\n", - "====> Epoch: 50 Average train loss: 0.0000\n", - "====> Epoch: 50 Average val loss: 0.0000\n", - "reconstruction 
error=1.7372686071707115e-08\n", - "iteration 1, reconstruction error: 1.7372686071707115e-08, decrease = 0.0, unnormalized = 2.634178031930877e-09\n", - "PARAFAC converged after 1 iterations\n" - ] - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "4PLCBruxlAAu" - }, - "source": [ - "" - ], - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file diff --git a/tutorial_notebooks/cca_zoo_weights_and_sparsity.ipynb b/tutorial_notebooks/cca_zoo_weights_and_sparsity.ipynb deleted file mode 100644 index 5dd11ec7..00000000 --- a/tutorial_notebooks/cca_zoo_weights_and_sparsity.ipynb +++ /dev/null @@ -1,1096 +0,0 @@ -{ - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "colab": { - "name": "cca_zoo_sparsity.ipynb", - "provenance": [], - "toc_visible": true, - "include_colab_link": true - }, - "kernelspec": { - "name": "python3", - "language": "python", - "display_name": "Python 3" - }, - "language_info": { - "name": "python" - } - }, - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "view-in-github", - "colab_type": "text" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rpca4lWRvlwC" - }, - "source": [ - "# A tutorial on using cca-zoo to generate multiview models with sparsity on weights" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "riuTJcsuvRcS", - "outputId": "f7b0b940-9513-4b10-e842-1ba20178c223" - }, - "source": [ - "!pip install cca-zoo --upgrade" - ], - "execution_count": 1, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Collecting cca-zoo\n", - " Downloading cca_zoo-1.9.0-py3-none-any.whl (68 kB)\n", - "\u001b[?25l\r\u001b[K |████▊ | 10 kB 23.0 MB/s eta 0:00:01\r\u001b[K |█████████▌ | 20 kB 10.1 MB/s eta 0:00:01\r\u001b[K |██████████████▎ | 30 kB 8.3 MB/s eta 0:00:01\r\u001b[K |███████████████████ | 40 kB 7.7 MB/s eta 0:00:01\r\u001b[K 
|███████████████████████▊ | 51 kB 4.2 MB/s eta 0:00:01\r\u001b[K |████████████████████████████▌ | 61 kB 4.4 MB/s eta 0:00:01\r\u001b[K |████████████████████████████████| 68 kB 2.8 MB/s \n", - "\u001b[?25hRequirement already satisfied: seaborn in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (0.11.2)\n", - "Collecting scipy>=1.7\n", - " Downloading scipy-1.7.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.whl (28.5 MB)\n", - "\u001b[K |████████████████████████████████| 28.5 MB 49 kB/s \n", - "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (1.19.5)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (1.1.5)\n", - "Collecting mvlearn\n", - " Downloading mvlearn-0.4.1-py3-none-any.whl (2.1 MB)\n", - "\u001b[K |████████████████████████████████| 2.1 MB 45.2 MB/s \n", - "\u001b[?25hCollecting tensorly\n", - " Downloading tensorly-0.6.0-py3-none-any.whl (160 kB)\n", - "\u001b[K |████████████████████████████████| 160 kB 53.3 MB/s \n", - "\u001b[?25hRequirement already satisfied: joblib in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (1.0.1)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.7/dist-packages (from cca-zoo) (3.2.2)\n", - "Collecting scikit-learn>=0.23\n", - " Downloading scikit_learn-1.0-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (23.1 MB)\n", - "\u001b[K |████████████████████████████████| 23.1 MB 79.1 MB/s \n", - "\u001b[?25hCollecting threadpoolctl>=2.0.0\n", - " Downloading threadpoolctl-3.0.0-py3-none-any.whl (14 kB)\n", - "Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo) (2.4.7)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo) (0.10.0)\n", - "Requirement already satisfied: python-dateutil>=2.1 in 
/usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo) (2.8.2)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.7/dist-packages (from matplotlib->cca-zoo) (1.3.2)\n", - "Requirement already satisfied: six in /usr/local/lib/python3.7/dist-packages (from cycler>=0.10->matplotlib->cca-zoo) (1.15.0)\n", - "Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.7/dist-packages (from pandas->cca-zoo) (2018.9)\n", - "Collecting nose\n", - " Downloading nose-1.3.7-py3-none-any.whl (154 kB)\n", - "\u001b[K |████████████████████████████████| 154 kB 47.1 MB/s \n", - "\u001b[?25hInstalling collected packages: threadpoolctl, scipy, scikit-learn, nose, tensorly, mvlearn, cca-zoo\n", - " Attempting uninstall: scipy\n", - " Found existing installation: scipy 1.4.1\n", - " Uninstalling scipy-1.4.1:\n", - " Successfully uninstalled scipy-1.4.1\n", - " Attempting uninstall: scikit-learn\n", - " Found existing installation: scikit-learn 0.22.2.post1\n", - " Uninstalling scikit-learn-0.22.2.post1:\n", - " Successfully uninstalled scikit-learn-0.22.2.post1\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. 
This behaviour is the source of the following dependency conflicts.\n", - "albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.2.9 which is incompatible.\u001b[0m\n", - "Successfully installed cca-zoo-1.9.0 mvlearn-0.4.1 nose-1.3.7 scikit-learn-1.0 scipy-1.7.1 tensorly-0.6.0 threadpoolctl-3.0.0\n" - ] - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "LVmJ5X8RvV3_" - }, - "source": [ - "from cca_zoo.models import PMD, SCCA, ElasticCCA, CCA, PLS, SCCA_ADMM, SpanCCA\n", - "from cca_zoo.model_selection import GridSearchCV\n", - "from cca_zoo.data import generate_covariance_data\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import itertools\n", - "import pandas as pd" - ], - "execution_count": 20, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IkMwUGzkwbQY" - }, - "source": [ - "## Generate some data\n", - "set the true correlation and the sparsity of the true weights" - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "zdYep44wvtKo" - }, - "source": [ - "np.random.seed(42)\n", - "n=200\n", - "p=100\n", - "q=100\n", - "view_1_sparsity=0.1\n", - "view_2_sparsity=0.1\n", - "true_latent_dims=1\n", - "\n", - "(X,Y),(tx, ty)=generate_covariance_data(n,view_features=[p,q],latent_dims=true_latent_dims,\n", - " view_sparsity=[view_1_sparsity,view_2_sparsity],correlation=[0.9])\n", - "#normalize weights for comparability\n", - "tx/=np.sqrt(n)\n", - "ty/=np.sqrt(n)" - ], - "execution_count": 3, - "outputs": [] - }, - { - "cell_type": "code", - "metadata": { - "id": "ijitQkskw_jw" - }, - "source": [ - "def plot_true_weights_coloured(ax, weights, true_weights, title=''):\n", - " ind = np.arange(len(true_weights))\n", - " mask = np.squeeze(true_weights == 0)\n", - " ax.scatter(ind[~mask], weights[~mask], c='b')\n", - " ax.scatter(ind[mask], weights[mask], c='r')\n", - " ax.set_title(title)\n", - "\n", - "def plot_model_weights(wx,wy,tx,ty):\n", - " 
fig,axs=plt.subplots(2,2,sharex=True,sharey=True)\n", - " plot_true_weights_coloured(axs[0,0],tx,tx,title='true x weights')\n", - " plot_true_weights_coloured(axs[0,1],ty,ty,title='true y weights')\n", - " plot_true_weights_coloured(axs[1,0],wx,tx,title='model x weights')\n", - " plot_true_weights_coloured(axs[1,1],wy,ty,title='model y weights')\n", - " plt.tight_layout()\n", - " plt.show()" - ], - "execution_count": 4, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "x_-JR1lywpNO" - }, - "source": [ - "## First try with CCA" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 297 - }, - "id": "as1irviNwnCW", - "outputId": "9c3a3a77-77f0-41a1-f8ef-08a20c176413" - }, - "source": [ - "#fit a cca model\n", - "cca=CCA().fit([X,Y])\n", - "\n", - "plot_model_weights(cca.weights[0],cca.weights[1],tx,ty)" - ], - "execution_count": 5, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "AMLK2z5C1bFf" - }, - "source": [ - "## PLS" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 297 - }, - "id": "VqnBFLwFw1Fi", - "outputId": "55d3dae2-9e3d-4bc4-a3e9-6c2ef72281d3" - }, - "source": [ - "#fit a pls model\n", - "pls=PLS().fit([X,Y])\n", - "\n", - "plot_model_weights(pls.weights[0],pls.weights[1],tx,ty)" - ], - "execution_count": 6, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "maZg5LdP1l3H" - }, - "source": [ - "## Penalized Matrix Decomposition (Sparse CCA by Witten)\n", - "Initially set c=2 for both views arbitrarily" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 297 - }, - "id": "2petCaj61ffh", - "outputId": "445a4f81-9a6f-42e1-8840-b56750f75ff1" - }, - "source": [ - "#fit a pmd model\n", - "pmd=PMD(c=[2,2]).fit([X,Y])\n", - "\n", - "plot_model_weights(pmd.weights[0],pmd.weights[1],tx,ty)" - ], - "execution_count": 7, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gTZEL2SBTijd" - }, - "source": [ - "## Tracking the objective\n", - "For these iterative algorithms, you can access the convergence over iterations" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 312 - }, - "id": "eMutm5DjTh_V", - "outputId": "351effcb-3498-40f2-c09f-cac8167b3c9f" - }, - "source": [ - "#Convergence\n", - "plt.figure()\n", - "plt.title('Objective Convergence')\n", - "plt.plot(np.array(pmd.track[0]['objective']).T)\n", - "plt.ylabel('Objective')\n", - "plt.xlabel('#iterations')" - ], - "execution_count": 13, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Text(0.5, 0, '#iterations')" - ] - }, - "metadata": {}, - "execution_count": 13 - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OusaWIn82Wb7" - }, - "source": [ - "### We can also tune the hyperparameter using GridSearchCV" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "O7VXP9h21vyB", - "outputId": "8343f511-301d-40fd-910f-4071e761893e" - }, - "source": [ - "#Set up a grid. We can't use c<1 or c>sqrt(#features)\n", - "c1 = [1, 3, 7, 9]\n", - "c2 = [1, 3, 7, 9]\n", - "param_grid = {'c': [c1,c2]}\n", - "\n", - "#GridSearchCV can use multiple cores (jobs) and takes folds (number of cv folds) as a parameter. It can also produce a plot.\n", - "pmd = GridSearchCV(PMD(),param_grid=param_grid,\n", - " cv=3,\n", - " verbose=True).fit([X,Y])" - ], - "execution_count": 17, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Fitting 3 folds for each of 16 candidates, totalling 48 fits\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DscFV-7P3dU2" - }, - "source": [ - "Also the model object now has a pandas dataframe containing the results from each fold" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 545 - }, - "id": "2GNSiTjC21fB", - "outputId": "cb6ad3e8-6e82-4c84-f231-35babd967776" - }, - "source": [ - "pd.DataFrame(pmd.cv_results_)" - ], - "execution_count": 21, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
mean_fit_timestd_fit_timemean_score_timestd_score_timeparam_cparamssplit0_test_scoresplit1_test_scoresplit2_test_scoremean_test_scorestd_test_scorerank_test_score
00.0342420.0067900.0016470.000124[1, 1]{'c': [1, 1]}0.052058-0.033554-0.071958-0.0178180.05183813
10.0360550.0054290.0015950.000006[1, 3]{'c': [1, 3]}0.1537890.3604820.0363180.1835300.1340002
20.0331630.0014880.0016050.000035[1, 7]{'c': [1, 7]}0.1608030.331418-0.0130430.1597260.1406284
30.0273500.0030100.0015670.000019[1, 9]{'c': [1, 9]}0.1640700.330148-0.0012630.1643180.1352983
40.0446100.0157160.0015720.000018[3, 1]{'c': [3, 1]}0.073034-0.208602-0.024809-0.0534590.11674915
50.1711060.0709840.0015840.000017[3, 3]{'c': [3, 3]}0.0461860.6000120.4144160.3535380.2301601
60.2430980.0428680.0016190.000019[3, 7]{'c': [3, 7]}0.014920-0.0918100.1638790.0289960.10485810
70.2246370.0199150.0016430.000097[3, 9]{'c': [3, 9]}-0.0605650.2759570.1132120.1095350.1374095
80.0404130.0089710.0017050.000155[7, 1]{'c': [7, 1]}0.024604-0.163847-0.104591-0.0812780.07868116
90.1898190.0656530.0015820.000013[7, 3]{'c': [7, 3]}0.0497190.135290-0.1421860.0142750.11601911
100.1835630.0326230.0017250.000167[7, 7]{'c': [7, 7]}0.0738140.0444160.0444390.0542230.0138537
110.1026140.0168890.0016250.000106[7, 9]{'c': [7, 9]}0.0681250.0665070.0448140.0598150.0106286
120.0323350.0098010.0015770.000022[9, 1]{'c': [9, 1]}0.052391-0.073953-0.130107-0.0505560.07631914
130.1469380.0570940.0017080.000134[9, 3]{'c': [9, 3]}0.0433830.149690-0.219036-0.0086540.15496412
140.1512950.0740130.0017060.000203[9, 7]{'c': [9, 7]}0.0754320.0194890.0524290.0491170.0229588
150.0374960.0067300.0015590.000010[9, 9]{'c': [9, 9]}0.0557170.0376960.0414030.0449390.0077719
\n", - "
" - ], - "text/plain": [ - " mean_fit_time std_fit_time ... std_test_score rank_test_score\n", - "0 0.034242 0.006790 ... 0.051838 13\n", - "1 0.036055 0.005429 ... 0.134000 2\n", - "2 0.033163 0.001488 ... 0.140628 4\n", - "3 0.027350 0.003010 ... 0.135298 3\n", - "4 0.044610 0.015716 ... 0.116749 15\n", - "5 0.171106 0.070984 ... 0.230160 1\n", - "6 0.243098 0.042868 ... 0.104858 10\n", - "7 0.224637 0.019915 ... 0.137409 5\n", - "8 0.040413 0.008971 ... 0.078681 16\n", - "9 0.189819 0.065653 ... 0.116019 11\n", - "10 0.183563 0.032623 ... 0.013853 7\n", - "11 0.102614 0.016889 ... 0.010628 6\n", - "12 0.032335 0.009801 ... 0.076319 14\n", - "13 0.146938 0.057094 ... 0.154964 12\n", - "14 0.151295 0.074013 ... 0.022958 8\n", - "15 0.037496 0.006730 ... 0.007771 9\n", - "\n", - "[16 rows x 12 columns]" - ] - }, - "metadata": {}, - "execution_count": 21 - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "O7rWUBmb4apX" - }, - "source": [ - "## Sparse CCA by iterative lasso (Mai)" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 592 - }, - "id": "FimdDUDe3kML", - "outputId": "8af2f49f-2571-4f6f-e8ba-3b5becabadaf" - }, - "source": [ - "#fit a scca model\n", - "scca=SCCA(c=[1e-3,1e-3]).fit([X,Y])\n", - "\n", - "plot_model_weights(scca.weights[0],scca.weights[1],tx,ty)\n", - "\n", - "#Convergence\n", - "plt.figure()\n", - "plt.title('Objective Convergence')\n", - "plt.plot(np.array(scca.track[0]['objective']).T)\n", - "plt.ylabel('Objective')\n", - "plt.xlabel('#iterations')" - ], - "execution_count": 23, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Text(0.5, 0, '#iterations')" - ] - }, - "metadata": {}, - "execution_count": 23 - }, - { - "output_type": "display_data", - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYgAAAEWCAYAAAB8LwAVAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAgAElEQVR4nO3deZxcVZ338c+3qrq7ku6snQYhhASSCAMuLGEbN2RRcNSog4IrjMww6uAyjs6DvlyQR2fEURm3R4cRBBcEjaJRUVTiiKOIJGyySsKWhADZl056rd/zx73VuWmql5CuVFP1fb9e9ep7z71169y+SX37nHMXRQRmZmaD5WpdATMzG58cEGZmVpEDwszMKnJAmJlZRQ4IMzOryAFhZmYVOSCsJiRdKOnbwyy/W9KJVfjcqmzXrB45IKwqJJ0j6c+Stkt6XNJXJU0d7fsj4vCI+J89rMMVkj451tsd4rOa09B7QFKnpIclXS5pzlh/ltne4oCwMSfpX4CLgQ8CU4DjgdnAryQ117JuVbQIeDXwJpJ9fj6wDDi5lpXKklSodR3sGSYi/PJrzF7AZGAb8IZB5W3AWuDt6fyFJF+q1wBbgVuB52fWfxg4JZ3OARcAK4D1wPeA6Zl1Xwj8AdgErATOAc4DeoGetD4/yW4X2B/YMWg7RwLrgKZ0/u3AvcBG4Hpg9hD7fEq6rVnD/F72BxYDG4DlwD9kll2Y7tM309/F3cCCdNn/ARYN2tYXgC+m01OAy4A1wGrgk0A+XXYO8HvgkvT39kmgHfgJsAW4JS3738y2DwV+ldbz/uxxBK4AvgL8LK3nzcDczPLDM+99AvjwaI6fX+P35RaEjbW/BorAD7OFEbENuA44NVO8EPg+MB24CviRpKYK23w38BrgJSRftBtJvqiQNBv4OfAloAM4Arg9Ii4FvgN8JiLaIuJVg+rzGHAT8LeZ4jeRfBn3SloIfBh4Xbrd3wHfHWKfTwH+FBErh1gOcDWwKq3/GcC/STops/zV6TpTSYLky5n3vULSpHR/88AbSH5fkHxp9wHzSALuZcDfZ7Z7HPAgsC/wKZLfWyfwLODs9EW67VaSL/irgH2As4D/J+mwzPbOAj4BTCMJuk+l750E/Br4RbqP84Ab0vcMefxsnKt1QvlVXy/gLcDjQyz7NPCrdPpC4I+ZZTmSv4JflM4/zM4WxL3AyZl19yNpHRSADwHXDvF5VwCfHFSW3e7fA0vSaZG0Pl6czv8cOHdQ/bZToRUB/Ddw9TC/k1lAPzApU/bvwBWZ38WvM8sOA3Zk5v8XeFs6fSqwIp3eF+gGJmTWfSPwm3T6HODRzLJ8+ns7JFM20IIAzgR+N6ju/wV8PPP7/Hpm2SuA+zKfe9sQ+z/k8av1v1e/hn+5T9LG2jpghqRCRPQNWrZfurxs4C/uiChJKv+FPdhs4FpJpUxZP8kX5CySroun4wfAlyTtBzwbKJG0FMqf+QVJn8usL2Am8Mig7axP3z+U/YENEbE1U/YIsCAz/3hmejtQzPwOryL5Av4mSSun3HqYDTQBaySV35sj83sdNN1BEqpDLZ8NHCdpU6asAHxrmHq2pdPDHYfhjt/qId5j44C7mGys3UTyV+3rsoWS2oDT2dntAMmXSnl5DjgAeKzCNlcCp0fE1MyrGBGr02Vzh6jLsLcqjoiNwC9J/nJ+E0kroPye
lcA/DvrMCRHxhwqb+jVwrKQDhviox4Dp5W6i1IGM/svx+8CJ6fZfy86AWEnyu56RqePkiDg8u5uZ6bUk3VHZes7KTK8Efjton9si4p2jqONK4OBhlg11/Gwcc0DYmIqIzSR91F+SdJqkpvRUz++R9MFn/xo9WtLr0rNr3kfyZffHCpv9GvCpdLwBSR3pGAEk4wynSHqDpIKkdklHpMueYOgvrbKrgLeRjAtclSn/GvAhSYennzlF0uuH2Odfk/TdXyvp6LQekyS9Q9LbIxmb+APw75KKkp4HnAsMeR3IoO2vBf4H+AbwUETcm5avIQm4z0maLCknaa6klwyxnX6SsaELJU2UdGi672U/BZ4t6a3pcWuSdIykvxpFNX8K7CfpfZJa0v0/Ll023PGzccwBYWMuIj5DMsD7WZKzZW4m+Svy5Ijozqz6Y5K/3jcCbwVeFxG9FTb5BZKB219K2koSIseln/UoSV/4v5CcPXM7ySmmkJzdc5ikTZJ+NER1FwPzScZN7sjsw7Ukp+peLWkLcBdJC2goZ5AMwl8DbE7XX0DSuoCki2gOSWviWpJ+/V8/dTNDuopkMPyqQeVvA5qBe0h+j4tIuvKGcj7JmU+Pk4T1d0mCmbQL7GUkA9GPpetcDLSMVLn0vacCr0rf9wDw0nTxkMfPxjftbFGbjR+SHgXeEhE31rou9UzSxcCzIuLsEVe2huMWhI07kjpIBlQfrnFV6o6kQyU9T4ljSbq6rq11vWx88llMNq5IOoakP/9LafeRja1JJN1K+5OM0XyOpKvP7CncxWRmZhW5i8nMzCqqmy6mGTNmxJw5c2pdDTOzZ5Rly5ati4iOSsvqJiDmzJnD0qVLa10NM7NnFEmD7wwwwF1MZmZWkQPCzMwqckCYmVlFDggzM6vIAWFmZhU5IMzMrCIHhJmZVdTwAfHYph18/pf389C6zlpXxcxsXGn4gFi3rZsvLlnOiie31boqZmbjSsMHRLEpD0BXX3+Na2JmNr44IAppQPSWRljTzKyxOCCakl9BV69bEGZmWQ0fEC3lLiYHhJnZLho+IMotiO4+dzGZmWU1fEA053NIbkGYmQ3W8AEhiWIh74AwMxuk4QMCoKUp57OYzMwGcUCAWxBmZhU4IEgGqrs8SG1mtgsHBMnV1G5BmJntygFBci2EA8LMbFcOCKBYyNHtQWozs104IEi7mHyzPjOzXTggSAep3cVkZrYLBwTlQWp3MZmZZTkg8HUQZmaVOCBwF5OZWSUOCMqD1O5iMjPLckCQXAfR01eiVIpaV8XMbNxwQLDzmRA9/W5FmJmVVTUgJJ0m6X5JyyVdUGF5i6Rr0uU3S5qTWfY8STdJulvSnyUVq1XPnc+l9jiEmVlZ1QJCUh74CnA6cBjwRkmHDVrtXGBjRMwDLgEuTt9bAL4NvCMiDgdOBHqrVdfiwGNH3YIwMyurZgviWGB5RDwYET3A1cDCQessBK5MpxcBJ0sS8DLgzoi4AyAi1kdE1f68L3cxuQVhZrZTNQNiJrAyM78qLau4TkT0AZuBduDZQEi6XtKtkv61ivXc2YLw7TbMzAYUal2BIRSAFwLHANuBGyQti4gbsitJOg84D+DAAw982h+2swXhLiYzs7JqtiBWA7My8wekZRXXSccdpgDrSVobN0bEuojYDlwHHDX4AyLi0ohYEBELOjo6nnZFPUhtZvZU1QyIW4D5kg6S1AycBSwetM5i4Ox0+gxgSUQEcD3wXEkT0+B4CXBPtSra0uSAMDMbrGpdTBHRJ+l8ki/7PHB5RNwt6SJgaUQsBi4DviVpObCBJESIiI2SPk8SMgFcFxE/q1Zd3cVkZvZUVR2DiIjrSLqHsmUfy0x3Aa8f4r3fJjnVterKg9TdHqQ2MxvgK6nJXgfhgDAzK3NAkDxyFNzFZGaW5YDALQgzs0ocEECLWxBmZk/hgAAK+RyFnHwltZlZhgMilTyX2gFhZlbmgEgljx11F5OZWZkDItVSyNPtFoSZ
2QAHRKrYlPMYhJlZhgMilYxBuIvJzKzMAZHyILWZ2a4cEKlkkNoBYWZW5oBIFQvuYjIzy3JApIpNeQ9Sm5llOCBSLU05ut2CMDMb4IBIeZDazGxXDohUMgbhgDAzK3NApJIL5dzFZGZW5oBIFZvy9JeC3n6HhJkZOCAGFJuSX0W3WxFmZoADYoCfKmdmtisHRKpYcECYmWU5IFItTX7sqJlZlgMi5S4mM7NdOSBS5YDo9u02zMwAB8SAYsFdTGZmWQ6IlLuYzMx25YBI7QwItyDMzMABMaBloIvJLQgzM3BADBhoQXiQ2swMcEAMKPo6CDOzXTggUh6kNjPblQMiVR6D6HZAmJkBDogBkmgp+JkQZmZlDogMP3bUzGynqgaEpNMk3S9puaQLKixvkXRNuvxmSXPS8jmSdki6PX19rZr1LCs25RwQZmapQrU2LCkPfAU4FVgF3CJpcUTck1ntXGBjRMyTdBZwMXBmumxFRBxRrfpVkrQg3MVkZgbVbUEcCyyPiAcjoge4Glg4aJ2FwJXp9CLgZEmqYp2GVSy4i8nMrKyaATETWJmZX5WWVVwnIvqAzUB7uuwgSbdJ+q2kF1X6AEnnSVoqaenatWv3uMLFJg9Sm5mVjddB6jXAgRFxJPB+4CpJkwevFBGXRsSCiFjQ0dGxxx/a4kFqM7MB1QyI1cCszPwBaVnFdSQVgCnA+ojojoj1ABGxDFgBPLuKdQWSMQhfB2FmlqhmQNwCzJd0kKRm4Cxg8aB1FgNnp9NnAEsiIiR1pIPcSDoYmA88WMW6AskzITxIbWaWqNpZTBHRJ+l84HogD1weEXdLughYGhGLgcuAb0laDmwgCRGAFwMXSeoFSsA7ImJDtepaVmzK+2Z9ZmapqgUEQERcB1w3qOxjmeku4PUV3vcD4AfVrFslvg7CzGyn8TpIXRO+DsLMbCcHRIZvtWFmtpMDIqNYyNHdVyIial0VM7Oac0BktKTPhOj2xXJmZqMLCEkTJX1U0n+n8/MlvbK6Vdv7yg8N6vY4hJnZqFsQ3wC6gRPS+dXAJ6tSoxoaeOyoT3U1Mxt1QMyNiM8AvQARsR2o2U31qqVY8GNHzczKRhsQPZImAAEgaS5Ji6Ku7HwutbuYzMxGe6HchcAvgFmSvgO8ADinSnWqmYEuJrcgzMxGFxAR8UtJy4DjSbqW3hsR66pasxrY2YJwQJiZjSogJP0EuApYHBGd1a1S7bQUyoPU7mIyMxvtGMRngRcB90haJOkMScUq1qsm3IIwM9tptF1MvwV+m96C+yTgH4DLgac8xOeZzGMQZmY7jfpurulZTK8CzgSOYuezpOvGpGITAFu6+mpcEzOz2hvtGMT3gGNJzmT6MvDbiKi7jvrprc0ArN9Wd2fwmpntttG2IC4D3hgRdd330pTPMXViE+u39dS6KmZmNTdsQEg6KSKWAK3AQmnXi6cj4odVrFtNtLc2s84tCDOzEVsQLwGWkIw9DBZA3QXEjLYWtyDMzBghICLi4+nkRRHxUHaZpIOqVqsamtHWwr1rttS6GmZmNTfa6yAqPR960VhWZLyY0eYuJjMzGHkM4lDgcGCKpNdlFk0G6u5COYD2tha2dPXR3ddPS3p3VzOzRjTSGMQhwCuBqew6DrGV5GK5ujOjrQWADZ097DdlQo1rY2ZWOyONQfwY+LGkEyLipr1Up5pqb0uuhVi31QFhZo1ttGMQ75A0tTwjaZqky6tUp5qaUQ6ITo9DmFljG21APC8iNpVnImIjcGR1qlRb5S6mdVsdEGbW2EYbEDlJ08ozkqazG/dxeiZpTwNifaevhTCzxjbaL/nPATdJ+n46/3rgU9WpUm21NucpNuXcgjCzhjfa231/U9JSklt9A7wuIu6pXrVqRxLtrS1uQZhZwxttFxPAdKAzIr4MrK3XK6kBZkxq8cVyZtbwRhUQkj4O/B/gQ2lRE/DtalWq1ma0NrPO92MyswY32hbEa4FXA50AEfEYMKlalaq15IZ9
bkGYWWMbbUD0RESQ3MEVSa3Vq1Lttbc1s76zh1Ipal0VM7OaGW1AfE/SfwFTJf0D8Gvgv6tXrdqa0dZCfynYvKO31lUxM6uZ0Z7F9FlJpwJbSO7P9LGI+FVVa1ZDA7fb2NbNtPQxpGZmjWbUZzFFxK8i4oMR8YHRhoOk0yTdL2m5pAsqLG+RdE26/GZJcwYtP1DSNkkfGG09x0JH+WpqD1SbWQMbNiAk/W/6c6ukLRVeD0l61xDvzQNfAU4HDgPeKOmwQaudC2yMiHnAJcDFg5Z/Hvj57u/WnmkfCAgPVJtZ4xo2ICLihenPSRExefALWAC8d4i3Hwssj4gHI6IHuBpYOGidhcCV6fQi4GSlD76W9BrgIeDup7Nje6J8wz6fyWRmjWzUXUySjpL0HknvlnQkQESsB04c4i0zgZWZ+VVpWcV1IqIP2Ay0S2ojue7iEyPU6TxJSyUtXbt27Wh3ZURTJzaTk7uYzKyxjfZCuY+R/KXfDswArpD0EYCIWFOFel0IXBIR24ZbKSIujYgFEbGgo6NjzD48nxPTW1tY71t+m1kDG+3N+t4MPD8iugAkfRq4HfjkMO9ZDczKzB+QllVaZ5WkAjAFWA8cB5wh6TMkT7MrSepKb/OxV8xoa2btVrcgzKxxjTYgHiN5BnVXOt/CU7/sB7sFmJ/es2k1cBbwpkHrLAbOBm4CzgCWpBfkvai8gqQLgW17MxwgvZraLQgza2DDBoSkL5FcPb0ZuFtS+fTWU4A/DffeiOiTdD5wPZAHLo+IuyVdBCyNiMXAZcC3JC0HNpCEyLjQ3tbMI4921roaZmY1M1ILYmn68x7gBpKw6AN+M5qNR8R1wHWDyj6Wme4iebbEcNu4cDSfNdaS+zG5i8nMGtdIAXEVyYOB3g48Agg4EPgG8OHqVq222tua2d7Tz/aePiY21+XD88zMhjXSWUyfAaYBB0XE0RFxFHAwyWDyf1S7crVUfja1WxFm1qhGCohXAudFxNZyQURsAd4J/E01K1Zr5Yvl1vpiOTNrUCMFRKRnFQ0u7Ce99Xe9cgvCzBrdSAFxj6S3DS6U9BbgvupUaXzw/ZjMrNGNNPr6T8APJb0dWJaWLQAmkDxlrm51tLWQz4nVG3fUuipmZjUxbEBExGrgOEknAYenxddFxA1Vr1mNNRdyzJ4+kRVrh73bh5lZ3RrtA4OWAEuqXJdx5+CONpY/6YAws8Y06ru5NqJ5+7Tx8PpO+vpLta6Kmdle54AYxtyOVnr7g5UehzCzBuSAGMbcfdoAWOFuJjNrQA6IYcztSAPCA9Vm1oAcEMOYMqGJjkktDggza0gOiBHM7Wj1mUxm1pAcECOY29HGirWdVLjjiJlZXXNAjGBuRxubd/SyvtP3ZDKzxuKAGME8n8lkZg3KATGCgVNd1/rxo2bWWBwQI9hvcpEJTXkPVJtZw3FAjCCXEwd3tPpUVzNrOA6IUZi3T5sDwswajgNiFOZ2tLF60w529PTXuipmZnuNA2IU5na0EQEPrfNAtZk1DgfEKJRPdV3ubiYzayAOiFGY3T6RnOAvj2+tdVXMzPYaB8QoFJvyHL7/FP708IZaV8XMbK9xQIzSCXPbuf3RTXT1eqDazBqDA2KUTji4nZ7+Erc+srHWVTEz2yscEKO0YM408jlx04Pra10VM7O9wgExSpOKTTxn5hRuWuGAMLPG4IDYDScc3M4dqzaxvaev1lUxM6s6B8RuOP7g6fT2B8s8DmFmDcABsRuOmTOdQk7uZjKzhuCA2A2tLQWed8AU/uiBajNrAFUNCEmnSbpf0nJJF1RY3iLpmnT5zZLmpOXHSro9fd0h6bXVrOfuOP7gdu5ctZnObo9DmFl9q1pASMoDXwFOBw4D3ijpsEGrnQtsjIh5wCXAxWn5XcCCiDgCOA34L0mFatV1d5wwt52+UnCLr6o2szpXzRbEscDyiHgwInqAq4GFg9ZZCFyZTi8CTpakiNgeEeU/0YtAVLGeu+Xo2dNoyos/
eBzCzOpcNQNiJrAyM78qLau4ThoIm4F2AEnHSbob+DPwjkxgDJB0nqSlkpauXbu2CrvwVBObC7xg3gx+csdj9JfGTW6ZmY25cTtIHRE3R8ThwDHAhyQVK6xzaUQsiIgFHR0de61uZxx9AGs2d/GHFev22meame1t1QyI1cCszPwBaVnFddIxhinALn03EXEvsA14TtVquptO+at9mVwssGjZqlpXxcysaqoZELcA8yUdJKkZOAtYPGidxcDZ6fQZwJKIiPQ9BQBJs4FDgYerWNfdUmzK8+oj9ucXdz3Olq7eWlfHzKwqqhYQ6ZjB+cD1wL3A9yLibkkXSXp1utplQLuk5cD7gfKpsC8E7pB0O3At8K6IGFf9OWccPYvuvhI/u3NNratiZlYViqiPgdYFCxbE0qVL99rnRQSnXnIjUyY08YN3/vVe+1wzs7EkaVlELKi0bNwOUo93kjjj6ANY9shGHvSzqs2sDjkg9sBrj5xJTvB9D1abWR1yQOyBfScXOfWwffn2TY+wsbOn1tUxMxtTDog99P5TD2FbTx9fu3FFratiZjamHBB76JBnTeI1R8zkyj88zBNbumpdHTOzMeOAGAPvO2U+ff3Bl5Y8UOuqmJmNGQfEGJjd3spZx87i6j+t5NH122tdHTOzMeGAGCPvPmk+hby4+Pr7al0VM7Mx4YAYI/tOLvKuE+fxszvX8OPbB99yyszsmccBMYbedeJcFsyexkeuvYuVG9zVZGbPbA6IMVTI57jkzCMA+Odrbqevv1TjGpmZPX0OiDE2a/pEPvna57D0kY18ccnyWlfHzOxpGxfPea43C4+Yye8eWMcXb3iAjkktvPX42bWukpnZbnNAVMm/vfa5bNrew0d/dBcTmvKccfQBta6SmdlucRdTlTQXcnz5TUfxovkz+NdFd7D4jsdqXSUzs93igKiiYlOeS9+6gAVzpvOe797G5395P/2l+nj+hpnVPwdElU1ozvPNtx/L648+gC8uWc453/gTG3znVzN7BnBA7AXFpjyfOeN5fPp1z+XmhzZw+hdu5Kd3Pka9PM3PzOqTA2IvkcRZxx7ID9/517S3tnD+Vbfx5q/fzANPbK111czMKnJA7GXPmTmFn7z7hfzfhYdz1+rNvPw/b+T8q27lrtWba101M7Nd+DTXGsjnxFtPmMMrnrsfl/7uQb7zx0f56Z1reOG8GZx5zCxOPWxfik35WlfTzBqc6qUffMGCBbF06dJaV+Np2byjl+/c/AjfuukR1mzuYnKxwCufvz9/89z9OPag6TTl3dAzs+qQtCwiFlRc5oAYP/pLwU0r1vODW1fxi7seZ0dvP1MnNnHyofty0qH78ML5M5gyoanW1TSzOuKAeAba0dPPjQ+s5fq7HufX9z7Blq4+8jlx5KypnDC3nWMPms5RB06jtcW9hGb29DkgnuH6+kvcvnITv/3LWm78y1ruemwL/aUgnxOH7DuJ58+aypGzpnLY/pOZt0+bxy/MbNQcEHVmW3cftz6ykT89tIE7Vm3i9pWb2NrVB0BOMGdGK/P3aePgjjbmdrRx0IxW5rRPZHprM5JqXHszG0+GCwj3TzwDtbUUePGzO3jxszsAKJWCh9Z3cv/jW7nv8a3c//gWVqzt5IZ7n6Qvc2uPtpYCB06fyKzpE5g1bSKzpk/kgGkTOGDaRGZOm0Cbu6vMLMPfCHUglxNz09bCK56730B5b3+JlRu28/D6Th5Zv52H13WycuMOVqzt5H/uX0t3364PNJoyoYn9p05g5tQiz5pS5FmTizxrygT2ndzCsyYX2WdykcnFglshZg3CAVHHmvI5Du5IupoGiwjWbeth1cbtrNq4g1Ubd7Bm8w5Wp9PLHtnIxu29T3lfSyHHPpNb2GdSkY62FjomtTBj4GczMya10NGWlE1o9liI2TOZA6JBSaJjUvLFfuSB0yqu09Xbz+Obu3hiSxdPbO3myS1dPJn5uWLtNv740Ho2VQgSgNbmPO1tSXAM/Gxt
ob2tmemtzcxoa2F6a3MyP7GZgq/3MBtXHBA2pGJTnjkzWpkzo3XY9Xr6Sqzv7Gbd1h7Wbetm7bZu1m1L5td3drN+Ww8rN2zntkc3saGzm6HueD5lQhPtrc1Ma00CZPrEZqan4TGttZlpE5uY1trM1AlNTJ3YzORiwaFiVkUOCNtjzYUc+02ZwH5TJoy4bqkUbNzew4bOHtZ39rB+Ww8bOrtZ35mUlV8rN2znjpWb2NDZs8tA+2BtLQUmFwtMntDE5GITkycU0p/pq8KyScUCbS0F2ooFWgruBjMbigPC9qpcTrS3tdDe1sL8UawfEWzt7mNjZw8bt/eysbOHTTt62Ly9l807+ti0o4etXX1s2dHL5h29PLapi/u6trJ5Ry/buvsY6SzuQk5MaM4zsTnPxOYCxaY8E5pyTGjOUyzkKaY/JzTn0p95ik3lV7YsN1A+IX2Vp4vNOZrzOQ/u2zOOA8LGNUnJX//FJma37957S6VgW08fm7f3sqWrly07+tjS1cu2rj62diUBsr2nn+09/ezo6WdHb/LqSl+bd/Syvaef7t4SXb3Jel19/SOGTiU5kQRHcyY4BgIkT7GQGwidlkKelkKOlnS6uZBL5tPyYlO2LJdO71pWnm8u5MjnHEz29DggrG7lcjvDZaxEBD39Jbp6SnT1JcEy8LM3KesaKCuxvaeP7r5MwGRCqBxKm3f08uRAMJXo7uunu69Ed19pTB5Rm8+J5nxuIDCa8zuDpTxfnm7KZ8qy70mXtaTLmvKiuZBPfybLmvI5CnnRlEt/5kUhnd75UxTyueRnLinPl8tzIp+TW1rjSFUDQtJpwBeAPPD1iPj0oOUtwDeBo4H1wJkR8bCkU4FPA81AD/DBiFhSzbqajYak9C/5PFOo/o0T+/pLSSD1lujpS15dff309CVBUi7Phkp5vZ7+ncuyZYPX6U3X6+xOwqz8vnL5zvX2zl0XkhDZNVTy5TDJBEk+XS+fyw2UNY0wn912dn7k9yTzuXQ+p3IdIJ/LkZfI5SCvZHvl5bl0Plm+871PKcsuG0dBWbWAkJQHvgKcCqwCbpG0OCLuyax2LrAxIuZJOgu4GDgTWAe8KiIek/Qc4HpgZrXqajZeFfI5CvkcE5trXZOky64cKL39sTNE0iDp6w96S8nPvv4SvaXkZ18pkrJ0WX8pWa8/Wz4wnbynvxT0psvK6/WWSpRKyTr9mVdvKehPt93VW6K/1J++v5RZp0R/eftpWbmO5W2ONzkxEGT5TFiVwyRbftIh+/CRVx425nWoZgviWGB5RDwIIOlqYCGQDYiFwIXp9CLgy5IUEbdl1rkbmCCpJSK6q1hfMxtGLieKuXxd3gwyIg2NgfAI+uOpIdKfhll/KSiVoD92DatSJOuVymUxaFm63fL2ymWlQbpbfZwAAAeRSURBVJ9ffmXX6S+VBrZX3k55vf2mjnwG4dNRzYCYCazMzK8CjhtqnYjok7QZaCdpQZT9LXBrpXCQdB5wHsCBBx44djU3s4aitNvHZz3valxfZSTpcJJup3+stDwiLo2IBRGxoKOjY+9WzsyszlUzIFYDszLzB6RlFdeRVACmkAxWI+kA4FrgbRGxoor1NDOzCqoZELcA8yUdJKkZOAtYPGidxcDZ6fQZwJKICElTgZ8BF0TE76tYRzMzG0LVAiIi+oDzSc5Auhf4XkTcLekiSa9OV7sMaJe0HHg/cEFafj4wD/iYpNvT1z7VqquZmT2VnyhnZtbAhnui3LgepDYzs9pxQJiZWUUOCDMzq6huxiAkrQUe2YNNzGDXC/QaQSPuMzTmfnufG8fu7vfsiKh4IVndBMSekrR0qIGaetWI+wyNud/e58YxlvvtLiYzM6vIAWFmZhU5IHa6tNYVqIFG3GdozP32PjeOMdtvj0GYmVlFbkGYmVlFDggzM6uo4QNC0mmS7pe0XNIFI7/jmUfSLEm/kXSPpLslvTctny7pV5IeSH9Oq3Vdq0FSXtJtkn6azh8k6eb0
mF+T3m24bkiaKmmRpPsk3SvphEY41pL+Of33fZek70oq1uOxlnS5pCcl3ZUpq3h8lfhiuv93Sjpqdz6roQMi89zs04HDgDdKGvsHu9ZeH/AvEXEYcDzwT+l+XgDcEBHzgRvYeTfdevNekjsKl10MXBIR84CNJM9GrydfAH4REYcCzyfZ97o+1pJmAu8BFkTEc4A8ySMG6vFYXwGcNqhsqON7OjA/fZ0HfHV3PqihA4LMc7MjogcoPze7rkTEmoi4NZ3eSvKFMZNkX69MV7sSeE1talg96YOn/gb4ejov4CSSZ6BDne23pCnAi0lupU9E9ETEJhrgWJM8QnlC+vCxicAa6vBYR8SNwIZBxUMd34XANyPxR2CqpP1G+1mNHhCVnps9s0Z12SskzQGOBG4G9o2INemix4F9a1StavpP4F+BUjrfDmxKn1cC9XfMDwLWAt9Iu9W+LqmVOj/WEbEa+CzwKEkwbAaWUd/HOmuo47tH33GNHhANRVIb8APgfRGxJbsskvOd6+qcZ0mvBJ6MiGW1rsteVACOAr4aEUcCnQzqTqrTYz2N5K/lg4D9gVae2g3TEMby+DZ6QIzmudl1QVITSTh8JyJ+mBY/UW5upj+frFX9quQFwKslPUzSfXgSSf/81LQbAurvmK8CVkXEzen8IpLAqPdjfQrwUESsjYhe4Ickx7+ej3XWUMd3j77jGj0gRvPc7Ge8tN/9MuDeiPh8ZlH2meBnAz/e23Wrpoj4UEQcEBFzSI7tkoh4M/AbkmegQ53td0Q8DqyUdEhadDJwD3V+rEm6lo6XNDH9917e77o91oMMdXwXA29Lz2Y6Htic6YoaUcNfSS3pFST91Hng8oj4VI2rNOYkvRD4HfBndvbFf5hkHOJ7wIEkt0p/Q0QMHvyqC5JOBD4QEa+UdDBJi2I6cBvwlojormX9xpKkI0gG5ZuBB4G/I/ljsK6PtaRPAGeSnLV3G/D3JP3tdXWsJX0XOJHktt5PAB8HfkSF45uG5ZdJutu2A38XEaN+NnPDB4SZmVXW6F1MZmY2BAeEmZlV5IAwM7OKHBBmZlaRA8LMzCpyQJgBkv5d0kslvUbSh9KyiySdkk6/T9LEMfy812RvDJn9LLPxwqe5mgGSlpDc1O/fgEUR8ftByx8muVPout3YZj4i+odYdgXw04hYVGm52XjggLCGJuk/gJeT3MNnBTAXeIjkFhUHAz8lubfPZ4H7gXUR8VJJLwM+AbSk7/u7iNiWBsk1wKnAZ4BJJLdZbgaWA28Fjki3uzl9/S3wUdLAkHRy+nkFkqv93xkR3em2rwReBTQBr4+I+yS9hOQWIpDcg+fF6V17zfaIu5isoUXEB0meEXAFcAxwZ0Q8LyIuyqzzReAx4KVpOMwAPgKcEhFHAUuB92c2uz4ijoqIq4EfRsQxEVF+LsO5EfEHklsgfDAijoiIFeU3SiqmdTkzIp5LEhLvzGx7XfqZXwU+kJZ9APiniDgCeBGwY0x+OdbwHBBmyc3s7gAOZdcHCw3leJIHTP1e0u0k976ZnVl+TWb6OZJ+J+nPwJuBw0fY9iEkN537Szp/JcnzHcrKN1pcBsxJp38PfF7Se4Cpmdtbm+2RwsirmNWn9J5FV5Dc4XIdyUNmlH7pnzDcW4FfRcQbh1jemZm+AnhNRNwh6RySe+jsifJ9hPpJ//9GxKcl/Qx4BUlovTwi7tvDzzFzC8IaV0TcnnbL/IWkRbAEeHna7TO4m2YryXgCwB+BF0iaByCpVdKzh/iYScCa9Hbrbx5ie1n3A3PK2yYZs/jtcPshaW5E/DkiLiYZszh0uPXNRssBYQ1NUgewMSJKwKERcc8Qq14K/ELSbyJiLXAO8F1JdwI3MfSX8kdJ7pr7eyD7V/3VwAfTp77NLRdGRBfJ3Ve/n3ZLlYCvjbAb75N0V1qXXuDnI6xvNio+i8nMzCpyC8LMzCpyQJiZWUUOCDMzq8gBYWZmFTkgzMysIgeEmZlV5IAwM7OK/j96
COLo9CNtAwAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "OjrY4zGuQQix" - }, - "source": [ - "### Positivity Constraints\n", - "In this case it isn't helpful (the data were generated with positive and negative weights) but is a cool functionality!" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 592 - }, - "id": "gdBFdkvfQUb6", - "outputId": "92fff950-5567-4fbb-a83c-224ca4d6052f" - }, - "source": [ - "#fit a scca model with positivity constraints\n", - "scca_pos=SCCA(c=[1e-3,1e-3],positive=[True,True]).fit([X,Y])\n", - "\n", - "plot_model_weights(scca_pos.weights[0],scca_pos.weights[1],tx,ty)\n", - "\n", - "#Convergence\n", - "plt.figure()\n", - "plt.title('Objective Convergence')\n", - "plt.plot(np.array(scca_pos.track[0]['objective']).T)\n", - "plt.ylabel('Objective')\n", - "plt.xlabel('#iterations')" - ], - "execution_count": 24, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Text(0.5, 0, '#iterations')" - ] - }, - "metadata": {}, - "execution_count": 24 - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KPC7uhls4ycW" - }, - "source": [ - "## Sparse CCA by iterative elastic net (adapted from Waaijenborg)" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 592 - }, - "id": "wuZjjyN24j0P", - "outputId": "b8fba7ce-d303-402f-b78c-a316ec3f42f6" - }, - "source": [ - "#fit an elastic model\n", - "#for some reason this model needs REALLY big l2 regularisation. This is actually\n", - "#the same level of l1 regularisation as SCCA\n", - "elasticcca=ElasticCCA(c=[10000,10000],l1_ratio=[0.000001,0.000001]).fit([X,Y])\n", - "\n", - "plot_model_weights(elasticcca.weights[0],elasticcca.weights[1],tx,ty)\n", - "\n", - "#Convergence\n", - "plt.figure()\n", - "plt.title('Objective Convergence')\n", - "plt.plot(np.array(elasticcca.track[0]['objective']).T)\n", - "plt.ylabel('Objective')\n", - "plt.xlabel('#iterations')" - ], - "execution_count": 25, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Text(0.5, 0, '#iterations')" - ] - }, - "metadata": {}, - "execution_count": 25 - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E8TEaBhe7CYw" - }, - "source": [ - "## Sparse CCA by ADMM" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 592 - }, - "id": "7RkSPcWR7FY8", - "outputId": "d9590105-6016-4fa3-bae0-655b3e3336a3" - }, - "source": [ - "#fit a scca_admm model\n", - "scca_admm=SCCA_ADMM(c=[1e-3,1e-3]).fit([X,Y])\n", - "\n", - "plot_model_weights(scca_admm.weights[0],scca_admm.weights[1],tx,ty)\n", - "\n", - "#Convergence\n", - "plt.figure()\n", - "plt.title('Objective Convergence')\n", - "plt.plot(np.array(scca_admm.track[0]['objective']).T)\n", - "plt.ylabel('Objective')\n", - "plt.xlabel('#iterations')" - ], - "execution_count": 27, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Text(0.5, 0, '#iterations')" - ] - }, - "metadata": {}, - "execution_count": 27 - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LcbbCAa_5q5C" - }, - "source": [ - "## Sparse CCA by random projection (Span CCA)\n", - "This time the regularisation parameter c is the l0 norm of the weights i.e. the maximum number of non-zero weights. Let's cheat and give it the correct numbers. We can also change the rank of the estimation as described in the paper" - ] - }, - { - "cell_type": "code", - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 592 - }, - "id": "F_bT21qk5jA5", - "outputId": "bd683c19-1b7d-4c08-ee44-e398d3323cbb" - }, - "source": [ - "#fit a spancca model\n", - "spancca=SpanCCA(c=[10,10],max_iter=2000,rank=20).fit([X,Y])\n", - "\n", - "plot_model_weights(spancca.weights[0],spancca.weights[1],tx,ty)\n", - "\n", - "#Convergence\n", - "plt.figure()\n", - "plt.title('Objective Convergence')\n", - "plt.plot(np.array(spancca.track[0]['objective']).T)\n", - "plt.ylabel('Objective')\n", - "plt.xlabel('#iterations')" - ], - "execution_count": 28, - "outputs": [ - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - }, - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "Text(0.5, 0, '#iterations')" - ] - }, - "metadata": {}, - "execution_count": 28 - }, - { - "output_type": "display_data", - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - } - } - ] - }, - { - "cell_type": "code", - "metadata": { - "id": "eTEVZnXpQzFm" - }, - "source": [ - "" - ], - "execution_count": null, - "outputs": [] - } - ] -} \ No newline at end of file