19
19
FAKEDATA_DIR = os .path .join (IMAGE_ROOT , "fakedata" )
20
20
IMAGE_DIR = os .path .join (FAKEDATA_DIR , "imagefolder" )
21
21
DAMAGED_JPEG = os .path .join (IMAGE_ROOT , 'damaged_jpeg' )
22
+ ENCODE_JPEG = os .path .join (IMAGE_ROOT , "encode_jpeg" )
22
23
23
24
24
25
def get_images (directory , img_ext ):
@@ -33,14 +34,44 @@ def get_images(directory, img_ext):
33
34
yield os .path .join (root , fl )
34
35
35
36
37
+ def pil_read_image (img_path ):
38
+ with Image .open (img_path ) as img :
39
+ return torch .from_numpy (np .array (img ))
40
+
41
+
42
+ def normalize_dimensions (img_pil ):
43
+ if len (img_pil .shape ) == 3 :
44
+ img_pil = img_pil .permute (2 , 0 , 1 )
45
+ else :
46
+ img_pil = img_pil .unsqueeze (0 )
47
+ return img_pil
48
+
49
+
36
50
class ImageTester (unittest .TestCase ):
37
51
def test_decode_jpeg (self ):
52
+ conversion = [(None , 0 ), ("L" , 1 ), ("RGB" , 3 )]
38
53
for img_path in get_images (IMAGE_ROOT , ".jpg" ):
39
- img_pil = torch .load (img_path .replace ('jpg' , 'pth' ))
40
- img_pil = img_pil .permute (2 , 0 , 1 )
41
- data = read_file (img_path )
42
- img_ljpeg = decode_jpeg (data )
43
- self .assertTrue (img_ljpeg .equal (img_pil ))
54
+ for pil_mode , channels in conversion :
55
+ with Image .open (img_path ) as img :
56
+ is_cmyk = img .mode == "CMYK"
57
+ if pil_mode is not None :
58
+ if is_cmyk :
59
+ # libjpeg does not support the conversion
60
+ continue
61
+ img = img .convert (pil_mode )
62
+ img_pil = torch .from_numpy (np .array (img ))
63
+ if is_cmyk :
64
+ # flip the colors to match libjpeg
65
+ img_pil = 255 - img_pil
66
+
67
+ img_pil = normalize_dimensions (img_pil )
68
+ data = read_file (img_path )
69
+ img_ljpeg = decode_image (data , channels = channels )
70
+
71
+ # Permit a small variation on pixel values to account for implementation
72
+ # differences between Pillow and LibJPEG.
73
+ abs_mean_diff = (img_ljpeg .type (torch .float32 ) - img_pil ).abs ().mean ().item ()
74
+ self .assertTrue (abs_mean_diff < 2 )
44
75
45
76
with self .assertRaisesRegex (RuntimeError , "Expected a non empty 1-dimensional tensor" ):
46
77
decode_jpeg (torch .empty ((100 , 1 ), dtype = torch .uint8 ))
@@ -68,7 +99,7 @@ def test_damaged_images(self):
68
99
decode_jpeg (data )
69
100
70
101
def test_encode_jpeg (self ):
71
- for img_path in get_images (IMAGE_ROOT , ".jpg" ):
102
+ for img_path in get_images (ENCODE_JPEG , ".jpg" ):
72
103
dirname = os .path .dirname (img_path )
73
104
filename , _ = os .path .splitext (os .path .basename (img_path ))
74
105
write_folder = os .path .join (dirname , 'jpeg_write' )
@@ -111,7 +142,7 @@ def test_encode_jpeg(self):
111
142
encode_jpeg (torch .empty ((100 , 100 ), dtype = torch .uint8 ))
112
143
113
144
def test_write_jpeg (self ):
114
- for img_path in get_images (IMAGE_ROOT , ".jpg" ):
145
+ for img_path in get_images (ENCODE_JPEG , ".jpg" ):
115
146
data = read_file (img_path )
116
147
img = decode_jpeg (data )
117
148
@@ -134,20 +165,25 @@ def test_write_jpeg(self):
134
165
self .assertEqual (torch_bytes , pil_bytes )
135
166
136
167
def test_decode_png (self ):
168
+ conversion = [(None , 0 ), ("L" , 1 ), ("LA" , 2 ), ("RGB" , 3 ), ("RGBA" , 4 )]
137
169
for img_path in get_images (FAKEDATA_DIR , ".png" ):
138
- img_pil = torch .from_numpy (np .array (Image .open (img_path )))
139
- if len (img_pil .shape ) == 3 :
140
- img_pil = img_pil .permute (2 , 0 , 1 )
141
- else :
142
- img_pil = img_pil .unsqueeze (0 )
143
- data = read_file (img_path )
144
- img_lpng = decode_png (data )
145
- self .assertTrue (img_lpng .equal (img_pil ))
170
+ for pil_mode , channels in conversion :
171
+ with Image .open (img_path ) as img :
172
+ if pil_mode is not None :
173
+ img = img .convert (pil_mode )
174
+ img_pil = torch .from_numpy (np .array (img ))
146
175
147
- with self .assertRaises (RuntimeError ):
148
- decode_png (torch .empty ((), dtype = torch .uint8 ))
149
- with self .assertRaises (RuntimeError ):
150
- decode_png (torch .randint (3 , 5 , (300 ,), dtype = torch .uint8 ))
176
+ img_pil = normalize_dimensions (img_pil )
177
+ data = read_file (img_path )
178
+ img_lpng = decode_image (data , channels = channels )
179
+
180
+ tol = 0 if conversion is None else 1
181
+ self .assertTrue (img_lpng .allclose (img_pil , atol = tol ))
182
+
183
+ with self .assertRaises (RuntimeError ):
184
+ decode_png (torch .empty ((), dtype = torch .uint8 ))
185
+ with self .assertRaises (RuntimeError ):
186
+ decode_png (torch .randint (3 , 5 , (300 ,), dtype = torch .uint8 ))
151
187
152
188
def test_encode_png (self ):
153
189
for img_path in get_images (IMAGE_DIR , '.png' ):
@@ -196,19 +232,6 @@ def test_write_png(self):
196
232
197
233
self .assertTrue (img_pil .equal (saved_image ))
198
234
199
- def test_decode_image (self ):
200
- for img_path in get_images (IMAGE_ROOT , ".jpg" ):
201
- img_pil = torch .load (img_path .replace ('jpg' , 'pth' ))
202
- img_pil = img_pil .permute (2 , 0 , 1 )
203
- img_ljpeg = decode_image (read_file (img_path ))
204
- self .assertTrue (img_ljpeg .equal (img_pil ))
205
-
206
- for img_path in get_images (IMAGE_DIR , ".png" ):
207
- img_pil = torch .from_numpy (np .array (Image .open (img_path )))
208
- img_pil = img_pil .permute (2 , 0 , 1 )
209
- img_lpng = decode_image (read_file (img_path ))
210
- self .assertTrue (img_lpng .equal (img_pil ))
211
-
212
235
def test_read_file (self ):
213
236
with get_tmp_dir () as d :
214
237
fname , content = 'test1.bin' , b'TorchVision\211 \n '
0 commit comments