|
28 | 28 | import numpy.typing as npt
|
29 | 29 | import torch
|
30 | 30 |
|
| 31 | +import cebra |
31 | 32 | import cebra.data as cebra_data
|
| 33 | +import cebra.data.masking as cebra_data_masking |
32 | 34 | import cebra.helper as cebra_helper
|
33 | 35 | import cebra.io as cebra_io
|
34 | 36 | from cebra.data.datatypes import Batch
|
@@ -304,7 +306,7 @@ def _iter_property(self, attr):
|
304 | 306 |
|
305 | 307 |
|
306 | 308 | # TODO(stes): This should be a single session dataset?
|
307 |
| -class DatasetxCEBRA(cebra_io.HasDevice): |
| 309 | +class DatasetxCEBRA(cebra_io.HasDevice, cebra_data_masking.MaskedMixin): |
308 | 310 | """Dataset class for xCEBRA models.
|
309 | 311 |
|
310 | 312 | This class handles neural data and associated labels for xCEBRA models, providing
|
@@ -435,3 +437,95 @@ def load_batch_contrastive(self, index: BatchIndex) -> Batch:
|
435 | 437 | positive=[self[idx] for idx in index.positive],
|
436 | 438 | negative=self[index.negative],
|
437 | 439 | )
|
| 440 | + |
| 441 | + |
class UnifiedDataset(DatasetCollection):
    """Multi session dataset made up of a list of datasets, considered as a unique session.

    Considering the sessions as a unique session, or pseudo-session, is used to later train a single
    model for all the sessions, even if they originally contain a variable number of neurons.
    To do that, we sample ref/pos/neg for each session and concatenate them along the neurons axis.

    For instance, for a batch size ``batch_size``, we sample ``(batch_size, num_neurons(session), offset)`` for
    each type of samples (ref/pos/neg) and then concatenate so that the final :py:class:`cebra.data.datatypes.Batch`
    is of shape ``(batch_size, total_num_neurons, offset)``, where ``total_num_neurons`` is the sum of all the
    ``num_neurons(session)``.
    """

    def __init__(self, *datasets: cebra_data.SingleSessionDataset):
        super().__init__(*datasets)

    @property
    def input_dimension(self) -> int:
        """Returns the sum of the input dimension for each session.

        Uses the builtin ``sum`` so the return value is a plain ``int``
        (``np.sum`` would return a numpy scalar, contradicting the
        annotated return type).
        """
        return sum(
            self.get_input_dimension(session_id)
            for session_id in range(self.num_sessions))

    def _get_batches(self, index) -> List[cebra_data.Batch]:
        """Sample one :py:class:`cebra.data.datatypes.Batch` per session.

        Args:
            index: A :py:class:`cebra.data.datatypes.BatchIndex` whose
                ``reference``/``positive``/``negative`` attributes hold
                one index collection per session.

        Returns:
            A list with one batch per session, each of shape
            ``(batch_size, num_neurons(session), offset)``.
        """
        return [
            cebra_data.Batch(
                reference=self.get_session(session_id)[
                    index.reference[session_id]],
                positive=self.get_session(session_id)[
                    index.positive[session_id]],
                negative=self.get_session(session_id)[
                    index.negative[session_id]],
            ) for session_id in range(self.num_sessions)
        ]

    def configure_for(self, model: "cebra.models.Model"):
        """Configure the dataset offset for the provided model.

        Call this function before indexing the dataset. This sets the
        :py:attr:`~.Dataset.offset` attribute of the dataset.

        Args:
            model: The model to configure the dataset for.
        """
        # All sessions are served to the same unified model, so each one
        # receives the same offset configuration.
        for session in self.iter_sessions():
            session.configure_for(model)

    def load_batch(self, index: BatchIndex) -> Batch:
        """Return the data at the specified index location.

        Concatenates the per-session batches along the neurons axis
        (dim=1), and applies the masking scheme if one is configured.

        Args:
            index: A :py:class:`cebra.data.datatypes.BatchIndex` with one
                set of reference/positive/negative indices per session.

        Returns:
            A :py:class:`cebra.data.datatypes.Batch` of shape
            ``(batch_size, total_num_neurons, offset)``, where
            ``total_num_neurons`` is the sum of all the
            ``num_neurons(session)``.
        """
        batches = self._get_batches(index)

        def _concat(field: str) -> torch.Tensor:
            # Concatenate the per-session samples along the neurons axis.
            samples = torch.cat([getattr(batch, field) for batch in batches],
                                dim=1)
            # Apply the mask only when the dataset provides one (e.g. via
            # a masking mixin); otherwise pass the samples through.
            if hasattr(self, "apply_mask"):
                samples = self.apply_mask(samples)
            return samples

        return cebra_data.Batch(
            reference=_concat("reference"),
            positive=_concat("positive"),
            negative=_concat("negative"),
        )

    def __getitem__(self, args) -> List[Batch]:
        """Return samples from a single session.

        Args:
            args: A ``(session_id, index)`` tuple; ``index`` is forwarded
                to the selected session's ``__getitem__``.

        Returns:
            The samples at ``index`` for the selected session.
            NOTE(review): the ``List[Batch]`` annotation looks inaccurate —
            a single session's indexing result is returned; confirm against
            ``SingleSessionDataset.__getitem__``.
        """
        session_id, index = args
        return self.get_session(session_id).__getitem__(index)
0 commit comments