Add parameters for MixedDataLoader #101
minor comment but this also includes the doc changes from #99 - I'd remove those first to keep this PR clean, thanks @timonmerk - we will review it.
e8f73fe to f1894a1
Yes, sorry for that! It took me some time to figure out how to remove an existing commit from GitHub, but git rebase and a force push did the job :)
Sorry for the super late review. I think we should mainly add tests to this PR before merging, and I would propose to do that in an upcoming release, not as part of 0.3.1 --- @MMathisLab ?
cebra/data/single_session.py
  @property
- def dindex(self):
-     # TODO(stes) rename to discrete_index
+ def discrete_index(self):
      return self.dataset.discrete_index

  @property
- def cindex(self):
-     # TODO(stes) rename to continuous_index
+ def continuous_index(self):
      return self.dataset.continuous_index
I think we should not delete these and instead keep the old names (with a deprecation notice) plus the new ones side by side before making the switch.
Added the warnings in 8af5b0e
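For readers following along, the side-by-side approach requested above can be sketched roughly as follows. This is a minimal illustration and not the code from 8af5b0e: `ToyLoader` and `_ToyDataset` are hypothetical stand-ins, and the warning text is an assumption.

```python
import warnings


class _ToyDataset:
    """Hypothetical stand-in for a dataset exposing a discrete index."""
    discrete_index = [0, 1, 1, 0]


class ToyLoader:
    """Toy loader for illustration; not the actual CEBRA loader class."""

    def __init__(self, dataset):
        self.dataset = dataset

    @property
    def discrete_index(self):
        # New, preferred name.
        return self.dataset.discrete_index

    @property
    def dindex(self):
        # Old name kept as a thin alias that warns on every access.
        warnings.warn(
            "dindex is deprecated, use discrete_index instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.discrete_index


loader = ToyLoader(_ToyDataset())
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_value = loader.dindex  # still works, but emits a DeprecationWarning
```

The same pattern would apply to `cindex` and `continuous_index`. Keeping the alias as a one-line forwarder means the two names cannot drift apart before the old one is removed.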
cebra/data/single_session.py
1. Positive pairs always share their discrete variable.
2. Positive pairs are drawn only based on their conditional,
   not discrete variable.

Args:
    conditional (str): The conditional variable for sampling positive pairs. :py:attr:`cebra.CEBRA.conditional`
    time_offset (int): :py:attr:`cebra.CEBRA.time_offsets`
    positive_sampling (str): either "discrete_variable" (default) or "conditional"
    discrete_sampling_prior (str): either "empirical" (default) or "uniform"
Optional: We could extend the docs here a bit and visualize the different options with examples (?)
I extended the docstring a bit. We also had an example in our analysis where the embedding looked very different when a discrete_variable was used, such that there was a different number of clusters for each discrete variable instance (a comparison of different cohorts with different movement types for movement classification). But it would probably be tricky to upload the data and present the pipeline?
Previously I obtained an error in the
e1dd659 to 8d1f0a8
Hi @timonmerk, apologies for the slow replies here. Let's get this merged before the PR's first anniversary :D Checking compatibility with the current main branch now and aiming to get this ready later today! The test looks good.
8d1f0a8 to f91b64a
Upon reading this PR more closely today, I noticed a few issues I missed last time. Specifically, I think the current way of picking the options can also be implemented by choosing the DiscreteDataLoader variants in CEBRA, as the new sampling mode ignores the continuous variables.
I think before proceeding, we should discuss the goal of this PR (ideally in the form of an example or test that demonstrates the functionality we want to accomplish).
One option, as remarked by @timonmerk in the description/issue #100, is support for the empirical vs. uniform sampling mode. However, I think the best place to implement this is in the MixedDataLoader directly.
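To make the empirical vs. uniform distinction concrete, here is a small self-contained sketch. It assumes that "empirical" means drawing indices uniformly over samples (so label frequencies follow the data) and "uniform" means drawing a label uniformly first and then a sample carrying that label. The function names are hypothetical and do not use the CEBRA API.

```python
import random
from collections import Counter


def sample_prior_empirical(labels, num_samples, rng):
    """Draw indices uniformly over samples, so frequent labels dominate."""
    return [rng.randrange(len(labels)) for _ in range(num_samples)]


def sample_prior_uniform(labels, num_samples, rng):
    """Draw a label uniformly first, then an index carrying that label."""
    by_label = {}
    for idx, label in enumerate(labels):
        by_label.setdefault(label, []).append(idx)
    classes = sorted(by_label)
    return [
        rng.choice(by_label[rng.choice(classes)]) for _ in range(num_samples)
    ]


rng = random.Random(0)
labels = [0] * 90 + [1] * 10  # imbalanced toy labels

# Count how often each label shows up under the two priors.
empirical = Counter(
    labels[i] for i in sample_prior_empirical(labels, 10_000, rng))
uniform = Counter(
    labels[i] for i in sample_prior_uniform(labels, 10_000, rng))
```

With the imbalanced toy labels, the rare label should appear in roughly 10% of the empirical draws and roughly 50% of the uniform draws.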
The second option is to implement the mode "Positive pairs always share their discrete variable." (this is what we have right now) vs. "Positive pairs are drawn only based on their conditional, not discrete variable.". This is the description from the original docstring. This second mode would make sense if we want to sample the positive variable from the positive conditional distribution, and want to sample the negative distribution from the uniform conditional (thereby making the embedding invariant to this variable).
Both features make a lot of sense, and should be added to CEBRA.
@timonmerk , would you be interested in implementing this, or should I go ahead and make a suggestion?
if self.positive_sampling == "conditional":
    self.distribution = cebra.distributions.MixedTimeDeltaDistribution(
        discrete=self.discrete_index,
        continuous=self.continuous_index,
        time_delta=self.time_offset)
This needs to be the default behavior; that is how the class used to behave.
# taken from the DiscreteDataLoader get_indices function
reference_idx = self.distribution.sample_prior(num_samples * 2)
negative_idx = reference_idx[num_samples:]
reference_idx = reference_idx[:num_samples]
reference = self.discrete_index[reference_idx]
positive_idx = self.distribution.sample_conditional(reference)
return BatchIndex(reference=reference_idx,
                  positive=positive_idx,
                  negative=negative_idx)
- # taken from the DiscreteDataLoader get_indices function
- reference_idx = self.distribution.sample_prior(num_samples * 2)
- negative_idx = reference_idx[num_samples:]
- reference_idx = reference_idx[:num_samples]
- reference = self.discrete_index[reference_idx]
- positive_idx = self.distribution.sample_conditional(reference)
- return BatchIndex(reference=reference_idx,
-                   positive=positive_idx,
-                   negative=negative_idx)
+ return self.distribution.get_indices(num_samples)
should be equivalent. But I think this is actually not the desired functionality, as this then completely ignores the continuous index...
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "empirical":
    self.distribution = cebra.distributions.DiscreteEmpirical(self.discrete_index)
elif self.positive_sampling == "discrete_variable" and self.discrete_sampling_prior == "uniform":
    self.distribution = cebra.distributions.DiscreteUniform(self.discrete_index)
How are these modes different from going for the empirical discrete / uniform discrete distribution in the first place? I think what we rather want is to specify an option on the MixedTimeDeltaDistribution to support empirical vs. uniform sampling.
I agree, but I understood that the current docstring of MixedDataLoader suggests that this is indeed the intended functionality:
CEBRA/cebra/data/single_session.py, line 268 in 9898850:
1. Positive pairs always share their discrete variable.
Even though I agree that it wouldn't make sense in this case to call the MixedDataLoader
if self.positive_sampling == "conditional":
    reference_idx = self.distribution.sample_prior(num_samples)
    return BatchIndex(
        reference=reference_idx,
        negative=self.distribution.sample_prior(num_samples),
        positive=self.distribution.sample_conditional(reference_idx),
    )
else:
    # taken from the DiscreteDataLoader get_indices function
    reference_idx = self.distribution.sample_prior(num_samples * 2)
    negative_idx = reference_idx[num_samples:]
    reference_idx = reference_idx[:num_samples]
    reference = self.discrete_index[reference_idx]
    positive_idx = self.distribution.sample_conditional(reference)
    return BatchIndex(reference=reference_idx,
                      positive=positive_idx,
                      negative=negative_idx)
The way this is set up currently means we either take both variables into account, or we ignore the continuous variables. I think that behavior is not necessarily intended (?)
Yes, agree. As you suggested above, it would then make sense to pass the MixedDataLoader.discrete_sampling_prior argument to MixedTimeDeltaDistribution directly, and adapt MixedTimeDeltaDistribution so that it can sample from more than just the DiscreteUniform distribution.
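As a rough illustration of that direction, the sketch below threads a prior option through a toy mixed sampler whose positives depend on both variables. `ToyMixedSampler` is a hypothetical stand-in, not `cebra.distributions.MixedTimeDeltaDistribution`, and it simplifies the conditional to "nearest continuous value within the same discrete label" instead of sampling a time delta.

```python
import random


class ToyMixedSampler:
    """Toy sampler using a discrete and a continuous index (illustration only)."""

    def __init__(self, discrete, continuous, prior="uniform", seed=0):
        if prior not in ("empirical", "uniform"):
            raise ValueError(f"Unknown prior: {prior!r}")
        self.discrete = list(discrete)
        self.continuous = list(continuous)
        self.prior = prior
        self.rng = random.Random(seed)
        # Group sample indices by their discrete label.
        self.by_label = {}
        for idx, label in enumerate(self.discrete):
            self.by_label.setdefault(label, []).append(idx)

    def sample_prior(self, num_samples):
        if self.prior == "empirical":
            # Uniform over samples: label frequencies follow the data.
            return [
                self.rng.randrange(len(self.discrete))
                for _ in range(num_samples)
            ]
        # Uniform over labels, then uniform within the chosen label.
        classes = sorted(self.by_label)
        return [
            self.rng.choice(self.by_label[self.rng.choice(classes)])
            for _ in range(num_samples)
        ]

    def sample_conditional(self, reference_idx):
        # Positive: same discrete label, closest continuous value.
        # The reference itself is excluded unless it is the only candidate.
        positives = []
        for i in reference_idx:
            candidates = [
                j for j in self.by_label[self.discrete[i]] if j != i
            ] or [i]
            positives.append(
                min(candidates,
                    key=lambda j: abs(self.continuous[j] - self.continuous[i])))
        return positives


discrete = [0, 0, 0, 1, 1, 1]
continuous = [0.0, 0.1, 0.9, 0.05, 0.5, 0.55]
sampler = ToyMixedSampler(discrete, continuous, prior="empirical")
reference = sampler.sample_prior(4)
positive = sampler.sample_conditional(reference)
```

In this layout the prior only affects how reference and negative indices are drawn, while the conditional keeps using both indices, which avoids the "ignore the continuous variable" branch discussed above.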
dataset = RandomDataset(N=100, d=5, device=device)
loader = cebra.data.MixedDataLoader(
    dataset=dataset,
    num_steps=10,
    batch_size=8,
    conditional=conditional,
    positive_sampling=positive_sampling,
    discrete_sampling_prior=discrete_sampling_prior,
)
_assert_dataset_on_correct_device(loader, device)
load_speed = LoadSpeed(loader)
benchmark(load_speed)
We should extend the test to check the properties of the positive and negative samples (e.g., check if the discrete labels match and so forth, as expected for each setting of parameters)
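A property check along those lines might look like the sketch below. It uses a toy sampler instead of the real loader, so `get_indices` and `check_batch` here are hypothetical helpers, and the asserted properties are only what one would expect for the mode where positives always share the discrete label.

```python
import random


def get_indices(labels, num_samples, rng):
    """Toy sampler: reference/negative from the prior, positive sharing the label."""
    by_label = {}
    for idx, label in enumerate(labels):
        by_label.setdefault(label, []).append(idx)
    reference = [rng.randrange(len(labels)) for _ in range(num_samples)]
    negative = [rng.randrange(len(labels)) for _ in range(num_samples)]
    positive = [rng.choice(by_label[labels[i]]) for i in reference]
    return reference, positive, negative


def check_batch(labels, reference, positive, negative):
    """Properties a test could assert for the 'discrete_variable' mode."""
    # All three index lists have the batch size.
    assert len(reference) == len(positive) == len(negative)
    # Every index points into the dataset.
    assert all(0 <= i < len(labels) for i in reference + positive + negative)
    # Positive pairs always share their discrete label.
    assert all(labels[r] == labels[p] for r, p in zip(reference, positive))


rng = random.Random(1)
labels = [rng.randrange(3) for _ in range(100)]
reference, positive, negative = get_indices(labels, 8, rng)
check_batch(labels, reference, positive, negative)
```

For the "conditional" mode the label-matching assertion would be replaced by whatever property that mode is meant to guarantee, which is exactly the point still under discussion in this thread.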
@stes Thanks for addressing this PR. And yes, took me some time to get into it again :D As mentioned in the paper you wrote
So I thought that it would be possible to pass two "labels", one task-relevant and one task-invariant. I understood that the positive sampling then samples with respect to the task-relevant variable, but also with respect to the task-invariant variable to make the embedding invariant to it, e.g. by specifying the distribution to be uniform for that variable. I think this would be an ideal fit for a couple of use cases, such as the one mentioned in the paper, or building an embedding that is invariant across patients but nevertheless "variant" to a behavioural variable. Ideally both variables could be discrete, continuous, or a mixture of both. But I guess this wouldn't be directly supported by the current Loader setup, even though I see that the sampling methods for that are implemented, e.g. CEBRA/cebra/distributions/mixed.py, line 88 in 9898850
So I guess this is more up to the user to declare their own DataLoader, and then define manually how
And to be honest I agree that the
Addresses #100
Maybe I misunderstood the function, but I added keywords so that the sampling now allows for empirical and uniform priors of the discrete label, and additionally lets you select between conditional and discrete-only positive sampling.