Commit: Update model.py
w_g=4
Change to PIL to get the mask; rewrite the dataloader for faster data input.


Former-commit-id: 9ffbe6b
Former-commit-id: 830f75c4c7d74eae141fba9c12982429a8fc48b8
yu45020 committed Jun 30, 2018
1 parent 1288d1d commit 7f52590
Showing 6 changed files with 51 additions and 30 deletions.
9 changes: 8 additions & 1 deletion ReadME.md
@@ -22,7 +22,14 @@ I don't use a text-detection model such as Textbox Plus Plus, Single Shot MultiB
To generate training data, I use two copies of each image: one is the original, and the other is the same image with only the text cleaned off. Such image pairs are abundant and easy to obtain from either targeted users or web scraping. By subtracting the two, I get a mask that marks the text region. The idea is inspired by He et al.'s [Single Shot Text Detector with Regional Attention](https://arxiv.org/abs/1709.00138) and He et al.'s [Mask R-CNN](https://arxiv.org/abs/1703.06870). Both papers demonstrate pixel-level object detection.


-Special notes on training the model: The model runs less than a second with 50 512x512 images in Nvidia P-100. The run time bottleneck will probably lay in CPU speed. Try to get more CPUs and large memory when setting ```num_workers``` in PyTOrch's dataloader. 8 workers takes around 29 GB memory. However, I use PIL to process images and find CPUS always hand up for few seconds in every epochs.
+Notes on training the model:
+
+The model runs in less than a second on a batch of 80 512x512 images on an Nvidia P-100. The run-time bottleneck will probably lie in CPU speed. Try to get more CPUs and plenty of memory when setting ```num_workers``` in PyTorch's dataloader; 6 workers take around 10 GB of memory. If the CPUs are not fast enough to keep the GPUs busy, downscale the input images. The loss scores are similar for the original and the downscaled versions.
+
+I train the model with focal loss (gamma of 2, alpha of 0.25) and SGD with Nesterov momentum (0.98); the batch size is 80, and the learning rate is 0.1, which works surprisingly well until the loss reaches 0.00150. The learning rate then decreases gradually to 0.008, but the model doesn't improve further.
+
+I am now trying cross-entropy loss with class weights.


The model is trained on black/white images, but it also works for color images.

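For anyone reproducing the setup in those notes, here is a minimal, hypothetical sketch. The `FocalLoss` below is the standard formulation from Lin et al. (2017), not this repository's own implementation, and the loader/model names are placeholders:

```python
# Hypothetical training-setup sketch matching the README notes; all names
# outside torch are placeholders, and the loss is the textbook focal loss.
import torch
from torch import nn
from torch.nn import functional as F

class FocalLoss(nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, logits, target):
        # logits, target: (N, 1, H, W); target holds {0, 1} pixel labels
        target = target.float()
        bce = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        p_t = torch.exp(-bce)  # model's probability for the true class
        alpha_t = self.alpha * target + (1 - self.alpha) * (1 - target)
        return (alpha_t * (1 - p_t) ** self.gamma * bce).mean()

# loader = torch.utils.data.DataLoader(train_data, batch_size=80, shuffle=True,
#                                      num_workers=6, pin_memory=True)
# criterion = FocalLoss(gamma=2.0, alpha=0.25)
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1,
#                             momentum=0.98, nesterov=True)
```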
1 change: 0 additions & 1 deletion checkpoints/text_seg_model_380epos.pt.REMOVED.git-id

This file was deleted.

43 changes: 33 additions & 10 deletions dataloader.py
@@ -9,7 +9,7 @@
from torch import nn
from torch.nn.functional import pad
from torch.utils.data import Dataset
-from torchvision.transforms import ColorJitter, ToTensor, RandomResizedCrop, Compose, Normalize, transforms
+from torchvision.transforms import ColorJitter, ToTensor, RandomResizedCrop, Compose, Normalize, transforms, Grayscale
from torchvision.transforms.functional import resized_crop, to_tensor

use_cuda = torch.cuda.is_available()
@@ -30,12 +30,12 @@ def __init__(self, img_folder, max_img=False, img_size=(512, 512)):
# get raw images
self.images = glob.glob(os.path.join(img_folder, "raw/*"))
assert len(self.images) > 0
-        self.max_img = max_img
+        self.max_img = max_img if max_img else len(self.images)
# if len(self.images) > max_img and max_img:
# self.images = random.choices(self.images, k=max_img)
-            self.images = random.choices(self.images, k=max_img)
+        self.images = random.choices(self.images, k=self.max_img)
print("Find {} images. ".format(len(self.images)))

+        self.grayscale = Grayscale(num_output_channels=1)
self.img_size = img_size
# image augment
self.transformer = Compose([ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
@@ -56,20 +56,20 @@ def __getitem__(self, item):

def process_images(self, raw, clean):
i, j, h, w = RandomResizedCrop.get_params(raw, scale=(0.1, 2.0), ratio=(3. / 4., 4. / 3.))
-        raw_img = resized_crop(raw, i, j, h, w, size=self.img_size)
-        clean_img = resized_crop(clean, i, j, h, w, self.img_size)
+        raw_img = resized_crop(raw, i, j, h, w, size=self.img_size, interpolation=Image.BICUBIC)
+        clean_img = resized_crop(clean, i, j, h, w, self.img_size, interpolation=Image.BICUBIC)

# get mask before further image augment
mask_tensor_long = self.get_mask(raw_img, clean_img)
raw_img = self.transformer(raw_img)
return raw_img, mask_tensor_long

-    @staticmethod
-    def get_mask(raw_pil, clean_pil):
+    def get_mask(self, raw_pil, clean_pil):
        # use PIL! It takes care of differences in brightness/contrast
mask = ImageChops.difference(raw_pil, clean_pil)
+        mask = self.grayscale(mask)  # single channel
mask = to_tensor(mask)
-        mask = mask[:1, :, :] > brightness_difference  # single channel
+        mask = mask > brightness_difference
return mask.long()


@@ -100,7 +100,7 @@ def resize_pad_tensor(self, pil_img):
long = max(pil_img.size)
ratio = fix_len / long
new_size = tuple(map(lambda x: int(x * ratio), pil_img.size))
-        img = pil_img.resize(new_size, Image.LANCZOS)
+        img = pil_img.resize(new_size, Image.BICUBIC)
# img = pil_img
img = self.transformer(img)

@@ -116,6 +116,29 @@ def resize_pad_tensor(self, pil_img):
mask_resizer = self.resize_mask(boarder_pad, pil_img.size)
return self.normalizer(img), origin, mask_resizer

+    #
+    # def resize_pad_tensor(self, pil_img):
+    #     origin = self.transformer(pil_img)
+    #     fix_len = 512
+    #     long = min(pil_img.size)
+    #     ratio = fix_len / long
+    #     new_size = tuple(map(lambda x: int(x * ratio), pil_img.size))
+    #     img = pil_img.resize(new_size, Image.BICUBIC)
+    #     # img = pil_img
+    #     img = self.transformer(img)
+    #
+    #     _, _, h, w = img.size()
+    #     if w > fix_len:
+    #
+    #         boarder_pad = (0, w-fix_len, 0, 0)
+    #     else:
+    #
+    #         boarder_pad = (0, 0, 0, h-fix_len)
+    #
+    #     img = pad(img, boarder_pad, value=0)
+    #     mask_resizer = self.resize_mask(boarder_pad, pil_img.size)
+    #     return self.normalizer(img), origin, mask_resizer

@staticmethod
def resize_mask(padded_values, origin_size):
unpad = tuple(map(lambda x: -x, padded_values))
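Pulled out of the diff, the commit's new PIL-based mask path amounts to the following self-contained sketch. The `brightness_difference` threshold is defined elsewhere in dataloader.py; the value below and the image paths are only assumed placeholders:

```python
# Self-contained sketch of the PIL-based mask extraction; the threshold
# value and the image paths in the usage lines are assumptions.
from PIL import Image, ImageChops
from torchvision.transforms import Grayscale
from torchvision.transforms.functional import to_tensor

brightness_difference = 10 / 255              # placeholder threshold
grayscale = Grayscale(num_output_channels=1)

def get_text_mask(raw_pil, clean_pil):
    mask = ImageChops.difference(raw_pil, clean_pil)  # pixel-wise |raw - clean|
    mask = grayscale(mask)                            # collapse to one channel
    mask = to_tensor(mask) > brightness_difference    # boolean text mask
    return mask.long()                                # (1, H, W) tensor of {0, 1}

# raw = Image.open("raw/page.png").convert("RGB")
# clean = Image.open("clean/page.png").convert("RGB")
# mask = get_text_mask(raw, clean)
```

Letting PIL compute the difference keeps the whole mask path in one image library and avoids the per-channel slicing the old tensor-based code needed.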
7 changes: 4 additions & 3 deletions models/BaseModels.py
@@ -89,7 +89,10 @@ def set_activation_inplace(self):
yield

def total_parameters(self):
-        return sum([i.numel() for i in self.parameters()])
+        total = sum([i.numel() for i in self.parameters()])
+        trainable = sum([i.numel() for i in self.parameters() if i.requires_grad])
+        print("Total parameters : {}. Trainable parameters : {}".format(total, trainable))
+        return total

def forward(self, *x):
raise NotImplementedError
@@ -105,5 +108,3 @@ def Conv_block(in_channels, out_channels, kernel_size, stride=1,
if activation:
m.append(activation)
return m
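Since `Conv_block` returns a plain list of modules rather than a module, call sites splice it into `nn.Sequential`, as the text_segmentation.py diff below does. A small sketch, with the keyword set inferred from those call sites:

```python
from torch import nn
from models.BaseModels import Conv_block

# Builds [Conv2d, BatchNorm2d, SELU] and wraps it into a single module.
block = nn.Sequential(*Conv_block(3, 32, kernel_size=3, stride=1, padding=1,
                                  bias=True, BN=True, activation=nn.SELU()))
```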


2 changes: 1 addition & 1 deletion models/MobileNetV2.py
@@ -92,7 +92,7 @@ def __init__(self, in_channel, out_channel, stride, expand_ratio, dilation, conv
self.stride = stride
self.act_fn = activation
self.bias = bias
-        assert stride in [1, 2]
+        # assert stride in [1, 2]

self.res_connect = self.stride == 1 and in_channel == out_channel
self.conv = self.make_body(in_channel, out_channel, stride, expand_ratio, dilation)
19 changes: 5 additions & 14 deletions models/text_segmentation.py
@@ -5,7 +5,6 @@
import torch
from torch import nn
from torch.nn import functional as F
-from torch.utils.checkpoint import checkpoint

from .BaseModels import BaseModule, Conv_block
from .MobileNetV2 import MobileNetV2
@@ -37,7 +36,7 @@ def __init__(self, pre_train_checkpoint=None, free_last_blocks=None, add_partial
print("Encoder check point is loaded")
else:
print("No check point for the encoder is loaded. ")
-            if free_last_blocks:
+            if free_last_blocks >= 0:
self.freeze_params(free_last_blocks)

else:
@@ -84,25 +83,17 @@ def forward(self, x):
asp_pool = [layer(x) for layer in self.asp.children()]
return torch.cat(avg_pool + asp_pool, dim=1)

-    def forward_checkpoint(self, x):
-        avg_pool = checkpoint(self.img_pooling_1, x)
-        avg_pool = F.upsample(avg_pool, size=x.shape[2:], mode='bilinear')
-        avg_pool = [checkpoint(self.img_pooling_2, avg_pool)]
-
-        asp_pool = [checkpoint(layer, x) for layer in self.asp.children()]
-        return torch.cat(avg_pool + asp_pool, dim=1)


class TextSegament(BaseModule):
-    def __init__(self, encoder_checkpoint=None, free_last_blocks=None):
+    def __init__(self, encoder_checkpoint=None, free_last_blocks=False, width_mult=1):
super(TextSegament, self).__init__()
self.act_fn = nn.SELU()
self.bias = True

-        self.layer_4x_conv = nn.Sequential(*Conv_block(24, 128, kernel_size=3, padding=1,
+        self.layer_4x_conv = nn.Sequential(*Conv_block(int(24 * width_mult), 128, kernel_size=3, padding=1,
bias=self.bias, BN=True, activation=self.act_fn))
# self.encoder.last_channel --|
-        self.feature_pooling = ASP(320, out_channel=256)
+        self.feature_pooling = ASP(int(320 * width_mult), out_channel=256)

# decoder
self.transition_2_decoder = nn.Sequential(*Conv_block(256 * 5, 128, kernel_size=1,
@@ -123,7 +114,7 @@ def __init__(self, encoder_checkpoint=None, free_last_blocks=None):
else:
self.initialize_weights()
# use the pre-train weights to initialize the model
-        self.encoder = MobileNetEncoder(encoder_checkpoint, free_last_blocks,
+        self.encoder = MobileNetEncoder(encoder_checkpoint, free_last_blocks, width_mult=width_mult,
activation=nn.ReLU6(), bias=False) # may need to retrain the last 4 layers

def forward(self, x):
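With the new `width_mult` argument, the decoder's input widths track the MobileNetV2 channel multiplier (24 becomes `int(24 * width_mult)`, 320 becomes `int(320 * width_mult)`). A hypothetical smoke test of the updated constructor:

```python
# Hypothetical usage sketch; constructing without an encoder checkpoint
# initializes the weights from scratch.
import torch
from models.text_segmentation import TextSegament

model = TextSegament(width_mult=1)
model.total_parameters()  # prints total vs. trainable counts (see BaseModels above)

x = torch.randn(1, 3, 512, 512)  # one 512x512 image
with torch.no_grad():
    scores = model(x)  # per-pixel text scores
```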
