Split dataset into multiple files #2320

Merged: 11 commits, Jun 2, 2017

77 changes: 76 additions & 1 deletion python/paddle/v2/dataset/common.py
@@ -19,8 +19,10 @@
import sys
import importlib
import paddle.v2.dataset
import cPickle
import glob

__all__ = ['DATA_HOME', 'download', 'md5file']
__all__ = ['DATA_HOME', 'download', 'md5file', 'split', 'cluster_files_reader']

DATA_HOME = os.path.expanduser('~/.cache/paddle/dataset')

@@ -74,3 +76,76 @@ def fetch_all():
getattr(
importlib.import_module("paddle.v2.dataset.%s" % module_name),
"fetch")()


def split(reader, line_count, suffix="%05d.pickle", dumper=cPickle.dump):
    """
    You can call the function as:

    split(paddle.v2.dataset.cifar.train10(), line_count=1000,
          suffix="imikolov-train-%05d.pickle")

    The output files will be:

    |-imikolov-train-00000.pickle
    |-imikolov-train-00001.pickle
    |- ...
    |-imikolov-train-00480.pickle

    :param reader: a reader creator
    :param line_count: the number of lines in each output file
    :param suffix: the suffix of the output files; it must contain "%d",
        which is replaced by the index of each file.
        Default is "%05d.pickle".
    :param dumper: a callable that dumps an object to a file; it is
        called as dumper(obj, f), where obj is the object to be dumped
        and f is a file object. Default is cPickle.dump.
    """
    if not callable(dumper):
        raise TypeError("dumper should be callable.")
    lines = []
    indx_f = 0
    for i, d in enumerate(reader()):
        lines.append(d)
        # Flush a chunk as soon as it holds line_count lines.
        if (i + 1) % line_count == 0:
            with open(suffix % indx_f, "w") as f:
                dumper(lines, f)
            lines = []
            indx_f += 1
    # Dump whatever is left over as the final (possibly smaller) file.
    if lines:
        with open(suffix % indx_f, "w") as f:
            dumper(lines, f)
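

# A quick illustrative sketch of the chunking behaviour (the reader
# below is made up): with line_count=4, a reader that yields 10 samples
# produces three files holding 4, 4 and 2 samples respectively.
#
#     def sample_reader():
#         for x in xrange(10):
#             yield x
#
#     split(sample_reader, line_count=4)  # 00000/00001/00002.pickle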


def cluster_files_reader(files_pattern,
trainer_count,
trainer_id,
loader=cPickle.load):
"""
Create a reader that yield element from the given files, select
a file set according trainer count and trainer_id

:param files_pattern: the files which generating by split(...)
:param trainer_count: total trainer count
:param trainer_id: the trainer rank id
:param loader: is a callable function that load object from file, this
function will be called as loader(f) and f is a file object.
Default is cPickle.load
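
    For example (the pattern below is illustrative), trainer 0 of a
    4-trainer job reads files 0, 4, 8, ... of the sorted file list:

        cluster_files_reader("imikolov-train-*.pickle",
                             trainer_count=4, trainer_id=0)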
"""

    if not callable(loader):
        raise TypeError("loader should be callable.")

    def reader():
file_list = glob.glob(files_pattern)
file_list.sort()
my_file_list = []
for idx, fn in enumerate(file_list):
if idx % trainer_count == trainer_id:
print "append file: %s" % fn
my_file_list.append(fn)
for fn in my_file_list:
with open(fn, "r") as f:
lines = loader(f)
for line in lines:
yield line

return reader
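
Taken together, a typical cluster job shards the dataset once with split() and then builds one reader per trainer from the shared pattern. A minimal sketch (the paths and counts are placeholders, and in a real job trainer_id would come from the cluster environment):

    import paddle.v2.dataset.cifar
    import paddle.v2.dataset.common as common

    # One-time preprocessing step: shard the training data.
    common.split(paddle.v2.dataset.cifar.train10(), line_count=1000,
                 suffix="cifar-train-%05d.pickle")

    # On each of the 4 trainers: read only this rank's shards.
    train_reader = common.cluster_files_reader("cifar-train-*.pickle",
                                               trainer_count=4,
                                               trainer_id=0)
    for sample in train_reader():
        pass  # feed the sample to the trainer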
25 changes: 25 additions & 0 deletions python/paddle/v2/dataset/tests/common_test.py
@@ -15,6 +15,7 @@
import paddle.v2.dataset.common
import unittest
import tempfile
import glob


class TestCommon(unittest.TestCase):
@@ -32,6 +33,30 @@ def test_download(self):
paddle.v2.dataset.common.download(
yi_avatar, 'test', 'f75287202d6622414c706c36c16f8e0d'))

def test_split(self):
def test_reader():
def reader():
for x in xrange(10):
yield x

return reader

        # mkdtemp gives a directory the shard files can be written into.
        temp_path = tempfile.mkdtemp()
        paddle.v2.dataset.common.split(
            test_reader(), 4, suffix=temp_path + '/test-%05d.pickle')
        files = glob.glob(temp_path + '/test-*.pickle')
self.assertEqual(len(files), 3)

    def test_cluster_file_reader(self):
        temp_path = tempfile.mkdtemp()
        for x in xrange(5):
            with open(temp_path + '/%05d.test' % x, 'w') as f:
                f.write('%d\n' % x)
        # The shard files hold plain text, so use a line-based loader
        # instead of the default cPickle.load.
        reader = paddle.v2.dataset.common.cluster_files_reader(
            temp_path + '/*.test',
            5,
            0,
            loader=lambda f: [line.strip() for line in f])
        for idx, e in enumerate(reader()):
            self.assertEqual(e, str("0"))


if __name__ == '__main__':
unittest.main()
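
Because dumper and loader are plain callables, the helpers are not tied to pickle; any matching serializer pair works. An illustrative sketch with JSON (the file names and reader below are made up):

    import json
    import paddle.v2.dataset.common as common

    def sample_reader():
        for x in xrange(10):
            yield x

    # Shard to JSON by swapping in json.dump ...
    common.split(sample_reader, 4, suffix="sample-%05d.json",
                 dumper=json.dump)

    # ... and read the shards back with the matching json.load.
    reader = common.cluster_files_reader("sample-*.json",
                                         trainer_count=1, trainer_id=0,
                                         loader=json.load)
    print list(reader())  # prints [0, 1, ..., 9]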