Description
🐛 Bug
Non-CPU Generator objects cause torch.utils.data.random_split
to fail without any error message.
To Reproduce
Steps to reproduce the behavior:
- Create a Generator object with device type CUDA.
- Pass that CUDA generator to the torch.utils.data.random_split function.
- Run the code and observe that it fails without any error message.
```python
import torch

rnd_generator = torch.Generator(device='cuda:0')
print(sorted(torch.utils.data.random_split([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], [8, 2], generator=rnd_generator)[0]))
```
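For comparison, the same call completes successfully when the generator lives on the CPU (the default device) — a minimal sketch:

```python
import torch
from torch.utils.data import random_split

# A CPU generator (the default device) works as expected with random_split.
cpu_generator = torch.Generator()  # device defaults to 'cpu'
cpu_generator.manual_seed(42)

subset_a, subset_b = random_split(list(range(10)), [8, 2], generator=cpu_generator)
assert len(subset_a) == 8 and len(subset_b) == 2
```

Only the generator's device differs between this working snippet and the failing one above.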
Expected behavior
The device type of the Generator object either shouldn't affect torch.utils.data.random_split,
or a clear error message should be raised.
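Until the library raises the error itself, a caller can guard against silent failure with a thin wrapper. This is a hypothetical helper (`safe_random_split` is not part of PyTorch), sketched only to illustrate the expected behavior:

```python
import torch
from torch.utils.data import random_split

def safe_random_split(dataset, lengths, generator=None):
    """Hypothetical wrapper: raise a clear error for non-CPU generators
    instead of letting random_split fail silently."""
    if generator is not None and generator.device.type != 'cpu':
        raise ValueError(
            f"random_split requires a CPU generator, got device '{generator.device}'"
        )
    return random_split(dataset, lengths, generator=generator)
```

With this guard, passing a CUDA generator fails loudly with a ValueError rather than crashing without a message.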
Environment
- PyTorch version: 1.6.0+cu101
- Is debug build: False
- CUDA used to build PyTorch: 10.1
- ROCM used to build PyTorch: N/A
- OS: Ubuntu 18.04.5 LTS (x86_64)
- GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
- Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
- CMake version: version 3.12.0
- Python version: 3.6 (64-bit runtime)
- Is CUDA available: True
- CUDA runtime version: 10.1.243
- GPU models and configuration: GPU 0: Tesla K80
- Nvidia driver version: 418.67
- cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
- HIP runtime version: N/A
- MIOpen runtime version: N/A
Versions of relevant libraries:
- [pip3] numpy==1.18.5
- [pip3] torch==1.6.0+cu101
- [pip3] torchsummary==1.5.1
- [pip3] torchtext==0.3.1
- [pip3] torchvision==0.7.0+cu101
- [conda] Could not collect
Additional context
The environment above is from Google Colab (the instance crashed when I ran the test code); I can confirm the issue is also present on Windows.