
Control whether training accuracy is calculated via an option (PaddlePaddle#11)
* Bug fixes and Python 3 compatibility

* Add a function to set whether accuracy (acc1/acc5) is calculated during training.
lilong12 authored Dec 23, 2019
1 parent 35cb91e commit 51c752e
Showing 7 changed files with 97 additions and 48 deletions.
7 changes: 6 additions & 1 deletion docs/api_intro.md
@@ -58,6 +58,7 @@ The PLSC large-scale classification library provides default configuration parameters for training and evaluation
| set_model_save_dir(dir) | Set the model save path model_save_dir | Type: string |
| set_dataset_dir(dir) | Set the dataset root directory dataset_dir | Type: string |
| set_train_image_num(num) | Set the total number of training images | Type: int |
| set_calc_acc(calc) | Set whether to calculate acc1 and acc5 during training | Type: bool |
| set_class_num(num) | Set the total number of classes | Type: int |
| set_emb_size(size) | Set the output dimension of the last hidden layer | Type: int |
| set_model(model) | Set the custom model class instance to use | Subclass of BaseModel |
@@ -71,7 +72,11 @@ The PLSC large-scale classification library provides default configuration parameters for training and evaluation
| test() | Model evaluation | None |
| train() | Model training | None |

Note: the above APIs are all methods of the Entry class of PLSC, PaddlePaddle's large-scale classification library, and must be called through an instance of
Note:

When set_calc_acc is set to True, acc1 and acc5 are calculated during training, but this consumes additional GPU memory.

The above APIs are all methods of the Entry class of PLSC, PaddlePaddle's large-scale classification library, and must be called through an instance of
that class, for example:

```python
# Illustrative sketch; the original example is truncated in this diff view,
# and the argument values here are assumptions.
from plsc import Entry

ins = Entry()
ins.set_class_num(1000000)   # total number of classes
ins.set_calc_acc(True)       # calculate acc1/acc5 during training (uses extra GPU memory)
ins.train()
```
85 changes: 56 additions & 29 deletions plsc/entry.py
@@ -48,7 +48,7 @@
format='%(asctime)s - %(levelname)s - %(message)s',
datefmt='%d %b %Y %H:%M:%S')
logger = logging.getLogger(__name__)


class Entry(object):
"""
@@ -116,6 +116,7 @@ def __init__(self):
self.with_test = self.config.with_test
self.model_save_dir = self.config.model_save_dir
self.warmup_epochs = self.config.warmup_epochs
self.calc_train_acc = False

if self.checkpoint_dir:
self.checkpoint_dir = os.path.abspath(self.checkpoint_dir)
@@ -173,6 +174,14 @@ def set_model_save_dir(self, directory):
self.model_save_dir = directory
logger.info("Set model_save_dir to {}.".format(directory))

def set_calc_acc(self, calc):
"""
Set whether to calculate acc1 and acc5 during training.
"""
self.calc_train_acc = calc
logger.info("Calculating acc1 and acc5 during training: {}.".format(
calc))

def set_dataset_dir(self, directory):
"""
Set the root directory for datasets.
@@ -321,21 +330,27 @@ def build_program(self,
margin=self.margin,
scale=self.scale)

if self.loss_type in ["dist_softmax", "dist_arcface"]:
shard_prob = loss._get_info("shard_prob")
acc1 = None
acc5 = None

prob_all = fluid.layers.collective._c_allgather(shard_prob,
nranks=num_trainers, use_calc_stream=True)
prob_list = fluid.layers.split(prob_all, dim=0,
num_or_sections=num_trainers)
prob = fluid.layers.concat(prob_list, axis=1)
label_all = fluid.layers.collective._c_allgather(label,
nranks=num_trainers, use_calc_stream=True)
acc1 = fluid.layers.accuracy(input=prob, label=label_all, k=1)
acc5 = fluid.layers.accuracy(input=prob, label=label_all, k=5)
if self.loss_type in ["dist_softmax", "dist_arcface"]:
if self.calc_train_acc:
shard_prob = loss._get_info("shard_prob")

prob_all = fluid.layers.collective._c_allgather(shard_prob,
nranks=num_trainers, use_calc_stream=True)
prob_list = fluid.layers.split(prob_all, dim=0,
num_or_sections=num_trainers)
prob = fluid.layers.concat(prob_list, axis=1)
label_all = fluid.layers.collective._c_allgather(label,
nranks=num_trainers, use_calc_stream=True)
acc1 = fluid.layers.accuracy(input=prob, label=label_all, k=1)
acc5 = fluid.layers.accuracy(input=prob, label=label_all, k=5)
else:
acc1 = fluid.layers.accuracy(input=prob, label=label, k=1)
acc5 = fluid.layers.accuracy(input=prob, label=label, k=5)
if self.calc_train_acc:
acc1 = fluid.layers.accuracy(input=prob, label=label, k=1)
acc5 = fluid.layers.accuracy(input=prob, label=label, k=5)

optimizer = None
if is_train:
# initialize optimizer
@@ -621,11 +636,11 @@ def test(self, pass_id=0):

feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'], program=test_program)
fetch_list = [emb.name, acc1.name, acc5.name]
fetch_list = [emb.name]
real_test_batch_size = self.global_test_batch_size

test_start = time.time()
for i in xrange(len(test_list)):
for i in range(len(test_list)):
data_list, issame_list = test_list[i]
embeddings_list = []
for j in xrange(len(data_list)):
@@ -643,7 +658,7 @@ def test(self, pass_id=0):
for k in xrange(begin, end):
_data.append((data[k], 0))
assert len(_data) == self.test_batch_size
[_embeddings, acc1, acc5] = exe.run(test_program,
[_embeddings] = exe.run(test_program,
fetch_list = fetch_list, feed=feeder.feed(_data),
use_program_cache=True)
if embeddings is None:
@@ -657,7 +672,7 @@ def test(self, pass_id=0):
_data = []
for k in xrange(end - self.test_batch_size, end):
_data.append((data[k], 0))
[_embeddings, acc1, acc5] = exe.run(test_program,
[_embeddings] = exe.run(test_program,
fetch_list = fetch_list, feed=feeder.feed(_data),
use_program_cache=True)
_embeddings = _embeddings[0:self.test_batch_size,:]
@@ -730,7 +745,7 @@ def train(self):
if self.with_test:
test_feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'], program=test_program)
fetch_list_test = [test_emb.name, test_acc1.name, test_acc5.name]
fetch_list_test = [test_emb.name]
real_test_batch_size = self.global_test_batch_size

if self.checkpoint_dir:
@@ -750,8 +765,11 @@ def train(self):
feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'], program=origin_prog)

fetch_list = [train_loss.name, global_lr.name,
train_acc1.name, train_acc5.name]
if self.calc_train_acc:
fetch_list = [train_loss.name, global_lr.name,
train_acc1.name, train_acc5.name]
else:
fetch_list = [train_loss.name, global_lr.name]

local_time = 0.0
nsamples = 0
@@ -763,9 +781,13 @@ def train(self):
for batch_id, data in enumerate(train_reader()):
nsamples += global_batch_size
t1 = time.time()
loss, lr, acc1, acc5 = exe.run(train_prog,
feed=feeder.feed(data), fetch_list=fetch_list,
use_program_cache=True)
if self.calc_train_acc:
loss, lr, acc1, acc5 = exe.run(train_prog,
feed=feeder.feed(data), fetch_list=fetch_list,
use_program_cache=True)
else:
loss, lr = exe.run(train_prog, feed=feeder.feed(data),
fetch_list=fetch_list, use_program_cache=True)
t2 = time.time()
period = t2 - t1
local_time += period
@@ -776,9 +798,14 @@ def train(self):
if batch_id % inspect_steps == 0:
avg_loss = np.mean(local_train_info[0])
avg_lr = np.mean(local_train_info[1])
print("Pass:%d batch:%d lr:%f loss:%f qps:%.2f acc1:%.4f acc5:%.4f" % (
pass_id, batch_id, avg_lr, avg_loss, nsamples / local_time,
acc1, acc5))
if self.calc_train_acc:
logger.info("Pass:%d batch:%d lr:%f loss:%f qps:%.2f "
"acc1:%.4f acc5:%.4f" % (pass_id, batch_id, avg_lr,
avg_loss, nsamples / local_time, acc1, acc5))
else:
logger.info("Pass:%d batch:%d lr:%f loss:%f qps:%.2f" %(
pass_id, batch_id, avg_lr, avg_loss,
nsamples / local_time))
local_time = 0
nsamples = 0
local_train_info = [[], [], [], []]
@@ -807,7 +834,7 @@ def train(self):
for k in xrange(begin, end):
_data.append((data[k], 0))
assert len(_data) == self.test_batch_size
[_embeddings, acc1, acc5] = exe.run(test_program,
[_embeddings] = exe.run(test_program,
fetch_list = fetch_list_test, feed=test_feeder.feed(_data),
use_program_cache=True)
if embeddings is None:
@@ -821,7 +848,7 @@ def train(self):
_data = []
for k in xrange(end - self.test_batch_size, end):
_data.append((data[k], 0))
[_embeddings, acc1, acc5] = exe.run(test_program,
[_embeddings] = exe.run(test_program,
fetch_list = fetch_list_test, feed=test_feeder.feed(_data),
use_program_cache=True)
_embeddings = _embeddings[0:self.test_batch_size,:]
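A note on the dist_softmax/dist_arcface accuracy path in the build_program hunk above: each trainer computes probabilities for the whole batch but only for its own shard of the classes, and _c_allgather stacks the per-rank tensors along axis 0, so the result must be split and rejoined along the class axis before fluid.layers.accuracy can be applied. Below is a minimal single-process NumPy sketch of that reshuffle; the shapes and names are illustrative assumptions, not PLSC API:

```python
# Single-process NumPy sketch of the shard-gather-and-rejoin step above.
# Shapes and names are illustrative; in PLSC the gather is a collective op.
import numpy as np

num_trainers = 4          # ranks; each owns one shard of the class dimension
batch_size = 8
classes_per_shard = 16    # total classes = num_trainers * classes_per_shard

# shard_prob on rank i: probabilities for the full batch, restricted to
# rank i's slice of the classes (model-parallel softmax).
shard_probs = [np.random.rand(batch_size, classes_per_shard)
               for _ in range(num_trainers)]
labels = np.random.randint(0, num_trainers * classes_per_shard,
                           size=batch_size)

# _c_allgather concatenates the per-rank tensors along axis 0 ...
prob_all = np.concatenate(shard_probs, axis=0)
# ... so split back into per-rank pieces and rejoin along the class axis.
prob = np.concatenate(np.split(prob_all, num_trainers, axis=0), axis=1)

def topk_acc(prob, labels, k):
    topk = np.argsort(prob, axis=1)[:, -k:]   # indices of the k largest probs
    return float(np.mean([labels[i] in topk[i] for i in range(len(labels))]))

print("acc1:", topk_acc(prob, labels, 1), "acc5:", topk_acc(prob, labels, 5))
```

The reconstructed (batch, total_classes) probability matrix is presumably the extra GPU memory the docs note for set_calc_acc warns about, which is why the new code only builds this subgraph when calc_train_acc is set.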
4 changes: 2 additions & 2 deletions plsc/models/base_model.py
@@ -19,7 +19,7 @@
import paddle
import paddle.fluid as fluid
from paddle.fluid import unique_name
import dist_algo
from . import dist_algo


__all__ = ["BaseModel"]
@@ -73,7 +73,7 @@ def get_output(self,
param_attr,
bias_attr)
elif loss_type == "arcface":
loss, prob = self.fc_arcface(emb,
loss, prob = self.arcface(emb,
label,
num_classes,
param_attr,
5 changes: 3 additions & 2 deletions plsc/utils/jpeg_reader.py
@@ -5,6 +5,7 @@
import functools
import numpy as np
import paddle
import six
from PIL import Image, ImageEnhance
try:
from StringIO import StringIO
@@ -238,9 +239,9 @@ def load_bin(path, image_size):
for flip in [0, 1]:
data = np.empty((len(issame_list)*2, 3, image_size[0], image_size[1]))
data_list.append(data)
for i in xrange(len(issame_list)*2):
for i in range(len(issame_list)*2):
_bin = bins[i]
if not isinstance(_bin, basestring):
if not isinstance(_bin, six.string_types):
_bin = _bin.tostring()
img_ori = Image.open(StringIO(_bin))
for flip in [0, 1]:
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ scipy ; python_version>="3.5"
Pillow
sklearn
easydict
six
2 changes: 1 addition & 1 deletion setup.py
@@ -23,7 +23,7 @@


REQUIRED_PACKAGES = [
'sklearn', 'easydict', 'Pillow', 'numpy', 'scipy'
'sklearn', 'easydict', 'Pillow', 'numpy', 'scipy', 'six'
]


41 changes: 28 additions & 13 deletions tools/process_base64_files.py
@@ -23,6 +23,7 @@
import logging
import sqlite3
import tempfile
import six


logging.basicConfig(level=logging.INFO,
@@ -85,6 +86,13 @@ def __init__(self, data_dir, file_list, nranks):
self.conn = None
self.cursor = None

def insert_to_db(self, cnt, line):
label = int(line[0])
data = line[1]
sql_cmd = "INSERT INTO DATASET (ID, DATA, LABEL) "
sql_cmd += "VALUES ({}, '{}', {});".format(cnt, data, label)
self.cursor.execute(sql_cmd)
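# A parameterized statement (standard sqlite3 placeholder syntax) is an
# equivalent, quoting-safe form of the insert above, e.g.:
#   self.cursor.execute(
#       "INSERT INTO DATASET (ID, DATA, LABEL) VALUES (?, ?, ?);",
#       (cnt, data, label))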

def create_db(self):
start = time.time()
print(self.sqlite3_file)
@@ -98,19 +106,26 @@ def create_db(self):
file_list_path = os.path.join(self.data_dir, self.file_list)
with open(file_list_path, 'r') as f:
cnt = 0
for line in f.xreadlines():
line = line.strip()
file_path = os.path.join(self.data_dir, line)
with open(file_path, 'r') as df:
for line in df.xreadlines():
line = line.strip().split('\t')
label = int(line[0])
data = line[1]
sql_cmd = "INSERT INTO DATASET (ID, DATA, LABEL) "
sql_cmd += "VALUES ({}, '{}', {});".format(cnt, data, label)
self.cursor.execute(sql_cmd)
cnt += 1
os.remove(file_path)
if six.PY2:
for line in f.xreadlines():
line = line.strip()
file_path = os.path.join(self.data_dir, line)
with open(file_path, 'r') as df:
for line in df.xreadlines():
line = line.strip().split('\t')
self.insert_to_db(cnt, line)
cnt += 1
os.remove(file_path)
else:
for line in f:
line = line.strip()
file_path = os.path.join(self.data_dir, line)
with open(file_path, 'r') as df:
for line in df:
line = line.strip().split('\t')
self.insert_to_db(cnt, line)
cnt += 1
os.remove(file_path)

self.conn.commit()
diff = time.time() - start
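On the create_db hunk above: the Python 2 and Python 3 branches differ only in calling f.xreadlines() versus iterating the file object directly, and direct iteration is portable across both versions. A small self-contained sketch of the portable form, using a throwaway file and illustrative names:

```python
# Portable iteration sketch: `for line in f` works on Python 2 and 3 alike,
# so it can stand in for the f.xreadlines() branch above.
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), "part-00000")
with open(path, 'w') as f:
    f.write("0\tBASE64DATA0\n1\tBASE64DATA1\n")

with open(path, 'r') as f:
    for line in f:                        # no xreadlines() needed
        label, data = line.strip().split('\t')
        print("%s %s" % (label, data))
os.remove(path)
```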
