
instantiate_class should be recursive #13279

Closed
aRI0U opened this issue Jun 13, 2022 · 9 comments · Fixed by #18105
Labels: feature (Is an improvement or enhancement), lightningcli (pl.cli.LightningCLI)
Milestone: future


aRI0U commented Jun 13, 2022

🐛 Bug

I have a LightningModule that takes nn.Module instances as arguments. I use Lightning CLI to instantiate my model. However, when I instantiate the model manually with instantiate_class, the dictionaries given for the submodule parameters of my model are not instantiated.

To Reproduce

Implementation of the (dummy) networks, file models.py:

import pytorch_lightning as pl
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(3, out_channels, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


class Decoder(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_channels, 3, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


class AE(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.loss_fn = nn.MSELoss()

    def shared_step(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = self.loss_fn(x, x_hat)
        return loss

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

Config file (file config.yaml):

model:
  class_path: models.AE
  init_args:
    encoder:
      class_path: models.Encoder
      init_args:
        out_channels: 16
    decoder:
      class_path: models.Decoder
      init_args:
        in_channels: 16

Main script (file main.py):

import yaml
from pytorch_lightning.utilities.cli import instantiate_class

config_path = "config.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

model_config = config["model"]
model = instantiate_class((), model_config)

print("model:")
print(model)
print("encoder:")
print(model.encoder)
print("decoder:")
print(model.decoder)

It outputs:

model:
AE(
  (loss_fn): MSELoss()
)
encoder:
{'class_path': 'models.Encoder', 'init_args': {'out_channels': 16}}
decoder:
{'class_path': 'models.Decoder', 'init_args': {'in_channels': 16}}

whereas encoder and decoder should be instantiated.

Environment

  • CUDA:
    - GPU:
    - available: False
    - version: None
  • Packages:
    - numpy: 1.21.5
    - pyTorch_debug: False
    - pyTorch_version: 1.11.0
    - pytorch-lightning: 1.5.8
    - tqdm: 4.64.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.4
    - version: #49-Ubuntu SMP Wed May 18 13:28:06 UTC 2022

cc @Borda @carmocca @mauvilsa

aRI0U added the needs triage label Jun 13, 2022
akihironitta added the lightningcli label and removed the needs triage label Jun 14, 2022
carmocca added the feature label Jun 14, 2022
carmocca added this to the future milestone Jun 14, 2022
carmocca (Contributor) commented:

I think this is fair. Thoughts @mauvilsa? Would you like to work on this?

aRI0U (Author) commented Jun 14, 2022

Thanks a lot for your reply!

I wrote a really basic recursive implementation of instantiate_class, but it only works when there are no positional args. I put it here for reference:

from pytorch_lightning.utilities.cli import instantiate_class


def rec_instantiate_class(config):
    # Lists: instantiate each element.
    if isinstance(config, list):
        return [rec_instantiate_class(cfg) for cfg in config]
    if not isinstance(config, dict):
        return config

    # Dicts: instantiate nested values first (depth-first).
    for k, v in config.items():
        config[k] = rec_instantiate_class(v)

    # A plain dict without a class_path is returned unchanged.
    if config.get("class_path") is None:
        return config

    # No positional args are supported; only init_args are forwarded.
    return instantiate_class((), config)
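
For reference, a possible usage sketch with the config.yaml from the report above (assuming models.py is importable from the working directory):

import yaml

with open("config.yaml") as f:
    config = yaml.safe_load(f)

model = rec_instantiate_class(config["model"])
print(model.encoder)  # an Encoder instance instead of a plain dict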

mauvilsa (Contributor) commented:

The instantiate_class function was introduced to instantiate optimizers, which require a positional argument, and it was not expected to handle nested classes. It can certainly be made recursive so that it can be used in other situations, and it is fine if positional args are supported only at the first level. However, the change should modify instantiate_class so that it calls itself, rather than adding a new function like rec_instantiate_class in the snippet above. This is simple enough for anyone to contribute, so there is no point in me working on this. @aRI0U, will you work on it?
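
For context, a minimal sketch of that original optimizer use case, roughly following the pattern once shown in the Lightning docs (the LitModel class and optimizer_init argument below are illustrative, not part of this issue):

import pytorch_lightning as pl
import torch
from pytorch_lightning.utilities.cli import instantiate_class


class LitModel(pl.LightningModule):
    def __init__(self, optimizer_init: dict):
        super().__init__()
        # e.g. {"class_path": "torch.optim.Adam", "init_args": {"lr": 1e-3}}
        self.optimizer_init = optimizer_init

    def configure_optimizers(self):
        # self.parameters() is the positional argument that instantiate_class
        # forwards to the optimizer constructor ahead of init_args.
        return instantiate_class(self.parameters(), self.optimizer_init)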

Do note that it would be up to the developer to know when it is okay to use this function. For example, someone might want to link_arguments from out_channels to in_channels. The saved config would then not include a value for in_channels, so loading a saved fit config and calling instantiate_class on it would fail. Later on, someone will request that instantiate_class support argument links. On first thought, I am not sure this is a good direction to take. The core issue is that the purpose of the saved configs is reproducibility of the command that was run, and it must be only reproducibility. The same config cannot simultaneously serve another purpose, e.g. instantiating the model for prediction. In many cases the config for training must be different from what is needed for other purposes, and this depends heavily on the specific case.
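
For illustration, a minimal sketch of this failure mode (the empty init_args is hypothetical; the exact keys in a saved config depend on how the arguments are linked):

from pytorch_lightning.utilities.cli import instantiate_class

# If in_channels is filled in by an argument link at parse time, the saved
# config omits it, so instantiating directly from the saved dict fails:
decoder_cfg = {"class_path": "models.Decoder", "init_args": {}}
decoder = instantiate_class((), decoder_cfg)  # TypeError: missing 'in_channels'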

Having said this, making instantiate_class recursive does not hurt. Still I have been thinking about a more proper solution to the core issue.


pisarik commented Feb 1, 2024

Oh, I was expecting this feature. It would be highly valuable for my data generator.

@mauvilsa do you have any suggestions to establish training config? I know Hydra and OmegaConf are pretty handy, but I am wary of mixing them with Lightning CLI.

mauvilsa (Contributor) commented Feb 1, 2024

This issue is from long ago. Right now I don't even think it is worth extending instantiate_class. Now the recommended way to add optimizers is to use dependency injection as explained in multiple-optimizers-and-schedulers. And the same can be done for nn.Module submodules or other classes in general. There is no need for instantiate_class anymore.
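
A minimal sketch of that dependency-injection pattern, assuming a recent lightning.pytorch release where OptimizerCallable is exposed in lightning.pytorch.cli (the class below mirrors the AE example from this issue, with an added optimizer argument):

import torch
from lightning.pytorch import LightningModule
from lightning.pytorch.cli import OptimizerCallable


class AE(LightningModule):
    def __init__(
        self,
        encoder: torch.nn.Module,
        decoder: torch.nn.Module,
        optimizer: OptimizerCallable = torch.optim.Adam,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.optimizer = optimizer

    def configure_optimizers(self):
        # The CLI injects encoder/decoder from their class_path entries and
        # passes a callable that builds the optimizer from the config.
        return self.optimizer(self.parameters())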

Regarding how to instantiate a previously trained model, I did come up with what I think is a proper solution. My proposal is in #18105. Unfortunately it hasn't received much attention.

@pisarik I am not sure what you meant by "establish training config". LightningCLI already supports training configs quite extensively.


pisarik commented Feb 1, 2024

@mauvilsa Thank you for the info. You are right! I just tried to write a config with nested classes and it worked like a charm. Should this issue be closed then?

mauvilsa (Contributor) commented Feb 1, 2024

However, when I instantiate the model manually with instantiate_class, the dictionaries given for the submodule parameters of my model are not instantiated.

This ticket was originally about instantiating a model manually, not about training. So I would say this issue shouldn't be closed yet.

aRI0U (Author) commented Feb 5, 2024

Hi! Actually, the main issue is that reloading checkpoints when using dependency injection is not straightforward in general, because the hparams dict is only saved for the main LightningModule. One has to manually reload the configuration that corresponds to the checkpoint, instantiate the model, and then load the checkpoint.
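
A sketch of that manual workflow, assuming the CLI-saved config has the same nested structure as the example at the top of this issue (the config and checkpoint paths are placeholders):

import torch
import yaml

from models import AE, Decoder, Encoder

# Reload the config that the CLI saved for the run.
with open("lightning_logs/version_0/config.yaml") as f:
    model_cfg = yaml.safe_load(f)["model"]["init_args"]

# Re-instantiate the model manually, mirroring the nested config.
model = AE(
    encoder=Encoder(**model_cfg["encoder"]["init_args"]),
    decoder=Decoder(**model_cfg["decoder"]["init_args"]),
)

# Finally, load the trained weights from the checkpoint.
ckpt = torch.load("path/to/checkpoint.ckpt", map_location="cpu")
model.load_state_dict(ckpt["state_dict"])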

The ideal situation would be to store the whole config in the checkpoint (not sure how easy and general it would be to implement though).

If this is not possible, it would be great to have a convenient way to re-instantiate the model and datamodule from a YAML config generated by the CLI, and I guess instantiate_class would be the appropriate method. Moreover, I don't see any case where one would want to instantiate a model non-recursively; if there is one, maybe a flag could handle it.

mauvilsa (Contributor) commented Feb 6, 2024

@aRI0U you are welcome to have a look at the proposal #18105, comment on it, test it out in its current state, etc.

The ideal situation would be to store the whole config in the checkpoint

This is what that pull request does.
