Skip to content
This repository was archived by the owner on Aug 17, 2024. It is now read-only.

Commit 2c038ca

Browse files
committed
perf(data): 测试集进行随机旋转;使用随机像素进行边界填充
1 parent 87f831a commit 2c038ca

File tree

5 files changed

+33
-15
lines changed

5 files changed

+33
-15
lines changed

rotnet/data/datasets/base_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __getitem__(self, index):
4040
if self.target_transform is not None:
4141
img, target = self.target_transform(img)
4242
else:
43-
# 假定所有训练/测试图像的旋转角度为0
43+
# 假定所有训练/测试图像的初始旋转角度为0
4444
target = 0
4545

4646
# doing this so that it is consistent with all other datasets

rotnet/data/transforms/build.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
def build_transform(cfg, train=True):
1717
size = cfg.MODEL.INPUT_SIZE
1818

19-
target_transform = None
2019
if train:
2120
transform = transforms.Compose([
2221
transforms.Resize(size),
@@ -26,17 +25,17 @@ def build_transform(cfg, train=True):
2625
transforms.Normalize((0.5,), (0.5,)),
2726
transforms.RandomErasing()
2827
])
29-
30-
target_transform = Compose([
31-
Rotate(),
32-
ToGray(),
33-
])
3428
else:
3529
transform = transforms.Compose([
3630
transforms.Resize(size),
3731
transforms.Grayscale(),
3832
transforms.ToTensor(),
3933
transforms.Normalize((0.5,), (0.5,)),
4034
])
35+
# 对测试集和训练集都进行随机旋转
36+
target_transform = Compose([
37+
Rotate(random=True),
38+
ToGray(),
39+
])
4140

4241
return transform, target_transform

rotnet/data/transforms/rotate.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,39 @@
77
@description:
88
"""
99

10-
import math
11-
import random
12-
import cv2
1310
import numpy as np
1411

1512
from rotnet.util.utils import rotate
1613

1714

1815
class Rotate:
1916

17+
def __init__(self, random=False, borderValue=(255, 255, 255)):
18+
"""
19+
:param random: 默认为False,表示使用borderValue指定的边界填充值;如果为True,则忽略borderValue,随机选择填充值
20+
:param borderValue: 边界填充值
21+
"""
22+
self.random = random
23+
self.borderValue = borderValue
24+
2025
def __call__(self, img: np.ndarray, angle=None):
26+
"""
27+
:param img:
28+
:param angle: 如果为None,则随机选择[0,360)的旋转角度
29+
:return:
30+
"""
2131
assert isinstance(img, np.ndarray)
2232

23-
angle = random.randint(0, 359)
24-
rotate_img = rotate(img, angle)
33+
low = 0
34+
high = 360
35+
if not angle:
36+
angle = np.random.randint(low, high=high)
37+
if self.random:
38+
high = 260
39+
borderValue = (
40+
np.random.randint(low, high=high), np.random.randint(low, high=high), np.random.randint(low, high=high))
41+
else:
42+
borderValue = self.borderValue
43+
rotate_img = rotate(img, angle, borderValue=borderValue)
2544

2645
return rotate_img, angle

rotnet/engine/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def do_train(cfg, arguments,
9595
eval_results = do_evaluation(cfg, model, device)
9696
if summary_writer:
9797
for eval_result, dataset_name in zip(eval_results, cfg.DATASETS.TEST):
98-
summary_writer.add_scalar(f'loss/{dataset_name}', eval_result[dataset_name], global_step=iteration)
98+
summary_writer.add_scalar(f'eval/{dataset_name}', eval_result[dataset_name], global_step=iteration)
9999
model.train() # *IMPORTANT*: change to train mode after eval.
100100

101101
if summary_writer:

rotnet/util/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import cv2
1212

1313

14-
def rotate(img, degree):
14+
def rotate(img, degree, borderValue=(255, 255, 255)):
1515
h, w = img.shape[:2]
1616
center = (w // 2, h // 2)
1717

@@ -21,7 +21,7 @@ def rotate(img, degree):
2121
matrix = cv2.getRotationMatrix2D(center, degree, 1)
2222
matrix[0, 2] += dst_w // 2 - center[0]
2323
matrix[1, 2] += dst_h // 2 - center[1]
24-
dst_img = cv2.warpAffine(img, matrix, (dst_w, dst_h), borderValue=(255, 255, 255))
24+
dst_img = cv2.warpAffine(img, matrix, (dst_w, dst_h), borderValue=borderValue)
2525

2626
# imshow(img, 'src')
2727
# imshow(dst_img, 'dst')

0 commit comments

Comments
 (0)