
Add parameters for MixedDataLoader #101

Open: wants to merge 5 commits into main

Conversation

timonmerk (Contributor)

Addresses #100
Maybe I misunderstood the function, but I added keyword arguments so that sampling can now allow for empirical and uniform priors over the discrete label, and additionally for selecting between conditional and discrete-only positive sampling.

@cla-bot cla-bot bot added the CLA signed label Oct 29, 2023
@MMathisLab (Member)

Minor comment, but this also includes the doc changes from #99 - I'd remove those first to keep this PR clean. Thanks @timonmerk, we will review it.

@timonmerk (Contributor Author)

Yes, sorry for that! It took me some time to figure out how to remove an existing commit from GitHub, but git rebase and a force push did the job :)

@stes (Member) left a comment:

Sorry for the super late review. I think we should mainly add tests to this PR before merging, and I would propose to do that in an upcoming release, not as part of 0.3.1 --- @MMathisLab ?

Comment on lines 277 to 309
@property
def discrete_index(self):  # was: dindex; resolves TODO(stes) rename to discrete_index
    return self.dataset.discrete_index

@property
def continuous_index(self):  # was: cindex; resolves TODO(stes) rename to continuous_index
    return self.dataset.continuous_index
Member:

I think we should not delete these and instead keep the old names (with a deprecation notice) plus the new ones side by side before making the switch.

Contributor Author:

Added the warnings in 8af5b0e

cebra/data/single_session.py (resolved)
Comment on lines 261 to 283
1. Positive pairs always share their discrete variable.
2. Positive pairs are drawn only based on their conditional,
   not discrete variable.

Args:
    conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
    time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
    positive_sampling (str): either "discrete_variable" (default) or "conditional"
    discrete_sampling_prior (str): either "empirical" (default) or "uniform"
Member:

Optional: We could extend the docs here a bit and visualize the different options with examples (?)
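For instance, the difference between the two priors could be illustrated with a small plain-Python sketch (a stand-in for the loader's sampling, not CEBRA's implementation; `sample_prior` and its signature are illustrative):

```python
import random
from collections import Counter


def sample_prior(labels, num_samples, prior="empirical", seed=0):
    """Draw reference indices under a prior over the discrete label."""
    rng = random.Random(seed)
    if prior == "empirical":
        # Uniform over timesteps: labels appear proportionally to
        # their frequency in the data.
        return [rng.randrange(len(labels)) for _ in range(num_samples)]
    if prior == "uniform":
        # First pick a label uniformly, then a timestep carrying it.
        by_label = {}
        for i, label in enumerate(labels):
            by_label.setdefault(label, []).append(i)
        classes = sorted(by_label)
        return [
            rng.choice(by_label[rng.choice(classes)])
            for _ in range(num_samples)
        ]
    raise ValueError(f"unknown prior: {prior}")


labels = [0] * 90 + [1] * 10  # heavily imbalanced discrete label
empirical = Counter(labels[i] for i in sample_prior(labels, 1000, "empirical"))
uniform = Counter(labels[i] for i in sample_prior(labels, 1000, "uniform"))
# "empirical" stays skewed toward label 0; "uniform" is roughly balanced
```

With an imbalanced label, the empirical prior reproduces the ~9:1 skew of the data, while the uniform prior draws each class about equally often.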

Contributor Author:

I extended the docstring a bit. We also had an example in our analysis where the embedding looked very different when a discrete_variable was used, such that there was a different number of clusters for each discrete variable instance (a comparison of different cohorts with different movement types for movement classification). But it would probably be tricky to upload the data and present the pipeline?

@timonmerk (Contributor Author)

Previously I obtained an error in test_criterions.py, but locally I did not see this error. After merging the upstream main changes, the error no longer occurs.
@stes Are there any comments about the docstring and test I added?

@stes (Member) commented Oct 27, 2024

Hi @timonmerk, apologies for the slow replies here. Let's get this merged before the PR's 1st anniversary :D Checking compatibility with the current main branch now and aiming to get this ready later today!

The test looks good.

@stes stes removed the request for review from nastya236 October 27, 2024 13:46
@stes stes added the enhancement New feature or request label Oct 27, 2024
@stes (Member) left a comment:
Upon reading this PR more closely today, I noticed a few issues I missed last time. Specifically, I think the current way of picking the options can also be implemented by choosing the DiscreteDataLoader variants in CEBRA, as the new sampling mode ignores the continuous variables.

I think before proceeding, we should discuss the goal of this PR (ideally in form of an example or test that demonstrates the functionality we want to accomplish).

One option, as remarked by @timonmerk in the description/issue #100 is support for the empirical vs. uniform sampling mode. However, I think the best place to implement this is in the MixedDataLoader directly.

The second option is to implement the mode "Positive pairs always share their discrete variable." (this is what we have right now) vs. "Positive pairs are drawn only based on their conditional, not discrete variable.". This is the description from the original docstring. This second mode would make sense if we want to sample the positive variable from the positive conditional distribution, and want to sample the negative distribution from the uniform conditional (thereby making the embedding invariant to this variable).

Both features make a lot of sense, and should be added to CEBRA.

@timonmerk , would you be interested in implementing this, or should I go ahead and make a suggestion?

Comment on lines +313 to +317
if self.positive_sampling == "conditional":
    self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
        discrete=self.discrete_index,
        continuous=self.continuous_index,
        time_delta=self.time_offset)

Member:

This needs to be the default behavior, that was how the class used to behave.

Comment on lines +363 to +371
# taken from the DiscreteDataLoader get_indices function
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference = self.discrete_index[reference_idx]
positive_idx = self.distribution.sample_conditional(reference)
return BatchIndex(reference=reference_idx,
                  positive=positive_idx,
                  negative=negative_idx)

Member:

Suggested change: replace the block above with

    return self.distribution.get_indices(num_samples)

should be equivalent. But I think this is actually not the desired functionality, as this then completely ignores the continuous index...

Comment on lines +318 to +321
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical":
    self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index)
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform":
    self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index)

Member:

How are these modes different from going with the empirical discrete / uniform discrete distribution in the first place? I think what we rather want is to specify an option on the MixedTimeDeltaDistribution to support empirical vs. uniform.

Contributor Author:

I agree, but I understood that the current docstring of MixedDataLoader suggests that this is indeed the intended functionality:

1. Positive pairs always share their discrete variable.

Even so, I agree that it wouldn't make much sense to call the MixedDataLoader in this case.

Comment on lines +355 to +371
if self.positive_sampling == "conditional":
    reference_idx = self.distribution.sample_prior(num_samples)
    return BatchIndex(
        reference=reference_idx,
        negative=self.distribution.sample_prior(num_samples),
        positive=self.distribution.sample_conditional(reference_idx),
    )
else:
    # taken from the DiscreteDataLoader get_indices function
    reference_idx = self.distribution.sample_prior(num_samples * 2)
    negative_idx = reference_idx[num_samples:]
    reference_idx = reference_idx[:num_samples]
    reference = self.discrete_index[reference_idx]
    positive_idx = self.distribution.sample_conditional(reference)
    return BatchIndex(reference=reference_idx,
                      positive=positive_idx,
                      negative=negative_idx)
Member:

The way this is set up currently means we either take both variables into account, or we ignore the continuous variables. I think that behavior is not necessarily intended (?)

Contributor Author:

Yes, agreed. As you suggested above, it would then make sense to pass the MixedDataLoader.discrete_sampling_prior argument to MixedTimeDeltaDistribution directly, and to adapt MixedTimeDeltaDistribution to not only sample from the DiscreteUniform distribution.

Comment on lines +206 to +217
dataset = RandomDataset(N=100, d=5, device=device)
loader = cebra.data.MixedDataLoader(
    dataset=dataset,
    num_steps=10,
    batch_size=8,
    conditional=conditional,
    positive_sampling=positive_sampling,
    discrete_sampling_prior=discrete_sampling_prior,
)
_assert_dataset_on_correct_device(loader, device)
load_speed = LoadSpeed(loader)
benchmark(load_speed)
Member:

We should extend the test to check the properties of the positive and negative samples (e.g., check if the discrete labels match and so forth, as expected for each setting of parameters)
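A property-style check along those lines could look like this sketch (using a hypothetical plain-Python sampler in place of the loader, since the exact BatchIndex plumbing depends on the PR):

```python
import random


def get_indices(labels, num_samples, rng):
    """Hypothetical discrete-only sampler mirroring the PR's
    DiscreteDataLoader-style get_indices."""
    by_label = {}
    for i, label in enumerate(labels):
        by_label.setdefault(label, []).append(i)
    reference = [rng.randrange(len(labels)) for _ in range(num_samples)]
    negative = [rng.randrange(len(labels)) for _ in range(num_samples)]
    # Positives are drawn conditionally on the reference label, so each
    # positive must share its reference's discrete variable.
    positive = [rng.choice(by_label[labels[r]]) for r in reference]
    return reference, positive, negative


def test_positive_pairs_share_discrete_label():
    rng = random.Random(0)
    labels = [i % 3 for i in range(60)]
    ref, pos, neg = get_indices(labels, 32, rng)
    assert len(ref) == len(pos) == len(neg) == 32
    # The defining property of "discrete_variable" positive sampling:
    assert all(labels[r] == labels[p] for r, p in zip(ref, pos))


test_positive_pairs_share_discrete_label()
```

The same pattern extends to the other parameter settings, e.g. checking label frequencies of the reference samples under the empirical vs. uniform prior.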

@timonmerk (Contributor Author)

@stes Thanks for picking this PR back up. And yes, it took me some time to get into it again :D
I think the first intention of this PR was simply to implement the empirical and uniform priors that were already mentioned in the docstring of the MixedDataLoader. I realize now, however, that there is more functionality under discussion behind the MixedDataLoader.

As mentioned in the paper you wrote:

if we aim to build a robust brain machine interface that should be invariant to such short-term changes, we would include trial information as a task-irrelevant variable and obtain an embedding space that no longer carries this information

So I thought that it would be possible to pass two "labels", one task-relevant and one task-irrelevant. I understood that the positive sampling then samples with respect to the task-relevant variable, but also with respect to the task-irrelevant variable to make the embedding invariant to it, e.g. by specifying the distribution to be uniform for that variable.

I think this would be an optimal application for a couple of use cases, such as the one mentioned in the paper, or also for building an embedding that is invariant across patients but nevertheless "variant" with respect to a behavioural variable. Optimally, both variables could be discrete, continuous, or a mixture of both. But I guess this wouldn't be directly supported by the current Loader setup, even though I see that the sampling methods for it are implemented, e.g. sample_conditional_continuous, which is not used within the MixedDataLoader.

So I guess it is more up to the user to declare their own DataLoader and then define manually how the prior and index should be specified. As far as I understand it, the scikit-learn API also doesn't provide access to modify those.
With regard to the current code, you're right that the continuous sampling wouldn't be used right now, only if positive_sampling equals "conditional", which doesn't make that much sense.

And to be honest, I agree that it also makes more sense for mixed.MixedTimeDeltaDistribution to be uniform; if you think it would be useful to add the empirical option, I would be happy to add it. Otherwise, maybe that part of the docstring could just be removed to avoid confusion, since there is currently no option to specify the discrete distribution?

Labels: CLA signed, enhancement (New feature or request)