Skip to content

Commit 5614d80

Browse files
authored
Add unified CEBRA encoder: pytorch implementation (#251)
* start tests * remove print statements * first passing test * move functionality to base file in solver and separate in functions * add test_select_model for multisession * remove float16 * Improve modularity remove duplicate code and todos * Add tests to solver * Fix save/load * Fix extra docs errors * Add review updates * apply ruff auto-fixes * fix linting errors * Run isort, ruff, yapf * Fix gaussian mixture dataset import * Fix all tests but xcebra tests * Fix pytorch API usage example * Make xCEBRA compatible with the batched inference & padding in solver * Add some tests on transform() with xCEBRA * Add some docstrings and typings and clean unnecessary changes * Implement review comments * Fix sklearn test * Initial pass at integrating unifiedCEBRA * Add name in NOTE * Implement reviews on tests and typing * Fix import errors * Add select_model to aux solvers * Fix tests * Add mask tests * Fix docs error * Remove masking init() * Remove shuffled neurons in unified dataset * Remove extra datasets * Add tests on the private functions in base solver * Update tests and duplicate code based on review * Fix quantized_embedding_norm undefined when `normalize=False` (#249) * Fix tests * Adapt unified code to get_model method * Update mask.py add headers to new files * Update masking.py - header * Update test_data_masking.py - header * Implement review comments and fix typos * Fix docs errors * Remove np.int typing error * Fix docstring warning * Fix indentation docstrings * Implement review comments * Fix circular import and abstract method * Add maskedmixin to __all__ * Implement extra review comments * Change masking kwargs as tuple and not dict in sklearn impl * Add integrations/decoders.py * Fix typo * minor simplification in solver --------- Note, some comments in this PR overlap with #168 and #225 which were developed in parallel.
1 parent 7ae5e1e commit 5614d80

22 files changed

+1951
-206
lines changed

cebra/data/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,4 @@
5151
from cebra.data.multiobjective import *
5252
from cebra.data.datasets import *
5353
from cebra.data.helper import *
54+
from cebra.data.masking import *

cebra/data/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import torch
2828

2929
import cebra.data.assets as cebra_data_assets
30+
import cebra.data.masking as cebra_data_masking
3031
import cebra.distributions
3132
import cebra.io
3233
from cebra.data.datatypes import Batch
@@ -36,7 +37,7 @@
3637
__all__ = ["Dataset", "Loader"]
3738

3839

39-
class Dataset(abc.ABC, cebra.io.HasDevice):
40+
class Dataset(abc.ABC, cebra.io.HasDevice, cebra_data_masking.MaskedMixin):
4041
"""Abstract base class for implementing a dataset.
4142
4243
The class attributes provide information about the shape of the data when

cebra/data/datasets.py

Lines changed: 95 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,9 @@
2828
import numpy.typing as npt
2929
import torch
3030

31+
import cebra
3132
import cebra.data as cebra_data
33+
import cebra.data.masking as cebra_data_masking
3234
import cebra.helper as cebra_helper
3335
import cebra.io as cebra_io
3436
from cebra.data.datatypes import Batch
@@ -304,7 +306,7 @@ def _iter_property(self, attr):
304306

305307

306308
# TODO(stes): This should be a single session dataset?
307-
class DatasetxCEBRA(cebra_io.HasDevice):
309+
class DatasetxCEBRA(cebra_io.HasDevice, cebra_data_masking.MaskedMixin):
308310
"""Dataset class for xCEBRA models.
309311
310312
This class handles neural data and associated labels for xCEBRA models, providing
@@ -435,3 +437,95 @@ def load_batch_contrastive(self, index: BatchIndex) -> Batch:
435437
positive=[self[idx] for idx in index.positive],
436438
negative=self[index.negative],
437439
)
440+
441+
442+
class UnifiedDataset(DatasetCollection):
    """Multi session dataset made up of a list of datasets, considered as a unique session.

    Considering the sessions as a unique session, or pseudo-session, is used to later train a single
    model for all the sessions, even if they originally contain a variable number of neurons.
    To do that, we sample ref/pos/neg for each session and concatenate them along the neurons axis.

    For instance, for a batch size ``batch_size``, we sample ``(batch_size, num_neurons(session), offset)`` for
    each type of samples (ref/pos/neg) and then concatenate so that the final :py:class:`cebra.data.datatypes.Batch`
    is of shape ``(batch_size, total_num_neurons, offset)``, where ``total_num_neurons`` is the sum of all the
    ``num_neurons(session)``.
    """

    def __init__(self, *datasets: cebra_data.SingleSessionDataset):
        super().__init__(*datasets)

    @property
    def input_dimension(self) -> int:
        """Returns the sum of the input dimension for each session."""
        # Builtin ``sum`` keeps the return value a plain ``int``; ``np.sum``
        # would return a numpy scalar, which does not match the annotation.
        return sum(
            self.get_input_dimension(session_id)
            for session_id in range(self.num_sessions))

    def _get_batches(self, index) -> List[Batch]:
        """Sample one :py:class:`cebra.data.datatypes.Batch` per session.

        Args:
            index: A multi-session index whose ``reference``, ``positive`` and
                ``negative`` attributes each hold one index tensor per session.

        Returns:
            A list with one :py:class:`cebra.data.datatypes.Batch` per session,
            each of shape ``(batch_size, num_neurons(session), offset)``.
        """
        return [
            cebra_data.Batch(
                reference=self.get_session(session_id)[
                    index.reference[session_id]],
                positive=self.get_session(session_id)[
                    index.positive[session_id]],
                negative=self.get_session(session_id)[
                    index.negative[session_id]],
            ) for session_id in range(self.num_sessions)
        ]

    def configure_for(self, model: "cebra.models.Model"):
        """Configure the dataset offset for the provided model.

        Call this function before indexing the dataset. This sets the
        :py:attr:`~.Dataset.offset` attribute of the dataset.

        Args:
            model: The model to configure the dataset for.
        """
        # Every session is configured with the same (single, unified) model.
        for session in self.iter_sessions():
            session.configure_for(model)

    def load_batch(self, index: BatchIndex) -> Batch:
        """Return the data at the specified index location.

        Concatenate the batches sampled for each session on the
        number-of-neurons axis.

        Args:
            index: The indices to sample reference, positive and negative
                samples from, with one entry per session. Each per-session
                batch is of shape ``(batch_size, num_neurons(session), offset)``.

        Returns:
            A :py:class:`cebra.data.datatypes.Batch`, of shape ``(batch_size, total_num_neurons, offset)``, where
            ``total_num_neurons`` is the sum of all the ``num_neurons(session)``.
        """
        batches = self._get_batches(index)

        def _concat(attr: str) -> torch.Tensor:
            # Concatenate per-session samples along the neurons axis (dim=1).
            return torch.cat([getattr(batch, attr) for batch in batches],
                             dim=1)

        if hasattr(self, "apply_mask"):
            # If the dataset has a mask, apply it to the data.
            return cebra_data.Batch(
                reference=self.apply_mask(_concat("reference")),
                positive=self.apply_mask(_concat("positive")),
                negative=self.apply_mask(_concat("negative")),
            )
        return cebra_data.Batch(
            reference=_concat("reference"),
            positive=_concat("positive"),
            negative=_concat("negative"),
        )

    def __getitem__(self, args) -> List[Batch]:
        """Return a set of samples from the given session at the given index.

        Args:
            args: A ``(session_id, index)`` tuple selecting the session and
                the sample indices within that session.
        """
        session_id, index = args
        return self.get_session(session_id).__getitem__(index)

0 commit comments

Comments
 (0)