Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/operators/reader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
reader_library(create_recordio_file_reader_op SRCS create_recordio_file_reader_op.cc)
reader_library(create_double_buffer_reader_op SRCS create_double_buffer_reader_op.cc)
reader_library(create_multi_pass_reader_op SRCS create_multi_pass_reader_op.cc)
# Export local libraries to parent
set(READER_LIBRARY ${LOCAL_READER_LIBS} PARENT_SCOPE)
101 changes: 101 additions & 0 deletions paddle/fluid/operators/reader/create_multi_pass_reader_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/detail/safe_ref.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class MultiPassReader : public framework::DecoratedReader {
public:
MultiPassReader(ReaderBase* reader, int pass_num)
: DecoratedReader(reader), pass_num_(pass_num), pass_count_(0) {}

void ReadNext(std::vector<framework::LoDTensor>* out) override {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}
reader_->ReadNext(out);
}

bool HasNext() const override {
if (reader_->HasNext()) {
return true;
} else {
++pass_count_;
if (pass_count_ >= pass_num_) {
return false;
} else {
reader_->ReInit();
return true;
}
}
}

void ReInit() override {
pass_count_ = 0;
reader_->ReInit();
}

private:
int pass_num_;
mutable int pass_count_;
};

class CreateMultiPassReaderOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;

private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& underlying_reader = scope.FindVar(Input("UnderlyingReader"))
->Get<framework::ReaderHolder>();
auto& out = detail::Ref(scope.FindVar(Output("Out")));
int pass_num = Attr<int>("pass_num");
out.GetMutable<framework::ReaderHolder>()->Reset(
new MultiPassReader(underlying_reader.Get(), pass_num));
}
};

class CreateMultiPassReaderOpMaker : public DecoratedReaderMakerBase {
public:
CreateMultiPassReaderOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: DecoratedReaderMakerBase(op_proto, op_checker) {
AddAttr<int>("pass_num", "The number of pass to run.").GreaterThan(0);
AddComment(R"DOC(
CreateMultiPassReader Operator

This operator creates a multi-pass reader. A multi-pass reader
is used to yield data for several pass training continuously.
It takes the the number of pass to run as one of its attributes
('pass_num'), and maintains a pass counter to record how many
passes it has completed. When the underlying reader reach the EOF,
the multi-pass reader checks whether it has completed training
of the given number of pass. If not, the underlying reader will
be re-initialized and starts a new pass automatically.
)DOC");
}
};

} // namespace reader
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators::reader;
REGISTER_DECORATED_READER_OPERATOR(create_multi_pass_reader,
ops::CreateMultiPassReaderOp,
ops::CreateMultiPassReaderOpMaker);
7 changes: 6 additions & 1 deletion python/paddle/fluid/layers/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader'
'create_double_buffer_reader', 'create_multi_pass_reader'
]


Expand Down Expand Up @@ -345,6 +345,11 @@ def create_double_buffer_reader(reader, place=None):
attrs)


def create_multi_pass_reader(reader, pass_num):
return __create_decorated_reader__('create_multi_pass_reader', reader,
{'pass_num': int(pass_num)})


def read_file(file_obj):
helper = LayerHelper('read_file')
out = [
Expand Down
65 changes: 65 additions & 0 deletions python/paddle/fluid/tests/unittests/test_multi_pass_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle.fluid as fluid
import paddle.v2 as paddle
import paddle.v2.dataset.mnist as mnist


class TestMultipleReader(unittest.TestCase):
def setUp(self):
self.batch_size = 64
self.pass_num = 3
# Convert mnist to recordio file
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = paddle.batch(mnist.train(), batch_size=self.batch_size)
feeder = fluid.DataFeeder(
feed_list=[
fluid.layers.data(
name='image', shape=[784]),
fluid.layers.data(
name='label', shape=[1], dtype='int64'),
],
place=fluid.CPUPlace())
self.num_batch = fluid.recordio_writer.convert_reader_to_recordio_file(
'./mnist.recordio', data_file, feeder)

def test_main(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
data_file = fluid.layers.open_recordio_file(
filename='./mnist.recordio',
shapes=[(-1, 784), (-1, 1)],
lod_levels=[0, 0],
dtypes=['float32', 'int64'])
data_file = fluid.layers.create_multi_pass_reader(
reader=data_file, pass_num=self.pass_num)
img, label = fluid.layers.read_file(data_file)

if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
else:
place = fluid.CPUPlace()

exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())

batch_count = 0
while not data_file.eof():
img_val, = exe.run(fetch_list=[img])
batch_count += 1
self.assertLessEqual(img_val.shape[0], self.batch_size)
data_file.reset()
self.assertEqual(batch_count, self.num_batch * self.pass_num)