LightningModule.configure_callbacks overrides Trainer callbacks #18784
Any thoughts on this proposal?
Hey, personally I think it should stay the way it is right now. The reason this hook was introduced to the LM was that some models need specific callbacks for proper training. If user-passed callbacks overrode them, you could end up without some of these essential callbacks. Also, if they are no longer merged, it is a lot of boilerplate just to extend the list of callbacks: you can no longer pass only the additional ones to the Trainer without also instantiating and listing the ones that were previously handled by the LM.
Just to clarify, I'm proposing to keep this behavior:
but change this behavior:
to reverse the priority and let the user-passed callback override the one from `configure_callbacks`.
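For illustration, the documented behavior deduplicates callbacks by type with the model side winning; the proposal amounts to flipping which side wins. A minimal, Lightning-free sketch of that resolution (function and class names here are hypothetical, not Lightning API):

```python
# Sketch: deduplicate callbacks by type, letting one side win.
# model_wins=True mimics the current documented behavior where
# configure_callbacks replaces Trainer callbacks of the same type;
# the proposal corresponds to model_wins=False.

class ModelCheckpoint:
    def __init__(self, every_n_epochs=1):
        self.every_n_epochs = every_n_epochs

def resolve_callbacks(trainer_callbacks, model_callbacks, model_wins=True):
    winners, losers = (
        (model_callbacks, trainer_callbacks)
        if model_wins
        else (trainer_callbacks, model_callbacks)
    )
    winner_types = {type(cb) for cb in winners}
    # keep losers whose type is not duplicated, then the winning instances
    return [cb for cb in losers if type(cb) not in winner_types] + winners

user_ckpt = ModelCheckpoint(every_n_epochs=5)  # passed to Trainer
model_ckpt = ModelCheckpoint()                 # from configure_callbacks
current = resolve_callbacks([user_ckpt], [model_ckpt])               # model wins
proposed = resolve_callbacks([user_ckpt], [model_ckpt], model_wins=False)  # user wins
```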
The use case that you are describing here is that you provide default callbacks in `configure_callbacks`. But we could discuss making it configurable. For example, the LightningModule could have a special attribute:

```python
class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.replace_trainer_callbacks = True   # default
        self.replace_trainer_callbacks = False  # Trainer callbacks take precedence over configure_callbacks (need to think about a good name)
```

In general though, I think that "providing default callbacks and letting a user override them" should be handled completely outside the LightningModule. For example, the client code (torchgeo entry point) parsing the callbacks provided by the user could check them against the default callbacks chosen by the application (torchgeo).
In that case, is there a different approach recommended for default callbacks? Our users will almost always want model checkpoints and early stopping, but might want to customize how these work. I'm hoping not to have to manually parse callbacks in our entrypoint if possible, but if that's the only way I may need help figuring out how to do that. |
What would prevent us from doing this (like any other default-argument handling)?

```python
def main(callbacks=None):
    if callbacks is None:
        callbacks = get_default_callbacks()
    # OR
    # append callbacks that the user didn't configure
    if not_has_this_other_callback(callbacks):
        callbacks.append(...)
    trainer = Trainer(callbacks=callbacks)

main()  # user calls with or without callbacks
```

Is there something in your use case that is strongly incompatible or inconvenient with this? Callbacks are just an example here of course; there would be plenty of other Trainer arguments that fall into the same usage pattern. The above pattern is just a conceptual example and may need some adjustment if a CLI is used, etc.
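One way to make the pseudo-check in the snippet above concrete is to look for a callback by type before appending a default. A small sketch; `EarlyStopping` here is a hypothetical stand-in class, not an import from Lightning:

```python
# Append a default callback only if the user didn't already configure
# one of that type. EarlyStopping is a stand-in for illustration.

class EarlyStopping:
    pass

def has_callback_of_type(callbacks, cls):
    return any(isinstance(cb, cls) for cb in callbacks)

callbacks = []
if not has_callback_of_type(callbacks, EarlyStopping):
    callbacks.append(EarlyStopping())
```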
Part of the problem is that the callbacks need to be configured based on which LightningModule is chosen. That was the whole motivation for #18480. How would I get access to the chosen LightningModule inside `main`?
Yeah, this code snippet was just a sketch of course. If there are only a couple of LightningModules, I guess a switch statement or a mapping would do. If this is truly a new use case for the LightningCLI, it may need its own solution.
We don't have too many modules to maintain a mapping for. How would I get access to the selected LightningModule? I've never written custom LightningCLI arg parsing before...
For LightningCLI specifically, I'm actually not sure what the best practice is. This might be a bit outside of what LightningCLI can offer without overriding it. @carmocca do you have any suggestions?
The `instantiate_trainer` method runs after the model has been instantiated. At that point, `self.model` is available:

```python
class MyLightningCLI(LightningCLI):
    def instantiate_trainer(self, **kwargs):
        if isinstance(self.model, ThisModel):
            kwargs["callbacks"] = ThisCallback()
        return super().instantiate_trainer(**kwargs)
```
Thanks, now I just need to decide whether this added complexity is worth it vs. just letting the users define their own callbacks. I would still be interested in some setting somewhere that allows Trainer args to take precedence, but I don't want to add any more complexity to LightningModules if I'm the only one using them this way. Maybe we can leave this issue open for a while and see if anyone else needs this feature.
Outline & Motivation
In TorchGeo, we provide LightningModules for common tasks, like classification, regression, etc. We also provide default callbacks using `LightningModule.configure_callbacks` as suggested in #18480. However, it seems that these default callbacks override any user-specified callbacks provided to the `Trainer`. According to the documentation, this actually seems to be the intended behavior. I propose reversing this behavior such that `Trainer` callbacks override `LightningModule.configure_callbacks` when both share the same type.
Pitch
Users should be able to override callbacks hard-coded in the `LightningModule` with whatever they pass to the `Trainer`. Otherwise the only way to change something like the model checkpoint frequency is to subclass the `LightningModule`, which is not ideal. Alternatively, if there's another way to provide model-specific default callbacks, please let me know.
Additional context
@robmarkcole and @roybenhayun reported several issues to us that stemmed from this behavior. If we can't solve this, we may have to remove all default callbacks.
@calebrob6
cc @Borda @carmocca @justusschock @awaelchli