Skip to content

Commit d31d432

Browse files
authored
Create shanghaitech.py
1 parent 1ed67fa commit d31d432

File tree

1 file changed

+271
-0
lines changed

1 file changed

+271
-0
lines changed

src/lib/dataset/shanghaitech.py

Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
#!/usr/bin/env python
2+
# -*- coding: utf-8 -*-
3+
# Written by Willy
4+
from __future__ import print_function, division
5+
from skimage import io, color
6+
from skimage import transform as sk_transform
7+
from torch.utils.data import Dataset, DataLoader
8+
from torchvision import transforms
9+
import os
10+
import torch
11+
import matplotlib.pyplot as plt
12+
import scipy.io as scio
13+
import warnings
14+
import numpy as np
15+
warnings.filterwarnings("ignore")
16+
import random
17+
from src.lib.utils.image_opt import genDensity,showGt,getPerspective,getLevel,getAttentionDensity,showMultiscale,findSigma
18+
19+
class IsColor(object):
    """Transform that switches ``sample['image']`` between grey and colour.

    With ``color=True`` a 2-D image is expanded through ``color.gray2rgb``.
    With ``color=False`` a non-2-D image is reduced through ``color.rgb2gray``
    and the result (times 255) is written into three identical uint8
    channels. In every other case the image is left as it is.
    """

    def __init__(self, color=True):
        # The parameter name hides the skimage ``color`` module only inside
        # this method; ``__call__`` still sees the module.
        self.color = color

    def __call__(self, sample):
        """Convert ``sample['image']`` as configured and return ``sample``."""
        img = sample['image']
        is_gray = len(img.shape) == 2

        if is_gray and self.color:
            img = color.gray2rgb(img)
        elif not is_gray and not self.color:
            gray = color.rgb2gray(img)
            # NOTE(review): the ``* 255`` presumes rgb2gray returns values in
            # [0, 1] -- confirm for the float32 images this file loads.
            stacked = np.zeros((gray.shape[0], gray.shape[1], 3), np.uint8)
            for channel in range(3):
                stacked[:, :, channel] = gray * 255
            img = stacked

        sample['image'] = img
        return sample
37+
38+
39+
class RandomFlip(object):
    """Transform that mirrors the sample left-to-right with probability 0.5.

    When the flip happens, ``sample['image']`` is reversed along axis 1 and
    the first column of ``sample['dots']`` is remapped to
    ``image.shape[1] - 1 - x`` (in place) so the dots follow the pixels.
    """

    def __call__(self, sample):
        """Possibly flip ``sample`` and return it."""
        if random.random() < 0.5:
            image = sample['image']
            # A bare ``[:, ::-1]`` view has a negative stride, which
            # torch.from_numpy / the default collate function cannot handle,
            # so store a contiguous copy instead of the view.
            sample['image'] = np.ascontiguousarray(image[:, ::-1])
            if len(sample['dots']) != 0:
                sample['dots'][:, 0] = image.shape[1] - 1 - sample['dots'][:, 0]
        return sample
47+
48+
class PreferredSize(object):
    """Transform that fits images into a ``size`` x ``size`` square.

    Each image is resized so that its longer side equals ``size`` (aspect
    ratio kept, rounding the shorter side), then pasted into the top-left
    corner of a float32 canvas pre-filled with 127.5. ``sample['dots']`` is
    multiplied by the same scale factor as the main image. With
    ``use_multiscale`` every entry of ``sample['scale_images']`` gets the
    same treatment. A ``size`` that is not positive disables the transform.
    """

    def __init__(self, size=0, use_multiscale=False):
        self.size = size
        self.use_multiscale = use_multiscale

    def _fit_to_square(self, img):
        """Resize ``img`` (H, W, C) and pad it onto the square canvas."""
        rows, cols, channels = img.shape
        side = self.size
        if rows > cols:
            target = (side, int(side * cols / rows + 0.5))
        else:
            target = (int(side * rows / cols + 0.5), side)

        shrunk = sk_transform.resize(img, target, preserve_range=True)
        canvas = np.full((side, side, channels), 127.5, dtype=np.float32)
        canvas[:target[0], :target[1], :] = shrunk
        return canvas

    def __call__(self, sample):
        """Resize/pad the images in ``sample``, rescale its dots, return it."""
        if not self.size > 0:
            return sample

        image = sample['image']
        rows, cols, _ = image.shape
        # Scale factor of the main image; the dots are expressed in its pixels.
        ratio = self.size / rows if rows > cols else self.size / cols

        if self.use_multiscale:
            sample['scale_images'] = np.array(
                [self._fit_to_square(crop) for crop in sample['scale_images']])

        # NOTE(review): assumed to run whether or not use_multiscale is set;
        # the original indentation was not available -- confirm.
        sample['dots'] = sample['dots'] * ratio
        sample['image'] = self._fit_to_square(image)
        return sample
94+
95+
96+
class NineCrop(object):
    """Transform that crops one of nine half-size windows at random.

    The window is ``int(w/2)`` wide and ``int(h/2)`` high; its left/top
    corner is ``int(w/4 * i)``, ``int(h/4 * j)`` with ``i`` and ``j`` drawn
    independently from {0, 1, 2}. Dots outside the window are dropped, the
    rest are shifted into window coordinates, and ``sample['sigma']`` is
    reduced to the matching rows (``index_select`` is called on it, so it is
    expected to be a torch tensor).
    """

    def __call__(self, sample):
        """Crop ``sample`` in place and return it."""
        image, dots, sigmas = sample['image'], sample['dots'], sample['sigma']
        rows, cols = image.shape[:2]

        # Horizontal position is drawn first, vertical second.
        col_step = random.randint(0, 2)
        row_step = random.randint(0, 2)
        x0, y0 = int(cols / 4 * col_step), int(rows / 4 * row_step)
        x1, y1 = x0 + int(cols / 2), y0 + int(rows / 2)

        window = image[y0:y1, x0:x1]

        if len(dots) != 0:
            inside = ((dots[:, 0] >= x0) & (dots[:, 1] >= y0)
                      & (dots[:, 0] < x1) & (dots[:, 1] < y1))
            hits = np.where(inside)
            dots = dots[hits]  # fancy indexing: a copy, safe to shift in place
            dots[:, 0] -= x0
            dots[:, 1] -= y0

            kept = hits[0].tolist()
            if kept:
                sigmas = sigmas.index_select(0, torch.LongTensor(kept))
            else:
                sigmas = torch.FloatTensor([])

        sample['image'] = window
        sample['dots'] = dots
        sample['sigma'] = sigmas
        return sample
128+
129+
class Multiscale(object):
    """Transform that adds centred crops of the image at several scales.

    For every factor ``s`` in ``cropscale`` the crop spans
    ``int(h*s/2)`` rows and ``int(w*s/2)`` columns on each side of the image
    centre. The crops are stored under ``sample['scale_images']``.

    Args:
        cropscale: iterable of scale factors; ``None`` (the default) means
            no crops.
    """

    def __init__(self, cropscale=None):
        # ``None`` sentinel instead of a shared mutable ``[]`` default.
        self.cropscale = [] if cropscale is None else cropscale

    @staticmethod
    def _pack(crops):
        """Bundle crops: stacked when shapes agree, 1-D object array otherwise."""
        if len({crop.shape for crop in crops}) <= 1:
            return np.array(crops)
        # Different scales give different crop shapes; np.array() on such a
        # ragged list raises ValueError on recent NumPy, so fill an object
        # array element by element. Iterating it still yields each crop.
        packed = np.empty(len(crops), dtype=object)
        for position, crop in enumerate(crops):
            packed[position] = crop
        return packed

    def __call__(self, sample):
        """Attach the centred crops to ``sample`` and return it."""
        image = sample['image']
        h, w = image.shape[:2]
        cx, cy = int(w / 2), int(h / 2)

        crops = []
        for scale in self.cropscale:
            half_h, half_w = int(h * scale / 2), int(w * scale / 2)
            crops.append(image[cy - half_h: cy + half_h,
                               cx - half_w: cx + half_w])

        sample['scale_images'] = self._pack(crops)
        return sample
148+
149+
class ToTensor(object):
    """Transform that builds the density map and channel-first arrays.

    After the call the sample holds: ``image`` as float32 with axes
    (2, 0, 1) of the input, ``densityMap``, ``dots`` as a zero-padded
    ``(max_dot, 2)`` array, ``count`` (number of dots) and, with
    ``use_multiscale``, ``scale_images`` as float32 with axes (0, 3, 1, 2).
    The ``sigma`` entry is removed.

    Args:
        rescale: forwarded to the density helper.
        margin_size: forwarded to the density helper.
        max_dot: number of rows of the padded ``dots`` array.
        use_att: use ``getAttentionDensity`` instead of ``genDensity``.
        use_multiscale: also convert ``sample['scale_images']``.
    """

    def __init__(self, rescale=1.0, margin_size=1001, max_dot=4000,
                 use_att=False, use_multiscale=False):
        self.rescale = rescale
        self.margin_size = margin_size
        self.max_dot = max_dot
        self.use_att = use_att
        self.use_multiscale = use_multiscale

    def __call__(self, sample):
        """Convert ``sample`` in place and return it."""
        image = sample['image']
        dots = sample['dots']
        sigmas = sample['sigma']

        if self.use_att:
            density = getAttentionDensity(
                image, 3, dots, sigmas, self.margin_size, self.rescale)
        else:
            density = genDensity(
                image, dots, sigmas, self.margin_size, self.rescale)

        # Rescale so the map sums to the number of dots.
        # NOTE(review): assumed to apply to both branches above; the original
        # indentation was not available -- confirm.
        total = np.sum(density)
        if total != 0:
            density = density * (len(dots) / total)

        chw = image.transpose((2, 0, 1))
        if self.use_multiscale:
            crops = sample['scale_images'].transpose((0, 3, 1, 2))
            sample['scale_images'] = crops.astype(np.float32)

        # Fixed-size dot buffer; rows past ``count`` stay zero.
        count = len(dots)
        padded = np.zeros((self.max_dot, 2))
        if count:
            padded[:dots.shape[0], :] = dots

        sample['image'] = chw.astype(np.float32)
        sample['densityMap'] = density
        sample['dots'] = padded
        sample['count'] = count
        sample.pop('sigma')
        return sample
193+
194+
195+
class Normalize(object):
    """Transform that maps pixel values with ``(x - 127.5) / 127.5``.

    Applied to ``sample['image']`` and, when ``use_multiscale`` is set, to
    ``sample['scale_images']`` as well.
    """

    def __init__(self, use_multiscale=False):
        self.use_multiscale = use_multiscale

    def __call__(self, sample):
        """Normalise the selected entries of ``sample`` and return it."""
        targets = ['image']
        if self.use_multiscale:
            targets.append('scale_images')

        for key in targets:
            sample[key] = (sample[key] - 127.5) / 127.5
        return sample
209+
210+
class HeadCountDataset(Dataset):
    """Dataset of images with head annotations listed in a text file.

    Every non-blank line of ``data_file`` holds an image path and a ``.mat``
    annotation path separated by one space; both are resolved against
    ``root_path`` (three directories above this module).

    Args:
        max_iter: in the ``'train'`` phase the list is repeated until it has
            at least this many entries.
        phase: ``'train'`` enables the repetition above.
        data_file: path of the list file.
        transform: optional callable applied to each sample dict.
        use_pers: derive sigma from a perspective map (``getPerspective``)
            instead of ``findSigma``.
        use_attention: store the ``getLevel`` result under ``'sigma'``
            instead of the sigma computed above.
    """

    def __init__(self, max_iter, phase, data_file, transform=None,
                 use_pers=False, use_attention=True):
        self.data_file = data_file
        # Context manager closes the file even if reading fails; blank lines
        # are skipped because they cannot be split into two paths later.
        with open(self.data_file, 'r') as f:
            self.data_idx = [entry.strip() for entry in f if entry.strip()]

        if phase == 'train':
            if not self.data_idx:
                raise ValueError(
                    'no samples listed in {}'.format(self.data_file))
            repeats = int(np.ceil(float(max_iter) / len(self.data_idx)))
            self.data_idx = self.data_idx * repeats
        print('iteration length:', len(self.data_idx))

        self.transform = transform
        self.use_pmap = use_pers
        self.root_path = os.path.abspath(
            os.path.dirname(__file__) + os.path.sep + "../../../")
        self.use_att = use_attention

    def __len__(self):
        """Number of (possibly repeated) entries."""
        return len(self.data_idx)

    def __getitem__(self, idx):
        """Load entry ``idx`` and return its (optionally transformed) sample.

        The sample dict has the keys ``image`` (float32), ``dots`` (rows of
        x, y inside the image), ``sigma`` and ``image_path``.
        """
        line = self.data_idx[idx]
        img_path, dot_path = line.split(' ')
        image = io.imread(
            os.path.join(self.root_path, img_path)).astype(np.float32)
        notation = scio.loadmat(os.path.join(self.root_path, dot_path),
                                struct_as_record=False, squeeze_me=True)
        info = notation['image_info']

        dots = np.asarray(info.location)
        if dots.ndim < 2:
            # squeeze_me=True flattens a single annotation to shape (2,) and
            # an empty one to (0,); restore the (N, 2) layout used below.
            dots = dots.reshape(-1, 2)

        # Keep only the dots that fall inside the image.
        in_bounds = np.where((dots[:, 0] >= 0) & (dots[:, 1] >= 0)
                             & (dots[:, 0] < image.shape[1])
                             & (dots[:, 1] < image.shape[0]))
        dots = dots[in_bounds]

        if self.use_pmap:
            pmap_path = os.path.join(
                self.root_path,
                img_path.replace('images', 'pmap').replace('.jpg', '.mat'))
            pmap_mat = scio.loadmat(pmap_path)
            pmap = pmap_mat['pmap']
            sigma = getPerspective(dots, pmap)
        else:
            sigma = findSigma(dots, 3, 0.3)

        if self.use_att:
            atv = getLevel(3, 0.1, np.array([3, 9, 27]), dots, 5)
            sample = {'image': image, 'dots': dots, 'sigma': atv,
                      'image_path': img_path}
        else:
            sample = {'image': image, 'dots': dots, 'sigma': sigma,
                      'image_path': img_path}

        if self.transform:
            sample = self.transform(sample)
        return sample
259+
260+
261+
if __name__ == '__main__':
    # Manual smoke test: run the full multiscale pipeline over the part B
    # training list and display every batch.
    pipeline = transforms.Compose([
        IsColor(True),
        NineCrop(),
        RandomFlip(),
        Multiscale(cropscale=[0.75, 0.5]),
        PreferredSize(512, use_multiscale=True),
        ToTensor(use_att=False, use_multiscale=True),
        Normalize(use_multiscale=True),
    ])
    headcount_dataset = HeadCountDataset(
        250000,
        'train',
        'data/ShanghaiTech/part_B_final/train_data.txt',
        use_pers=False,
        use_attention=False,
        transform=pipeline,
    )
    dataloader = DataLoader(
        headcount_dataset, batch_size=1, shuffle=False, num_workers=1)

    for i_batch, sample_batched in enumerate(dataloader):
        print('success')
        showMultiscale(sample_batched)
271+

0 commit comments

Comments
 (0)