predict_dataloaders[0].batch_size always returns None #15186

Open
KinWaiCheuk opened this issue Oct 19, 2022 · 2 comments
Labels: data handling (generic data-related topic), trainer: predict

Comments

KinWaiCheuk commented Oct 19, 2022

Bug description

I need access to the batch size in my predict_step to calculate the correct output index.
However, self.trainer.predict_dataloaders[0].batch_size always returns None instead of the configured batch size.

I tried the same thing with self.trainer.test_dataloaders[0].batch_size and it works fine, so I strongly suspect this is a bug in self.trainer.predict_dataloaders[0].batch_size.

The example code below should print 32 for both trainer.test(model, trainloader) and trainer.predict(model, predictloader).

But when trainer.predict(model, predictloader) is executed, it prints None.

How to reproduce the bug

import torch
import torch.nn as nn
import pytorch_lightning as pl
from sklearn.datasets import make_blobs
from sklearn.model_selection import train_test_split
import torch.optim as optim

import matplotlib.pyplot as plt


X, Y = make_blobs(64,1000,centers=10, cluster_std=10)
X_train, X_test, y_train, y_test = train_test_split(X,Y, test_size=0.2, random_state=0)


trainset = torch.utils.data.TensorDataset(torch.from_numpy(X_train).float(),torch.from_numpy(y_train))
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,shuffle=True, num_workers=2)
predictloader = torch.utils.data.DataLoader(trainset, batch_size=32,shuffle=True, num_workers=2)


class Model(pl.LightningModule):
    def __init__(self):
        super(Model, self).__init__()
        self.lstm = nn.LSTM(100, 256, bidirectional=True)
        self.classifier = nn.Linear(256*2*10,10)

    def forward(self, x):
        x, _ = self.lstm(x.view(-1,10,100))
        x = self.classifier(x.flatten(1))
        return x

    def test_step(self, batch, batch_idx):
        # Prints 32: the test dataloader keeps its configured batch size.
        print(f"{self.trainer.test_dataloaders[0].batch_size=}")
        pred = self(batch[0])
        loss = torch.nn.functional.cross_entropy(pred, batch[1])
        return loss

    def predict_step(self, batch, batch_idx):
        # Prints None: this is the behavior being reported.
        print(f"{self.trainer.predict_dataloaders[0].batch_size=}")
        pred = self(batch[0])
        loss = torch.nn.functional.cross_entropy(pred, batch[1])
        return loss

    def configure_optimizers(self):
        r"""Configure optimizer."""
        return optim.Adam(self.parameters())

model = Model()

trainer = pl.Trainer(max_epochs=1, gpus=1)

trainer.test(model, trainloader)
trainer.predict(model, predictloader)

Environment

PyTorch Lightning version: 1.6.0 - 1.7.7
PyTorch version: 1.10
Python version: 3.8.10



cc @justusschock @awaelchli
KinWaiCheuk added the "needs triage" label on Oct 19, 2022
stale bot commented Apr 15, 2023

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions - the Lightning Team!

stale bot added the "won't fix" label on Apr 15, 2023
awaelchli (Contributor) commented Sep 12, 2023

Unfortunately, we cannot support this, because Lightning needs to inject a custom sampler into the dataloader for prediction. The PyTorch DataLoader sets batch_size=None when a batch sampler is used: https://github.com/pytorch/pytorch/blob/8025b193a966a6d8e3afc9c03a54e577bc04eb3d/torch/utils/data/dataloader.py#L329-L336
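The PyTorch behavior described here can be reproduced without Lightning. The following is a minimal sketch, assuming a made-up TensorDataset of 51 samples and PyTorch's stock BatchSampler standing in for whatever sampler Lightning injects; it is not Lightning's actual wrapper, and the thread does not confirm whether that wrapper exposes a batch_size attribute.

```python
import torch
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, TensorDataset

# Illustrative dataset (an assumption, not the issue's data): 51 samples, 10 classes.
dataset = TensorDataset(torch.randn(51, 1000), torch.randint(0, 10, (51,)))

# A loader built with batch_size keeps that value on the attribute.
plain_loader = DataLoader(dataset, batch_size=32)

# A loader built with batch_sampler has its batch_size attribute set to None.
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=32, drop_last=False)
wrapped_loader = DataLoader(dataset, batch_sampler=batch_sampler)

plain_batch_size = plain_loader.batch_size      # 32
wrapped_batch_size = wrapped_loader.batch_size  # None

# PyTorch's own BatchSampler still carries the configured size.
sampler_batch_size = wrapped_loader.batch_sampler.batch_size  # 32

# The actual per-batch size is always available from the tensors themselves.
observed = [features.size(0) for features, _ in wrapped_loader]  # [32, 19]
```

The last line also shows why the attribute is a weak source of truth even when it is set: the final batch can be smaller than the configured size.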

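My recommendation is to read the batch size from the batch itself with batch.size(0) (the first dimension of your tensor; in the example above that is batch[0].size(0), since batch is a list of tensors), or to obtain it by other means (e.g. storing the configuration).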
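For the original goal of computing output indices, one option is to keep a running offset and advance it by the size of each batch, so the configured batch size is never needed. This is a rough sketch in plain PyTorch under two assumptions: the loader is not shuffled, and a single process iterates it. The names offset and index_ranges are hypothetical; in a LightningModule the offset could be kept as an attribute and updated in predict_step.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative dataset (an assumption): 70 samples, so the last batch is short.
dataset = TensorDataset(torch.arange(70).float().unsqueeze(1))
loader = DataLoader(dataset, batch_size=32, shuffle=False)

offset = 0          # global index of the first sample in the current batch
index_ranges = []   # (start, stop) dataset indices covered by each batch

for (features,) in loader:
    n = features.size(0)  # actual batch size, read from the tensor
    index_ranges.append((offset, offset + n))
    offset += n

# index_ranges is now [(0, 32), (32, 64), (64, 70)]
```

With shuffle=True, as in the issue's predictloader, batch positions no longer correspond to dataset order, so an offset like this would only index into the stream of outputs and not into the dataset.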

stale bot removed the "won't fix" label on Sep 12, 2023
awaelchli added the "data handling" and "trainer: predict" labels and removed the "needs triage" label on Sep 12, 2023