Add an argument to omit row_sparse_push for end to end benchmarking (apache#7799)

* Omit row_sparse_push changes

* Call get_bz2 in test_utils

* Add get_bz_data

* Remove unneeded files

* Keep only bz2 changes in test_io

* Remove unneeded blank line
anirudh2290 authored and piiswrong committed Sep 24, 2017
1 parent 6c3ca32 commit ebf1bf9
Showing 5 changed files with 33 additions and 60 deletions.
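
For context, a minimal sketch of how the new flag might be invoked once this change is applied (the script path and the flag itself come from the diff below; any dataset or kvstore setup the script otherwise requires is unchanged):

    # Run the end-to-end sparse benchmark with the explicit
    # row_sparse_push step skipped, so it is excluded from the timing.
    python benchmark/python/sparse/sparse_end2end.py --omit-row-sparse-push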
6 changes: 3 additions & 3 deletions benchmark/python/sparse/dot.py
@@ -26,9 +26,9 @@
 import mxnet as mx
 import numpy as np
 import numpy.random as rnd
-from mxnet.test_utils import rand_ndarray, set_default_context, assert_almost_equal
+from mxnet.test_utils import rand_ndarray, set_default_context, assert_almost_equal, get_bz2_data
 from mxnet.base import check_call, _LIB
-from util import get_data, estimate_density
+from util import estimate_density

 PARSER = argparse.ArgumentParser(description="Benchmark sparse operators",
                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
@@ -204,7 +204,7 @@ def test_dot_real(data_dict):

     path = os.path.join(data_dir, data_dict['data_name'])
     if not os.path.exists(path):
-        get_data(
+        get_bz2_data(
             data_dir,
             data_dict['data_name'],
             data_dict['url'],
37 changes: 14 additions & 23 deletions benchmark/python/sparse/sparse_end2end.py
@@ -15,11 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.

-from mxnet.test_utils import *
 import time
 import argparse
 import os
 import multiprocessing
+from mxnet.test_utils import *

 MAX_NUM_BATCH = 99999999
 COMP = "compute"
@@ -57,19 +57,8 @@
 parser.add_argument('--measure-only', default=None,
                     help="Measure only",
                     choices=[IO, COMP, COMM])
-
-
-def get_libsvm_data(data_dir, data_name, url, data_origin_name):
-    if not os.path.isdir(data_dir):
-        os.system("mkdir " + data_dir)
-    os.chdir(data_dir)
-    if (not os.path.exists(data_name)):
-        import urllib
-        zippath = os.path.join(data_dir, data_origin_name)
-        urllib.urlretrieve(url, zippath)
-        os.system("bzip2 -d %r" % data_origin_name)
-    os.chdir("..")
-
+parser.add_argument('--omit-row-sparse-push', action='store_true',
+                    help="omit row_sparse_push")

 class DummyIter(mx.io.DataIter):
     "A dummy iterator that always return the same batch, used for speed testing"
@@ -119,13 +108,14 @@ def next(self):


 def get_sym(feature_dim):
-    x = mx.symbol.Variable("data", stype='csr')
-    norm_init = mx.initializer.Normal(sigma=0.01)
-    w = mx.symbol.Variable("w", shape=(feature_dim, args.output_dim), init=norm_init, stype='row_sparse')
-    embed = mx.symbol.sparse.dot(x, w)
-    y = mx.symbol.Variable("softmax_label")
-    model = mx.symbol.SoftmaxOutput(data=embed, label=y, name="out")
-    return model
+    inputs = mx.symbol.Variable("data", stype='csr')
+    norm_init = mx.initializer.Normal(sigma=0.01)
+    weights = mx.symbol.Variable("w", shape=(feature_dim, args.output_dim),
+                                 init=norm_init, stype='row_sparse')
+    embed = mx.symbol.sparse.dot(inputs, weights)
+    softmax_output = mx.symbol.Variable("softmax_label")
+    model = mx.symbol.SoftmaxOutput(data=embed, label=softmax_output, name="out")
+    return model


 def row_sparse_push(kv, param_arrays, grad_arrays, param_names):
@@ -170,6 +160,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
     log_level = args.sparse_log_level
     measure_only = args.measure_only
     num_cores = multiprocessing.cpu_count()
+    omit_row_sparse_push = args.omit_row_sparse_push
     if measure_only == COMP or measure_only == IO:
         assert not kvstore, "when compute_only or io_only is set, kvstore should be None"
     num_batch = datasets[dataset]['lc'] / batch_size if num_batch == MAX_NUM_BATCH else num_batch
@@ -211,7 +202,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
     data_dir = os.path.join(os.getcwd(), 'data')
     path = os.path.join(data_dir, metadata['data_name'])
     if not os.path.exists(path):
-        get_libsvm_data(data_dir, metadata['data_name'], metadata['url'],
+        get_bz2_data(data_dir, metadata['data_name'], metadata['url'],
                         metadata['data_origin_name'])
     assert os.path.exists(path)

@@ -279,7 +270,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
             if nbatch == 1:
                 mod.forward_backward(batch)
                 mod.update()
-            else:
+            elif not omit_row_sparse_push:
                 row_sparse_push(kv, mod._exec_group.param_arrays, mod._exec_group.grad_arrays, mod._exec_group.param_names)


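
The new option uses argparse's store_true action: the attribute defaults to False and flips to True when the flag is present on the command line. A self-contained sketch of the pattern, standalone from the benchmark script:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--omit-row-sparse-push', action='store_true',
                        help="omit row_sparse_push")
    # argparse maps dashes in option names to underscores in attribute names.
    args = parser.parse_args(['--omit-row-sparse-push'])
    assert args.omit_row_sparse_push  # True only when the flag was passed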
13 changes: 0 additions & 13 deletions benchmark/python/sparse/util.py
@@ -18,19 +18,6 @@
 import os
 import random

-
-def get_data(data_dir, data_name, url, data_origin_name):
-    if not os.path.isdir(data_dir):
-        os.system("mkdir " + data_dir)
-    os.chdir(data_dir)
-    if (not os.path.exists(data_name)):
-        import urllib
-        zippath = os.path.join(data_dir, data_origin_name)
-        urllib.urlretrieve(url, zippath)
-        os.system("bzip2 -d %r" % data_origin_name)
-    os.chdir("..")
-
-
 def estimate_density(DATA_PATH, feature_size):
     """sample 10 times of a size of 1000 for estimating the density of the sparse dataset"""
     if not os.path.exists(DATA_PATH):
15 changes: 15 additions & 0 deletions python/mxnet/test_utils.py
@@ -28,6 +28,7 @@
 import os
 import errno
 import logging
+import bz2
 from contextlib import contextmanager
 import numpy as np
 import numpy.testing as npt
@@ -1409,6 +1410,20 @@ def read_data(label_url, image_url):
     return {'train_data':train_img, 'train_label':train_lbl,
             'test_data':test_img, 'test_label':test_lbl}

+def get_bz2_data(data_dir, data_name, url, data_origin_name):
+    download(url, dirname=data_dir, overwrite=False)
+    os.chdir(data_dir)
+    if not os.path.exists(data_name):
+        bz_file = bz2.BZ2File(data_origin_name, 'rb')
+        with open(data_name, 'wb') as fout:
+            try:
+                content = bz_file.read()
+                fout.write(content)
+            finally:
+                bz_file.close()
+        os.remove(data_origin_name)
+    os.chdir("..")
+
 def set_env_var(key, val, default_val=""):
     """Set environment variable
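
A minimal sketch of calling the new helper, mirroring the call sites updated in this diff (the URL and file names are placeholders, not a real dataset):

    import os
    from mxnet.test_utils import get_bz2_data

    data_dir = os.path.join(os.getcwd(), 'data')
    # Downloads example.libsvm.bz2 into data_dir unless data/example.libsvm
    # already exists, decompresses it, and deletes the .bz2 archive.
    get_bz2_data(data_dir,
                 'example.libsvm',
                 'https://example.com/example.libsvm.bz2',
                 'example.libsvm.bz2')

Note that the helper changes the working directory to data_dir and then moves back up one level before returning.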
22 changes: 1 addition & 21 deletions tests/python/unittest/test_io.py
@@ -177,26 +177,6 @@ def test_NDArrayIter_csr():
         begin += batch_size

 def test_LibSVMIter():
-    def get_data(data_dir, data_name, url, data_origin_name):
-        if not os.path.isdir(data_dir):
-            os.system("mkdir " + data_dir)
-        os.chdir(data_dir)
-        if (not os.path.exists(data_name)):
-            if sys.version_info[0] >= 3:
-                from urllib.request import urlretrieve
-            else:
-                from urllib import urlretrieve
-            zippath = os.path.join(data_dir, data_origin_name)
-            urlretrieve(url, zippath)
-            import bz2
-            bz_file = bz2.BZ2File(data_origin_name, 'rb')
-            with open(data_name, 'wb') as fout:
-                try:
-                    content = bz_file.read()
-                    fout.write(content)
-                finally:
-                    bz_file.close()
-        os.chdir("..")

     def check_libSVMIter_synthetic():
         cwd = os.getcwd()
@@ -239,7 +219,7 @@ def check_libSVMIter_news_data():
         batch_size = 128
         num_examples = news_metadata['num_examples']
         data_dir = os.path.join(os.getcwd(), 'data')
-        get_data(data_dir, news_metadata['name'], news_metadata['url'],
+        get_bz2_data(data_dir, news_metadata['name'], news_metadata['url'],
                  news_metadata['origin_name'])
         path = os.path.join(data_dir, news_metadata['name'])
         data_train = mx.io.LibSVMIter(data_libsvm=path, data_shape=(news_metadata['feature_dim'],),
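
For reference, a minimal LibSVMIter setup along the lines of the truncated call above (the file path, feature dimension, and batch size here are illustrative):

    import mxnet as mx

    # Assumes data/example.libsvm holds LibSVM-format rows with 100 features.
    data_train = mx.io.LibSVMIter(data_libsvm='data/example.libsvm',
                                  data_shape=(100,),
                                  batch_size=128)
    for batch in data_train:
        pass  # batch.data[0] is a sparse (CSR) NDArray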
