
Commit

Change namespace and make logging functionality changes (apache#7627)
* Change namespace and make logging functionality changes

* Help comment changes
anirudh2290 authored and piiswrong committed Aug 29, 2017
1 parent 62e6d2f commit cb432a7
Showing 2 changed files with 14 additions and 7 deletions.
4 changes: 2 additions & 2 deletions benchmark/python/sparse/dot.py
@@ -172,7 +172,7 @@ def run_benchmark(mini_path):
        for _ in train_iter:
            csr_data = train_iter.getdata()
            dns_data = csr_data.tostype('default')
-            cost_sparse = measure_cost(num_repeat, False, False, mx.nd.dot, csr_data, weight, transpose_a=transpose)
+            cost_sparse = measure_cost(num_repeat, False, False, mx.nd.sparse.dot, csr_data, weight, transpose_a=transpose)
            cost_dense = measure_cost(num_repeat, False, False, mx.nd.dot, dns_data, weight, transpose_a=transpose)
            total_cost["sparse"] += cost_sparse
            total_cost["dense"] += cost_dense
@@ -270,7 +270,7 @@ def bench_dot(lhs_shape, rhs_shape, lhs_stype, rhs_stype,
    set_default_context(ctx)
    assert fw == "mxnet" or fw == "scipy"
    # Set funcs
-    dot_func_sparse = mx.nd.dot if fw == "mxnet" else sp.spmatrix.dot
+    dot_func_sparse = mx.nd.sparse.dot if fw == "mxnet" else sp.spmatrix.dot
    dot_func_dense = mx.nd.dot if fw == "mxnet" else np.dot
    # Create matrix instances
    lhs_nd = rand_ndarray(lhs_shape, lhs_stype, density=lhs_den, distribution=distribution)
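For orientation (not part of this commit): the dot.py change swaps mx.nd.dot for the namespaced mx.nd.sparse.dot when the left operand is stored as CSR. A minimal sketch of that call, assuming an MXNet build that exposes the mx.nd.sparse namespace and NDArray.tostype; the values below are made up:

import mxnet as mx

# Illustrative only: build a tiny CSR matrix and multiply it by a dense weight.
dense = mx.nd.array([[1, 0, 2], [0, 0, 3]])
csr_data = dense.tostype('csr')                # compressed sparse row storage
weight = mx.nd.array([[1., 0.], [0., 1.], [2., 2.]])

out = mx.nd.sparse.dot(csr_data, weight)       # namespaced sparse-dense dot
print(out.asnumpy())                           # [[5. 4.] [6. 6.]]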
17 changes: 12 additions & 5 deletions benchmark/python/sparse/sparse_end2end.py
@@ -35,8 +35,8 @@
                    help='whether to use dummy iterator to exclude io cost')
parser.add_argument('--kvstore', type=str, default='local',
                    help='what kvstore to use [local, dist_sync, etc]')
-parser.add_argument('--log-level', type=str, default='debug',
-                    help='logging level [debug, info, error]')
+parser.add_argument('--sparse-log-level', type=str, default='INFO',
+                    help='logging level [DEBUG, INFO, ERROR]')
parser.add_argument('--dataset', type=str, default='avazu',
                    help='what test dataset to use')
parser.add_argument('--num-gpu', type=int, default=0,
@@ -46,6 +46,8 @@
                    help='number of columns of the forward output')
parser.add_argument('--dummy-metric', type=int, default=0,
                    help='whether to call update_metric')
+parser.add_argument('--enable-logging-for', default="0",
+                    help="Enable logging for the specified list of workers")


def get_libsvm_data(data_dir, data_name, url, data_origin_name):
@@ -101,7 +103,7 @@ def get_sym(feature_dim):
    x = mx.symbol.Variable("data", stype='csr')
    norm_init = mx.initializer.Normal(sigma=0.01)
    w = mx.symbol.Variable("w", shape=(feature_dim, args.output_dim), init=norm_init, stype='row_sparse')
-    embed = mx.symbol.dot(x, w)
+    embed = mx.symbol.sparse.dot(x, w)
    y = mx.symbol.Variable("softmax_label")
    model = mx.symbol.SoftmaxOutput(data=embed, label=y, name="out")
    return model
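Aside (not from the patch): the symbolic API mirrors the NDArray change above. A rough sketch, assuming mx.symbol.sparse.dot is available; the shapes and variable names are hypothetical:

import mxnet as mx

# Illustrative only: namespaced symbolic sparse dot with made-up shapes.
x = mx.symbol.Variable("data", stype='csr')
w = mx.symbol.Variable("w", shape=(100, 8), stype='row_sparse')
embed = mx.symbol.sparse.dot(x, w)             # replaces mx.symbol.dot on the sparse path

# Shape inference with a hypothetical batch of 32 rows.
arg_shapes, out_shapes, aux_shapes = embed.infer_shape(data=(32, 100))
print(out_shapes)                              # [(32, 8)]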
@@ -137,7 +139,7 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):
    batch_size = args.batch_size if args.num_gpu == 0 else args.num_gpu * args.batch_size
    dummy_iter = args.dummy_iter
    dataset = args.dataset
-    log_level = args.log_level
+    log_level = args.sparse_log_level
    contexts = mx.context.cpu(0) if args.num_gpu < 1\
               else [mx.context.gpu(i) for i in range(args.num_gpu)]

@@ -148,12 +150,17 @@ def row_sparse_pull(kv, key, data, slices, weight_array, priority):

    # only print log for rank 0 worker
    import logging
-    if rank != 0:
+    if log_level == 'ERROR':
        log_level = logging.ERROR
    elif log_level == 'DEBUG':
        log_level = logging.DEBUG
    else:
        log_level = logging.INFO
+
+    # Only log if it is in the list of workers to be logged
+    logging_workers_list = [int(i) for i in args.enable_logging_for.split(",")]
+    log_level = log_level if rank in logging_workers_list else logging.CRITICAL
+
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(level=log_level, format=head)

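To summarize the new logging behaviour (an illustrative sketch, not code from the patch): the name passed via --sparse-log-level is mapped to a logging constant, and any worker whose rank is not listed in --enable-logging-for is effectively silenced by raising its level to CRITICAL. The helper name and the values below are hypothetical stand-ins for kv.rank and the parsed arguments:

import logging

def resolve_log_level(level_name, rank, enabled_workers="0"):
    # Map a level name to a logging constant; silence ranks not in the list.
    names = {'ERROR': logging.ERROR, 'DEBUG': logging.DEBUG}
    level = names.get(level_name, logging.INFO)
    workers = [int(i) for i in enabled_workers.split(",")]
    return level if rank in workers else logging.CRITICAL

logging.basicConfig(level=resolve_log_level('INFO', rank=0, enabled_workers="0,1"),
                    format='%(asctime)-15s %(message)s')
logging.info("visible only on workers 0 and 1")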
