Skip to content

Commit 84b2855

Browse files
authored
Merge pull request #577 from will-am/fluid_se_resnext
Add fluid version of SE-ResNeXt
2 parents 0e844a1 + 670090a commit 84b2855

File tree

2 files changed

+282
-0
lines changed

2 files changed

+282
-0
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import os
2+
import random
3+
import functools
4+
import numpy as np
5+
import paddle.v2 as paddle
6+
from PIL import Image, ImageEnhance
7+
8+
# Seed the `random` module once at import so shuffling and augmentation
# draw from a deterministic sequence.
random.seed(0)

# Per-channel means (R, G, B) subtracted from every image.
_R_MEAN = 123.0
_G_MEAN = 117.0
_B_MEAN = 104.0

# Side length, in pixels, of the square crop produced by process_image.
DATA_DIM = 224

# Arguments forwarded to paddle.reader.xmap_readers.
THREAD = 8
BUF_SIZE = 1024

# Dataset root directory and the list files describing the train/test splits.
DATA_DIR = 'ILSVRC2012'
TRAIN_LIST = 'ILSVRC2012/train_list.txt'
TEST_LIST = 'ILSVRC2012/test_list.txt'

# Channel means shaped (3, 1, 1) so they broadcast over a channel-first image.
img_mean = np.array([[[_R_MEAN]], [[_G_MEAN]], [[_B_MEAN]]])
24+
25+
26+
def resize_short(img, target_size):
    """Scale ``img`` so that its shorter side becomes ``target_size`` pixels.

    Both sides are multiplied by the same factor and rounded to whole pixels,
    and the image is resampled with ``Image.LANCZOS``.

    Args:
        img: Image exposing ``size`` as ``(width, height)`` and ``resize``.
        target_size: Desired length of the shorter side.

    Returns:
        The resized image returned by ``img.resize``.
    """
    width, height = img.size[0], img.size[1]
    scale = float(target_size) / min(width, height)
    new_size = (int(round(width * scale)), int(round(height * scale)))
    return img.resize(new_size, Image.LANCZOS)
32+
33+
34+
def crop_image(img, target_size, center):
    """Cut a ``target_size`` x ``target_size`` square out of ``img``.

    Args:
        img: Image exposing ``size`` as ``(width, height)`` and a
            ``crop(box)`` method taking ``(left, top, right, bottom)``.
        target_size: Side length of the square crop, in pixels.
        center: If truthy the crop is centered; otherwise its top-left corner
            is drawn with ``random.randint`` so the square stays inside the
            image.

    Returns:
        Whatever ``img.crop`` returns for the computed box.
    """
    width, height = img.size
    if center:
        # Floor division keeps the box coordinates integral on Python 3,
        # where `/` on two ints would produce floats.
        left = (width - target_size) // 2
        top = (height - target_size) // 2
    else:
        left = random.randint(0, width - target_size)
        top = random.randint(0, height - target_size)
    return img.crop((left, top, left + target_size, top + target_size))
47+
48+
49+
def distort_color(img):
    """Jitter brightness, contrast and color of ``img`` in a random order.

    The three enhancements are shuffled, then each one is applied exactly
    once with its own factor drawn uniformly from [0.5, 1.5].

    Args:
        img: Image accepted by the ``ImageEnhance`` enhancer classes.

    Returns:
        The image after all three enhancements.
    """
    lower, upper = 0.5, 1.5
    enhancers = [
        ImageEnhance.Brightness,
        ImageEnhance.Contrast,
        ImageEnhance.Color,
    ]
    random.shuffle(enhancers)

    for enhancer in enhancers:
        # Draw the factor before building the enhancer, one draw per step.
        factor = random.uniform(lower, upper)
        img = enhancer(img).enhance(factor)
    return img
70+
71+
72+
def process_image(sample, mode):
    """Load the image named by ``sample[0]`` and turn it into a float array.

    In 'train' mode the image is resized to ``DATA_DIM + 32`` on its short
    side, randomly cropped, color-distorted and randomly mirrored. In every
    other mode it is resized to ``DATA_DIM`` and center-cropped. The result
    is converted to RGB, cast to float32, transposed with ``(2, 0, 1)`` and
    has ``img_mean`` subtracted.

    Args:
        sample: Sequence whose first item is the image path and, for
            'train'/'test', whose second item is the label.
        mode: One of 'train', 'test' or 'infer'.

    Returns:
        ``(image, label)`` for 'train' and 'test', the image alone for
        'infer'. Any other mode falls through and returns ``None``.
    """
    training = mode == 'train'

    img = Image.open(sample[0])

    # Training resizes a little larger so the random crop has room to move.
    short_side = DATA_DIM + 32 if training else DATA_DIM
    img = resize_short(img, short_side)
    img = crop_image(img, target_size=DATA_DIM, center=not training)

    if training:
        img = distort_color(img)
        # Mirror horizontally with probability one half.
        if random.randint(0, 1) == 1:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)

    if img.mode != 'RGB':
        img = img.convert('RGB')

    # Move the last axis to the front, then remove the per-channel mean.
    img = np.array(img).astype('float32').transpose((2, 0, 1))
    img -= img_mean

    if mode in ('train', 'test'):
        return img, sample[1]
    if mode == 'infer':
        return img
96+
97+
98+
def _reader_creator(file_list, mode, shuffle=False):
    """Create a reader that yields processed samples listed in ``file_list``.

    Each line of ``file_list`` is stripped. For 'train'/'test' a line is
    split into ``<relative path> <label>``; for 'infer' the whole line is the
    relative path. Paths are joined onto ``DATA_DIR``. The raw samples are
    handed to ``paddle.reader.xmap_readers`` together with ``process_image``
    (bound to ``mode``), ``THREAD`` and ``BUF_SIZE``.

    Args:
        file_list: Path of the text file listing the samples.
        mode: One of 'train', 'test' or 'infer'.
        shuffle: If true, the lines are shuffled with ``random.shuffle``
            every time the reader is iterated.

    Returns:
        The reader returned by ``paddle.reader.xmap_readers``.
    """

    def sample_reader():
        labelled = mode == 'train' or mode == 'test'
        with open(file_list) as flist:
            entries = [raw.strip() for raw in flist]
            if shuffle:
                random.shuffle(entries)
            for entry in entries:
                if labelled:
                    rel_path, label = entry.split()
                    full_path = os.path.join(DATA_DIR, rel_path)
                    yield full_path, int(label)
                elif mode == 'infer':
                    yield [os.path.join(DATA_DIR, entry)]

    return paddle.reader.xmap_readers(
        functools.partial(process_image, mode=mode),
        sample_reader,
        THREAD,
        BUF_SIZE)
116+
117+
118+
def train():
    """Return the shuffled 'train'-mode reader over ``TRAIN_LIST``."""
    return _reader_creator(file_list=TRAIN_LIST, mode='train', shuffle=True)
120+
121+
122+
def test():
    """Return the unshuffled 'test'-mode reader over ``TEST_LIST``."""
    return _reader_creator(file_list=TEST_LIST, mode='test', shuffle=False)
124+
125+
126+
def infer(file_list):
    """Return the unshuffled 'infer'-mode reader over ``file_list``.

    Args:
        file_list: Path of a text file with one relative image path per line.
    """
    return _reader_creator(file_list=file_list, mode='infer', shuffle=False)
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
import os
2+
import paddle.v2 as paddle
3+
import paddle.v2.fluid as fluid
4+
import reader
5+
6+
7+
def conv_bn_layer(input, num_filters, filter_size, stride=1, groups=1,
                  act=None):
    """Stack a bias-free 2-D convolution and a batch-norm layer.

    The convolution itself has no activation; ``act`` is applied by the
    batch-norm layer.

    Args:
        input: Fluid variable fed to the convolution.
        num_filters: Number of convolution filters.
        filter_size: Integer kernel size; padding is ``(filter_size - 1) // 2``.
        stride: Convolution stride.
        groups: Number of convolution groups.
        act: Activation name passed to ``batch_norm``, or ``None``.

    Returns:
        The output variable of ``fluid.layers.batch_norm``.
    """
    conv = fluid.layers.conv2d(
        input=input,
        num_filters=num_filters,
        filter_size=filter_size,
        stride=stride,
        # Floor division keeps the padding an int on Python 3 as well.
        padding=(filter_size - 1) // 2,
        groups=groups,
        act=None,
        bias_attr=False)
    return fluid.layers.batch_norm(input=conv, act=act)
19+
20+
21+
def squeeze_excitation(input, num_channels, reduction_ratio):
    """Rescale ``input`` with a squeeze-and-excitation gate.

    A global average pool feeds a ReLU fully connected layer of width
    ``num_channels // reduction_ratio`` followed by a sigmoid fully connected
    layer of width ``num_channels``; ``input`` is then multiplied by that
    output with ``elementwise_mul(axis=0)``.

    Args:
        input: Fluid variable to rescale.
        num_channels: Width of the excitation layer.
        reduction_ratio: Integer divisor for the squeeze layer's width.

    Returns:
        The rescaled variable.
    """
    pool = fluid.layers.pool2d(
        input=input, pool_size=0, pool_type='avg', global_pooling=True)
    squeeze = fluid.layers.fc(
        input=pool,
        # Layer sizes must be ints; `/` would yield a float on Python 3.
        size=num_channels // reduction_ratio,
        act='relu')
    excitation = fluid.layers.fc(
        input=squeeze, size=num_channels, act='sigmoid')
    scale = fluid.layers.elementwise_mul(x=input, y=excitation, axis=0)
    return scale
30+
31+
32+
def shortcut(input, ch_out, stride):
    """Return the residual branch for a block producing ``ch_out`` channels.

    Args:
        input: Variable whose ``shape[1]`` is compared against ``ch_out``.
        ch_out: Number of channels the main branch produces.
        stride: Stride used by the projection convolution, when one is needed.

    Returns:
        ``input`` itself when ``input.shape[1]`` already equals ``ch_out``;
        otherwise a 3x3 ``conv_bn_layer`` projection to ``ch_out`` channels.
    """
    if input.shape[1] == ch_out:
        return input
    return conv_bn_layer(input, ch_out, 3, stride)
38+
39+
40+
def bottleneck_block(input, num_filters, stride, cardinality, reduction_ratio):
    """Build one SE-ResNeXt bottleneck: 1x1 -> grouped 3x3 -> 1x1 -> SE -> add.

    Args:
        input: Fluid variable entering the block.
        num_filters: Filters of the first two convolutions; the block outputs
            ``num_filters * 2`` channels.
        stride: Stride of the grouped 3x3 convolution and of the shortcut.
        cardinality: Number of groups in the 3x3 convolution.
        reduction_ratio: Reduction ratio passed to ``squeeze_excitation``.

    Returns:
        ReLU of the shortcut added to the squeeze-excitation output.
    """
    out_channels = num_filters * 2

    reduced = conv_bn_layer(
        input=input, num_filters=num_filters, filter_size=1, act='relu')
    grouped = conv_bn_layer(
        input=reduced,
        num_filters=num_filters,
        filter_size=3,
        stride=stride,
        groups=cardinality,
        act='relu')
    expanded = conv_bn_layer(
        input=grouped, num_filters=out_channels, filter_size=1, act=None)
    recalibrated = squeeze_excitation(
        input=expanded,
        num_channels=out_channels,
        reduction_ratio=reduction_ratio)

    residual = shortcut(input, out_channels, stride)

    return fluid.layers.elementwise_add(x=residual, y=recalibrated, act='relu')
60+
61+
62+
def SE_ResNeXt(input, class_dim, infer=False):
    """Assemble the SE-ResNeXt classifier on top of ``input``.

    The network is a three-convolution stem and a max pool, four stages of
    bottleneck blocks (3, 8, 36 and 3 blocks with 128, 256, 512 and 1024
    filters), a global average pool, dropout (skipped when ``infer`` is
    true) and a softmax fully connected layer.

    Args:
        input: Fluid variable holding the input images.
        class_dim: Number of output classes.
        infer: When true, the dropout layer is left out.

    Returns:
        The softmax output variable of size ``class_dim``.
    """
    cardinality = 64
    reduction_ratio = 16
    depth = [3, 8, 36, 3]
    num_filters = [128, 256, 512, 1024]

    # Stem: (filters, stride) of three consecutive 3x3 conv+bn+relu layers.
    stem = [(64, 2), (64, 1), (128, 1)]
    conv = input
    for filters, stride in stem:
        conv = conv_bn_layer(
            input=conv,
            num_filters=filters,
            filter_size=3,
            stride=stride,
            act='relu')
    conv = fluid.layers.pool2d(
        input=conv, pool_size=3, pool_stride=2, pool_type='max')

    for stage, (repeats, filters) in enumerate(zip(depth, num_filters)):
        for i in range(repeats):
            # Only the first block of every stage after the first uses stride 2.
            downsample = i == 0 and stage != 0
            conv = bottleneck_block(
                input=conv,
                num_filters=filters,
                stride=2 if downsample else 1,
                cardinality=cardinality,
                reduction_ratio=reduction_ratio)

    pool = fluid.layers.pool2d(
        input=conv, pool_size=0, pool_type='avg', global_pooling=True)
    drop = pool if infer else fluid.layers.dropout(x=pool, dropout_prob=0.2)
    return fluid.layers.fc(input=drop, size=class_dim, act='softmax')
94+
95+
96+
def train(learning_rate, batch_size, num_passes, model_save_dir='model'):
    """Train SE-ResNeXt and save an inference model after every pass.

    Builds the network for 1000 classes on 3x224x224 inputs, optimizes the
    mean cross-entropy with Momentum on ``fluid.CUDAPlace(0)``, evaluates on
    the test reader at the end of each pass, and writes the inference model
    to ``<model_save_dir>/<pass_id>``.

    Args:
        learning_rate: Base learning rate; it is divided by ``batch_size``
            before being given to the optimizer.
        batch_size: Number of samples per batch for both readers.
        num_passes: Number of passes over the training reader.
        model_save_dir: Directory under which per-pass models are stored.
    """
    class_dim = 1000
    image_shape = [3, 224, 224]

    image = fluid.layers.data(name='image', shape=image_shape, dtype='float32')
    label = fluid.layers.data(name='label', shape=[1], dtype='int64')

    out = SE_ResNeXt(input=image, class_dim=class_dim)

    cost = fluid.layers.cross_entropy(input=out, label=label)
    avg_cost = fluid.layers.mean(x=cost)

    # NOTE(review): the rate is divided and the decay multiplied by
    # batch_size; the reason is not visible here - confirm it is intended.
    optimizer = fluid.optimizer.Momentum(
        learning_rate=learning_rate / batch_size,
        momentum=0.9,
        regularization=fluid.regularizer.L2Decay(1e-4 * batch_size))
    optimizer.minimize(avg_cost)
    accuracy = fluid.evaluator.Accuracy(input=out, label=label)

    inference_program = fluid.default_main_program().clone()
    with fluid.program_guard(inference_program):
        test_accuracy = fluid.evaluator.Accuracy(input=out, label=label)
        test_target = [avg_cost] + test_accuracy.metrics + test_accuracy.states
        inference_program = fluid.io.get_inference_program(test_target)

    place = fluid.CUDAPlace(0)
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # The data module is imported as `reader`; there is no `datareader` name.
    train_reader = paddle.batch(reader.train(), batch_size=batch_size)
    test_reader = paddle.batch(reader.test(), batch_size=batch_size)
    feeder = fluid.DataFeeder(place=place, feed_list=[image, label])

    for pass_id in range(num_passes):
        accuracy.reset(exe)
        for batch_id, data in enumerate(train_reader()):
            loss, acc = exe.run(
                fluid.default_main_program(),
                feed=feeder.feed(data),
                fetch_list=[avg_cost] + accuracy.metrics)
            print("Pass {0}, batch {1}, loss {2}, acc {3}".format(
                pass_id, batch_id, loss[0], acc[0]))
        pass_acc = accuracy.eval(exe)

        test_accuracy.reset(exe)
        for data in test_reader():
            # Use separate names for the fetched values: rebinding `out`
            # would hand a fetched array, not the network output variable,
            # to save_inference_model below.
            test_loss, test_acc = exe.run(
                inference_program,
                feed=feeder.feed(data),
                fetch_list=[avg_cost] + test_accuracy.metrics)
        test_pass_acc = test_accuracy.eval(exe)
        print("End pass {0}, train_acc {1}, test_acc {2}".format(
            pass_id, pass_acc, test_pass_acc))

        model_path = os.path.join(model_save_dir, str(pass_id))
        fluid.io.save_inference_model(model_path, ['image'], [out], exe)
152+
153+
154+
if __name__ == '__main__':
    # Script entry point: run training with the default hyper-parameters.
    train(
        learning_rate=0.1,
        batch_size=7,
        num_passes=100)

0 commit comments

Comments
 (0)