Batched inference CEBRA & padding at the Solver level #168
base: main
Conversation
…ional models in _transform
@stes @MMathisLab, if you have time to review this that would be great :)
This looks fine to me, and it is already used internally in production, right @stes?
@CeliaBenquet can you resolve the conflicts? Then I think it's fine to merge!
@MMathisLab there have been big code changes / refactoring since @stes's last review, so I would be more confident about merging after an in-depth review, but your call :)
reviewing now
Left a few comments.
The biggest issue I see is that a lot of features are changed that do not seem to be directly related to the batched implementation (but I might be wrong). So one iteration addressing some of these questions in my review would help me understand the logic a bit better.
cebra/data/base.py
Outdated
raise NotImplementedError
self.offset = model.get_offset()
Typo? / missing cleanup?
i.e., should the line below be removed here? Why is that relevant for batched inference?
configure_for was previously called in the cebra.CEBRA class (in configure_for_all) and is now moved to the solvers directly; also, the configure_for in the multisession solver was wrongly implemented and never used.
So it is now left unimplemented in the base class and defined in the multi- and single-session solvers.
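To illustrate the refactor described above, here is a minimal sketch: configure_for is abstract on the base solver and implemented per solver type. The class and method names mirror the discussion, but the bodies (and the dataset interface, e.g. iter_sessions) are illustrative assumptions, not the actual CEBRA code.

```python
import abc


class Solver(abc.ABC):
    """Base solver: configure_for is only defined in concrete solvers."""

    @abc.abstractmethod
    def configure_for(self, dataset):
        """Configure the dataset(s) for this solver's model(s)."""
        raise NotImplementedError


class SingleSessionSolver(Solver):

    def __init__(self, model):
        self.model = model

    def configure_for(self, dataset):
        # Single model, single dataset: one call is enough.
        dataset.configure_for(self.model)


class MultiSessionSolver(Solver):

    def __init__(self, models):
        self.model = models  # one model per session

    def configure_for(self, dataset):
        # Pair each session's dataset with its own model.
        for session, model in zip(dataset.iter_sessions(), self.model):
            session.configure_for(model)
```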
@@ -192,7 +203,6 @@ class ContinuousDataLoader(cebra_data.Loader):
    and become equivalent to time contrastive learning.
    """,
)
time_offset: int = dataclasses.field(default=10)
why removed?
Moved to the base class instead: it was defined in each child class (see the multisession dataset, where I also removed it and added it to the base one).
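The change above can be sketched as follows: the time_offset dataclass field is declared once on the base loader and inherited by every child loader. The class names here are simplified stand-ins for the actual cebra.data classes.

```python
import dataclasses


@dataclasses.dataclass
class Loader:
    # Declared once on the base class; every child loader inherits it.
    time_offset: int = dataclasses.field(default=10)


@dataclasses.dataclass
class ContinuousDataLoader(Loader):
    # Loader-specific fields only; time_offset is inherited from Loader.
    conditional: str = "time_delta"
```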
cebra/solver/base.py
Outdated
if not hasattr(self, "n_features"):
    raise ValueError(
        f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with "
        "appropriate arguments before using this estimator.")
I think it is not ideal to use n_features for this. Can you implement a @property that gives you that info directly (is_fitted, for example)?
That's what was done initially with the sklearn function.
I'm not sure I understand how is_fitted changes that; it's just an implementation detail, right? Should I keep n_features?
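A minimal sketch of the suggested @property, assuming (as in the diff above) that n_features is only set once fit() has run. The class here is illustrative, not the actual CEBRA solver; only the attribute name n_features and the error message come from the diff.

```python
class Solver:

    @property
    def is_fitted(self) -> bool:
        # n_features is only assigned inside fit(), so its presence
        # is the fitting indicator; call sites use is_fitted instead
        # of checking the attribute directly.
        return hasattr(self, "n_features")

    def _check_is_fitted(self):
        if not self.is_fitted:
            raise ValueError(
                f"This {type(self).__name__} instance is not fitted yet. "
                "Call 'fit' with appropriate arguments before using this "
                "estimator.")

    def fit(self, X):
        # Hypothetical fit: record the input dimensionality.
        self.n_features = X.shape[1]
```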
cebra/solver/base.py
Outdated
@@ -336,7 +647,7 @@ def load(self, logdir, filename="checkpoint.pth"):
    checkpoint = torch.load(savepath, map_location=self.device)
    self.load_state_dict(checkpoint, strict=True)

def save(self, logdir, filename="checkpoint_last.pth"):
def save(self, logdir, filename="checkpoint.pth"):
let's keep the old naming here
Is there a reason it's different from the default in the load function?
cebra/solver/multi_session.py
Outdated
def parameters(self, session_id: Optional[int] = None):
    """Iterate over all parameters."""
    self._check_is_session_id_valid(session_id=session_id)
    for parameter in self.model[session_id].parameters():
        yield parameter

    for parameter in self.criterion.parameters():
        yield parameter
if None is given, we should return all parameters from the super() class
Not sure I get it: do you mean we should return the parameters from the criterion at least, or the parameters for all models?
The super class only has an abstract method for that.
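One possible reading of the request above, as a sketch: when session_id is None, yield the parameters of every session's model plus the shared criterion, otherwise only the requested session's model (and the criterion). This is an assumption about the intended semantics, not the merged code, and the validity check on session_id is omitted for brevity.

```python
class MultiSessionSolver:

    def __init__(self, models, criterion):
        self.model = models        # list: one model per session
        self.criterion = criterion

    def parameters(self, session_id=None):
        """Iterate over parameters; all sessions when session_id is None."""
        if session_id is None:
            models = self.model
        else:
            models = [self.model[session_id]]
        for model in models:
            yield from model.parameters()
        # The criterion's parameters are shared across sessions.
        yield from self.criterion.parameters()
```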
Is this related to batched inference only?
@@ -104,7 +104,7 @@ def get_years(start_year=2021):

intersphinx_mapping = {
    "python": ("https://docs.python.org/3", None),
    "torch": ("https://pytorch.org/docs/master/", None),
    "torch": ("https://pytorch.org/docs/stable/", None),
can go in extra PR for quick merge, but ofc fine here as well
Re your comment on changes not related to the batched inference: the PR was started with 2 (related) goals at once, if I'm correct (it wasn't me who started it)
--> see the other linked issues for better understanding.
Ok, makes sense!
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/746
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/624
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/issues/637
fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/594