Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions paddle/fluid/operators/reader/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ function(reader_library TARGET_NAME)
PARENT_SCOPE)
endfunction()

reader_library(open_files_op SRCS open_files_op.cc)
reader_library(create_random_data_generator_op SRCS create_random_data_generator_op.cc)
reader_library(create_shuffle_reader_op SRCS create_shuffle_reader_op.cc)
reader_library(create_batch_reader_op SRCS create_batch_reader_op.cc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,13 @@ class CreateDoubleBufferReaderOpMaker : public DecoratedReaderMakerBase {
};

void DoubleBufferReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}

if (local_buffer_.payloads_.empty()) {
buffer_->Receive(&local_buffer_);
}

*out = local_buffer_.payloads_;
local_buffer_.payloads_.clear();
if (local_buffer_.ctx_) {
Expand Down
212 changes: 212 additions & 0 deletions paddle/fluid/operators/reader/open_files_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/operators/reader/reader_op_registry.h"

namespace paddle {
namespace operators {
namespace reader {

class MultipleReader : public framework::ReaderBase {
public:
MultipleReader(const std::vector<std::string>& file_names,
const std::vector<framework::DDim>& dims, size_t thread_num)
: file_names_(file_names), dims_(dims) {
prefetchers_.resize(thread_num);
StartNewScheduler();
}

void ReadNext(std::vector<framework::LoDTensor>* out) override;
bool HasNext() const override;
void ReInit() override;

~MultipleReader() { EndScheduler(); }

private:
void StartNewScheduler();
void EndScheduler();
void ScheduleThreadFunc();
void PrefetchThreadFunc(std::string file_name, size_t thread_idx);

std::vector<std::string> file_names_;
std::vector<framework::DDim> dims_;
std::thread scheduler_;
std::vector<std::thread> prefetchers_;
framework::Channel<size_t>* waiting_file_idx_;
framework::Channel<size_t>* available_thread_idx_;
framework::Channel<std::vector<framework::LoDTensor>>* buffer_;
mutable std::vector<framework::LoDTensor> local_buffer_;
};

void MultipleReader::ReadNext(std::vector<framework::LoDTensor>* out) {
if (!HasNext()) {
PADDLE_THROW("There is no next data!");
}

if (local_buffer_.empty()) {
buffer_->Receive(&local_buffer_);
}
*out = local_buffer_;
local_buffer_.clear();
}

bool MultipleReader::HasNext() const {
return local_buffer_.empty() ? buffer_->Receive(&local_buffer_) : true;
}

void MultipleReader::ReInit() {
EndScheduler();
local_buffer_.clear();
StartNewScheduler();
}

void MultipleReader::StartNewScheduler() {
size_t thread_num = prefetchers_.size();
waiting_file_idx_ = framework::MakeChannel<size_t>(file_names_.size());
available_thread_idx_ = framework::MakeChannel<size_t>(thread_num);
buffer_ =
framework::MakeChannel<std::vector<framework::LoDTensor>>(thread_num);

for (size_t i = 0; i < file_names_.size(); ++i) {
waiting_file_idx_->Send(&i);
}
waiting_file_idx_->Close();
for (size_t i = 0; i < thread_num; ++i) {
available_thread_idx_->Send(&i);
}

scheduler_ = std::thread([this] { ScheduleThreadFunc(); });
}

void MultipleReader::EndScheduler() {
available_thread_idx_->Close();
buffer_->Close();
waiting_file_idx_->Close();
if (scheduler_.joinable()) {
scheduler_.join();
}
delete buffer_;
delete available_thread_idx_;
delete waiting_file_idx_;
}

void MultipleReader::ScheduleThreadFunc() {
VLOG(5) << "MultipleReader schedule thread starts.";
size_t completed_thread_num = 0;
size_t thread_idx;
while (available_thread_idx_->Receive(&thread_idx)) {
std::thread& prefetcher = prefetchers_[thread_idx];
if (prefetcher.joinable()) {
prefetcher.join();
}
size_t file_idx;
if (waiting_file_idx_->Receive(&file_idx)) {
// Still have files to read. Start a new prefetch thread.
std::string file_name = file_names_[file_idx];
prefetcher = std::thread([this, file_name, thread_idx] {
PrefetchThreadFunc(file_name, thread_idx);
});
} else {
// No more file to read.
++completed_thread_num;
if (completed_thread_num == prefetchers_.size()) {
buffer_->Close();
break;
}
}
}
// If users invoke ReInit() when scheduler is running, it will close the
// 'avaiable_thread_idx_' and prefecther threads have no way to tell scheduler
// to release their resource. So a check is needed before scheduler ends.
for (auto& p : prefetchers_) {
if (p.joinable()) {
p.join();
}
}
VLOG(5) << "MultipleReader schedule thread terminates.";
}

void MultipleReader::PrefetchThreadFunc(std::string file_name,
size_t thread_idx) {
VLOG(5) << "The prefetch thread of file '" << file_name << "' starts.";
std::unique_ptr<framework::ReaderBase> reader =
CreateReaderByFileName(file_name, dims_);
while (reader->HasNext()) {
std::vector<framework::LoDTensor> ins;
reader->ReadNext(&ins);
if (!buffer_->Send(&ins)) {
VLOG(5) << "WARNING: The buffer channel has been closed. The prefetch "
"thread of file '"
<< file_name << "' will terminate.";
break;
}
}
if (!available_thread_idx_->Send(&thread_idx)) {
VLOG(5) << "WARNING: The available_thread_idx_ channel has been closed. "
"Fail to send thread_idx.";
}
VLOG(5) << "The prefetch thread of file '" << file_name << "' terminates.";
}

class OpenFilesOp : public framework::OperatorBase {
public:
using framework::OperatorBase::OperatorBase;

private:
void RunImpl(const framework::Scope& scope,
const platform::Place& dev_place) const override {
const auto& shape_concat = Attr<std::vector<int>>("shape_concat");
const auto& ranks = Attr<std::vector<int>>("ranks");
PADDLE_ENFORCE(!shape_concat.empty() && !ranks.empty());
PADDLE_ENFORCE_EQ(std::accumulate(ranks.begin(), ranks.end(), 0),
int(shape_concat.size()),
"The accumulate of all ranks should be equal to the "
"shape concat's length.");
const auto& file_names = Attr<std::vector<std::string>>("file_names");
PADDLE_ENFORCE(!file_names.empty(), "No file to be read!");
const size_t thread_num = Attr<int>("thread_num");

auto* out = scope.FindVar(Output("Out"))
->template GetMutable<framework::ReaderHolder>();
out->Reset(new MultipleReader(
file_names, RestoreShapes(shape_concat, ranks), thread_num));
}
};

class OpenFilesOpMaker : public FileReaderMakerBase {
public:
OpenFilesOpMaker(OpProto* op_proto, OpAttrChecker* op_checker)
: FileReaderMakerBase(op_proto, op_checker) {
AddAttr<std::vector<std::string>>("file_names", "Files to be read.");
AddAttr<int>("thread_num", "The maximal concurrent prefetch thread number.")
.GreaterThan(0);

AddComment(R"DOC(
OpenFiles Operator

An OpenFilesOp creates a MultipleReader, which is able to
read data multi-threaded from multiple files.
)DOC");
}
};

} // namespace reader
} // namespace operators
} // namespace paddle

namespace reader = paddle::operators::reader;

REGISTER_FILE_READER_OPERATOR(open_files, reader::OpenFilesOp,
reader::OpenFilesOpMaker);
15 changes: 15 additions & 0 deletions paddle/fluid/operators/reader/reader_op_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,21 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry() {
return regs;
}

std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims) {
size_t separator_pos = file_name.find_last_of(kFileFormatSeparator);
PADDLE_ENFORCE_NE(separator_pos, std::string::npos,
"File name illegal! A legal file name should be like: "
"[file_name].[file_format] (e.g., 'data_file.recordio').");
std::string filetype = file_name.substr(separator_pos + 1);

auto itor = FileReaderRegistry().find(filetype);
PADDLE_ENFORCE(itor != FileReaderRegistry().end(),
"No file reader registered for '%s' format.", filetype);
framework::ReaderBase* reader = (itor->second)(file_name, dims);
return std::unique_ptr<framework::ReaderBase>(reader);
}

FileReaderMakerBase::FileReaderMakerBase(
framework::OpProtoAndCheckerMaker::OpProto* op_proto,
framework::OpAttrChecker* op_checker)
Expand Down
9 changes: 7 additions & 2 deletions paddle/fluid/operators/reader/reader_op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ namespace paddle {
namespace operators {
namespace reader {

static constexpr char kFileFormatSeparator[] = ".";

using FileReaderCreator = std::function<framework::ReaderBase*(
const std::string&, const std::vector<framework::DDim>&)>;

Expand All @@ -29,12 +31,15 @@ std::unordered_map<std::string, FileReaderCreator>& FileReaderRegistry();
template <typename Reader>
int RegisterFileReader(const std::string& filetype) {
FileReaderRegistry()[filetype] = [](
const std::string& fn, const std::vector<paddle::framework::DDim>& dim) {
return new Reader(fn, dim);
const std::string& fn, const std::vector<framework::DDim>& dims) {
return new Reader(fn, dims);
};
return 0;
}

std::unique_ptr<framework::ReaderBase> CreateReaderByFileName(
const std::string& file_name, const std::vector<framework::DDim>& dims);

extern std::vector<framework::DDim> RestoreShapes(
const std::vector<int>& shape_concat, const std::vector<int>& ranks);

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/recordio/header.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ Header::Header(uint32_t num, uint32_t sum, Compressor c, uint32_t cs)

bool Header::Parse(std::istream& is) {
uint32_t magic;
size_t read_size =
is.readsome(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
is.read(reinterpret_cast<char*>(&magic), sizeof(uint32_t));
size_t read_size = is.gcount();
if (read_size < sizeof(uint32_t)) {
return false;
}
Expand Down
33 changes: 32 additions & 1 deletion python/paddle/fluid/layers/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@

__all__ = [
'data', 'BlockGuardServ', 'ListenAndServ', 'Send', 'open_recordio_file',
'read_file', 'create_shuffle_reader', 'create_double_buffer_reader'
'open_files', 'read_file', 'create_shuffle_reader',
'create_double_buffer_reader'
]


Expand Down Expand Up @@ -287,6 +288,36 @@ def open_recordio_file(filename, shapes, lod_levels, dtypes):
startup_var)


def open_files(filenames, thread_num, shapes, lod_levels, dtypes):
dtypes = [convert_np_dtype_to_dtype_(dt) for dt in dtypes]
shape_concat = []
ranks = []

for shape in shapes:
shape_concat.extend(shape)
ranks.append(len(shape))

var_name = unique_name('multiple_reader')

startup_blk = default_startup_program().current_block()
startup_var = startup_blk.create_var(name=var_name)
startup_blk.append_op(
type='open_files',
outputs={'Out': [startup_var]},
attrs={
'shape_concat': shape_concat,
'lod_levels': lod_levels,
'ranks': ranks,
'file_names': filenames,
'thread_num': thread_num
})

startup_var.desc.set_dtypes(dtypes)
startup_var.persistable = True
return _copy_reader_var_(default_main_program().current_block(),
startup_var)


def __create_decorated_reader__(op_type, reader, attrs):
var_name = unique_name(op_type)
startup_blk = default_startup_program().current_block()
Expand Down
3 changes: 3 additions & 0 deletions python/paddle/fluid/tests/unittests/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
mnist.recordio
mnist_0.recordio
mnist_1.recordio
mnist_2.recordio
Loading