🐛 Bug
I have a `LightningModule` that takes `nn.Module` instances as arguments. I use LightningCLI to instantiate my model. However, when I instantiate the model manually with `instantiate_class`, the nested dictionaries describing the module arguments (`encoder` and `decoder`) are not instantiated: they are passed to the model as plain dicts.
To Reproduce
Implementation of the (dummy) networks, file `models.py`:
```python
import pytorch_lightning as pl
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, out_channels):
        super().__init__()
        # 3x3 conv without padding: (N, 3, H, W) -> (N, out_channels, H-2, W-2)
        self.conv = nn.Conv2d(3, out_channels, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


class Decoder(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        # 3x3 transposed conv restores the size: (N, in_channels, H-2, W-2) -> (N, 3, H, W)
        self.conv = nn.ConvTranspose2d(in_channels, 3, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


class AE(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        # encoder and decoder are expected to be nn.Module instances
        self.encoder = encoder
        self.decoder = decoder
        self.loss_fn = nn.MSELoss()

    def shared_step(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        # MSELoss takes (input, target): reconstruction first, original second
        loss = self.loss_fn(x_hat, x)
        return loss

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
```
Config file (file `config.yaml`):
```yaml
model:
  class_path: models.AE
  init_args:
    encoder:
      class_path: models.Encoder
      init_args:
        out_channels: 16
    decoder:
      class_path: models.Decoder
      init_args:
        in_channels: 16
```
Main script (file `main.py`):
```python
import yaml
from pytorch_lightning.utilities.cli import instantiate_class

config_path = "config.yaml"
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

model_config = config["model"]
model = instantiate_class((), model_config)

print("model:")
print(model)
print("encoder:")
print(model.encoder)
print("decoder:")
print(model.decoder)
```
It outputs:
```
model:
AE(
  (loss_fn): MSELoss()
)
encoder:
{'class_path': 'models.Encoder', 'init_args': {'out_channels': 16}}
decoder:
{'class_path': 'models.Decoder', 'init_args': {'in_channels': 16}}
```
whereas `encoder` and `decoder` should be instantiated as `Encoder` and `Decoder` modules (and then appear as submodules when printing `model`).
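As a possible workaround, here is a minimal sketch that instantiates nested `class_path`/`init_args` dicts depth-first, so the children exist before the parent is built. `instantiate_recursive` and `_import_class` are hypothetical helpers, not part of Lightning; they reimplement the `class_path`/`init_args` convention with `importlib` instead of calling `instantiate_class`. The demo uses stdlib classes (`types.SimpleNamespace`, `datetime.timedelta`, `fractions.Fraction`) in place of `models.AE`, `Encoder` and `Decoder` so it runs without PyTorch.

```python
import importlib
from typing import Any


def _import_class(class_path: str) -> type:
    """Resolve a dotted path such as "models.Encoder" to the class object."""
    module_name, class_name = class_path.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


def instantiate_recursive(config: Any) -> Any:
    """Hypothetical helper (not Lightning API): build objects from nested
    {"class_path": ..., "init_args": {...}} dicts, children first."""
    if isinstance(config, dict):
        if "class_path" in config:
            cls = _import_class(config["class_path"])
            init_args = config.get("init_args") or {}
            # Instantiate every argument before constructing the parent class.
            kwargs = {key: instantiate_recursive(value) for key, value in init_args.items()}
            return cls(**kwargs)
        # Plain dict without class_path: keep it a dict, but recurse into its values.
        return {key: instantiate_recursive(value) for key, value in config.items()}
    if isinstance(config, list):
        return [instantiate_recursive(item) for item in config]
    # Scalars (ints, strings, None, ...) are passed through unchanged.
    return config


# Stand-alone demo with stdlib classes standing in for AE / Encoder / Decoder.
demo = {
    "class_path": "types.SimpleNamespace",
    "init_args": {
        "encoder": {"class_path": "datetime.timedelta", "init_args": {"days": 1}},
        "decoder": {"class_path": "fractions.Fraction", "init_args": {"numerator": 3, "denominator": 4}},
    },
}
obj = instantiate_recursive(demo)
print(obj.encoder)  # 1 day, 0:00:00
print(obj.decoder)  # 3/4
```

In `main.py` this would replace `instantiate_class((), model_config)` with `model = instantiate_recursive(config["model"])`, assuming `models.py` is importable from the script's directory.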
Environment
- CUDA:
  - GPU:
  - available: False
  - version: None
- Packages:
  - numpy: 1.21.5
  - pyTorch_debug: False
  - pyTorch_version: 1.11.0
  - pytorch-lightning: 1.5.8
  - tqdm: 4.64.0
- System:
  - OS: Linux
  - architecture:
    - 64bit
    - ELF
  - processor: x86_64
  - python: 3.10.4
  - version: #49-Ubuntu SMP Wed May 18 13:28:06 UTC 2022