instantiate_class should be recursive #13279

Closed
@aRI0U

Description

🐛 Bug

I have a LightningModule that takes nn.Module instances (an encoder and a decoder) as arguments. Instantiating the model through the Lightning CLI works fine, but when I instantiate it manually with instantiate_class, the nested dictionaries describing those submodules are passed to the constructor as plain dicts instead of being instantiated.
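
As far as I can tell, instantiate_class simply imports the class named by class_path and calls it with init_args as keyword arguments, without looking inside those arguments. A rough paraphrase (just to illustrate the behaviour, not the actual library source):

def instantiate_class_sketch(args, init):
    # Paraphrased sketch of pytorch_lightning.utilities.cli.instantiate_class.
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    cls = getattr(module, class_name)
    # kwargs values are passed as-is: a nested {"class_path": ..., "init_args": ...}
    # dict is handed to the constructor instead of being instantiated first.
    return cls(*args, **kwargs)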

To Reproduce

Implementation of the (dummy) networks (file models.py):

import pytorch_lightning as pl
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(3, out_channels, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


class Decoder(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = nn.ConvTranspose2d(in_channels, 3, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


class AE(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.loss_fn = nn.MSELoss()

    def shared_step(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = self.loss_fn(x, x_hat)
        return loss

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

Config file (file config.yaml):

model:
  class_path: models.AE
  init_args:
    encoder:
      class_path: models.Encoder
      init_args:
        out_channels: 16
    decoder:
      class_path: models.Decoder
      init_args:
        in_channels: 16

Main script (file main.py):

import yaml
from pytorch_lightning.utilities.cli import instantiate_class

config_path = "config.yaml"
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

model_config = config["model"]
model = instantiate_class((), model_config)

print("model:")
print(model)
print("encoder:")
print(model.encoder)
print("decoder:")
print(model.decoder)

It outputs:

model:
AE(
  (loss_fn): MSELoss()
)
encoder:
{'class_path': 'models.Encoder', 'init_args': {'out_channels': 16}}
decoder:
{'class_path': 'models.Decoder', 'init_args': {'in_channels': 16}}

whereas encoder and decoder should have been instantiated as Encoder and Decoder modules.
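
As a workaround I currently wrap instantiate_class in a small helper that walks the config and instantiates nested class_path/init_args dicts bottom-up. instantiate_recursively below is my own hypothetical helper (not part of the Lightning API), applied to the config dict loaded in main.py above:

from pytorch_lightning.utilities.cli import instantiate_class


def instantiate_recursively(config):
    # Hypothetical helper: instantiate nested class_path/init_args dicts
    # bottom-up so parents receive real objects instead of plain dicts.
    if isinstance(config, dict):
        # Recurse into values first so nested modules are built before their parent.
        config = {key: instantiate_recursively(value) for key, value in config.items()}
        if "class_path" in config:
            return instantiate_class((), config)
        return config
    if isinstance(config, list):
        return [instantiate_recursively(item) for item in config]
    return config


model = instantiate_recursively(config["model"])  # AE with real Encoder/Decoder submodules

This produces the expected model, but it would be nice if instantiate_class (or a recursive variant of it) handled nested configs directly, like LightningCLI already does.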

Environment

  • CUDA:
    - GPU:
    - available: False
    - version: None
  • Packages:
    - numpy: 1.21.5
    - pyTorch_debug: False
    - pyTorch_version: 1.11.0
    - pytorch-lightning: 1.5.8
    - tqdm: 4.64.0
  • System:
    - OS: Linux
    - architecture:
    - 64bit
    - ELF
    - processor: x86_64
    - python: 3.10.4
    - version: #49-Ubuntu SMP Wed May 18 13:28:06 UTC 2022

cc @Borda @carmocca @mauvilsa

Labels: feature, lightningcli