Skip to content

Commit

Permalink
Adding a hook (wrapper) for non-std stream reader in PyTorchStreamRea…
Browse files Browse the repository at this point in the history
…der (pytorch#15551)

Summary:
Implementing a custom stream is cumbersome, because a stream is tightly coupled to its underlying stream buffer.

So in this PR, we add ReadAdapterInterface and make PyTorchStreamReader use it. We implement IStreamAdapter as a wrapper around std::istream, and keep the user-facing interface unchanged.
Pull Request resolved: pytorch#15551

Reviewed By: zrphercule

Differential Revision: D13568907

Pulled By: houseroad

fbshipit-source-id: 93708cb801248a6c101f35cb14d1631029365c3c
  • Loading branch information
houseroad authored and facebook-github-bot committed Jan 5, 2019
1 parent 1488c5d commit a918f1d
Show file tree
Hide file tree
Showing 13 changed files with 229 additions and 50 deletions.
5 changes: 4 additions & 1 deletion caffe2/serialize/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ file(GLOB tmp *_test.cc)
set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} ${tmp})
list(APPEND Caffe2_CPU_SRCS
${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8/miniz.c
${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc)
${CMAKE_CURRENT_SOURCE_DIR}/inline_container.cc
${CMAKE_CURRENT_SOURCE_DIR}/istream_adapter.cc
${CMAKE_CURRENT_SOURCE_DIR}/file_adapter.cc
${CMAKE_CURRENT_SOURCE_DIR}/read_adapter_interface.cc)
list(APPEND Caffe2_CPU_INCLUDE ${PROJECT_SOURCE_DIR}/third_party/miniz-2.0.8)

set(Caffe2_CPU_TEST_SRCS ${Caffe2_CPU_TEST_SRCS} PARENT_SCOPE)
Expand Down
28 changes: 28 additions & 0 deletions caffe2/serialize/file_adapter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include "caffe2/serialize/file_adapter.h"
#include <c10/util/Exception.h>
#include "caffe2/core/common.h"

namespace caffe2 {
namespace serialize {

// Opens `file_name` for binary reading and wraps the resulting stream in an
// IStreamAdapter. Raises through AT_ERROR when the stream is in a failed
// state after the open attempt.
FileAdapter::FileAdapter(const std::string& file_name) {
  const auto open_mode = std::ifstream::in | std::ifstream::binary;
  file_stream_.open(file_name, open_mode);
  if (file_stream_.fail()) {
    AT_ERROR("open file failed, file path: ", file_name);
  }
  // The adapter is handed the address of the member stream.
  istream_adapter_ = caffe2::make_unique<IStreamAdapter>(&file_stream_);
}

// Forwards to the wrapped IStreamAdapter and returns whatever size it reports.
size_t FileAdapter::size() const {
  const size_t stream_size = istream_adapter_->size();
  return stream_size;
}

// Forwards the read request unchanged to the wrapped IStreamAdapter and
// returns its result.
size_t FileAdapter::read(
    uint64_t pos,
    void* buf,
    size_t n,
    const char* what) const {
  const size_t count = istream_adapter_->read(pos, buf, n, what);
  return count;
}

// No custom teardown: the members clean up after themselves.
FileAdapter::~FileAdapter() = default;

} // namespace serialize
} // namespace caffe2
28 changes: 28 additions & 0 deletions caffe2/serialize/file_adapter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include <fstream>
#include <memory>

#include <c10/macros/Macros.h>
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"

namespace caffe2 {
namespace serialize {

class FileAdapter final : public ReadAdapterInterface {
public:
C10_DISABLE_COPY_AND_ASSIGN(FileAdapter);
explicit FileAdapter(const std::string& file_name);
size_t size() const override;
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
const override;
~FileAdapter();

private:
std::ifstream file_stream_;
std::unique_ptr<IStreamAdapter> istream_adapter_;
};

} // namespace serialize
} // namespace caffe2
71 changes: 42 additions & 29 deletions caffe2/serialize/inline_container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,17 @@
#include <c10/core/Allocator.h>
#include <c10/core/Backend.h>

#include "caffe2/core/common.h"
#include "caffe2/core/logging.h"
#include "caffe2/serialize/file_adapter.h"
#include "caffe2/serialize/inline_container.h"
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"

#include "miniz.h"

namespace torch { namespace jit {
namespace caffe2 {
namespace serialize {

size_t istream_read_func(void *pOpaque, mz_uint64 file_ofs, void *pBuf, size_t n) {
auto self = static_cast<PyTorchStreamReader*>(pOpaque);
Expand Down Expand Up @@ -42,27 +47,33 @@ static std::string basename(const std::string& name) {
}

size_t PyTorchStreamReader::read(uint64_t pos, char* buf, size_t n) {
in_->seekg(pos);
if(!*in_)
return 0;
in_->read(static_cast<char*>(buf), n);
if(!*in_)
return 0;
return n;
return in_->read(pos, buf, n, "reading file");
}

PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in)
: ar_(new mz_zip_archive), in_(in) {
memset(ar_.get(), 0, sizeof(mz_zip_archive));
// Builds a reader over the file at `file_name`: allocates the miniz archive
// state, wraps the file in a FileAdapter, then hands off to init().
PyTorchStreamReader::PyTorchStreamReader(const std::string& file_name)
: ar_(caffe2::make_unique<mz_zip_archive>()),
in_(caffe2::make_unique<FileAdapter>(file_name)) {
init();
}

if (!in_) {
file_stream_.open(file_name, std::ifstream::in | std::ifstream::binary);
in_ = &file_stream_;
valid("opening archive");
}
// Builds a reader over a caller-supplied std::istream by wrapping it in an
// IStreamAdapter, then hands off to init().
// NOTE(review): the stream is passed as a raw pointer, so the caller
// presumably has to keep it alive while the reader is in use - confirm.
PyTorchStreamReader::PyTorchStreamReader(std::istream* in)
: ar_(caffe2::make_unique<mz_zip_archive>()),
in_(caffe2::make_unique<IStreamAdapter>(in)) {
init();
}

// Builds a reader over an arbitrary ReadAdapterInterface implementation.
// Ownership of `in` is moved into the reader before init() runs.
PyTorchStreamReader::PyTorchStreamReader(
std::unique_ptr<ReadAdapterInterface> in)
: ar_(caffe2::make_unique<mz_zip_archive>()), in_(std::move(in)) {
init();
}

in_->seekg(0, in_->end);
size_t size = in_->tellg();
void PyTorchStreamReader::init() {
AT_ASSERT(in_ != nullptr);
AT_ASSERT(ar_ != nullptr);
memset(ar_.get(), 0, sizeof(mz_zip_archive));

size_t size = in_->size();

// check for the old magic number,
constexpr size_t kMagicValueLength = 8;
Expand All @@ -81,7 +92,6 @@ PyTorchStreamReader::PyTorchStreamReader(std::string file_name, std::istream* in
mz_zip_reader_init(ar_.get(), size, 0);
valid("reading zip archive");


// figure out the archive_name (i.e. the zip folder all the other files are in)
// all lookups to getRecord will be prefixed by this folder
int n = mz_zip_reader_get_num_files(ar_.get());
Expand Down Expand Up @@ -126,9 +136,6 @@ void PyTorchStreamReader::valid(const char* what) {
if (err != MZ_ZIP_NO_ERROR) {
CAFFE_THROW("PytorchStreamReader failed ", what, ": ", mz_zip_get_error_string(err));
}
if (!*in_) {
CAFFE_THROW("PytorchStreamReader failed ", what, ".");
}
}

constexpr int MZ_ZIP_LOCAL_DIR_HEADER_SIZE = 30;
Expand Down Expand Up @@ -191,11 +198,12 @@ size_t PyTorchStreamReader::getRecordOffset(const std::string& name) {
mz_zip_archive_file_stat stat;
mz_zip_reader_file_stat(ar_.get(), getFileID(name), &stat);
valid("retriving file meta-data");
in_->seekg(stat.m_local_header_ofs);
valid("seeking to file header");
uint8_t local_header[MZ_ZIP_LOCAL_DIR_HEADER_SIZE];
in_->read(reinterpret_cast<char*>(local_header), MZ_ZIP_LOCAL_DIR_HEADER_SIZE);
valid("reading file header");
in_->read(
stat.m_local_header_ofs,
local_header,
MZ_ZIP_LOCAL_DIR_HEADER_SIZE,
"reading file header");
size_t filename_len = read_le_16(local_header + MZ_ZIP_LDH_FILENAME_LEN_OFS);
size_t extra_len = read_le_16(local_header + MZ_ZIP_LDH_EXTRA_LEN_OFS);
return stat.m_local_header_ofs + MZ_ZIP_LOCAL_DIR_HEADER_SIZE + filename_len + extra_len;
Expand Down Expand Up @@ -226,8 +234,12 @@ size_t ostream_write_func(void *pOpaque, mz_uint64 file_ofs, const void *pBuf, s
return n;
}

PyTorchStreamWriter::PyTorchStreamWriter(std::string file_name, std::ostream* out)
: ar_(new mz_zip_archive), archive_name_(basename(file_name)), out_(out) {
PyTorchStreamWriter::PyTorchStreamWriter(
std::string file_name,
std::ostream* out)
: ar_(caffe2::make_unique<mz_zip_archive>()),
archive_name_(basename(file_name)),
out_(out) {
memset(ar_.get(), 0, sizeof(mz_zip_archive));

if (archive_name_.size() == 0) {
Expand Down Expand Up @@ -302,4 +314,5 @@ PyTorchStreamWriter::~PyTorchStreamWriter() {
}
}

}} // namespace torch::jit
} // namespace serialize
} // namespace caffe2
33 changes: 19 additions & 14 deletions caffe2/serialize/inline_container.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
#include <c10/core/Backend.h>

#include "caffe2/core/logging.h"
#include "caffe2/serialize/istream_adapter.h"
#include "caffe2/serialize/read_adapter_interface.h"

extern "C" {
typedef struct mz_zip_archive mz_zip_archive;
Expand Down Expand Up @@ -84,7 +86,8 @@ typedef struct mz_zip_archive mz_zip_archive;
// model.json as the last file when writing after we have accumulated all
// other information.

namespace torch { namespace jit {
namespace caffe2 {
namespace serialize {

constexpr uint64_t kMinSupportedFileFormatVersion = 0x1L;
constexpr uint64_t kMaxSupportedFileFormatVersion = 0x1L;
Expand All @@ -97,9 +100,9 @@ constexpr uint64_t kFieldAlignment = 64;

class CAFFE2_API PyTorchStreamReader final {
public:
PyTorchStreamReader(std::string archive_name, std::istream* in=nullptr);
PyTorchStreamReader(std::istream* in)
: PyTorchStreamReader("archive", in) {}
explicit PyTorchStreamReader(const std::string& file_name);
explicit PyTorchStreamReader(std::istream* in);
explicit PyTorchStreamReader(std::unique_ptr<ReadAdapterInterface> in);

// return dataptr, size
std::tuple<at::DataPtr, size_t> getRecord(const std::string& name);
Expand All @@ -109,15 +112,16 @@ class CAFFE2_API PyTorchStreamReader final {
~PyTorchStreamReader();

private:
size_t read(uint64_t pos, char* buf, size_t n);
void valid(const char* what);
size_t getFileID(const std::string& name);

friend size_t istream_read_func(void *pOpaque, uint64_t file_ofs, void *pBuf, size_t n);
std::unique_ptr<mz_zip_archive> ar_;
std::string archive_name_;
std::istream* in_;
std::ifstream file_stream_;
void init();
size_t read(uint64_t pos, char* buf, size_t n);
void valid(const char* what);
size_t getFileID(const std::string& name);

friend size_t
istream_read_func(void* pOpaque, uint64_t file_ofs, void* pBuf, size_t n);
std::unique_ptr<mz_zip_archive> ar_;
std::string archive_name_;
std::unique_ptr<ReadAdapterInterface> in_;
};

class CAFFE2_API PyTorchStreamWriter final {
Expand Down Expand Up @@ -150,4 +154,5 @@ class CAFFE2_API PyTorchStreamWriter final {
friend size_t ostream_write_func(void *pOpaque, uint64_t file_ofs, const void *pBuf, size_t n);
};

}} // namespace torch::jit
} // namespace serialize
} // namespace caffe2
10 changes: 6 additions & 4 deletions caffe2/serialize/inline_container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,16 @@

#include "caffe2/serialize/inline_container.h"

namespace at {
namespace caffe2 {
namespace serialize {
namespace {

TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
int64_t kFieldAlignment = 64L;

std::ostringstream oss;
// write records through writers
torch::jit::PyTorchStreamWriter writer(&oss);
PyTorchStreamWriter writer(&oss);
std::array<char, 127> data1;

for (int i = 0; i < data1.size(); ++i) {
Expand All @@ -37,7 +38,7 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
std::istringstream iss(the_file);

// read records through readers
torch::jit::PyTorchStreamReader reader(&iss);
PyTorchStreamReader reader(&iss);
at::DataPtr data_ptr;
int64_t size;
std::tie(data_ptr, size) = reader.getRecord("key1");
Expand All @@ -58,4 +59,5 @@ TEST(PyTorchStreamWriterAndReader, SaveAndLoad) {
}

} // namespace
} // namespace at
} // namespace serialize
} // namespace caffe2
39 changes: 39 additions & 0 deletions caffe2/serialize/istream_adapter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "caffe2/serialize/istream_adapter.h"
#include <c10/util/Exception.h>

namespace caffe2 {
namespace serialize {

// Wraps `istream` by storing the raw pointer; the adapter never deletes it.
// Raises through AT_ERROR when `istream` is null.
IStreamAdapter::IStreamAdapter(std::istream* istream) : istream_(istream) {
  // size() and read() dereference istream_ unconditionally, so reject a null
  // stream here with a catchable error instead of crashing on first use.
  if (istream_ == nullptr) {
    AT_ERROR("istream reader failed: null istream.");
  }
}

size_t IStreamAdapter::size() const {
auto prev_pos = istream_->tellg();
validate("getting the current position");
istream_->seekg(0, istream_->end);
validate("seeking to end");
auto result = istream_->tellg();
validate("getting size");
istream_->seekg(prev_pos);
validate("seeking to the original position");
return result;
}

// Seeks to `pos` and reads `n` characters into `buf`. `what` is the label
// passed to validate() for the error message if either step leaves the
// stream in a failed state. Returns `n` when both steps succeed.
size_t IStreamAdapter::read(
    uint64_t pos,
    void* buf,
    size_t n,
    const char* what) const {
  char* const dest = static_cast<char*>(buf);

  istream_->seekg(pos);
  validate(what);

  istream_->read(dest, n);
  validate(what);

  return n;
}

void IStreamAdapter::validate(const char* what) const {
if (!*istream_) {
AT_ERROR("istream reader failed: ", what, ".");
}
}

// Nothing to release: the wrapped stream is not deleted here.
IStreamAdapter::~IStreamAdapter() = default;

} // namespace serialize
} // namespace caffe2
28 changes: 28 additions & 0 deletions caffe2/serialize/istream_adapter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once

#include <istream>

#include <c10/macros/Macros.h>

#include "caffe2/serialize/read_adapter_interface.h"

namespace caffe2 {
namespace serialize {

// this is a reader implemented by std::istream
class IStreamAdapter final : public ReadAdapterInterface {
public:
C10_DISABLE_COPY_AND_ASSIGN(IStreamAdapter);
explicit IStreamAdapter(std::istream* istream);
size_t size() const override;
size_t read(uint64_t pos, void* buf, size_t n, const char* what = "")
const override;
~IStreamAdapter();

private:
std::istream* istream_;
void validate(const char* what) const;
};

} // namespace serialize
} // namespace caffe2
9 changes: 9 additions & 0 deletions caffe2/serialize/read_adapter_interface.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#include "caffe2/serialize/read_adapter_interface.h"

namespace caffe2 {
namespace serialize {

// Out-of-line definition of the interface's virtual destructor.
ReadAdapterInterface::~ReadAdapterInterface() = default;

} // namespace serialize
} // namespace caffe2
21 changes: 21 additions & 0 deletions caffe2/serialize/read_adapter_interface.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#pragma once

#include <cstddef>
#include <cstdint>

namespace caffe2 {
namespace serialize {

// Abstract reader interface used by PyTorchStreamReader. Implementing it
// lets the reader work with sources (file/stream/memory) beyond a standard
// istream.
class ReadAdapterInterface {
 public:
  // Reports the size of the underlying source.
  virtual size_t size() const = 0;

  // Reads from the source at offset `pos` into `buf`, for a request of size
  // `n`. NOTE(review): `what` looks like a label for error reporting and the
  // return value looks like the amount read - confirm against implementers.
  virtual size_t read(
      uint64_t pos,
      void* buf,
      size_t n,
      const char* what = "") const = 0;

  // Virtual so implementations can be destroyed through this interface.
  virtual ~ReadAdapterInterface();
};

} // namespace serialize
} // namespace caffe2
Loading

0 comments on commit a918f1d

Please sign in to comment.