diff --git a/benchmark/python/sparse/dot.py b/benchmark/python/sparse/dot.py index 145d05d2e1ea..164e50aef051 100644 --- a/benchmark/python/sparse/dot.py +++ b/benchmark/python/sparse/dot.py @@ -26,9 +26,9 @@ import mxnet as mx import numpy as np import numpy.random as rnd -from mxnet.test_utils import rand_ndarray, set_default_context, assert_almost_equal +from mxnet.test_utils import rand_ndarray, set_default_context, assert_almost_equal, get_bz2_data from mxnet.base import check_call, _LIB -from util import get_data, estimate_density +from util import estimate_density PARSER = argparse.ArgumentParser(description="Benchmark sparse operators", formatter_class=argparse.ArgumentDefaultsHelpFormatter) @@ -204,7 +204,7 @@ def test_dot_real(data_dict): path = os.path.join(data_dir, data_dict['data_name']) if not os.path.exists(path): - get_data( + get_bz2_data( data_dir, data_dict['data_name'], data_dict['url'], diff --git a/benchmark/python/sparse/sparse_end2end.py b/benchmark/python/sparse/sparse_end2end.py index 33a82881780c..0c1699b1daa5 100644 --- a/benchmark/python/sparse/sparse_end2end.py +++ b/benchmark/python/sparse/sparse_end2end.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. -from mxnet.test_utils import * import time import argparse import os import multiprocessing +from mxnet.test_utils import * MAX_NUM_BATCH = 99999999 COMP = "compute" @@ -57,19 +57,8 @@ parser.add_argument('--measure-only', default=None, help="Measure only", choices=[IO, COMP, COMM]) - - -def get_libsvm_data(data_dir, data_name, url, data_origin_name): - if not os.path.isdir(data_dir): - os.system("mkdir " + data_dir) - os.chdir(data_dir) - if (not os.path.exists(data_name)): - import urllib - zippath = os.path.join(data_dir, data_origin_name) - urllib.urlretrieve(url, zippath) - os.system("bzip2 -d %r" % data_origin_name) - os.chdir("..") - +parser.add_argument('--omit-row-sparse-push', action='store_true', + help="omit row_sparse_push") class DummyIter(mx.io.DataIter): "A dummy iterator that always return the same batch, used for speed testing" @@ -119,13 +108,14 @@ def next(self): def get_sym(feature_dim): - x = mx.symbol.Variable("data", stype='csr') - norm_init = mx.initializer.Normal(sigma=0.01) - w = mx.symbol.Variable("w", shape=(feature_dim, args.output_dim), init=norm_init, stype='row_sparse') - embed = mx.symbol.sparse.dot(x, w) - y = mx.symbol.Variable("softmax_label") - model = mx.symbol.SoftmaxOutput(data=embed, label=y, name="out") - return model + inputs = mx.symbol.Variable("data", stype='csr') + norm_init = mx.initializer.Normal(sigma=0.01) + weights = mx.symbol.Variable("w", shape=(feature_dim, args.output_dim), + init=norm_init, stype='row_sparse') + embed = mx.symbol.sparse.dot(inputs, weights) + softmax_output = mx.symbol.Variable("softmax_label") + model = mx.symbol.SoftmaxOutput(data=embed, label=softmax_output, name="out") + return model def row_sparse_push(kv, param_arrays, grad_arrays, param_names): @@ -170,6 +160,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority): log_level = args.sparse_log_level measure_only = args.measure_only num_cores = multiprocessing.cpu_count() + omit_row_sparse_push = args.omit_row_sparse_push if measure_only == COMP or measure_only == IO: assert not kvstore, "when compute_only or io_only is set, kvstore should be None" num_batch = datasets[dataset]['lc'] / batch_size if num_batch == MAX_NUM_BATCH else num_batch @@ -211,7 +202,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority): data_dir = os.path.join(os.getcwd(), 'data') path = os.path.join(data_dir, metadata['data_name']) if not os.path.exists(path): - get_libsvm_data(data_dir, metadata['data_name'], metadata['url'], + get_bz2_data(data_dir, metadata['data_name'], metadata['url'], metadata['data_origin_name']) assert os.path.exists(path) @@ -279,7 +270,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority): if nbatch == 1: mod.forward_backward(batch) mod.update() - else: + elif not omit_row_sparse_push: row_sparse_push(kv, mod._exec_group.param_arrays, mod._exec_group.grad_arrays, mod._exec_group.param_names) diff --git a/benchmark/python/sparse/util.py b/benchmark/python/sparse/util.py index 947ff4a65037..c20b33a86d65 100644 --- a/benchmark/python/sparse/util.py +++ b/benchmark/python/sparse/util.py @@ -18,19 +18,6 @@ import os import random - -def get_data(data_dir, data_name, url, data_origin_name): - if not os.path.isdir(data_dir): - os.system("mkdir " + data_dir) - os.chdir(data_dir) - if (not os.path.exists(data_name)): - import urllib - zippath = os.path.join(data_dir, data_origin_name) - urllib.urlretrieve(url, zippath) - os.system("bzip2 -d %r" % data_origin_name) - os.chdir("..") - - def estimate_density(DATA_PATH, feature_size): """sample 10 times of a size of 1000 for estimating the density of the sparse dataset""" if not os.path.exists(DATA_PATH): diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py index 075d61d70697..daf421dbd631 100644 --- a/python/mxnet/test_utils.py +++ b/python/mxnet/test_utils.py @@ -28,6 +28,7 @@ import os import errno import logging +import bz2 from contextlib import contextmanager import numpy as np import numpy.testing as npt @@ -1409,6 +1410,20 @@ def read_data(label_url, image_url): return {'train_data':train_img, 'train_label':train_lbl, 'test_data':test_img, 'test_label':test_lbl} +def get_bz2_data(data_dir, data_name, url, data_origin_name): + download(url, dirname=data_dir, overwrite=False) + os.chdir(data_dir) + if not os.path.exists(data_name): + bz_file = bz2.BZ2File(data_origin_name, 'rb') + with open(data_name, 'wb') as fout: + try: + content = bz_file.read() + fout.write(content) + finally: + bz_file.close() + os.remove(data_origin_name) + os.chdir("..") + def set_env_var(key, val, default_val=""): """Set environment variable diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index c7f4f2004037..adf1d9151bb9 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -177,26 +177,6 @@ def test_NDArrayIter_csr(): begin += batch_size def test_LibSVMIter(): - def get_data(data_dir, data_name, url, data_origin_name): - if not os.path.isdir(data_dir): - os.system("mkdir " + data_dir) - os.chdir(data_dir) - if (not os.path.exists(data_name)): - if sys.version_info[0] >= 3: - from urllib.request import urlretrieve - else: - from urllib import urlretrieve - zippath = os.path.join(data_dir, data_origin_name) - urlretrieve(url, zippath) - import bz2 - bz_file = bz2.BZ2File(data_origin_name, 'rb') - with open(data_name, 'wb') as fout: - try: - content = bz_file.read() - fout.write(content) - finally: - bz_file.close() - os.chdir("..") def check_libSVMIter_synthetic(): cwd = os.getcwd() @@ -239,7 +219,7 @@ def check_libSVMIter_news_data(): batch_size = 128 num_examples = news_metadata['num_examples'] data_dir = os.path.join(os.getcwd(), 'data') - get_data(data_dir, news_metadata['name'], news_metadata['url'], + get_bz2_data(data_dir, news_metadata['name'], news_metadata['url'], news_metadata['origin_name']) path = os.path.join(data_dir, news_metadata['name']) data_train = mx.io.LibSVMIter(data_libsvm=path, data_shape=(news_metadata['feature_dim'],),