Skip to content

Commit b67b3d5

Browse files
datumboxvfdev-5
authored andcommitted
Support specifying output channels in io.image.read_image (pytorch#2988)
* Adding output channels implementation for pngs. * Adding tests for png. * Adding channels in the API and documentation. * Fixing formatting. * Refactoring test_image.py to remove huge grace_hopper_517x606.pth file from assets and reduce duplicate code. Moving jpeg assets used by encode and write unit-tests on their separate folders. * Adding output channels implementation for jpegs. Fix asset locations. * Add tests for JPEG, adding the channels in the API and documentation and adding checks for inputs. * Changing folder for unit-test. * Fixing windows flakiness, removing duplicate test. * Replacing components to channels. * Adding reference for supporting CMYK. * Minor changes: num_components to output_components, adding comments, fixing variable name etc. * Reverting output_components to num_components. * Replacing decoding with generic method on tests. * Palette converted to Gray.
1 parent cfd15fe commit b67b3d5

17 files changed

+223
-88
lines changed
3.45 KB
Loading
1.4 KB
Loading
2.08 KB
Loading

test/assets/grace_hopper_517x606.pth

-919 KB
Binary file not shown.

test/test_cpp_models.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,17 @@ def process_model(model, tensor, func, name):
2525

2626

2727
def read_image1():
28-
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
28+
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg',
29+
'grace_hopper_517x606.jpg')
2930
image = Image.open(image_path)
3031
image = image.resize((224, 224))
3132
x = F.to_tensor(image)
3233
return x.view(1, 3, 224, 224)
3334

3435

3536
def read_image2():
36-
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
37+
image_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg',
38+
'grace_hopper_517x606.jpg')
3739
image = Image.open(image_path)
3840
image = image.resize((299, 299))
3941
x = F.to_tensor(image)

test/test_datasets_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515

1616
TEST_FILE = get_file_path_2(
17-
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
17+
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
1818

1919

2020
class Tester(unittest.TestCase):

test/test_image.py

Lines changed: 55 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
FAKEDATA_DIR = os.path.join(IMAGE_ROOT, "fakedata")
2020
IMAGE_DIR = os.path.join(FAKEDATA_DIR, "imagefolder")
2121
DAMAGED_JPEG = os.path.join(IMAGE_ROOT, 'damaged_jpeg')
22+
ENCODE_JPEG = os.path.join(IMAGE_ROOT, "encode_jpeg")
2223

2324

2425
def get_images(directory, img_ext):
@@ -33,14 +34,44 @@ def get_images(directory, img_ext):
3334
yield os.path.join(root, fl)
3435

3536

37+
def pil_read_image(img_path):
38+
with Image.open(img_path) as img:
39+
return torch.from_numpy(np.array(img))
40+
41+
42+
def normalize_dimensions(img_pil):
43+
if len(img_pil.shape) == 3:
44+
img_pil = img_pil.permute(2, 0, 1)
45+
else:
46+
img_pil = img_pil.unsqueeze(0)
47+
return img_pil
48+
49+
3650
class ImageTester(unittest.TestCase):
3751
def test_decode_jpeg(self):
52+
conversion = [(None, 0), ("L", 1), ("RGB", 3)]
3853
for img_path in get_images(IMAGE_ROOT, ".jpg"):
39-
img_pil = torch.load(img_path.replace('jpg', 'pth'))
40-
img_pil = img_pil.permute(2, 0, 1)
41-
data = read_file(img_path)
42-
img_ljpeg = decode_jpeg(data)
43-
self.assertTrue(img_ljpeg.equal(img_pil))
54+
for pil_mode, channels in conversion:
55+
with Image.open(img_path) as img:
56+
is_cmyk = img.mode == "CMYK"
57+
if pil_mode is not None:
58+
if is_cmyk:
59+
# libjpeg does not support the conversion
60+
continue
61+
img = img.convert(pil_mode)
62+
img_pil = torch.from_numpy(np.array(img))
63+
if is_cmyk:
64+
# flip the colors to match libjpeg
65+
img_pil = 255 - img_pil
66+
67+
img_pil = normalize_dimensions(img_pil)
68+
data = read_file(img_path)
69+
img_ljpeg = decode_image(data, channels=channels)
70+
71+
# Permit a small variation on pixel values to account for implementation
72+
# differences between Pillow and LibJPEG.
73+
abs_mean_diff = (img_ljpeg.type(torch.float32) - img_pil).abs().mean().item()
74+
self.assertTrue(abs_mean_diff < 2)
4475

4576
with self.assertRaisesRegex(RuntimeError, "Expected a non empty 1-dimensional tensor"):
4677
decode_jpeg(torch.empty((100, 1), dtype=torch.uint8))
@@ -68,7 +99,7 @@ def test_damaged_images(self):
6899
decode_jpeg(data)
69100

70101
def test_encode_jpeg(self):
71-
for img_path in get_images(IMAGE_ROOT, ".jpg"):
102+
for img_path in get_images(ENCODE_JPEG, ".jpg"):
72103
dirname = os.path.dirname(img_path)
73104
filename, _ = os.path.splitext(os.path.basename(img_path))
74105
write_folder = os.path.join(dirname, 'jpeg_write')
@@ -111,7 +142,7 @@ def test_encode_jpeg(self):
111142
encode_jpeg(torch.empty((100, 100), dtype=torch.uint8))
112143

113144
def test_write_jpeg(self):
114-
for img_path in get_images(IMAGE_ROOT, ".jpg"):
145+
for img_path in get_images(ENCODE_JPEG, ".jpg"):
115146
data = read_file(img_path)
116147
img = decode_jpeg(data)
117148

@@ -134,20 +165,25 @@ def test_write_jpeg(self):
134165
self.assertEqual(torch_bytes, pil_bytes)
135166

136167
def test_decode_png(self):
168+
conversion = [(None, 0), ("L", 1), ("LA", 2), ("RGB", 3), ("RGBA", 4)]
137169
for img_path in get_images(FAKEDATA_DIR, ".png"):
138-
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
139-
if len(img_pil.shape) == 3:
140-
img_pil = img_pil.permute(2, 0, 1)
141-
else:
142-
img_pil = img_pil.unsqueeze(0)
143-
data = read_file(img_path)
144-
img_lpng = decode_png(data)
145-
self.assertTrue(img_lpng.equal(img_pil))
170+
for pil_mode, channels in conversion:
171+
with Image.open(img_path) as img:
172+
if pil_mode is not None:
173+
img = img.convert(pil_mode)
174+
img_pil = torch.from_numpy(np.array(img))
146175

147-
with self.assertRaises(RuntimeError):
148-
decode_png(torch.empty((), dtype=torch.uint8))
149-
with self.assertRaises(RuntimeError):
150-
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
176+
img_pil = normalize_dimensions(img_pil)
177+
data = read_file(img_path)
178+
img_lpng = decode_image(data, channels=channels)
179+
180+
tol = 0 if conversion is None else 1
181+
self.assertTrue(img_lpng.allclose(img_pil, atol=tol))
182+
183+
with self.assertRaises(RuntimeError):
184+
decode_png(torch.empty((), dtype=torch.uint8))
185+
with self.assertRaises(RuntimeError):
186+
decode_png(torch.randint(3, 5, (300,), dtype=torch.uint8))
151187

152188
def test_encode_png(self):
153189
for img_path in get_images(IMAGE_DIR, '.png'):
@@ -196,19 +232,6 @@ def test_write_png(self):
196232

197233
self.assertTrue(img_pil.equal(saved_image))
198234

199-
def test_decode_image(self):
200-
for img_path in get_images(IMAGE_ROOT, ".jpg"):
201-
img_pil = torch.load(img_path.replace('jpg', 'pth'))
202-
img_pil = img_pil.permute(2, 0, 1)
203-
img_ljpeg = decode_image(read_file(img_path))
204-
self.assertTrue(img_ljpeg.equal(img_pil))
205-
206-
for img_path in get_images(IMAGE_DIR, ".png"):
207-
img_pil = torch.from_numpy(np.array(Image.open(img_path)))
208-
img_pil = img_pil.permute(2, 0, 1)
209-
img_lpng = decode_image(read_file(img_path))
210-
self.assertTrue(img_lpng.equal(img_pil))
211-
212235
def test_read_file(self):
213236
with get_tmp_dir() as d:
214237
fname, content = 'test1.bin', b'TorchVision\211\n'

test/test_transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525

2626
GRACE_HOPPER = get_file_path_2(
27-
os.path.dirname(os.path.abspath(__file__)), 'assets', 'grace_hopper_517x606.jpg')
27+
os.path.dirname(os.path.abspath(__file__)), 'assets', 'encode_jpeg', 'grace_hopper_517x606.jpg')
2828

2929

3030
class Tester(unittest.TestCase):

0 commit comments

Comments
 (0)