forked from wuzhenyubuaa/TSNet
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e6dab0b
commit fd3530a
Showing
14 changed files
with
904 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# A Deeper Look at Salient Object Detection: Two-stream Framework with a Small Training Dataset
|
||
## abstract | ||
In this paper, we attempt to reveal the nuance in the training strategy of salient object detection, including the choice of training datasets and the amount of training data that the model requires. Furthermore, we also expose the ground-truth bias of existing salient object detection benchmarks and its detrimental effect on performance scores. Based on our discoveries, we propose a new two-stream framework trained on a small training dataset. To effectively integrate features of different networks, we introduce a novel gate control mechanism for the fusion of two-stream networks that achieves consistent improvements over baseline fusion approaches. To preserve clear object boundaries, we also propose a novel multi-layer attention module that uses high-level saliency activation maps to guide the extraction of detailed information from low-level feature maps. Extensive experimental results demonstrate that our proposed model can more accurately highlight salient objects with a small training dataset, and substantially improves performance scores compared to existing state-of-the-art saliency detection models.
|
||
## Network architecture | ||
|
||
![fig1](./img/pipeline.png) | ||
|
||
|
||
## Requirements
- Python 3.5 | ||
- OpenCV | ||
- PyTorch 0.4 | ||
|
||
### Visual comparison with previous state-of-the-art methods
|
||
![fig1](./img/sal_maps.png) | ||
|
||
## Usage | ||
Clone, and cd into the repo directory. | ||
|
||
|
||
|
||
git clone git@github.com:Diamond101010/TSNet.git | ||
|
||
Before you start, you also need our pretrained model. | ||
Then run | ||
|
||
cd examples
python demo.py | ||
|
||
## Download | ||
|
||
We provide the results online datasets including [DUT-OMRON](https://drive.google.com/open?id=1hq6w_LhvMblyYdLFFskLtR77wm4NDFFm), [DUTS-TE](https://drive.google.com/open?id=1LYsFtnCOGiCSL4nyyD9UWw1T0gBo-34F), [ECSSD](https://drive.google.com/open?id=1QHkds8ZMAB_YdJZ8WaOb-mFQnHDa55Un), [HKU-IS](https://drive.google.com/open?id=1ApPVWLRDJDsT0iM54jZyevkErqcVPJSy), [PASCAL-S](https://drive.google.com/open?id=1jMuhfouo3sFXcDYHZtt8S7iWanUv4ftE) | ||
<hr> | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
import os | ||
from PIL import Image | ||
import torch.utils.data as data | ||
import torchvision.transforms as transforms | ||
|
||
|
||
class SalObjDataset(data.Dataset):
    """Paired image / ground-truth dataset for salient object detection.

    Images are read as ``.jpg`` files under ``image_root`` and ground-truth
    masks as ``.jpg``/``.png`` files under ``gt_root``.  Pairs are matched by
    sorted filename order and pairs whose spatial sizes disagree are dropped.

    Args:
        image_root: directory containing the training images.
        gt_root: directory containing the ground-truth masks.
        trainsize: side length to which both image and mask are resized.
    """

    def __init__(self, image_root, gt_root, trainsize):
        self.trainsize = trainsize
        # os.path.join works whether or not the roots carry a trailing
        # slash, unlike the plain string concatenation it replaces.
        self.images = sorted(
            os.path.join(image_root, f)
            for f in os.listdir(image_root) if f.endswith('.jpg'))
        self.gts = sorted(
            os.path.join(gt_root, f)
            for f in os.listdir(gt_root) if f.endswith(('.jpg', '.png')))
        self.filter_files()
        self.size = len(self.images)
        self.img_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor(),
            # ImageNet channel statistics.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.Compose([
            transforms.Resize((self.trainsize, self.trainsize)),
            transforms.ToTensor()])

    def __getitem__(self, index):
        """Return the transformed (image, ground-truth) pair at *index*."""
        image = self.rgb_loader(self.images[index])
        gt = self.binary_loader(self.gts[index])
        image = self.img_transform(image)
        gt = self.gt_transform(gt)
        return image, gt

    def filter_files(self):
        """Drop any image/mask pair whose spatial sizes disagree."""
        assert len(self.images) == len(self.gts)
        images = []
        gts = []
        for img_path, gt_path in zip(self.images, self.gts):
            # Open only to read the header; close immediately so a large
            # dataset does not leak one file handle per sample.
            with Image.open(img_path) as img, Image.open(gt_path) as gt:
                if img.size == gt.size:
                    images.append(img_path)
                    gts.append(gt_path)
        self.images = images
        self.gts = gts

    def rgb_loader(self, path):
        """Load *path* as a 3-channel RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Load *path* as a single-channel grayscale ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')

    def resize(self, img, gt):
        """Upscale both images so neither side is below ``trainsize``.

        Bilinear resampling for the image, nearest-neighbour for the mask
        (to keep it binary).  Returns the pair unchanged when already large
        enough.
        """
        assert img.size == gt.size
        w, h = img.size
        if h < self.trainsize or w < self.trainsize:
            h = max(h, self.trainsize)
            w = max(w, self.trainsize)
            return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
        else:
            return img, gt

    def __len__(self):
        return self.size
|
||
|
||
def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True):
    """Build a DataLoader over paired saliency images and ground-truth masks."""
    training_set = SalObjDataset(image_root, gt_root, trainsize)
    return data.DataLoader(
        dataset=training_set,
        batch_size=batchsize,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
|
||
|
||
class test_dataset:
    """Sequential loader for evaluation.

    Iterates image/ground-truth pairs one at a time via :meth:`load_data`,
    advancing an internal cursor.  Callers are expected to call it exactly
    ``size`` times (see the evaluation loop in test_CPD.py).

    Args:
        image_root: directory containing the test images (``.jpg``).
        gt_root: directory containing the ground-truth masks (``.jpg``/``.png``).
        testsize: side length to which the input image is resized.
    """

    def __init__(self, image_root, gt_root, testsize):
        self.testsize = testsize
        # os.path.join is robust to roots with or without a trailing slash,
        # unlike the plain string concatenation it replaces.
        self.images = sorted(
            os.path.join(image_root, f)
            for f in os.listdir(image_root) if f.endswith('.jpg'))
        self.gts = sorted(
            os.path.join(gt_root, f)
            for f in os.listdir(gt_root) if f.endswith(('.jpg', '.png')))
        self.transform = transforms.Compose([
            transforms.Resize((self.testsize, self.testsize)),
            transforms.ToTensor(),
            # ImageNet channel statistics.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
        self.gt_transform = transforms.ToTensor()
        self.size = len(self.images)
        self.index = 0

    def load_data(self):
        """Return (image tensor with batch dim, raw PIL mask, output name).

        The returned name always carries a ``.png`` extension so saved
        saliency maps match the ground-truth naming.
        """
        image = self.rgb_loader(self.images[self.index])
        image = self.transform(image).unsqueeze(0)
        gt = self.binary_loader(self.gts[self.index])
        # os.path.basename is portable; the original split on '/' only.
        name = os.path.basename(self.images[self.index])
        if name.endswith('.jpg'):
            name = os.path.splitext(name)[0] + '.png'
        self.index += 1
        return image, gt, name

    def rgb_loader(self, path):
        """Load *path* as a 3-channel RGB PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')

    def binary_loader(self, path):
        """Load *path* as a single-channel grayscale ('L') PIL image."""
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('L')
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
|
||
import numpy as np | ||
import pdb, os, argparse | ||
from scipy import misc | ||
|
||
from model.CPD_models import CPD_VGG | ||
from model.CPD_ResNet_models import CPD_ResNet | ||
from data import test_dataset | ||
|
||
# ---------------------------------------------------------------------------
# Command-line options
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--testsize', type=int, default=352, help='testing size')


def _str2bool(value):
    """Parse a boolean CLI value explicitly.

    argparse's ``type=bool`` treats ANY non-empty string — including
    "False" — as True; this parser makes ``--is_ResNet False`` actually
    select the VGG backbone while keeping ``--is_ResNet True`` working.
    """
    return str(value).lower() in ('true', '1', 'yes')


parser.add_argument('--is_ResNet', type=_str2bool, default=False,
                    help='VGG or ResNet backbone')
opt = parser.parse_args()

dataset_path = 'path/dataset/'

# ---------------------------------------------------------------------------
# Model selection and checkpoint loading
# ---------------------------------------------------------------------------
if opt.is_ResNet:
    model = CPD_ResNet()
    model.load_state_dict(torch.load('CPD-R.pth'))
else:
    model = CPD_VGG()
    model.load_state_dict(torch.load('CPD.pth'))

model.cuda()
model.eval()

test_datasets = ['PASCAL', 'ECSSD', 'DUT-OMRON', 'DUTS-TEST', 'HKUIS']

# ---------------------------------------------------------------------------
# Evaluation: save one normalized saliency map per test image
# ---------------------------------------------------------------------------
for dataset in test_datasets:
    if opt.is_ResNet:
        save_path = './results/ResNet50/' + dataset + '/'
    else:
        save_path = './results/VGG16/' + dataset + '/'
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    image_root = dataset_path + dataset + '/images/'
    gt_root = dataset_path + dataset + '/gts/'
    test_loader = test_dataset(image_root, gt_root, opt.testsize)
    for i in range(test_loader.size):
        image, gt, name = test_loader.load_data()
        gt = np.asarray(gt, np.float32)
        # Normalization kept for parity with the original code; only
        # gt.shape is actually consumed below.
        gt /= (gt.max() + 1e-8)
        image = image.cuda()
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            _, res = model(image)
        # F.upsample is deprecated; F.interpolate is the drop-in replacement.
        res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)
        res = res.sigmoid().data.cpu().numpy().squeeze()
        res = (res - res.min()) / (res.max() - res.min() + 1e-8)
        # NOTE(review): scipy.misc.imsave was removed in SciPy >= 1.2; on a
        # modern environment switch to imageio.imwrite or PIL.Image.save.
        misc.imsave(save_path + name, res)
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import torch | ||
import torch.nn.functional as F | ||
import torch.nn as nn | ||
from torch.nn.parameter import Parameter | ||
|
||
import numpy as np | ||
import scipy.stats as st | ||
|
||
|
||
def gkern(kernlen=16, nsig=3):
    """Return a normalized Gaussian-like kernel of shape (kernlen, kernlen).

    The 1-D profile is built from differences of the normal CDF sampled over
    roughly [-nsig, nsig]; the 2-D kernel is the element-wise square root of
    its outer product, rescaled to sum to one.
    """
    step = (2 * nsig + 1.0) / kernlen
    grid = np.linspace(-nsig - step / 2.0, nsig + step / 2.0, kernlen + 1)
    profile_1d = np.diff(st.norm.cdf(grid))
    unnormalized = np.sqrt(np.outer(profile_1d, profile_1d))
    return unnormalized / unnormalized.sum()
|
||
|
||
def min_max_norm(in_):
    """Rescale each (N, C) map of a 4-D tensor to [0, 1] over its H and W."""
    per_map_max = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    per_map_min = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
    # The epsilon keeps constant maps from dividing by zero.
    return (in_ - per_map_min) / (per_map_max - per_map_min + 1e-8)
|
||
|
||
class HA(nn.Module):
    """Holistic attention module.

    Blurs a saliency map with a 31x31 Gaussian kernel, min-max normalizes
    it, and gates the feature tensor with the element-wise maximum of the
    blurred and raw attention maps.
    """

    def __init__(self):
        super(HA, self).__init__()
        kernel = np.float32(gkern(31, 4))[np.newaxis, np.newaxis, ...]
        # Registered as a Parameter, so the kernel is trainable even though
        # it starts out as a fixed Gaussian (matches the original behavior).
        self.gaussian_kernel = Parameter(torch.from_numpy(kernel))

    def forward(self, attention, x):
        # padding=15 keeps the 31x31 convolution size-preserving.
        blurred = F.conv2d(attention, self.gaussian_kernel, padding=15)
        blurred = min_max_norm(blurred)
        # Tensor.max(other) is the element-wise maximum of the two maps.
        return torch.mul(x, blurred.max(attention))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
import torch.nn as nn | ||
import math | ||
|
||
|
||
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 convolution with 1-pixel padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
|
||
|
||
class BasicBlock(nn.Module):
    """Two-convolution residual block (ResNet-18/34 style).

    Output channels equal ``planes`` (expansion factor 1); the optional
    ``downsample`` module adapts the skip connection when shape changes.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        # Module creation order matches the reference implementation so the
        # parameter registration order (and checkpoints) stay compatible.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += identity
        return self.relu(out)
|
||
|
||
class Bottleneck(nn.Module):
    """Three-convolution bottleneck residual block (ResNet-50/101 style).

    Channel flow: ``inplanes`` -> ``planes`` (1x1) -> ``planes`` (3x3,
    strided) -> ``planes * 4`` (1x1), plus the (optionally downsampled)
    skip connection.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        # Module creation order matches the reference implementation so the
        # parameter registration order (and checkpoints) stay compatible.
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
|
||
|
||
class B2_ResNet(nn.Module):
    """ResNet-50 backbone whose last two stages are duplicated into two
    parallel branches; :meth:`forward` returns one feature map per branch.
    """

    def __init__(self):
        self.inplanes = 64
        super(B2_ResNet, self).__init__()

        # Shared stem and first two stages.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(Bottleneck, 64, 3)
        self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)

        # Branch 1: stages 3 and 4.
        self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2)

        # Branch 2: rebuilt from the layer2 output (512 channels).
        self.inplanes = 512
        self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2)
        self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2)

        # He initialization for convolutions; identity-like batch norms.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = (module.kernel_size[0] * module.kernel_size[1]
                           * module.out_channels)
                module.weight.data.normal_(0, math.sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks, downsampling in the first one."""
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_channels,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        layers = [block(self.inplanes, planes, stride, downsample)]
        self.inplanes = out_channels
        layers.extend(block(self.inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*layers)

    def forward(self, x):
        stem = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        shared = self.layer2(self.layer1(stem))

        branch1 = self.layer4_1(self.layer3_1(shared))
        branch2 = self.layer4_2(self.layer3_2(shared))

        return branch1, branch2
Oops, something went wrong.