Skip to content

remove redundant code for transforms #83

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mindocr/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def build_dataset(
# TODO: find optimal setting automatically according to num of CPU cores
num_workers = loader_config.get("num_workers", 8) # Number of subprocesses used to fetch the dataset/map data row/gen batch in parallel
cores = multiprocessing.cpu_count()
num_devices = 1 if num_shards is None else num_shards
num_devices = 1 if num_shards is None else num_shards
if num_workers > int(cores / num_devices):
num_workers = int(cores / num_devices)
print('WARNING: num_workers is adjusted to {num_workers}, to fit {cores} CPU cores shared for {num_devices} devices')
Expand Down
207 changes: 1 addition & 206 deletions mindocr/data/transforms/det_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from shapely.geometry import Polygon
import numpy as np

__all__ = ['DetLabelEncode', 'BorderMap', 'ShrinkBinaryMap', 'EastRandomCropData', 'PSERandomCrop', 'expand_poly']
__all__ = ['DetLabelEncode', 'BorderMap', 'ShrinkBinaryMap', 'expand_poly']


class DetLabelEncode:
Expand Down Expand Up @@ -199,211 +199,6 @@ def _validate_polys(data):
data['polys'][i] = data['polys'][i][::-1]


# random crop algorithm similar to https://github.com/argman/EAST
# TODO: check randomness, seems to crop the same region every time
class EastRandomCropData:
"""
Code adopted from https://github1s.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py

Randomly select a region and crop images to a target size and make sure
to contain text region. This transform may break up text instances, and for
broken text instances, we will crop it's bbox and polygon coordinates. This
transform is recommend to be used in segmentation-based network.
"""
def __init__(self, size=(640, 640), max_tries=50, min_crop_side_ratio=0.1, require_original_image=False,
keep_ratio=True):
self.size = size
self.max_tries = max_tries
self.min_crop_side_ratio = min_crop_side_ratio
self.require_original_image = require_original_image
self.keep_ratio = keep_ratio

def __call__(self, data: dict) -> dict:
"""
从scales中随机选择一个尺度,对图片和文本框进行缩放
:param data: {'img':,'polys':,'texts':,'ignore_tags':}
:return:
"""
im = data['image']
text_polys = data['polys']
ignore_tags = data['ignore_tags']
texts = data['texts']
all_care_polys = [text_polys[i] for i, tag in enumerate(ignore_tags) if not tag]
# 计算crop区域
crop_x, crop_y, crop_w, crop_h = self.crop_area(im, all_care_polys)
print('crop area: ', crop_x, crop_y, crop_w, crop_h)
# crop 图片 保持比例填充
scale_w = self.size[0] / crop_w
scale_h = self.size[1] / crop_h
scale = min(scale_w, scale_h)
h = int(crop_h * scale)
w = int(crop_w * scale)
if self.keep_ratio:
if len(im.shape) == 3:
padimg = np.zeros((self.size[1], self.size[0], im.shape[2]), im.dtype)
else:
padimg = np.zeros((self.size[1], self.size[0]), im.dtype)
padimg[:h, :w] = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], (w, h))
img = padimg
else:
img = cv2.resize(im[crop_y:crop_y + crop_h, crop_x:crop_x + crop_w], tuple(self.size))
# crop 文本框
text_polys_crop = []
ignore_tags_crop = []
texts_crop = []
for poly, text, tag in zip(text_polys, texts, ignore_tags):
poly = ((poly - (crop_x, crop_y)) * scale).tolist()
if not self.is_poly_outside_rect(poly, 0, 0, w, h):
text_polys_crop.append(poly)
ignore_tags_crop.append(tag)
texts_crop.append(text)
data['image'] = img
data['polys'] = np.float32(text_polys_crop)
data['ignore_tags'] = ignore_tags_crop
data['texts'] = texts_crop
return data

def is_poly_in_rect(self, poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].min() < x or poly[:, 0].max() > x + w:
return False
if poly[:, 1].min() < y or poly[:, 1].max() > y + h:
return False
return True

def is_poly_outside_rect(self, poly, x, y, w, h):
poly = np.array(poly)
if poly[:, 0].max() < x or poly[:, 0].min() > x + w:
return True
if poly[:, 1].max() < y or poly[:, 1].min() > y + h:
return True
return False

def split_regions(self, axis):
regions = []
min_axis = 0
for i in range(1, axis.shape[0]):
if axis[i] != axis[i - 1] + 1:
region = axis[min_axis:i]
min_axis = i
regions.append(region)
return regions

def random_select(self, axis, max_size):
xx = np.random.choice(axis, size=2)
xmin = np.min(xx)
xmax = np.max(xx)
xmin = np.clip(xmin, 0, max_size - 1)
xmax = np.clip(xmax, 0, max_size - 1)
return xmin, xmax

def region_wise_random_select(self, regions, max_size):
selected_index = list(np.random.choice(len(regions), 2))
selected_values = []
for index in selected_index:
axis = regions[index]
xx = int(np.random.choice(axis, size=1))
selected_values.append(xx)
xmin = min(selected_values)
xmax = max(selected_values)
return xmin, xmax

def crop_area(self, im, text_polys):
h, w = im.shape[:2]
h_array = np.zeros(h, dtype=np.int32)
w_array = np.zeros(w, dtype=np.int32)
for points in text_polys:
points = np.round(points, decimals=0).astype(np.int32)
minx = np.min(points[:, 0])
maxx = np.max(points[:, 0])
w_array[minx:maxx] = 1
miny = np.min(points[:, 1])
maxy = np.max(points[:, 1])
h_array[miny:maxy] = 1
# ensure the cropped area not across a text
h_axis = np.where(h_array == 0)[0]
w_axis = np.where(w_array == 0)[0]

if len(h_axis) == 0 or len(w_axis) == 0:
return 0, 0, w, h

h_regions = self.split_regions(h_axis)
w_regions = self.split_regions(w_axis)

for i in range(self.max_tries):
if len(w_regions) > 1:
xmin, xmax = self.region_wise_random_select(w_regions, w)
else:
xmin, xmax = self.random_select(w_axis, w)
if len(h_regions) > 1:
ymin, ymax = self.region_wise_random_select(h_regions, h)
else:
ymin, ymax = self.random_select(h_axis, h)

if xmax - xmin < self.min_crop_side_ratio * w or ymax - ymin < self.min_crop_side_ratio * h:
# area too small
continue
num_poly_in_rect = 0
for poly in text_polys:
if not self.is_poly_outside_rect(poly, xmin, ymin, xmax - xmin, ymax - ymin):
num_poly_in_rect += 1
break

if num_poly_in_rect > 0:
return xmin, ymin, xmax - xmin, ymax - ymin

return 0, 0, w, h


class PSERandomCrop:
"""
Code adopted from https://github1s.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/random_crop_data.py
"""
def __init__(self, size):
self.size = size

def __call__(self, data):
imgs = data['imgs']

h, w = imgs[0].shape[0:2]
th, tw = self.size
if w == tw and h == th:
return imgs

# label中存在文本实例,并且按照概率进行裁剪,使用threshold_label_map控制
if np.max(imgs[2]) > 0 and random.random() > 3 / 8:
# 文本实例的左上角点
tl = np.min(np.where(imgs[2] > 0), axis=1) - self.size
tl[tl < 0] = 0
# 文本实例的右下角点
br = np.max(np.where(imgs[2] > 0), axis=1) - self.size
br[br < 0] = 0
# 保证选到右下角点时,有足够的距离进行crop
br[0] = min(br[0], h - th)
br[1] = min(br[1], w - tw)

for _ in range(50000):
i = random.randint(tl[0], br[0])
j = random.randint(tl[1], br[1])
# 保证shrink_label_map有文本
if imgs[1][i:i + th, j:j + tw].sum() <= 0:
continue
else:
break
else:
i = random.randint(0, h - th)
j = random.randint(0, w - tw)

# return i, j, th, tw
for idx in range(len(imgs)):
if len(imgs[idx].shape) == 3:
imgs[idx] = imgs[idx][i:i + th, j:j + tw, :]
else:
imgs[idx] = imgs[idx][i:i + th, j:j + tw]
data['imgs'] = imgs
return data


def expand_poly(poly, distance: float, joint_type=pyclipper.JT_ROUND) -> List[list]:
offset = pyclipper.PyclipperOffset()
offset.AddPath(poly, joint_type, pyclipper.ET_CLOSEDPOLYGON)
Expand Down
1 change: 0 additions & 1 deletion mindocr/data/transforms/modelzoo_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
__all__ = ['MZRandomColorAdjust', 'MZScalePad', 'MZResizeByGrid', 'MZRandomCropData', 'MZRandomScaleByShortSide']


# TODO: Does it support BGR mode?
class MZRandomColorAdjust:
def __init__(self, brightness=32.0 / 255, saturation=0.5, to_numpy=False):
self.colorjitter = RandomColorAdjust(brightness=brightness, saturation=saturation)
Expand Down
153 changes: 0 additions & 153 deletions tests/st/det_db_test.yaml

This file was deleted.

Loading