Skip to content

Commit 2a841d4

Browse files
committed
fix: add load_image
1 parent 1c632f2 commit 2a841d4

File tree

1 file changed

+125
-0
lines changed

1 file changed

+125
-0
lines changed

rapid_table_det/utils/load_image.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# -*- encoding: utf-8 -*-
2+
# @Author: SWHL
3+
# @Contact: liekkaskono@163.com
4+
from io import BytesIO
5+
from pathlib import Path
6+
from typing import Any, Union
7+
8+
import cv2
9+
import numpy as np
10+
from PIL import Image, UnidentifiedImageError
11+
12+
root_dir = Path(__file__).resolve().parent
13+
InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
14+
15+
16+
class LoadImage:
17+
def __init__(self):
18+
pass
19+
20+
def __call__(self, img: InputType) -> np.ndarray:
21+
if not isinstance(img, InputType.__args__):
22+
raise LoadImageError(
23+
f"The img type {type(img)} does not in {InputType.__args__}"
24+
)
25+
26+
origin_img_type = type(img)
27+
img = self.load_img(img)
28+
if img.ndim == 3:
29+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
30+
img = self.convert_img(img, origin_img_type)
31+
return img
32+
33+
def load_img(self, img: InputType) -> np.ndarray:
34+
if isinstance(img, (str, Path)):
35+
self.verify_exist(img)
36+
try:
37+
img = self.img_to_ndarray(Image.open(img))
38+
except UnidentifiedImageError as e:
39+
raise LoadImageError(f"cannot identify image file {img}") from e
40+
return img
41+
42+
if isinstance(img, bytes):
43+
img = self.img_to_ndarray(Image.open(BytesIO(img)))
44+
return img
45+
46+
if isinstance(img, np.ndarray):
47+
return img
48+
49+
if isinstance(img, Image.Image):
50+
return self.img_to_ndarray(img)
51+
52+
raise LoadImageError(f"{type(img)} is not supported!")
53+
54+
def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
55+
if img.mode == "1":
56+
img = img.convert("L")
57+
return np.array(img)
58+
return np.array(img)
59+
60+
def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
61+
if img.ndim == 2:
62+
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
63+
64+
if img.ndim == 3:
65+
channel = img.shape[2]
66+
if channel == 1:
67+
return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
68+
69+
if channel == 2:
70+
return self.cvt_two_to_three(img)
71+
72+
if channel == 3:
73+
if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
74+
return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
75+
return img
76+
77+
if channel == 4:
78+
return self.cvt_four_to_three(img)
79+
80+
raise LoadImageError(
81+
f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
82+
)
83+
84+
raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
85+
86+
@staticmethod
87+
def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
88+
"""gray + alpha → BGR"""
89+
img_gray = img[..., 0]
90+
img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
91+
92+
img_alpha = img[..., 1]
93+
not_a = cv2.bitwise_not(img_alpha)
94+
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
95+
96+
new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
97+
new_img = cv2.add(new_img, not_a)
98+
return new_img
99+
100+
@staticmethod
101+
def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
102+
"""RGBA → BGR"""
103+
r, g, b, a = cv2.split(img)
104+
new_img = cv2.merge((b, g, r))
105+
106+
not_a = cv2.bitwise_not(a)
107+
not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
108+
109+
new_img = cv2.bitwise_and(new_img, new_img, mask=a)
110+
111+
mean_color = np.mean(new_img)
112+
if mean_color <= 0.0:
113+
new_img = cv2.add(new_img, not_a)
114+
else:
115+
new_img = cv2.bitwise_not(new_img)
116+
return new_img
117+
118+
@staticmethod
119+
def verify_exist(file_path: Union[str, Path]):
120+
if not Path(file_path).exists():
121+
raise LoadImageError(f"{file_path} does not exist.")
122+
123+
124+
class LoadImageError(Exception):
125+
pass

0 commit comments

Comments
 (0)