LightningModule.configure_callbacks overrides Trainer callbacks #18784

Open
adamjstewart opened this issue Oct 11, 2023 · 12 comments
Labels
discussion In a discussion stage feature Is an improvement or enhancement lightningmodule pl.LightningModule

Comments

@adamjstewart
Contributor

adamjstewart commented Oct 11, 2023

Outline & Motivation

In TorchGeo, we provide LightningModules for common tasks, like classification, regression, etc. We also provide default callbacks using LightningModule.configure_callbacks as suggested in #18480. However, it seems that these default callbacks override any user-specific callbacks provided to the Trainer. According to the documentation, this actually seems to be the intended behavior:

the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument. If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them.

I propose reversing this behavior such that Trainer callbacks override LightningModule.configure_callbacks when both share the same type.
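To make the documented rule concrete, here is a minimal sketch of the merge as the docs describe it. The callback classes are plain stand-ins rather than the real Lightning ones, `attach_model_callbacks` is a hypothetical name, and matching on exact type is an assumption about how "same type" is decided; this is not Lightning's actual implementation.

```python
class Callback:
    """Stand-in for lightning.pytorch.callbacks.Callback."""


class ModelCheckpoint(Callback):
    def __init__(self, every_n_epochs=1, source=""):
        self.every_n_epochs = every_n_epochs
        self.source = source  # illustration only: who created this callback


class EarlyStopping(Callback):
    def __init__(self, patience=3, source=""):
        self.patience = patience
        self.source = source


class LearningRateMonitor(Callback):
    pass


def attach_model_callbacks(trainer_callbacks, model_callbacks):
    """Merge both lists; on a type clash the LightningModule's callback wins."""
    model_types = {type(cb) for cb in model_callbacks}
    # Drop every Trainer callback whose type the model also provides.
    kept = [cb for cb in trainer_callbacks if type(cb) not in model_types]
    return kept + list(model_callbacks)


# The user asks for a checkpoint every 5 epochs via the Trainer...
trainer_cbs = [ModelCheckpoint(every_n_epochs=5, source="trainer"), LearningRateMonitor()]
# ...but the module hard-codes its own checkpoint in configure_callbacks.
model_cbs = [ModelCheckpoint(every_n_epochs=1, source="model"), EarlyStopping(source="model")]

merged = attach_model_callbacks(trainer_cbs, model_cbs)
```

Under this rule the user's `ModelCheckpoint` is silently dropped, which is the behavior reported here; the unrelated `LearningRateMonitor` survives the merge.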

Pitch

Users should be able to override callbacks hard-coded in the LightningModule with whatever they pass to the Trainer. Otherwise the only way to change something like the model checkpoint frequency is to subclass the LightningModule, which is not ideal.

Alternatively, if there's another way to provide model-specific default callbacks, please let me know.

Additional context

@robmarkcole and @roybenhayun reported several issues to us that stemmed from this behavior. If we can't solve this, we may have to remove all default callbacks.

@calebrob6

cc @Borda @carmocca @justusschock @awaelchli

@adamjstewart
Contributor Author

Any thoughts on this proposal?

@justusschock
Member

Hey, personally I think it should stay the way it is right now. The reason this hook was introduced to the LM was that some models need specific callbacks for proper training. If user-passed callbacks overrode them, you might end up without some of these essential callbacks.

Also, if the two lists were no longer merged, it would take a lot of boilerplate just to extend the list of callbacks: you couldn't pass additional ones to the Trainer anymore without also instantiating and listing the ones that were previously handled by the LM.

@adamjstewart
Contributor Author

Just to clarify, I'm proposing to keep this behavior:

the list or a callback returned here will be merged with the list of callbacks passed to the Trainer’s callbacks argument.

but change this behavior:

If a callback returned here has the same type as one or several callbacks already present in the Trainer’s callbacks list, it will take priority and replace them.

to reverse priority and let the user-passed callback override the configure_callbacks version. Are there any LMs where a callback must be configured in a certain way and a different configuration would break it?

@awaelchli
Contributor

awaelchli commented Oct 31, 2023

The use case that you are describing here is that you provide default callbacks in configure_callbacks() and expect that the Trainer can override them. The configure_callbacks() is not meant to provide defaults, but rather the opposite: callbacks and their configuration that the LightningModule requires to operate correctly.

But we could discuss making it configurable. For example, the LightningModule could have a special attribute:

class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Proposed attribute (does not exist yet); set one of the two:
        self.replace_trainer_callbacks = True  # default
        self.replace_trainer_callbacks = False  # Trainer callbacks take precedence over configure_callbacks

(need to think about a good name)

In general though I think that "providing default callbacks and letting a user override them" should be handled completely outside the LightningModule. For example, the client code (torchgeo entry point) parsing the callbacks provided by the user could check them against the default callbacks chosen by the application (torchgeo).

@awaelchli awaelchli added feature Is an improvement or enhancement lightningmodule pl.LightningModule discussion In a discussion stage and removed refactor needs triage Waiting to be triaged by maintainers labels Oct 31, 2023
@adamjstewart
Contributor Author

adamjstewart commented Oct 31, 2023

The configure_callbacks() is not meant to provide defaults

In that case, is there a different approach recommended for default callbacks? Our users will almost always want model checkpoints and early stopping, but might want to customize how these work. I'm hoping not to have to manually parse callbacks in our entrypoint if possible, but if that's the only way I may need help figuring out how to do that.

@awaelchli
Contributor

What would prevent us from doing this (like any other default-argument handling):

def main(callbacks=None):
    if callbacks is None:
        callbacks = get_default_callbacks()

    # OR
    # append callbacks that the user didn't configure
    if not_has_this_other_callback(callbacks):
        callbacks.append(...)

    trainer = Trainer(callbacks=callbacks)

main()  # user calls with or without callbacks

Is there something in your use case that is strongly incompatible or inconvenient with this? Callbacks are just an example here, of course; plenty of other Trainer arguments fall into the same usage pattern.

The above pattern is just a conceptual example, it may need some adjustment if a CLI is used etc.
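One way the "append callbacks that the user didn't configure" branch could be fleshed out is sketched below. The helper name `with_default_callbacks` is hypothetical, the callback classes are plain stand-ins for the Lightning ones, and using `isinstance` to decide whether the user already covered a default is an assumption made for illustration.

```python
class Callback:
    """Stand-in for lightning.pytorch.callbacks.Callback."""


class ModelCheckpoint(Callback):
    def __init__(self, every_n_epochs=1):
        self.every_n_epochs = every_n_epochs


class EarlyStopping(Callback):
    def __init__(self, patience=3):
        self.patience = patience


def with_default_callbacks(user_callbacks, default_callbacks):
    """Return the user's callbacks plus any default whose type they did not supply."""
    user_callbacks = list(user_callbacks or [])
    missing = [
        default
        for default in default_callbacks
        # isinstance lets a user subclass of ModelCheckpoint stand in for the default.
        if not any(isinstance(cb, type(default)) for cb in user_callbacks)
    ]
    return user_callbacks + missing


defaults = [ModelCheckpoint(every_n_epochs=1), EarlyStopping(patience=10)]

# The user only customizes checkpointing; early stopping falls back to the default.
callbacks = with_default_callbacks([ModelCheckpoint(every_n_epochs=5)], defaults)
```

The resulting list would then be passed as `Trainer(callbacks=callbacks)`. Because the merge happens in application code, user-supplied callbacks always win and `configure_callbacks` stays reserved for callbacks the module genuinely requires.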

@adamjstewart
Contributor Author

Part of the problem is that the callbacks need to be configured based on which LightningModule is chosen. That was the whole motivation for #18480. How would I get access to the chosen LightningModule inside get_default_callbacks?

@awaelchli
Contributor

Yeah, this code snippet was just a sketch of course. If there are only a couple of LightningModules, I guess a switch statement with if elif elif else should do (the get_default_callbacks would take the selected LightningModule as input). But if there are many LightningModules then I understand that it's a bit cumbersome to maintain a mapping.

If this is truly a new use case for the configure_callbacks hook, I think the configurable approach would be a way forward.
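For a small number of modules, the mapping could look roughly like the sketch below. Everything here is illustrative: `ClassificationTask`, `RegressionTask`, the monitored metric names, and the stand-in base classes are assumptions, not the real TorchGeo or Lightning definitions.

```python
class LightningModule:
    """Stand-in for lightning.pytorch.LightningModule."""


class ClassificationTask(LightningModule):
    pass


class RegressionTask(LightningModule):
    pass


class ModelCheckpoint:
    def __init__(self, monitor, mode):
        self.monitor = monitor
        self.mode = mode


class EarlyStopping:
    def __init__(self, monitor, mode):
        self.monitor = monitor
        self.mode = mode


# Factories (not instances) so every call returns fresh callback objects.
DEFAULT_CALLBACK_FACTORIES = {
    ClassificationTask: lambda: [
        ModelCheckpoint(monitor="val_accuracy", mode="max"),
        EarlyStopping(monitor="val_accuracy", mode="max"),
    ],
    RegressionTask: lambda: [
        ModelCheckpoint(monitor="val_loss", mode="min"),
        EarlyStopping(monitor="val_loss", mode="min"),
    ],
}


def get_default_callbacks(model):
    """Pick default callbacks based on the selected LightningModule."""
    for module_cls, factory in DEFAULT_CALLBACK_FACTORIES.items():
        if isinstance(model, module_cls):
            return factory()
    return []  # unknown module: no defaults


regression_defaults = get_default_callbacks(RegressionTask())
```

A dictionary keeps the per-module defaults in one place, which is easier to extend than an if/elif chain, though it still has to be updated by hand whenever a new module is added.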

@adamjstewart
Contributor Author

We don't have too many modules to maintain a mapping for. How would I get access to the selected LightningModule? I've never written custom LightningCLI arg parsing before...

@awaelchli
Contributor

For LightningCLI specifically, I'm actually not sure what the best practice is. This might be a bit outside of what LightningCLI can offer without overriding it. @carmocca do you have any suggestions?

@carmocca
Contributor

carmocca commented Nov 2, 2023

The LightningCLI could be subclassed to override instantiate_trainer: https://github.com/Lightning-AI/lightning/blob/master/src/lightning/pytorch/cli.py#L537

At that point, self.model will already be defined and can be used to pass callbacks as kwargs to the super call. Roughly:

class MyLightningCLI(LightningCLI):
    def instantiate_trainer(self, **kwargs):
        if isinstance(self.model, ThisModel):
            kwargs["callbacks"] = ThisCallback()
        return super().instantiate_trainer(**kwargs)

@adamjstewart
Contributor Author

Thanks, now I just need to decide whether this added complexity is worth it vs. just letting the users define their own callbacks.

I would still be interested in some setting somewhere that allows Trainer args to take precedence, but I don't want to add any more complexity to LightningModules if I'm the only one using them this way. Maybe we can leave this issue open for a while and see if anyone else needs this feature.
