Float PILImage not converted as writeable #2194

Open
lukasHoel opened this issue May 7, 2020 · 8 comments

Comments

@lukasHoel

lukasHoel commented May 7, 2020

🐛 Bug

When a float PIL image (e.g. mode='F'), for instance one created in order to apply transforms to it, is finally converted with ToTensor, the following warning is printed:

/opt/conda/conda-bld/pytorch_1587428094786/work/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.

The writeability of the NumPy array gets lost somewhere between converting the NumPy array to a PIL image with Image.fromarray(numpy_array, mode='F') and converting that image, after some transforms, to a tensor with ToTensor.
This does not happen with non-float PIL images (e.g. mode='RGB').

This warning is especially annoying since it gets printed every epoch.

To Reproduce

Steps to reproduce the behavior:

    from torchvision.transforms import ToTensor
    import numpy as np
    from PIL import Image
    a = np.array([[1.0, 0.5], [1.0, 0.5]], dtype=np.float32)  # mode 'F' stores 32-bit floats
    print(a.flags.writeable)
    pilimg = Image.fromarray(a, mode='F')
    tensor = ToTensor()(pilimg)
    print(tensor.numpy().flags)

    b = np.asarray(pilimg)
    c = np.array(pilimg)
    print(b.flags)
    print(c.flags)

Running this code prints the warning above.

Also note the following:

  • The array b (from np.asarray) is NOT writeable, whereas the array c (from np.array) is writeable.
  • This suggests that the problem lies in the conversion from PIL to NumPy when no copy is made.
  • However, since there is no easy way to convert PIL to NumPy and then to a tensor within a transforms.Compose, and since the copied array c is writeable, I am opening the issue here.

As a workaround, I do the following:

    import numpy as np
    import torchvision


    class ToNumpy(object):
        # Copy the PIL image into a new, writeable NumPy array.
        def __call__(self, sample):
            return np.array(sample)


    def fix_compose_transform(transform):
        # If the pipeline ends with ToTensor, insert ToNumpy right before it.
        if isinstance(transform.transforms[-1], torchvision.transforms.ToTensor):
            transform = torchvision.transforms.Compose([
                *transform.transforms[:-1],
                ToNumpy(),
                torchvision.transforms.ToTensor(),
            ])
        return transform

Expected behavior

The warning is not printed, and ToTensor copes with the read-only data coming from the PIL image.

Environment

Collecting environment information...
PyTorch version: 1.5.0
Is debug build: No
CUDA used to build PyTorch: 10.1

OS: Ubuntu 18.04.4 LTS
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
CMake version: version 3.10.2

Python version: 3.8
Is CUDA available: Yes
CUDA runtime version: 10.1.243
GPU models and configuration: GPU 0: GeForce GTX 970
Nvidia driver version: 418.87.01
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy==1.18.3
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.1.243             h6bb024c_0  
[conda] mkl                       2020.0                      166  
[conda] mkl-service               2.3.0            py38he904b0f_0  
[conda] mkl_fft                   1.0.15           py38ha843d7b_0  
[conda] mkl_random                1.1.0            py38h962f231_0  
[conda] numpy                     1.18.1           py38h4f9e942_0  
[conda] numpy-base                1.18.1           py38hde5b4d6_1  
[conda] numpydoc                  0.9.2                      py_0  
[conda] pytorch                   1.5.0           py3.8_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] pytorch3d                 0.1.1                    pypi_0    pypi
[conda] torchvision               0.6.0                py38_cu101    pytorch

@zhangguanheng66
Contributor

In your code snippet, Image is not defined.

@lukasHoel
Author

Sorry, I forgot that. It is now added; it's the PIL Image class.

@zhangguanheng66
Contributor

Yep, I get the same warning. @fmassa, do you think we need a fix for this?

@pmeier
Collaborator

pmeier commented May 8, 2020

We have started to split to_tensor into pil_to_tensor (#2092) and convert_image_dtype (#2078). Maybe this is already covered by what we have. If not, I think it would be a good fit there.

@lukasHoel
Author

@pmeier I had a quick look at both suggestions:

  • pil_to_tensor has the same problem, as it uses np.asarray instead of np.array to convert from PIL to NumPy. I commented on the corresponding line of code in Add pil_to_tensor to functionals #2092. If it is OK to swap np.asarray for np.array, this would solve the problem in the same way as my workaround above, while also integrating nicely into transforms.Compose.

  • convert_image_dtype is not practical in this use case, as users want to keep the data as-is and only obtain it as a tensor, without converting its type.
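The swap proposed in the first point can be sketched as a standalone function. `pil_to_tensor_copy` is a hypothetical name, not a torchvision API; it follows the behaviour described in this thread for `pil_to_tensor` (keep the dtype, return a C x H x W tensor) but uses `np.array`, which copies, instead of `np.asarray`. It is only meant for common modes such as 'L', 'RGB' and 'F'.

```python
import numpy as np
import torch
from PIL import Image


def pil_to_tensor_copy(pic: Image.Image) -> torch.Tensor:
    """Hypothetical variant of pil_to_tensor that copies via np.array."""
    arr = np.array(pic)  # a fresh, writeable copy of the pixel data
    if arr.ndim == 2:
        arr = arr[:, :, None]  # H x W -> H x W x 1 for single-band modes
    # HWC -> CHW; the dtype (float32 for 'F', uint8 for 'L'/'RGB') is kept as-is.
    return torch.from_numpy(arr).permute(2, 0, 1).contiguous()


float_img = Image.fromarray(np.array([[1.0, 0.5], [1.0, 0.5]], dtype=np.float32))
rgb_img = Image.new("RGB", (3, 2), color=(10, 20, 30))  # width 3, height 2

print(pil_to_tensor_copy(float_img).dtype)
print(pil_to_tensor_copy(rgb_img).shape)
```

Because the array handed to `torch.from_numpy` is always a private copy, the tensor never aliases read-only memory; the trade-off is the extra copy per image.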

@pmeier
Collaborator

pmeier commented May 8, 2020

It was meant as a suggestion of where we should fix this. If I understood @fmassa correctly, convert_image_dtype(pil_to_tensor(pil_image)) will replace to_tensor(pil_image) in the future.

@fmassa
Member

fmassa commented May 15, 2020

@pmeier is correct: convert_image_dtype(pil_to_tensor(pil_image)) will be a substitute for to_tensor(pil_image) in most cases.

As for the writeable issue, I think the problem lies within PyTorch itself.
Making a copy in torchvision would add overhead that might not be necessary in most cases.

@bwesen

bwesen commented Oct 2, 2020

FYI, this also happens with "normal" loading: for example, opening a PNG with PIL, converting it to a luminance ('L') image, and then calling tvt.functional.pil_to_tensor triggers the same warning. It doesn't have to be a float image.

    pic = Image.open(fn).convert(mode='L')
    tensor = tvt.functional.pil_to_tensor(pic).type(torch.ByteTensor)

This prints the warning, ending with:

    (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)
    img = torch.as_tensor(np.asarray(pic))
