### Bug description
I need access to the `batch_size` in my `predict_step` to calculate the correct output index, but `self.trainer.predict_dataloaders[0].batch_size` always returns `None` instead of the actual batch size.

Doing the same thing with `self.trainer.test_dataloaders[0].batch_size` works fine, so I strongly suspect this is a bug specific to `self.trainer.predict_dataloaders[0].batch_size`.

The example code below should print `32` both when `trainer.test(model, trainloader)` is executed and when `trainer.predict(model, predictloader)` is executed. However, when `trainer.predict(model, predictloader)` is executed, it prints `None`.
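For context, the index calculation that needs the batch size can be sketched as follows. This is a minimal illustration, not code from the report: the helper name `output_indices` is hypothetical, and it assumes the dataloader iterates the dataset in order (no shuffling) with a fixed nominal batch size.

```python
def output_indices(batch_idx: int, batch_size: int, n_in_batch: int) -> range:
    """Global sample positions covered by one batch of an in-order loader.

    batch_size is the nominal (configured) batch size; n_in_batch is the
    number of samples actually present, which is smaller for a final
    partial batch.
    """
    start = batch_idx * batch_size
    return range(start, start + n_in_batch)


# 70 samples with batch_size=32 give batches of 32, 32 and 6 samples.
print(list(output_indices(2, 32, 6)))  # [64, 65, 66, 67, 68, 69]
```

The size of the current batch alone (for example `batch[0].shape[0]`) is not enough here, because the last batch may be smaller than the nominal batch size. That is why the value reported by the dataloader matters.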
### How to reproduce the bug
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_lightning as pl
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split

# 64 samples with 1000 features each, drawn from 10 clusters
X, Y = make_blobs(64, 1000, centers=10, cluster_std=10)
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0)

trainset = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(), torch.from_numpy(y_train))
# Both loaders are created with batch_size=32
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)
predictloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True, num_workers=2)


class Model(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(100, 256, bidirectional=True)
        self.classifier = nn.Linear(256 * 2 * 10, 10)

    def forward(self, x):
        x, _ = self.lstm(x.view(-1, 10, 100))
        x = self.classifier(x.flatten(1))
        return x

    def test_step(self, batch, batch_idx):
        # Prints 32, as expected
        print(f"{self.trainer.test_dataloaders[0].batch_size=}")
        pred = self(batch[0])
        loss = torch.nn.functional.cross_entropy(pred, batch[1])
        return loss

    def predict_step(self, batch, batch_idx):
        # Prints None instead of 32
        print(f"{self.trainer.predict_dataloaders[0].batch_size=}")
        pred = self(batch[0])
        loss = torch.nn.functional.cross_entropy(pred, batch[1])
        return loss

    def configure_optimizers(self):
        r"""Configure optimizer."""
        return optim.Adam(self.parameters())


model = Model()
trainer = pl.Trainer(max_epochs=1, gpus=1)
trainer.test(model, trainloader)
trainer.predict(model, predictloader)
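A possible workaround, sketched under assumptions: a PyTorch `DataLoader` that is constructed with a `batch_sampler` reports `batch_size` as `None`, and the size then lives on the batch sampler instead. Whether Lightning rebuilds the predict dataloader this way in the affected versions is an assumption here, not something confirmed in this report. The helper name `resolve_batch_size` is hypothetical, and the objects below are plain stand-ins rather than real dataloaders.

```python
from types import SimpleNamespace
from typing import Any, Optional


def resolve_batch_size(loader: Any) -> Optional[int]:
    """Return loader.batch_size, falling back to loader.batch_sampler.batch_size.

    Returns None if neither attribute carries a value.
    """
    size = getattr(loader, "batch_size", None)
    if size is None:
        batch_sampler = getattr(loader, "batch_sampler", None)
        size = getattr(batch_sampler, "batch_size", None)
    return size


# Stand-ins that mimic the two attribute layouts (not real DataLoaders).
plain = SimpleNamespace(batch_size=32, batch_sampler=SimpleNamespace(batch_size=32))
wrapped = SimpleNamespace(batch_size=None, batch_sampler=SimpleNamespace(batch_size=32))

print(resolve_batch_size(plain), resolve_batch_size(wrapped))  # 32 32
```

Inside `predict_step`, this would be called as `resolve_batch_size(self.trainer.predict_dataloaders[0])`. It has not been verified against the Lightning versions listed in this report, so treat it as something to try rather than a confirmed fix.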
### Error messages and logs

_No response_
### Environment

- PyTorch Lightning version: 1.6.0 - 1.7.7
- PyTorch version: 1.10
- Python version: 3.8.10

### More info
_No response_
cc @justusschock @awaelchli