Skip to content

Commit eeda45d

Browse files
committed
instead cv2 LANCZOS4 resize to pil resize
1 parent 51d57f0 commit eeda45d

File tree

5 files changed

+36
-15
lines changed

5 files changed

+36
-15
lines changed

finetune/tag_images_by_wd14_tagger.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from tqdm import tqdm
1212

1313
import library.train_util as train_util
14-
from library.utils import setup_logging
14+
from library.utils import setup_logging, pil_resize
1515

1616
setup_logging()
1717
import logging
@@ -42,8 +42,10 @@ def preprocess_image(image):
4242
pad_t = pad_y // 2
4343
image = np.pad(image, ((pad_t, pad_y - pad_t), (pad_l, pad_x - pad_l), (0, 0)), mode="constant", constant_values=255)
4444

45-
interp = cv2.INTER_AREA if size > IMAGE_SIZE else cv2.INTER_LANCZOS4
46-
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), interpolation=interp)
45+
if size > IMAGE_SIZE:
46+
image = cv2.resize(image, (IMAGE_SIZE, IMAGE_SIZE), cv2.INTER_AREA)
47+
else:
48+
image = pil_resize(image, (IMAGE_SIZE, IMAGE_SIZE))
4749

4850
image = image.astype(np.float32)
4951
return image

library/train_util.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
import library.huggingface_util as huggingface_util
7272
import library.sai_model_spec as sai_model_spec
7373
import library.deepspeed_utils as deepspeed_utils
74-
from library.utils import setup_logging
74+
from library.utils import setup_logging, pil_resize
7575

7676
setup_logging()
7777
import logging
@@ -2028,9 +2028,7 @@ def __getitem__(self, index):
20282028
# ), f"image size is small / 画像サイズが小さいようです: {image_info.absolute_path}"
20292029
# resize to target
20302030
if cond_img.shape[0] != target_size_hw[0] or cond_img.shape[1] != target_size_hw[1]:
2031-
cond_img = cv2.resize(
2032-
cond_img, (int(target_size_hw[1]), int(target_size_hw[0])), interpolation=cv2.INTER_LANCZOS4
2033-
)
2031+
cond_img=pil_resize(cond_img,(int(target_size_hw[1]), int(target_size_hw[0])))
20342032

20352033
if flipped:
20362034
cond_img = cond_img[:, ::-1, :].copy() # copy to avoid negative stride
@@ -2362,7 +2360,10 @@ def trim_and_resize_if_required(
23622360

23632361
if image_width != resized_size[0] or image_height != resized_size[1]:
23642362
# リサイズする
2365-
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
2363+
if image_width > resized_size[0] and image_height > resized_size[1]:
2364+
image = cv2.resize(image, resized_size, interpolation=cv2.INTER_AREA) # INTER_AREAでやりたいのでcv2でリサイズ
2365+
else:
2366+
image = pil_resize(image, resized_size)
23662367

23672368
image_height, image_width = image.shape[0:2]
23682369

library/utils.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,9 @@
77
from diffusers import EulerAncestralDiscreteScheduler
88
import diffusers.schedulers.scheduling_euler_ancestral_discrete
99
from diffusers.schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteSchedulerOutput
10-
10+
import cv2
11+
from PIL import Image
12+
import numpy as np
1113

1214
def fire_in_thread(f, *args, **kwargs):
1315
threading.Thread(target=f, args=args, kwargs=kwargs).start()
@@ -78,7 +80,17 @@ def setup_logging(args=None, log_level=None, reset=False):
7880
logger = logging.getLogger(__name__)
7981
logger.info(msg_init)
8082

83+
def pil_resize(image, size, interpolation=Image.LANCZOS):
84+
85+
pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
86+
87+
# use Pillow resize
88+
resized_pil = pil_image.resize(size, interpolation)
89+
90+
# return cv2 image
91+
resized_cv2 = cv2.cvtColor(np.array(resized_pil), cv2.COLOR_RGB2BGR)
8192

93+
return resized_cv2
8294

8395
# TODO make inf_utils.py
8496

tools/detect_face_rotate.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from anime_face_detector import create_detector
1616
from tqdm import tqdm
1717
import numpy as np
18-
from library.utils import setup_logging
18+
from library.utils import setup_logging, pil_resize
1919
setup_logging()
2020
import logging
2121
logger = logging.getLogger(__name__)
@@ -172,7 +172,10 @@ def process(args):
172172
if scale != 1.0:
173173
w = int(w * scale + .5)
174174
h = int(h * scale + .5)
175-
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA if scale < 1.0 else cv2.INTER_LANCZOS4)
175+
if scale < 1.0:
176+
face_img = cv2.resize(face_img, (w, h), interpolation=cv2.INTER_AREA)
177+
else:
178+
face_img = pil_resize(face_img, (w, h))
176179
cx = int(cx * scale + .5)
177180
cy = int(cy * scale + .5)
178181
fw = int(fw * scale + .5)

tools/resize_images_to_resolution.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import math
77
from PIL import Image
88
import numpy as np
9-
from library.utils import setup_logging
9+
from library.utils import setup_logging, pil_resize
1010
setup_logging()
1111
import logging
1212
logger = logging.getLogger(__name__)
@@ -24,9 +24,9 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
2424

2525
# Select interpolation method
2626
if interpolation == 'lanczos4':
27-
cv2_interpolation = cv2.INTER_LANCZOS4
27+
pil_interpolation = Image.LANCZOS
2828
elif interpolation == 'cubic':
29-
cv2_interpolation = cv2.INTER_CUBIC
29+
pil_interpolation = Image.BICUBIC
3030
else:
3131
cv2_interpolation = cv2.INTER_AREA
3232

@@ -64,7 +64,10 @@ def resize_images(src_img_folder, dst_img_folder, max_resolution="512x512", divi
6464
new_width = int(img.shape[1] * math.sqrt(scale_factor))
6565

6666
# Resize image
67-
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
67+
if cv2_interpolation:
68+
img = cv2.resize(img, (new_width, new_height), interpolation=cv2_interpolation)
69+
else:
70+
img = pil_resize(img, (new_width, new_height), interpolation=pil_interpolation)
6871
else:
6972
new_height, new_width = img.shape[0:2]
7073

0 commit comments

Comments
 (0)