@@ -34,6 +34,12 @@ class Image:
34
34
35
35
_pil_to_tensor = PILToTensor ()
36
36
_default_device = _get_device ()
37
+ _FILTER_EDGES_KERNEL = (
38
+ torch .tensor ([[- 1.0 , - 1.0 , - 1.0 ], [- 1.0 , 8.0 , - 1.0 ], [- 1.0 , - 1.0 , - 1.0 ]])
39
+ .unsqueeze (dim = 0 )
40
+ .unsqueeze (dim = 0 )
41
+ .to (_default_device )
42
+ )
37
43
38
44
@staticmethod
39
45
def from_file (path : str | Path , device : Device = _default_device ) -> Image :
@@ -116,7 +122,10 @@ def _repr_jpeg_(self) -> bytes | None:
116
122
if self .channel == 4 :
117
123
return None
118
124
buffer = io .BytesIO ()
119
- save_image (self ._image_tensor .to (torch .float32 ) / 255 , buffer , format = "jpeg" )
125
+ if self .channel == 1 :
126
+ func2 .to_pil_image (self ._image_tensor , mode = "L" ).save (buffer , format = "jpeg" )
127
+ else :
128
+ save_image (self ._image_tensor .to (torch .float32 ) / 255 , buffer , format = "jpeg" )
120
129
buffer .seek (0 )
121
130
return buffer .read ()
122
131
@@ -130,7 +139,10 @@ def _repr_png_(self) -> bytes:
130
139
The image as PNG.
131
140
"""
132
141
buffer = io .BytesIO ()
133
- save_image (self ._image_tensor .to (torch .float32 ) / 255 , buffer , format = "png" )
142
+ if self .channel == 1 :
143
+ func2 .to_pil_image (self ._image_tensor , mode = "L" ).save (buffer , format = "png" )
144
+ else :
145
+ save_image (self ._image_tensor .to (torch .float32 ) / 255 , buffer , format = "png" )
134
146
buffer .seek (0 )
135
147
return buffer .read ()
136
148
@@ -213,7 +225,10 @@ def to_jpeg_file(self, path: str | Path) -> None:
213
225
if self .channel == 4 :
214
226
raise IllegalFormatError ("png" )
215
227
Path (path ).parent .mkdir (parents = True , exist_ok = True )
216
- save_image (self ._image_tensor .to (torch .float32 ) / 255 , path , format = "jpeg" )
228
+ if self .channel == 1 :
229
+ func2 .to_pil_image (self ._image_tensor , mode = "L" ).save (path , format = "jpeg" )
230
+ else :
231
+ save_image (self ._image_tensor .to (torch .float32 ) / 255 , path , format = "jpeg" )
217
232
218
233
def to_png_file (self , path : str | Path ) -> None :
219
234
"""
@@ -225,7 +240,10 @@ def to_png_file(self, path: str | Path) -> None:
225
240
The path to the PNG file.
226
241
"""
227
242
Path (path ).parent .mkdir (parents = True , exist_ok = True )
228
- save_image (self ._image_tensor .to (torch .float32 ) / 255 , path , format = "png" )
243
+ if self .channel == 1 :
244
+ func2 .to_pil_image (self ._image_tensor , mode = "L" ).save (path , format = "png" )
245
+ else :
246
+ save_image (self ._image_tensor .to (torch .float32 ) / 255 , path , format = "png" )
229
247
230
248
# ------------------------------------------------------------------------------------------------------------------
231
249
# Transformations
@@ -457,6 +475,12 @@ def adjust_color_balance(self, factor: float) -> Image:
457
475
UserWarning ,
458
476
stacklevel = 2 ,
459
477
)
478
+ elif self .channel == 1 :
479
+ warnings .warn (
480
+ "Color adjustment will not have an affect on grayscale images with only one channel." ,
481
+ UserWarning ,
482
+ stacklevel = 2 ,
483
+ )
460
484
return Image (
461
485
self .convert_to_grayscale ()._image_tensor * (1.0 - factor * 1.0 ) + self ._image_tensor * (factor * 1.0 ),
462
486
device = self .device ,
@@ -568,3 +592,38 @@ def rotate_left(self) -> Image:
568
592
The image rotated 90 degrees counter-clockwise.
569
593
"""
570
594
return Image (func2 .rotate (self ._image_tensor , 90 , expand = True ), device = self .device )
595
+
596
+ def find_edges (self ) -> Image :
597
+ """
598
+ Return a grayscale version of the image with the edges highlighted.
599
+
600
+ The original image is not modified.
601
+
602
+ Returns
603
+ -------
604
+ result : Image
605
+ The image with edges found.
606
+ """
607
+ kernel = (
608
+ Image ._FILTER_EDGES_KERNEL
609
+ if self .device .type == Image ._default_device
610
+ else Image ._FILTER_EDGES_KERNEL .to (self .device )
611
+ )
612
+ edges_tensor = torch .clamp (
613
+ torch .nn .functional .conv2d (
614
+ self .convert_to_grayscale ()._image_tensor .float ()[0 ].unsqueeze (dim = 0 ),
615
+ kernel ,
616
+ padding = "same" ,
617
+ ).squeeze (dim = 1 ),
618
+ 0 ,
619
+ 255 ,
620
+ ).to (torch .uint8 )
621
+ if self .channel == 3 :
622
+ return Image (edges_tensor .repeat (3 , 1 , 1 ), device = self .device )
623
+ elif self .channel == 4 :
624
+ return Image (
625
+ torch .cat ([edges_tensor .repeat (3 , 1 , 1 ), self ._image_tensor [3 ].unsqueeze (dim = 0 )]),
626
+ device = self .device ,
627
+ )
628
+ else :
629
+ return Image (edges_tensor , device = self .device )
0 commit comments