Added class gzFileLoader to read models in gzipped files #1403

Open · wants to merge 8 commits into base: master
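For context, here is a minimal usage sketch of the new loader. It assumes gzFileLoader is declared in dynet/io.h alongside TextFileLoader (the header change is not visible in this excerpt); the file name and parameter dimensions are illustrative:

#include "dynet/init.h"
#include "dynet/io.h"
#include "dynet/model.h"

int main(int argc, char** argv) {
  dynet::initialize(argc, argv);

  // Build a collection with the same structure as the saved model.
  dynet::ParameterCollection model;
  dynet::Parameter W = model.add_parameters({8, 4});  // illustrative dimensions

  // Populate it from a gzip-compressed model in DyNet's text format.
  dynet::gzFileLoader loader("model.txt.gz");          // hypothetical file name
  loader.populate(model);
  return 0;
}
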
1 change: 1 addition & 0 deletions .appveyor.yml
@@ -42,6 +42,7 @@ before_build:
- cmd: cd build
- cmd: cmake -DEIGEN3_INCLUDE_DIR=C:/projects/eigen -G "Visual Studio 14 2015 Win64" -DCMAKE_BUILD_TYPE=%configuration% -DENABLE_BOOST=ON -DENABLE_CPP_EXAMPLES=ON -DBOOST_ROOT:PATHNAME="%BOOST_ROOT%" -DBoost_LIBRARY_DIRS:FILEPATH="%BOOST_LIBRARYDIR%" -DBoost_NO_BOOST_CMAKE=TRUE -DBoost_NO_SYSTEM_PATHS=TRUE -DPYTHON=python.exe ..
- cmd: set VS90COMNTOOLS=%VS140COMNTOOLS%
- cmd: set PATH=%BOOST_LIBRARYDIR%;%PATH%
- cmd: cd ..

build:
2 changes: 1 addition & 1 deletion .travis/install_dependencies.sh
@@ -8,7 +8,7 @@ if [[ "$TRAVIS_OS_NAME" == "linux" ]]; then
sudo apt-get install -y gcc-4.8 g++-4.8
PYTHON_PACKAGES="numpy pypandoc twine auditwheel cython"
if [[ "$PYTHON_INSTALL" == manual ]]; then
sudo apt-get install -y --allow-unauthenticated libboost-filesystem1.55-dev libboost-program-options1.55-dev libboost-serialization1.55-dev libboost-test1.55-dev libboost-regex1.55-dev
sudo apt-get install -y --allow-unauthenticated libboost-filesystem1.55-dev libboost-program-options1.55-dev libboost-serialization1.55-dev libboost-test1.55-dev libboost-regex1.55-dev libboost-iostreams1.55-dev
sudo -H pip install -U $PYTHON_PACKAGES
else
sudo apt-get install -y pandoc
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -124,7 +124,8 @@ if(ENABLE_BOOST)
endif()
endif()
set(Boost_REALPATH ON)
find_package(Boost COMPONENTS program_options regex serialization REQUIRED)
find_package(Boost COMPONENTS program_options regex serialization iostreams REQUIRED)
add_definitions(-DHAVE_BOOST)
message("-- Boost dir is " ${Boost_INCLUDE_DIR})
include_directories(${Boost_INCLUDE_DIR})
if(MSVC)
3 changes: 2 additions & 1 deletion dynet/CMakeLists.txt
@@ -68,6 +68,7 @@ set(dynet_library_SRCS
)
if(ENABLE_BOOST)
list(APPEND dynet_library_SRCS mp.cc)

endif()

# Headers:
@@ -161,7 +162,7 @@ weight-decay.h
if(ENABLE_BOOST)
list(APPEND dynet_library_HDRS mp.h)
endif()

set(dynet_gpu_mergeable_SRCS
nodes-activations
nodes-affinetransform
293 changes: 293 additions & 0 deletions dynet/io.cc
@@ -5,6 +5,11 @@

#include <algorithm>

#ifdef HAVE_BOOST
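// Boost.Iostreams provides the gzip decompression filter used by gzFileLoader below;
// the matching "iostreams" component is added to find_package(Boost) in CMakeLists.txt.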
#include <boost/iostreams/filtering_streambuf.hpp>
#include <boost/iostreams/filter/gzip.hpp>
#endif

// Normally DyNet style permits using namespace std, but to make compatibility
// possible with some external code, it is simpler if types are fully
// qualified in dynet/io.cc. Please do not uncomment the following:
@@ -20,6 +25,7 @@ static const int FLOAT32_EXPONENT = 2;
namespace dynet {
namespace {

// ----------------- utility functions
bool valid_key(const std::string & s) {
if (s.size() == 0) return true;
if (s == "/") return false;
@@ -52,11 +58,27 @@ void read_param_header(std::string line, std::string &type, std::string &name, D
}
}


bool relevant_header_line(std::string line, const std::string &pref, const std::string &key_, std::string &type, std::string &name, Dim &dim, size_t &byte_count, bool &zero_grad) {
if (line.substr(0, pref.size()) != pref)
return false; // not a header line with the required prefix ("#", "#Parameter#", ...)

// it is a header line: parse it and check that the name matches the requested key
read_param_header(line, type, name, dim, byte_count, zero_grad);
return (name.substr(0, key_.size()) == key_);
}

} // anonymous namespace

Saver::~Saver() {}
Loader::~Loader() {}


// =========== TextFileSaver ====================

TextFileSaver::TextFileSaver(const std::string & filename, bool append) :
p_datastream(
new std::ofstream(
@@ -135,6 +157,8 @@ void TextFileSaver::save(const LookupParameterStorage & p,
datastream << dynet::as_vector(p.all_grads) << std::endl;
}

// =========== TextFileLoader ====================

TextFileLoader::TextFileLoader(const std::string & filename) :
dataname(filename) { }

@@ -333,4 +357,273 @@ LookupParameter TextFileLoader::load_lookup_param(ParameterCollection & model,
DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");
}


// =========== gzFileLoader ====================
#ifdef HAVE_BOOST

gzFileLoader::gzFileLoader(const std::string & filename) :
dataname(filename) { }

gzFileLoader::~gzFileLoader() {}



void gzFileLoader::populate(ParameterCollection & model, const std::string & key) {

std::ifstream fmod(dataname, std::ios_base::in | std::ios_base::binary);
if(!fmod) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
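// Stack a gzip decompressor on top of the raw file stream; everything below then
// reads the same text format that TextFileLoader reads.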
boost::iostreams::filtering_streambuf<boost::iostreams::input> inbuf;
inbuf.push(boost::iostreams::gzip_decompressor());
inbuf.push(fmod);
std::istream datastream(&inbuf);

std::string line, type, name;
bool zero_grad = false;
Dim dim;
size_t byte_count = 0;
std::vector<float> values;
Tensor *value_t, *grad_t;
size_t param_id = 0, lookup_id = 0;
ParameterCollectionStorage & storage = model.get_storage();
std::string key_ = key;
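// ensure the key ends with '/' so the prefix test in relevant_header_line only matches whole sub-collection names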
if (key_.size() != 0 && key_.back() != '/') key_ += "/";
while(std::getline(datastream, line)) {
// skip lines that are not a relevant parameter header
if (!relevant_header_line(line, "#", key_, type, name, dim, byte_count, zero_grad))
continue;

// We found a relevant parameter line, load it
if (type == "#Parameter#") {
values.resize(dim.size());
if(param_id >= storage.params.size())
DYNET_RUNTIME_ERR("Too many parameters to load in populated model at " << name);
ParameterStorage & param = *storage.params[param_id++];
if(param.dim != dim)
DYNET_RUNTIME_ERR("Dimensions of parameter " << name << " looked up from file (" << dim <<
") do not match parameters to be populated (" << param.dim << ")");
value_t = &param.values;
grad_t = &param.g;
}

// Load a lookup parameter
else if(type == "#LookupParameter#") {
values.resize(dim.size());
if(lookup_id >= storage.lookup_params.size())
DYNET_RUNTIME_ERR("Too many lookup parameters in populated model at " << name);
LookupParameterStorage & param = *storage.lookup_params[lookup_id++];
if(param.all_dim != dim)
DYNET_RUNTIME_ERR("Dimensions of lookup parameter " << name << " lookup up from file (" << dim <<
") do not match parameters to be populated (" << param.all_dim << ")");
value_t = &param.all_values;
grad_t = &param.all_grads;
}

// some unexpected header
else {
DYNET_RUNTIME_ERR("Bad parameter specification in model: " << line);
}

// load parameter
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(*value_t, values);
if(!zero_grad){
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(*grad_t, values);
} else {
TensorTools::zero(*grad_t);
}
}

if(param_id != storage.params.size() || lookup_id != storage.lookup_params.size())
DYNET_RUNTIME_ERR("Number of parameter/lookup parameter objects loaded from file (" <<
param_id << '/' << lookup_id << ") did not match number to be populated (" <<
storage.params.size() << '/' << storage.lookup_params.size() << ')');

fmod.close();
}

void gzFileLoader::populate(Parameter & param,
const std::string & key) {
if(key == "") DYNET_INVALID_ARG("gzFileLoader.populate() requires non-empty key");

std::ifstream fmod(dataname, std::ios_base::in | std::ios_base::binary);
if(!fmod) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
boost::iostreams::filtering_streambuf<boost::iostreams::input> inbuf;
inbuf.push(boost::iostreams::gzip_decompressor());
inbuf.push(fmod);
std::istream datastream(&inbuf);

std::string line, type, name;
bool zero_grad=false;
Dim dim;
size_t byte_count = 0;

// skip non-relevant parameter lines
bool found = false;
std::getline(datastream, line);
while (!datastream.eof() && !found) {
if (relevant_header_line(line, "#Parameter#", key, type, name, dim, byte_count, zero_grad))
found = true;
else
std::getline(datastream, line);
}
// search parameter was not found
if (!found) DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");

// parameter was found, load it
if(param.p->dim != dim)
DYNET_RUNTIME_ERR("Attempted to populate parameter where arguments don't match (" << param.p->dim << " != " << dim << ")");
std::vector<float> values(dim.size());
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(param.get_storage().values, values);
if(!zero_grad){
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(param.get_storage().g, values);
}
else {
TensorTools::zero(param.get_storage().g);
}

fmod.close();
return;
}


void gzFileLoader::populate(LookupParameter & lookup_param,
const std::string & key) {
if(key == "") DYNET_INVALID_ARG("gzFileLoader.populate() requires non-empty key");

std::ifstream fmod(dataname, std::ios_base::in | std::ios_base::binary);
if(!fmod) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
boost::iostreams::filtering_streambuf<boost::iostreams::input> inbuf;
inbuf.push(boost::iostreams::gzip_decompressor());
inbuf.push(fmod);
std::istream datastream(&inbuf);

std::string line, type, name;
bool zero_grad=false;
Dim dim;
size_t byte_count = 0;

// skip non-relevant parameter lines
bool found = false;
std::getline(datastream, line);
while (!datastream.eof() && !found) {
if (relevant_header_line(line, "#LookupParameter#", key, type, name, dim, byte_count, zero_grad))
found = true;
else
std::getline(datastream, line);
}
// search parameter was not found
if (!found) DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");

// parameter was found, load it
if(lookup_param.p->all_dim != dim)
DYNET_RUNTIME_ERR("Attempted to populate lookup parameter where arguments don't match (" << lookup_param.p->all_dim << " != " << dim << ")");

std::vector<float> values(dim.size());
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(lookup_param.get_storage().all_values, values);
if(!zero_grad){
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(lookup_param.get_storage().all_grads, values);
}
else {
TensorTools::zero(lookup_param.get_storage().all_grads);
}
fmod.close();
return;
}

Parameter gzFileLoader::load_param(ParameterCollection & model,
const std::string & key) {
if (key == "") DYNET_INVALID_ARG("gzFileLoader.load_param() requires non-empty key");

std::ifstream fmod(dataname, std::ios_base::in | std::ios_base::binary);
if(!fmod) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
boost::iostreams::filtering_streambuf<boost::iostreams::input> inbuf;
inbuf.push(boost::iostreams::gzip_decompressor());
inbuf.push(fmod);
std::istream datastream(&inbuf);

std::string line, type, name;
bool zero_grad=false;
Dim dim;
size_t byte_count = 0;

// skip non-relevant parameter lines
bool found = false;
std::getline(datastream, line);
while (!datastream.eof() && !found) {
if (relevant_header_line(line, "#Parameter#", key, type, name, dim, byte_count, zero_grad))
found = true;
else
std::getline(datastream, line);
}
// search parameter was not found
if (!found) DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");

// parameter was found, add and load it
Parameter param = model.add_parameters(dim);
param.get_storage().name = name;
std::vector<float> values(dim.size());
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(param.get_storage().values, values);
if(!zero_grad){
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(param.get_storage().g, values);
} else {
TensorTools::zero(param.get_storage().g);
}
fmod.close();
return param;
}

LookupParameter gzFileLoader::load_lookup_param(ParameterCollection & model,
const std::string & key) {

if(key == "") DYNET_INVALID_ARG("gzFileLoader.load_lookup_param() requires non-empty key");

std::ifstream fmod(dataname, std::ios_base::in | std::ios_base::binary);
if(!fmod) DYNET_RUNTIME_ERR("Could not read model from " << dataname);
boost::iostreams::filtering_streambuf<boost::iostreams::input> inbuf;
inbuf.push(boost::iostreams::gzip_decompressor());
inbuf.push(fmod);
std::istream datastream(&inbuf);

std::string line, type, name;
bool zero_grad=false;
Dim dim;
size_t byte_count = 0;

// skip non-relevant parameter lines
bool found = false;
std::getline(datastream, line);
while (!datastream.eof() && !found) {
if (relevant_header_line(line, "#LookupParameter#", key, type, name, dim, byte_count, zero_grad))
found = true;
else
std::getline(datastream, line);
}
// search parameter was not found
if (!found) DYNET_RUNTIME_ERR("Could not find key " << key << " in the model file");

// parameter was found, load it
std::vector<float> values(dim.size());
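// the last dimension of the stored all_dim is the number of lookup entries;
// strip it off so the remaining dimensions describe a single entry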
size_t size = dim[dim.nd-1]; dim.nd--;
LookupParameter lookup_param = model.add_lookup_parameters(size, dim);
lookup_param.get_storage().name = name;
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(lookup_param.get_storage().all_values, values);
if(!zero_grad){
{ std::getline(datastream, line); std::istringstream iss(line); iss >> values; }
TensorTools::set_elements(lookup_param.get_storage().all_grads, values);
} else {
TensorTools::zero(lookup_param.get_storage().all_grads);
}
fmod.close();
return lookup_param;
}
#endif // ifdef HAVE_BOOST

} // namespace dynet
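
As a follow-up to the usage sketch near the top, retrieving a single parameter by its saved key could look like the snippet below; the key and file name are assumptions, while the load_param signature is the one added in this diff:

dynet::ParameterCollection model;
dynet::gzFileLoader loader("model.txt.gz");             // hypothetical gzipped model file
dynet::Parameter W = loader.load_param(model, "/W");    // adds a parameter with the stored name and dims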