Add explicit check for number of channels #3013

Merged
merged 17 commits into from
Nov 20, 2020

ademyanchuk
Contributor

Example of why the check is needed:
`M = torch.randint(low=0, high=2, size=(6, 64, 64), dtype=torch.float)`
When this input is passed through to_pil_image without a mode argument, it is converted to uint8 here:
```
if pic.is_floating_point() and mode != 'F':
    pic = pic.mul(255).byte()
```
and the mode is set to RGB here:
```
if mode is None and npimg.dtype == np.uint8:
    mode = 'RGB'
```
Image.fromarray does not raise when given mode RGB; it silently cuts the number of channels down to 3.
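
The flow described above can be modelled without torch or Pillow. The sketch below is a simplified, hypothetical stand-in (`simplified_mode_selection` is not a torchvision function, and dtypes are plain strings) that reproduces only the two quoted branches, to show that the channel count never enters the decision:

```python
def simplified_mode_selection(num_channels, dtype, mode=None):
    """Toy model of the two quoted branches only; not torchvision's real code.

    `dtype` is a plain string such as "float32" or "uint8".
    """
    # Branch 1: floating-point input is scaled and cast to uint8 unless mode is 'F'.
    if dtype.startswith("float") and mode != "F":
        dtype = "uint8"
    # Branch 2: with no explicit mode, uint8 data is labelled RGB.
    # Note that num_channels is never consulted on this path.
    if mode is None and dtype == "uint8":
        mode = "RGB"
    return mode, dtype


# A 6-channel float input is labelled RGB exactly like a 3-channel one.
print(simplified_mode_selection(6, "float32"))
print(simplified_mode_selection(3, "float32"))
```

Because nothing on this path looks at the number of channels, an explicit check has to be added somewhere before the image is built.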
@facebook-github-bot

Hi @ademyanchuk!

Thank you for your pull request and welcome to our community. We require contributors to sign our Contributor License Agreement, and we don't seem to have you on file.

In order for us to review and merge your code, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

If you have received this in error or have any questions, please contact us at cla@fb.com. Thanks!

@ademyanchuk
Contributor Author

This problem is also discussed in this PyTorch forum thread.

@vfdev-5
Collaborator

vfdev-5 commented Nov 17, 2020

@ademyanchuk thanks for the PR! I would instead add the check before the numpy conversion, before line 194, as if pic.shape[-3] > 4.
It would also be good to add a test for that, for example here:

def test_tensor_bad_types_to_pil_image(self):

and
def test_ndarray_bad_types_to_pil_image(self):

Thanks
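
A test along the lines suggested above might look like the following sketch. It is only an illustration: `reject_too_many_channels` and its error message are hypothetical stand-ins for the check under test, and shapes are plain tuples so the example runs without torch or numpy.

```python
import unittest


def reject_too_many_channels(shape):
    # Hypothetical stand-in for the check under test (channels-first shape).
    if len(shape) == 3 and shape[-3] > 4:
        raise ValueError(
            "pic should not have > 4 channels. Got {} channels.".format(shape[-3])
        )


class ChannelCheckTest(unittest.TestCase):
    def test_too_many_channels_raises(self):
        # Matching on the message guards against an unrelated ValueError passing the test.
        with self.assertRaisesRegex(ValueError, r"pic should not have > 4 channels"):
            reject_too_many_channels((6, 4, 4))

    def test_valid_channel_counts_pass(self):
        for channels in (1, 2, 3, 4):
            reject_too_many_channels((channels, 4, 4))
```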

@ademyanchuk
Contributor Author

@vfdev-5 thanks for the comment! I thought about where to put the check for the number of channels and decided to put it after the conversion to numpy, because at that point we know the channel dimension is .shape[2]. Since the function accepts both torch.tensor and numpy.array inputs and assumes tensors are channels-first, we would have to add the check twice: once for numpy with ndim==3 and once for torch with ndim==3. Do you think I should add the check twice, or is it better to check only once after the conversion?

Sure, I would like to add tests. I will see how I can add those. It is my first PR :)

@vfdev-5
Collaborator

vfdev-5 commented Nov 17, 2020

Since the function accepts both torch.tensor and numpy.array inputs and assumes tensors are channels-first, we would have to add the check twice: once for numpy with ndim==3 and once for torch with ndim==3.

Yes, true. For a tensor the channel dimension is assumed to be at index -3, and for numpy at index -1. IMO it would be better to raise the exception before doing any data processing ...
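
The layout difference can be sketched as a single pre-check on the shape. This is a minimal illustration, not the code merged in this PR: `check_num_channels`, its parameters, and its error message are hypothetical, and the limit of 4 is taken from the `pic.shape[-3] > 4` suggestion earlier in the thread.

```python
def check_num_channels(shape, channels_first, max_channels=4):
    """Raise ValueError if a 3D image shape has more than max_channels channels.

    shape: a tuple such as pic.shape
    channels_first: True for the tensor layout (C, H, W),
                    False for the ndarray layout (H, W, C)
    """
    if len(shape) != 3:
        return  # in this sketch, 2D inputs are treated as single-channel
    num_channels = shape[-3] if channels_first else shape[-1]
    if num_channels > max_channels:
        raise ValueError(
            "pic should not have > {} channels. Got {} channels.".format(
                max_channels, num_channels
            )
        )


check_num_channels((3, 64, 64), channels_first=True)   # passes silently
check_num_channels((64, 64, 4), channels_first=False)  # passes silently

try:
    check_num_channels((6, 64, 64), channels_first=True)
except ValueError as err:
    print(err)
```

Running the check on the shape alone means the error is raised before any scaling, casting, or conversion work is done.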

Sure, I would like to add tests. I will see how I can add those. It is my first PR :)

Sounds good. Feel free to ask questions if needed. Please take a look at the draft CONTRIBUTING guide if needed.

@ademyanchuk
Contributor Author

IMO it would be better to raise the exception before doing any data processing ...

That makes total sense, so I will proceed.

@ademyanchuk
Contributor Author

OK. I refactored the check to run before processing and added a test for an invalid number of channels :)

Collaborator

@vfdev-5 vfdev-5 left a comment

Thanks for the updates @ademyanchuk ! I left some comments to improve the PR.

ademyanchuk and others added 6 commits November 18, 2020 14:54
@ademyanchuk
Contributor Author

OK. It seems I managed to make all the discussed changes :)

Collaborator

@vfdev-5 vfdev-5 left a comment

Thanks for the updates @ademyanchuk
I left a few more comments.


def test_ndarray_bad_types_to_pil_image(self):
    trans = transforms.ToPILImage()
-   with self.assertRaises(TypeError):
+   with self.assertRaisesRegex(TypeError, r'Input type \w+ is not supported'):
Collaborator

Could you please also fix this test, as it was incorrect too: add a separate with self.assertRaisesRegex(TypeError, r'Input type \w+ is not supported') for each trans(np.ones([4, 4, 1], type)) call.
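
The reason one block per call matters: as soon as the first statement inside a `with self.assertRaises(...)` block raises, the rest of the block is skipped, so any later calls are never exercised. A minimal sketch, assuming a hypothetical `convert` function standing in for `trans(np.ones([4, 4, 1], type))` and illustrative dtype names:

```python
import unittest

calls = []


def convert(dtype_name):
    # Hypothetical stand-in for trans(np.ones([4, 4, 1], dtype)); always rejects its input.
    calls.append(dtype_name)
    raise TypeError("Input type {} is not supported".format(dtype_name))


class BadTypesTest(unittest.TestCase):
    def test_single_block_checks_only_first_call(self):
        calls.clear()
        with self.assertRaises(TypeError):
            convert("int64")
            convert("uint16")  # never reached: the block exits at the first raise
        self.assertEqual(calls, ["int64"])

    def test_one_block_per_call(self):
        calls.clear()
        for dtype_name in ("int64", "uint16", "float64"):
            with self.assertRaisesRegex(TypeError, r"Input type \w+ is not supported"):
                convert(dtype_name)
        # Every call was actually executed and checked.
        self.assertEqual(calls, ["int64", "uint16", "float64"])
```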

@ademyanchuk
Contributor Author

@vfdev-5, sorry to bother you. How can I run python setup.py develop so that it compiles with CUDA version 10.2?

@vfdev-5
Collaborator

vfdev-5 commented Nov 18, 2020

How can I run python setup.py develop so that it compiles with CUDA version 10.2?

It will build the CUDA extensions with the same CUDA version that torch uses. If you'd like to disable that, just set

CUDA_VISIBLE_DEVICES="" python setup.py develop

@ademyanchuk
Contributor Author

ademyanchuk commented Nov 18, 2020

It will build the CUDA extensions with the same CUDA version that torch uses

Apparently not in my case :(
I ran conda install pytorch -c pytorch-nightly,
which installed torch with CUDA 10.2,
but python setup.py develop compiled against CUDA 11.

CUDA 11 is installed on my local machine; could that be the issue?

Since this is WSL2, I simply deleted my local CUDA 11, and torchvision then compiled with the same CUDA version as the torch nightly!

@ademyanchuk
Contributor Author

Anyway, I made the changes requested in your last comments :)

@vfdev-5
Collaborator

vfdev-5 commented Nov 18, 2020

It will build the CUDA extensions with the same CUDA version that torch uses

Apparently not in my case :(
I ran conda install pytorch -c pytorch-nightly,
which installed torch with CUDA 10.2,
but python setup.py develop compiled against CUDA 11.

CUDA 11 is installed on my local machine; could that be the issue?

Since this is WSL2, I simply deleted my local CUDA 11, and torchvision then compiled with the same CUDA version as the torch nightly!

It would be nice to have detailed instructions for reproducing your situation. If you can reproduce the problem, please file an issue about it.

@vfdev-5
Collaborator

vfdev-5 commented Nov 18, 2020

Could you also please merge master into your branch? There is an unexpected modification here: https://github.com/pytorch/vision/pull/3013/files#diff-b5826ce59a98ee20294c65d06e09363609879221494b5ff7fbb7087b75f3e159L7

@ademyanchuk
Contributor Author

Hi, I merged master into my branch patch-1 and also checked this file manually, but I see no diff against pytorch vision master. As far as I understand, the diff you linked comes from commit 365f159.

@ademyanchuk
Contributor Author

It would be nice to have detailed instructions for reproducing your situation. If you can reproduce the problem, please file an issue about it.

I will try to reproduce the problem. If I succeed, I will file an issue.

Collaborator

@vfdev-5 vfdev-5 left a comment

Thanks @ademyanchuk !

@ademyanchuk
Contributor Author

Thanks @vfdev-5 for the valuable comments and review, much appreciated! So, is it done? How does the merge into master proceed?

@ademyanchuk
Contributor Author

It appears that most recent pull requests fail on the same Travis CI job, at the same place.

@fmassa fmassa merged commit a51c49e into pytorch:master Nov 20, 2020
bryant1410 pushed a commit to bryant1410/vision-1 that referenced this pull request Nov 22, 2020
* Add explicit check for number of channels

* Check number of channels before processing

* Add test for invalid number of channels

* Put check after channel dim unsqueeze

* Add test if error message is matching

* Delete redundant code

* Bug fix in checking for bad types

Co-authored-by: Demyanchuk <demyanca@mh-hannover.local>
Co-authored-by: vfdev <vfdev.5@gmail.com>
vfdev-5 added a commit to Quansight/vision that referenced this pull request Dec 4, 2020
4 participants