From e1b536ae7fc93785bc7bdae97397e675ec10b7a8 Mon Sep 17 00:00:00 2001 From: Ziheng Jiang Date: Sun, 9 Oct 2016 23:46:29 +0800 Subject: [PATCH] MXNet Profiler (#3163) * NNVM Refactor (#3194) * Init nnvm change * temp checkin * Move TShape to NNVM * Redirect Symbolic API to NNVM * Add Op Prop Adapter * Finish migrate in shape infer * Pass all symbolic test * temp commit * enable aux data * [EXEC] Basic version of exec for forward only * [EXEC] Enable most optimizations, still wait grad and context * fix legacy op with latest one * Update NNVM NodeRef * Adapt to newer interface * ALl registry of backop is complete * temp commit * Hack finish backward pass * [EXEC] One day pass * [EXEC] Pass all operator unittest * [EXEC] enable model parallel * Fully pass all legacy tests * Remove legacy symbolic code * update news * Make travis compile * Fix python3 * Update viz module to new json format * [NNVM] Imperative Invoke (#3208) * [Engine] Deduplicate Variable Util * [NNVM] NNVM Imperative Invoke * [NNVM] Imperative improve speed * fix * fix * [scala] link libnnvm.a (#3214) * [PYTHON] Optional Cython Module for Symbols (#3242) * [CYTHON] Checkin cython enhancement * fix lint * [DOC] Move common doc to base * [EXEC] Support fcompute (#3249) * [EXEC] Support fcompute * Fix lint * fix lint * [OP] Add alias support (#3261) * Fix path in setup.py (#3276) * Fix path in setup.py * revert the nnvm version * [WIP] Element wise op refactor (#3245) * [OPERATOR] Refactor Unary Ops * [OPERATOR] Refactor Binary Scalar Ops * Use alias * update nnvm version (#3290) * Fix breaking changes after pull master (#3291) * [CYTHON] Cython module for NDArray (#3292) * [NDARRAY] Cython module for ndarray * More strict tests * [NNVM] change of attr to set_attr (#3303) * Update run_test.sh * add nnvm cmake with windows (#3255) * [WIP] binary broadcast wip (#3301) * [WIP] binary broadcast wip [OPERATOR] Binary Broadcast ops fix lint lint fix max and min update submodule before removing reduce axis broad cast reduce ops * update * fix * fix warning * fix * x (#3308) * [IO] Python based ImageIter and Augumenter (#3227) * [IO] Python based ImageIter and Augumenter * fix * fix * fix * [OPT] NNVM Optimizer (#3314) * fix cpython in windows (#3309) * Add Mathematical functions (#3317) * fix image io * add hypot degrees radians cosh sinh tanh arcsinh arccosh arctanh (#3335) * add recent examples, collect some missing tutorials (#3340) * Improving docs & utilities for distributed training example. (#3341) * add init dict * disable SSE for arm hardware e.g. Raspberry Pi (#3346) * Add channel_ to Shape2D calculation (#3181) * Add channel_ to Shape2D calculation * scalapkg, add example multitask (#3186) * RNN cell demo with ptb LSTM language model (#3197) * rnn-cell demo (push to server for testing) * a running example with cuDNN RNN cell * Bulk lint fix (#3211) * [TENSOR] Add FlatTo1D for all elementwise ops (#3238) * Fix little bug on context (#3202) * add PennTreeBank Language Model using lstm model in R (#2659) * Add function 'print_summary' and some revise (#3161) * Add function 'print_summary' and some revise Add function 'print_summary' for print detail information of network, and format argument was add in 'plot_network'. You can use 'print_summary' like: """ net = get_symbol(1000) shape = {'softmax_label': (64, 12), 'data': (64, 3, 224, 224)} mx.viz.print_summary(net, shape=shape) """ If without shape, the number of arguments would be nonsense currently. 
* Update visualization.py * Update visualization.py * Update visualization.py * Update visualization.py * Update visualization.py * Update visualization.py * Update visualization.py * Update visualization.py * Update visualization.py * Update visualization.py * Update visualization.py * Added my CmakeLists.txt for caffe plugin, etc. * Revert "fix travis scala test config" (#3246) This reverts parts of commit 3e15f6288609419e7bf5aa3119b1212eaeeec8be. Reenables testing the Julia bindings * [Scala] Code generation for Symbol (#3217) [scala] auto-generate Symbol functions * fix spelling errors (#3258) Also align grammar and punctuation in short descriptions of features * fix typo in run_test.sh (#3260) * Copy slice along arbitrary axis (#3259) * rnn-cell demo (push to server for testing) * a running example with cuDNN RNN cell * add copyslice along arbitrary axis for NDArray * copy_slice_to as an ndarray operator * Python interface to the _copy_slice_to operator * fix lint error * Enable concatenation for dim-1 vectors (#3264) * fix PReLU backward computing (#3277) * Add `reverse` option in Reshape (#3280) * add scala example, end2end neural-style (#3267) add scala example, end2end neural-style * Improve multi-GPU performance (#3241) * update kvstore * update model.py * bandwith tool * update readme * tiny * fix lint * fix batch size of dist_device_sync * fix * fix perf problem of kvstore when only using a single device * roll back to previous strategy how to choose update_on_kvsotre * add an optionl MXNET_ENABLE_GPU_P2P to control whether or not use p2p * update dmlccore (#3293) * Fix newer version of gtest and cpptest (#3294) * when set use_global_stats then do not use cudnn (#3289) * when set use_global_stats then do not use cudnn * fix batch norm with use_global_stats * Fix req+reserve_space in cudnn_rnn (#3274) Fix req Fix reserve_space Allocate reserve_space using Storage * add cudnn off option in Convolution (#3270) * add support for building on power (#3302) * add recent examples, collect some missing tutorials (#3340) * CMake for caffe plugin * Fix metric & im2rec.py * [Scala] Nnvm ops for NDArray & Symbol (#3361) * [scala] nnvm op support * [scala] remove unused codes * fix scala native code style * [R] Fix the R interface (#3334) * [R] Fix the R interface. remove man * Fix BN legacy issue * Locate compiled library on Windows (#3369) * Fix metric & im2rec.py (#3375) image io fix * Update legacy op FBackwardInGradIndex (#3376) * Update legacy op FBackwardInGradIndex * fix test * Fix for LRN Layer (#3366) * fixed cpu forward bug * added out_data[lrn_enum::kOut] as backward req. 
* removed lint * removed duplicate out_data[lrn_enum::kTmpNorm], * removed inplace option * add backward index * include some special functions (#3337) - gamma - gammaln - log1p - expm1 * fix kv build (#3385) * initial profiler branch based on dmlc/mxnet:nnvm * [profiler] add profiler & modify engine API * [profiler] add USE_PROFILER compile flag & modify code for changed engine api * [profiler] add c_api interface & modify graph_executor * [profiler] add python api * [profiler] typo & lint error * [profiler] reduce overhead & add PROFIELR_MESSAGE_FUNCNAME macro * [profiler] remove profiling argument from PushSync/PushAsync * [profiler] refactor profiler.h/.cc * [profiler] improve readability * [profiler] typo && add TODO comment * [profiler] fix ndarray op name & add WaitForVar back * [profiler] add example/profiler/profiler_ndarray.py * [profiler] fix memleak by using op->name * [profiler] fix lint * [profiler] fix lint --- Makefile | 5 + example/profiler/profiler_executor.py | 142 +++++++++++ example/profiler/profiler_matmul.py | 49 ++++ example/profiler/profiler_ndarray.py | 311 +++++++++++++++++++++++++ include/mxnet/base.h | 13 ++ include/mxnet/c_api.h | 18 ++ include/mxnet/engine.h | 18 +- make/config.mk | 3 + python/mxnet/__init__.py | 2 + python/mxnet/profiler.py | 37 +++ src/c_api/c_api.cc | 15 ++ src/c_api/c_api_ndarray.cc | 6 +- src/common/mxrtc.cc | 3 +- src/engine/naive_engine.cc | 19 +- src/engine/profiler.cc | 173 ++++++++++++++ src/engine/profiler.h | 139 +++++++++++ src/engine/threaded_engine.cc | 45 +++- src/engine/threaded_engine.h | 42 +++- src/executor/graph_executor.cc | 21 +- src/executor/graph_executor.h | 2 + src/io/image_io.cc | 3 +- src/kvstore/comm.h | 2 +- src/kvstore/kvstore_dist.h | 8 +- src/ndarray/ndarray.cc | 54 +++-- src/operator/cudnn_convolution.cc | 3 +- src/operator/custom.cc | 6 +- src/operator/ndarray_op.cc | 6 +- src/operator/operator_util.cc | 9 +- src/optimizer/sgd-inl.h | 12 +- src/resource.cc | 3 +- tests/python/unittest/test_operator.py | 167 +++++++++++++ 31 files changed, 1261 insertions(+), 75 deletions(-) create mode 100644 example/profiler/profiler_executor.py create mode 100644 example/profiler/profiler_matmul.py create mode 100644 example/profiler/profiler_ndarray.py create mode 100644 python/mxnet/profiler.py create mode 100644 src/engine/profiler.cc create mode 100644 src/engine/profiler.h diff --git a/Makefile b/Makefile index 0de099b37f1a..e6a45aba7581 100644 --- a/Makefile +++ b/Makefile @@ -45,6 +45,11 @@ else NVCCFLAGS = -std=c++11 -Xcompiler -D_FORCE_INLINES -g -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS) endif +# CFLAGS for profiler +ifeq ($(USE_PROFILER), 1) + CFLAGS += -DMXNET_USE_PROFILER=1 +endif + ifndef LINT_LANG LINT_LANG="all" endif diff --git a/example/profiler/profiler_executor.py b/example/profiler/profiler_executor.py new file mode 100644 index 000000000000..514dc595f6df --- /dev/null +++ b/example/profiler/profiler_executor.py @@ -0,0 +1,142 @@ +import mxnet as mx +import argparse +import os, sys +import time +import numpy as np +from mxnet import profiler + + +def parse_args(): + parser = argparse.ArgumentParser(description='Set network parameters for benchmark test.') + parser.add_argument('--profile_filename', type=str, default='profile_executor_5iter.json') + parser.add_argument('--iter_num', type=int, default=5) + parser.add_argument('--fc1', type=int, default=128) + parser.add_argument('--fc2', type=int, default=128) + parser.add_argument('--fc3', type=int, default=128) + parser.add_argument('--fc4', type=int, 
default=128) + return parser.parse_args() + + +def _download(data_dir): + if not os.path.isdir(data_dir): + os.system("mkdir " + data_dir) + os.chdir(data_dir) + if (not os.path.exists('train-images-idx3-ubyte')) or \ + (not os.path.exists('train-labels-idx1-ubyte')) or \ + (not os.path.exists('t10k-images-idx3-ubyte')) or \ + (not os.path.exists('t10k-labels-idx1-ubyte')): + os.system("wget http://webdocs.cs.ualberta.ca/~bx3/data/mnist.zip") + os.system("unzip -u mnist.zip; rm mnist.zip") + os.chdir("..") + + +def get_data(data_shape): + data_dir = "mnist/" + batch_size = 128 + if '://' not in data_dir: + _download(data_dir) + + train = mx.io.MNISTIter( + image = data_dir + "train-images-idx3-ubyte", + label = data_dir + "train-labels-idx1-ubyte", + input_shape = data_shape, + batch_size = batch_size, + shuffle = True, + ) + + val = mx.io.MNISTIter( + image = data_dir + "t10k-images-idx3-ubyte", + label = data_dir + "t10k-labels-idx1-ubyte", + input_shape = data_shape, + batch_size = batch_size, + ) + + return (train, val) + +def get_symbol(): + data = mx.symbol.Variable('data') + fc1 = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=args.fc1) + act1 = mx.symbol.Activation(data=fc1, name='relu1', act_type='relu') + fc2 = mx.symbol.FullyConnected(data=act1 , name='fc2', num_hidden=args.fc2) + act2 = mx.symbol.Activation(data=fc2, name='relu2', act_type='relu') + fc3 = mx.symbol.FullyConnected(data=act2 , name='fc3', num_hidden=args.fc3) + act3 = mx.symbol.Activation(data=fc3, name='relu3', act_type='relu') + fc4 = mx.symbol.FullyConnected(data=act3 , name='fc4', num_hidden=args.fc4) + act4 = mx.symbol.Activation(data=fc4, name='relu4', act_type='relu') + fc5 = mx.symbol.FullyConnected(data=act4 , name='fc5', num_hidden=10) + net = mx.symbol.SoftmaxOutput(data=fc5 , name='softmax') + return net, [('data', (128, 1, 28, 28))], [('softmax_label', (128, ))] + +def get_module(ctx, sym, provide_data, provide_label, batch_size=None, is_train=True, use_memonger=False): + if use_memonger: + sym = search_plan(sym, data=data_shapes) + mod = mx.mod.Module(symbol=sym, + data_names=[name for name, _ in provide_data], + label_names=[name for name, _ in provide_label], + context=ctx) + if batch_size is not None: + provide_data = [(name, (batch_size,) + shape[1:]) for name, shape in provide_data] + provide_label = [(name, (batch_size,) + shape[1:]) for name, shape in provide_label] + if is_train: + mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=True, inputs_need_grad=False) + else: + mod.bind(data_shapes=provide_data, label_shapes=provide_label, for_training=False, inputs_need_grad=False) + + mod.init_params(initializer=mx.init.Xavier(magnitude=2.)) + mod.init_optimizer(optimizer='ccsgd', + optimizer_params={ + 'learning_rate': 0.0001, + 'momentum': 0.0, + 'wd': 0.0 + }) + return mod + + +def benchmark(mod, dry_run=10, iterations=10): + if len(mod._context) == 1: + ctx = mod._context[0] + else: + ctx = mx.cpu() + data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx) for _, shape in mod.data_shapes] + label = [mx.nd.array(np.random.randint(1, 100, size=shape), ctx=ctx) for _, shape in mod.label_shapes] + batch = mx.io.DataBatch(data, label) + + # dry run + for i in range(dry_run): + mod.forward(batch, is_train=True) + mod.backward() + for output in mod.get_outputs(merge_multi_context=False)[0]: + output.wait_to_read() + mod.update() + + t0 = time.clock() + + profiler.profiler_set_state('run') + # real run + for i in range(iterations): + mod.forward(batch, 
is_train=True) + mod.backward() + mod.update() + for output in mod.get_outputs(merge_multi_context=False)[0]: + output.wait_to_read() + profiler.profiler_set_state('stop') + + t1 = time.clock() + return (t1 - t0)*1000.0 / iterations + + +def executor(num_iteration): + sym, provide_data, provide_label = get_symbol() + ctx = [mx.gpu(0)] + mod = get_module(ctx, sym, provide_data, provide_label, batch_size=128) + return benchmark(mod, iterations=args.iter_num) + + +args = parse_args() + +if __name__ == '__main__': + mx.profiler.profiler_set_config(mode='symbolic', filename=args.profile_filename) + print('profile file save to {0}'.format(args.profile_filename)) + print('executor num_iteration: {0}'.format(args.iter_num)) + executor_time = executor(args.iter_num) + print("executor {0} ms / iteration".format(executor_time)) diff --git a/example/profiler/profiler_matmul.py b/example/profiler/profiler_matmul.py new file mode 100644 index 000000000000..b25877a55565 --- /dev/null +++ b/example/profiler/profiler_matmul.py @@ -0,0 +1,49 @@ +import mxnet as mx +import argparse +import os, sys +import time +import numpy as np + +def parse_args(): + parser = argparse.ArgumentParser(description='Set network parameters for benchmark test.') + parser.add_argument('--profile_filename', type=str, default='profile_matmul_20iter.json') + parser.add_argument('--iter_num', type=int, default=100) + parser.add_argument('--begin_profiling_iter', type=int, default=50) + parser.add_argument('--end_profiling_iter', type=int, default=70) + return parser.parse_args() + +args = parse_args() + +if __name__ == '__main__': + mx.profiler.profiler_set_config(mode='symbolic', filename=args.profile_filename) + print('profile file save to {0}'.format(args.profile_filename)) + + + A = mx.sym.Variable('A') + B = mx.sym.Variable('B') + C = mx.symbol.dot(A, B) + + executor = C.simple_bind(mx.gpu(1), 'write', A=(4096, 4096), B=(4096, 4096)) + + a = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096)) + b = mx.random.uniform(-1.0, 1.0, shape=(4096, 4096)) + + a.copyto(executor.arg_dict['A']) + b.copyto(executor.arg_dict['B']) + + flag = False + print "execution begin" + for i in range(args.iter_num): + if i == args.begin_profiling_iter: + t0 = time.clock() + mx.profiler.profiler_set_state('run') + if i == args.end_profiling_iter: + t1 = time.clock() + mx.profiler.profiler_set_state('stop') + executor.forward() + c = executor.outputs[0] + c.wait_to_read() + print "execution end" + duration = t1 - t0 + print('duration: {0}s'.format(duration)) + print(' {0}ms/operator'.format(duration*1000/args.iter_num)) diff --git a/example/profiler/profiler_ndarray.py b/example/profiler/profiler_ndarray.py new file mode 100644 index 000000000000..bb4d658275c0 --- /dev/null +++ b/example/profiler/profiler_ndarray.py @@ -0,0 +1,311 @@ +import os +import mxnet as mx +import numpy as np +import pickle as pkl + + +def _np_reduce(dat, axis, keepdims, numpy_reduce_func): + if isinstance(axis, int): + axis = [axis] + else: + axis = list(axis) if axis is not None else range(len(dat.shape)) + ret = dat + for i in reversed(sorted(axis)): + ret = numpy_reduce_func(ret, axis=i) + if keepdims: + keepdims_shape = list(dat.shape) + for i in axis: + keepdims_shape[i] = 1 + ret = ret.reshape(tuple(keepdims_shape)) + return ret + + +def reldiff(a, b): + diff = np.abs(a - b) + norm = np.abs(a) + reldiff = np.max(diff / (norm + 1e-7)) + return reldiff + + +def same(a, b): + return np.sum(a != b) == 0 + + +def check_with_uniform(uf, arg_shapes, dim=None, npuf=None, rmin=-10, 
type_list=[np.float32]): + """check function consistency with uniform random numbers""" + if isinstance(arg_shapes, int): + assert dim + shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) + arg_shapes = [shape] * arg_shapes + for dtype in type_list: + ndarray_arg = [] + numpy_arg = [] + for s in arg_shapes: + npy = np.random.uniform(rmin, 10, s).astype(dtype) + narr = mx.nd.array(npy, dtype=dtype) + ndarray_arg.append(narr) + numpy_arg.append(npy) + out1 = uf(*ndarray_arg) + if npuf is None: + out2 = uf(*numpy_arg).astype(dtype) + else: + out2 = npuf(*numpy_arg).astype(dtype) + + assert out1.shape == out2.shape + if isinstance(out1, mx.nd.NDArray): + out1 = out1.asnumpy() + if dtype == np.float16: + assert reldiff(out1, out2) < 2e-3 + else: + assert reldiff(out1, out2) < 1e-6 + + +def random_ndarray(dim): + shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) + data = mx.nd.array(np.random.uniform(-10, 10, shape)) + return data + +def test_ndarray_elementwise(): + np.random.seed(0) + nrepeat = 10 + maxdim = 4 + all_type = [np.float32, np.float64, np.float16, np.uint8, np.int32] + real_type = [np.float32, np.float64, np.float16] + for repeat in range(nrepeat): + for dim in range(1, maxdim): + check_with_uniform(lambda x, y: x + y, 2, dim, type_list=all_type) + check_with_uniform(lambda x, y: x - y, 2, dim, type_list=all_type) + check_with_uniform(lambda x, y: x * y, 2, dim, type_list=all_type) + check_with_uniform(lambda x, y: x / y, 2, dim, type_list=real_type) + check_with_uniform(lambda x, y: x / y, 2, dim, rmin=1, type_list=all_type) + check_with_uniform(mx.nd.sqrt, 1, dim, np.sqrt, rmin=0) + check_with_uniform(mx.nd.square, 1, dim, np.square, rmin=0) + check_with_uniform(lambda x: mx.nd.norm(x).asscalar(), 1, dim, np.linalg.norm) + +def test_ndarray_negate(): + npy = np.random.uniform(-10, 10, (2,3,4)) + arr = mx.nd.array(npy) + assert reldiff(npy, arr.asnumpy()) < 1e-6 + assert reldiff(-npy, (-arr).asnumpy()) < 1e-6 + + # a final check to make sure the negation (-) is not implemented + # as inplace operation, so the contents of arr does not change after + # we compute (-arr) + assert reldiff(npy, arr.asnumpy()) < 1e-6 + + +def test_ndarray_choose(): + shape = (100, 20) + npy = np.arange(np.prod(shape)).reshape(shape) + arr = mx.nd.array(npy) + nrepeat = 3 + for repeat in range(nrepeat): + indices = np.random.randint(shape[1], size=shape[0]) + assert same(npy[np.arange(shape[0]), indices], + mx.nd.choose_element_0index(arr, mx.nd.array(indices)).asnumpy()) + + +def test_ndarray_fill(): + shape = (100, 20) + npy = np.arange(np.prod(shape)).reshape(shape) + arr = mx.nd.array(npy) + new_npy = npy.copy() + nrepeat = 3 + for repeat in range(nrepeat): + indices = np.random.randint(shape[1], size=shape[0]) + val = np.random.randint(shape[1], size=shape[0]) + new_npy[:] = npy + new_npy[np.arange(shape[0]), indices] = val + assert same(new_npy, + mx.nd.fill_element_0index(arr, mx.nd.array(val), mx.nd.array(indices)).asnumpy()) + + +def test_ndarray_onehot(): + shape = (100, 20) + npy = np.arange(np.prod(shape)).reshape(shape) + arr = mx.nd.array(npy) + nrepeat = 3 + for repeat in range(nrepeat): + indices = np.random.randint(shape[1], size=shape[0]) + npy[:] = 0.0 + npy[np.arange(shape[0]), indices] = 1.0 + mx.nd.onehot_encode(mx.nd.array(indices), out=arr) + assert same(npy, arr.asnumpy()) + + +def test_ndarray_copy(): + c = mx.nd.array(np.random.uniform(-10, 10, (10, 10))) + d = c.copyto(mx.Context('cpu', 0)) + assert np.sum(np.abs(c.asnumpy() != d.asnumpy())) 
== 0.0 + + +def test_ndarray_scalar(): + c = mx.nd.empty((10,10)) + d = mx.nd.empty((10,10)) + c[:] = 0.5 + d[:] = 1.0 + d -= c * 2 / 3 * 6.0 + c += 0.5 + assert(np.sum(c.asnumpy()) - 100 < 1e-5) + assert(np.sum(d.asnumpy()) + 100 < 1e-5) + c[:] = 2 + assert(np.sum(c.asnumpy()) - 200 < 1e-5) + d = -c + 2 + assert(np.sum(d.asnumpy()) < 1e-5) + +def test_ndarray_pickle(): + np.random.seed(0) + maxdim = 5 + nrepeat = 10 + for repeat in range(nrepeat): + for dim in range(1, maxdim): + a = random_ndarray(dim) + b = mx.nd.empty(a.shape) + a[:] = np.random.uniform(-10, 10, a.shape) + b[:] = np.random.uniform(-10, 10, a.shape) + a = a + b + data = pkl.dumps(a) + a2 = pkl.loads(data) + assert np.sum(a.asnumpy() != a2.asnumpy()) == 0 + + +def test_ndarray_saveload(): + np.random.seed(0) + maxdim = 5 + nrepeat = 10 + fname = 'tmp_list.bin' + for repeat in range(nrepeat): + data = [] + for i in range(10): + data.append(random_ndarray(np.random.randint(1, 5))) + mx.nd.save(fname, data) + data2 = mx.nd.load(fname) + assert len(data) == len(data2) + for x, y in zip(data, data2): + assert np.sum(x.asnumpy() != y.asnumpy()) == 0 + dmap = {'ndarray xx %s' % i : x for i, x in enumerate(data)} + mx.nd.save(fname, dmap) + dmap2 = mx.nd.load(fname) + assert len(dmap2) == len(dmap) + for k, x in dmap.items(): + y = dmap2[k] + assert np.sum(x.asnumpy() != y.asnumpy()) == 0 + os.remove(fname) + + +def test_ndarray_slice(): + shape = (10,) + A = mx.nd.array(np.random.uniform(-10, 10, shape)) + A2 = A.asnumpy() + assert same(A[3:8].asnumpy(), A2[3:8]) + A2[3:8] *= 10; + A[3:8] = A2[3:8] + assert same(A[3:8].asnumpy(), A2[3:8]) + + +def test_ndarray_slice_along_axis(): + arr = mx.nd.array(np.random.uniform(-10, 10, (3, 4, 2, 3))) + sub_arr = mx.nd.zeros((3, 2, 2, 3)) + arr._copy_slice_to(1, 1, 3, sub_arr) + + # test we sliced correctly + assert same(arr.asnumpy()[:, 1:3, :, :], sub_arr.asnumpy()) + + # test that slice is copy, instead of shared memory + sub_arr[:] = 0 + assert not same(arr.asnumpy()[:, 1:3, :, :], sub_arr.asnumpy()) + + +def test_clip(): + shape = (10,) + A = mx.random.uniform(-10, 10, shape) + B = mx.nd.clip(A, -2, 2) + B1 = B.asnumpy() + for i in range(shape[0]): + assert B1[i] >= -2 + assert B1[i] <= 2 + +def test_dot(): + a = np.random.uniform(-3, 3, (3, 4)) + b = np.random.uniform(-3, 3, (4, 5)) + c = np.dot(a, b) + A = mx.nd.array(a) + B = mx.nd.array(b) + C = mx.nd.dot(A, B) + assert reldiff(c, C.asnumpy()) < 1e-5 + +def test_reduce(): + sample_num = 200 + def test_reduce_inner(numpy_reduce_func, nd_reduce_func): + for i in range(sample_num): + ndim = np.random.randint(1, 6) + shape = np.random.randint(1, 11, size=ndim) + axis_flags = np.random.randint(0, 2, size=ndim) + axes = [] + for (axis, flag) in enumerate(axis_flags): + if flag: + axes.append(axis) + keepdims = np.random.randint(0, 2) + dat = np.random.rand(*shape) - 0.5 + if 0 == len(axes): + axes = tuple(range(ndim)) + else: + axes = tuple(axes) + numpy_ret = numpy_reduce_func(dat, axis=axes, keepdims=keepdims) + + ndarray_ret = nd_reduce_func(mx.nd.array(dat), axis=axes, keepdims=keepdims) + if type(ndarray_ret) is mx.ndarray.NDArray: + ndarray_ret = ndarray_ret.asnumpy() + assert (ndarray_ret.shape == numpy_ret.shape) or \ + (ndarray_ret.shape == (1,) and numpy_ret.shape == ()), "nd:%s, numpy:%s" \ + %(ndarray_ret.shape, numpy_ret.shape) + err = np.square(ndarray_ret - numpy_ret).mean() + assert err < 1E-4 + test_reduce_inner(lambda data, axis, keepdims:_np_reduce(data, axis, keepdims, np.sum), + mx.nd.sum) + 
test_reduce_inner(lambda data, axis, keepdims:_np_reduce(data, axis, keepdims, np.max), + mx.nd.max) + test_reduce_inner(lambda data, axis, keepdims:_np_reduce(data, axis, keepdims, np.min), + mx.nd.min) + +def test_broadcast(): + sample_num = 1000 + def test_broadcast_to(): + for i in range(sample_num): + ndim = np.random.randint(1, 6) + target_shape = np.random.randint(1, 11, size=ndim) + shape = target_shape.copy() + axis_flags = np.random.randint(0, 2, size=ndim) + axes = [] + for (axis, flag) in enumerate(axis_flags): + if flag: + shape[axis] = 1 + dat = np.random.rand(*shape) - 0.5 + numpy_ret = dat + ndarray_ret = mx.nd.array(dat).broadcast_to(shape=target_shape) + if type(ndarray_ret) is mx.ndarray.NDArray: + ndarray_ret = ndarray_ret.asnumpy() + assert (ndarray_ret.shape == target_shape).all() + err = np.square(ndarray_ret - numpy_ret).mean() + assert err < 1E-8 + test_broadcast_to() + +if __name__ == '__main__': + mx.profiler.profiler_set_config(mode='all', filename='profile_ndarray.json') + mx.profiler.profiler_set_state('run') + test_ndarray_slice_along_axis() + test_broadcast() + test_ndarray_elementwise() + test_ndarray_slice() + test_ndarray_pickle() + test_ndarray_saveload() + test_ndarray_copy() + test_ndarray_negate() + test_ndarray_scalar() + test_clip() + test_dot() + test_ndarray_choose() + test_ndarray_onehot() + test_ndarray_fill() + test_reduce() + mx.profiler.profiler_set_state('stop') diff --git a/include/mxnet/base.h b/include/mxnet/base.h index 0df855965d09..52f72d756f0c 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -73,6 +73,19 @@ #define MXNET_PREDICT_ONLY 0 #endif +/*! + * \brief define operator message for profiler + */ +#if MXNET_USE_PROFILER +#define PROFILER_MESSAGE(msg) msg +#else +#define PROFILER_MESSAGE(msg) nullptr +#endif + +/*! + * \brief define function name as profiler message + */ +#define PROFILER_MESSAGE_FUNCNAME PROFILER_MESSAGE(__FUNCTION__) /*! \brief namespace of mxnet */ namespace mxnet { diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 80435815e62c..31ea0cf40273 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -160,6 +160,24 @@ MXNET_DLL int MXRandomSeed(int seed); * \return 0 when success, -1 when failure happens. */ MXNET_DLL int MXNotifyShutdown(); +/*! + * \brief Set up configuration of profiler + * \param mode indicate the working mode of profiler, + * record anly symbolic operator when mode == 0, + * record all operator when mode == 1 + * \param filename where to save trace file + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXSetProfilerConfig(int mode, const char* filename); +/*! + * \brief Set up state of profiler + * \param state indicate the working state of profiler, + * profiler not running when state == 0, + * profiler running when state == 1 + * \return 0 when success, -1 when failure happens. + */ +MXNET_DLL int MXSetProfilerState(int state); + //------------------------------------- // Part 1: NDArray creation and deletion //------------------------------------- diff --git a/include/mxnet/engine.h b/include/mxnet/engine.h index 162b3769f36e..c003d3f39097 100644 --- a/include/mxnet/engine.h +++ b/include/mxnet/engine.h @@ -107,12 +107,14 @@ class MXNET_API Engine { * mutate. * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. + * \param opr_name The operator name. * \return The new operator allocated. 
*/ virtual OprHandle NewOperator(AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop = FnProperty::kNormal) = 0; + FnProperty prop = FnProperty::kNormal, + const char* opr_name = nullptr) = 0; /*! * \brief Delete the given operator. * \param op The operator to delete. @@ -126,8 +128,9 @@ class MXNET_API Engine { * \param op The operator to push. * \param exec_ctx Execution context. * \param priority Priority of the action, as hint to the engine. + * \param profiling Whether to profile this operator. */ - virtual void Push(OprHandle op, Context exec_ctx, int priority = 0) = 0; + virtual void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) = 0; /*! * \brief Push an asynchronous operation to the engine. * \param exec_fun Execution function, this function takes a parameter * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. * \param priority Priority of the action, as hint to the engine. + * \param opr_name The operator name. */ virtual void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop = FnProperty::kNormal, - int priority = 0) = 0; + int priority = 0, + const char* opr_name = nullptr) = 0; /*! * \brief Schedule the deletion of a variable. * @@ -193,18 +198,19 @@ class MXNET_API Engine { * \param mutable_vars The variables that current operation will mutate. * \param prop Property of the function. * \param priority Priority of the action, as hint to the engine. + * \param opr_name The operator name. * \tparam SyncFn the synchronous function to be pushed. */ - template inline void PushSync(SyncFn exec_fn, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, FnProperty prop = FnProperty::kNormal, - int priority = 0) { + int priority = 0, + const char* opr_name = nullptr) { this->PushAsync([exec_fn](RunContext ctx, CallbackOnComplete on_complete) { exec_fn(ctx); on_complete(); - }, exec_ctx, const_vars, mutable_vars, prop, priority); + }, exec_ctx, const_vars, mutable_vars, prop, priority, opr_name); } /*!
diff --git a/make/config.mk b/make/config.mk index 6610b5997705..b8ac34f56b29 100644 --- a/make/config.mk +++ b/make/config.mk @@ -27,6 +27,9 @@ export NVCC = nvcc # whether compile with debug DEBUG = 0 +# whether to compile with profiler +USE_PROFILER = + # the additional link flags you want to add ADD_LDFLAGS =
diff --git a/python/mxnet/__init__.py b/python/mxnet/__init__.py index 715e31e343e8..efc9705eca62 100644 --- a/python/mxnet/__init__.py +++ b/python/mxnet/__init__.py @@ -45,6 +45,8 @@ from . import torch from . import torch as th +from . import profiler + from . import module from . import module as mod
diff --git a/python/mxnet/profiler.py b/python/mxnet/profiler.py new file mode 100644 index 000000000000..dbbe1b3b4bf7 --- /dev/null +++ b/python/mxnet/profiler.py @@ -0,0 +1,37 @@ +# coding: utf-8 +# pylint: disable=fixme, invalid-name, too-many-arguments, too-many-locals, too-many-lines +# pylint: disable=too-many-branches, too-many-statements +"""profiler setting methods.""" +from __future__ import absolute_import + +import ctypes +from .base import _LIB, check_call, c_str + +def profiler_set_config(mode='symbolic', filename='profile.json'): + """Set up the configuration of the profiler.
+ + Parameters + ---------- + mode : string, optional + Indicates the working mode of the profiler; can + be 'symbolic' or 'all'. Default is 'symbolic'. + filename : string, optional + The name of the output trace file. Default is + 'profile.json'. + """ + mode2int = {'symbolic': 0, 'all': 1} + check_call(_LIB.MXSetProfilerConfig( + ctypes.c_int(mode2int[mode]), + c_str(filename))) + +def profiler_set_state(state='stop'): + """Set up the running state of the profiler. + + Parameters + ---------- + state : string, optional + Indicates whether the profiler should be running; can + be 'stop' or 'run'. Default is 'stop'. + """ + state2int = {'stop': 0, 'run': 1} + check_call(_LIB.MXSetProfilerState(ctypes.c_int(state2int[state])))
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 7fae07934b52..d79e0f88a34b 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -25,6 +25,7 @@ #include #include "./c_api_common.h" #include "../operator/custom-inl.h" +#include "../engine/profiler.h" using namespace mxnet; @@ -76,6 +77,20 @@ int MXNotifyShutdown() { API_END(); } +int MXSetProfilerConfig(int mode, const char* filename) { + API_BEGIN(); + // mode, kOnlySymbolic: 0, kAllOperator: 1 + engine::Profiler::Get()->SetConfig(engine::Profiler::ProfilerMode(mode), std::string(filename)); + API_END(); +} + +int MXSetProfilerState(int state) { + API_BEGIN(); + // state, kNotRunning: 0, kRunning: 1 + engine::Profiler::Get()->SetState(engine::Profiler::ProfilerState(state)); + API_END(); +} + int MXNDArrayCreateNone(NDArrayHandle *out) { API_BEGIN(); *out = new NDArray();
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc index 59dad5c23051..5ad506c8d674 100644 --- a/src/c_api/c_api_ndarray.cc +++ b/src/c_api/c_api_ndarray.cc @@ -208,7 +208,8 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, rctx.get_stream()->Wait(); } on_complete(); - }, ctx, read_vars, write_vars); + }, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, PROFILER_MESSAGE(op->name.c_str())); } else if (createop.count(op)) { Operator* opr = createop[op](attrs, ctx, in_shapes, in_types); struct Capture { @@ -253,7 +254,8 @@ int MXImperativeInvoke(AtomicSymbolCreator creator, delete capture; on_complete(); } - }, ctx, read_vars, write_vars); + }, ctx, read_vars, write_vars, FnProperty::kNormal, + 0, PROFILER_MESSAGE(op->name.c_str())); } else { LOG(FATAL) << "Operator " << op->name
diff --git a/src/common/mxrtc.cc b/src/common/mxrtc.cc index c1ab065db627..e808e11215bf 100644 --- a/src/common/mxrtc.cc +++ b/src/common/mxrtc.cc @@ -71,7 +71,8 @@ void MXRtc::push(std::vector const& input, std::vector var_in, var_out; for (auto& i : input) var_in.push_back(i.var()); for (auto& i : output) var_out.push_back(i.var()); - Engine::Get()->PushSync(op, output[0].ctx(), var_in, var_out); + Engine::Get()->PushSync(op, output[0].ctx(), var_in, var_out, + FnProperty::kNormal, 0, PROFILER_MESSAGE("MXRtc")); } std::string MXRtc::decorate(const std::string& name,
diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index d39452e65bc1..6044d9a62de5 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -18,6 +18,7 @@ class NaiveEngine final : public Engine { std::vector const_vars; std::vector mutable_vars; FnProperty prop; + const char* opr_name; }; NaiveEngine() { @@ -44,32 +45,37 @@ class NaiveEngine final : public Engine { OprHandle NewOperator(AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop) override { + FnProperty prop = FnProperty::kNormal, + const char*
opr_name = nullptr) override { NaiveOpr *opr = new NaiveOpr(); opr->fn = fn; opr->const_vars = const_vars; opr->mutable_vars = mutable_vars; opr->prop = prop; + opr->opr_name = opr_name; return opr; } void DeleteOperator(OprHandle op) override { NaiveOpr *opr = op->Cast(); delete opr; } - void Push(OprHandle op, Context exec_ctx, int priority) override { + void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override { NaiveOpr *opr = op->Cast(); this->PushAsync(opr->fn, exec_ctx, opr->const_vars, opr->mutable_vars, - opr->prop); + opr->prop, + priority, + PROFILER_MESSAGE(opr->opr_name)); } void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop, - int priority = 0) override { + FnProperty prop = FnProperty::kNormal, + int priority = 0, + const char* opr_name = nullptr) override { CallbackOnComplete callback = CreateCallback( NaiveEngine::OnComplete, nullptr); this->req_completed_ = false; @@ -97,7 +103,8 @@ class NaiveEngine final : public Engine { << "NaiveEngine only support synchronize Push so far"; } void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override { - this->PushSync(delete_fn, exec_ctx, {}, {var}, FnProperty::kNormal); + this->PushSync(delete_fn, exec_ctx, {}, {var}, + FnProperty::kNormal, 0, PROFILER_MESSAGE("DeleteVariable")); } void WaitForVar(VarHandle var) override { } diff --git a/src/engine/profiler.cc b/src/engine/profiler.cc new file mode 100644 index 000000000000..ade8be4610ed --- /dev/null +++ b/src/engine/profiler.cc @@ -0,0 +1,173 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file profiler.cc + * \brief implements profiler + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include "./profiler.h" + +namespace mxnet { +namespace engine { + +// TODO(ziheng) more lock free + +Profiler* Profiler::instance_ = nullptr; +std::mutex Profiler::m_; + +Profiler::Profiler() + : state_(kNotRunning), enable_output_(false), mode_(kOnlySymbolic), filename_("profile.json") { + this->init_time_ = NowInUsec(); + + // TODO(ziheng) get device number during execution + int kMaxNumCpus = 60, kMaxNumGpus = 16; + this->cpu_num_ = kMaxNumCpus, this->gpu_num_ = kMaxNumGpus; + + this->profile_stat = new DevStat[cpu_num_ + gpu_num_ + 1]; + for (unsigned int i = 0; i < cpu_num_; ++i) { + profile_stat[i].dev_name = "cpu/" + std::to_string(i); + } + for (unsigned int i = 0; i < gpu_num_; ++i) { + profile_stat[cpu_num_ + i].dev_name = "gpu/" + std::to_string(i); + } + profile_stat[cpu_num_ + gpu_num_].dev_name = "cpu pinned/"; +} + +Profiler* Profiler::Get() { + std::lock_guard lock{Profiler::m_}; + if (instance_ == nullptr) { + instance_ = new Profiler; + } + return instance_; +} + +void Profiler::SetState(ProfilerState state) { + std::lock_guard lock{Profiler::m_}; + this->state_ = state; + // once running, output will be enabled. 
+ if (state == kRunning) + this->enable_output_ = true; +} + +void Profiler::SetConfig(ProfilerMode mode, std::string output_filename) { + std::lock_guard lock{Profiler::m_}; + this->mode_ = mode; + this->filename_ = output_filename; +} + +OprExecStat *Profiler::AddOprStat(int dev_type, int dev_id) { + std::lock_guard lock{Profiler::m_}; + + OprExecStat* opr_stat = new OprExecStat; + opr_stat->dev_type = dev_type; + opr_stat->dev_id = dev_id; + + int idx = (dev_type-1) * cpu_num_ + dev_id; + DevStat& dev_stat = profile_stat[idx]; + dev_stat.opr_exec_stats.push_back(opr_stat); + + return opr_stat; +} + +void Profiler::EmitPid(std::ostream& os, const std::string& name, int pid) { + os << " {\n" + << " \"ph\": \"M\",\n" + << " \"args\": {\n" + << " \"name\": \"" << name << "\"\n" + << " },\n" + << " \"pid\": " << pid << ",\n" + << " \"name\": \"process_name\"\n" + << " }"; +} + +void Profiler::EmitEvent(std::ostream& os, const std::string& name, + const std::string& category, const std::string& ph, + uint64_t ts, int pid, int tid) { + os << " {\n" + << " \"name\": \"" << name << "\",\n" + << " \"cat\": " << "\"" << category << "\",\n" + << " \"ph\": \""<< ph << "\",\n" + << " \"ts\": " << ts << ",\n" + << " \"pid\": " << pid << ",\n" + << " \"tid\": " << tid << "\n" + << " }"; +} + + +void Profiler::DumpProfile() { + std::lock_guard lock{Profiler::m_}; + std::ofstream file; + file.open(filename_); + + file << "{" << std::endl; + file << " \"traceEvents\": [" << std::endl; + + int dev_num = cpu_num_ + gpu_num_ + 1; + + for (int i = 0; i < dev_num; ++i) { + const DevStat &d = profile_stat[i]; + this->EmitPid(file, d.dev_name, i); + file << ",\n"; + } + + bool first_flag = true; + for (int i = 0; i < dev_num; ++i) { + const DevStat &d = profile_stat[i]; + int opr_num = d.opr_exec_stats.size(); + + for (int j = 0; j < opr_num; ++j) { + const OprExecStat* opr_stat = d.opr_exec_stats[j]; + + int pid = i; + int tid = opr_stat->thread_id; + + if (first_flag) { + first_flag = false; + } else { + file << ","; + } + file << std::endl; + this->EmitEvent(file, opr_stat->opr_name, "category", "B", + opr_stat->opr_start_rel_micros, pid, tid); + file << ",\n"; + this->EmitEvent(file, opr_stat->opr_name, "category", "E", + opr_stat->opr_end_rel_micros, pid, tid); + } + } + + file << "\n" << std::endl; + file << " ]," << std::endl; + file << " \"displayTimeUnit\": \"ms\"" << std::endl; + file << "}" << std::endl; +} + + +inline uint64_t NowInUsec() { + return std::chrono::duration_cast( + std::chrono::high_resolution_clock::now().time_since_epoch()).count(); +} + +void SetOprStart(OprExecStat* opr_stat) { + if (!opr_stat) { + LOG(WARNING) << "SetOpStart: nullptr"; + return; + } + opr_stat->opr_start_rel_micros = NowInUsec() - Profiler::Get()->GetInitTime(); +} + +void SetOprEnd(OprExecStat* opr_stat) { + if (!opr_stat) { + LOG(WARNING) << "SetOpEnd: nullptr"; + return; + } + opr_stat->opr_end_rel_micros = NowInUsec() - Profiler::Get()->GetInitTime(); +} + +} // namespace engine +} // namespace mxnet diff --git a/src/engine/profiler.h b/src/engine/profiler.h new file mode 100644 index 000000000000..b5a1cd5f7250 --- /dev/null +++ b/src/engine/profiler.h @@ -0,0 +1,139 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file profiler.h + * \brief implements profiler + */ +#ifndef MXNET_ENGINE_PROFILER_H_ +#define MXNET_ENGINE_PROFILER_H_ + +#include +#include +#include +#include + +namespace mxnet { +namespace engine { + +/*! + * \brief Operation execution statistics + */ +struct OprExecStat { + /*! 
\brief operation name */ + std::string opr_name; + /*! + * \brief operation execution start relative timestamp + * time unit is microsecond (10^-6 s) + */ + uint64_t opr_start_rel_micros; + /*! + * \brief operation execution end relative timestamp + * time unit is microsecond (10^-6 s) + */ + uint64_t opr_end_rel_micros; + /*! \brief id of thread which operation run on */ + uint32_t thread_id; + /*! + * \brief device type + * CPU: 1, GPU: 2, CPUPinned: 3 + */ + uint32_t dev_type; + /*! \brief device id */ + uint32_t dev_id; +}; + +/*! + * \brief Device statistics + */ +struct DevStat { + /*! \brief device name */ + std::string dev_name; + /*! \brief operation execution statistics on this device */ + std::vector opr_exec_stats; +}; + + +/*! + * \brief profiler that records the operation execution information + * and saves the profile statistics. + */ +class Profiler { + public: + enum ProfilerMode { + kOnlySymbolic = 0, + kAllOperator = 1 + }; + enum ProfilerState { + kNotRunning = 0, + kRunning = 1 + }; + /*! \return Profiler singleton */ + static Profiler* Get(); + /*! \brief set state of profiler */ + void SetState(ProfilerState state); + /*! \return state of profiler */ + inline ProfilerState GetState() { + return this->state_; + } + /*! \brief set configure of profiler */ + void SetConfig(ProfilerMode mode, std::string output_filename); + /*! \return mode of profiler */ + inline ProfilerMode GetMode() { + return this->mode_; + } + /*! \return whether the profiler is enabled to output */ + inline bool IsEnableOutput() { + return this->enable_output_; + } + /*! \brief dump the profile file */ + void DumpProfile(); + /*! \return the profiler init time, time unit is microsecond (10^-6) s */ + inline uint64_t GetInitTime() { + return init_time_; + } + /*! \brief add one operation execution record in + * corresponding device statistics */ + OprExecStat* AddOprStat(int dev_type, int dev_id); + + protected: + /*! \brief make constructor protected. */ + Profiler(); + + private: + /*! \brief generate device information following chrome profile file format */ + void EmitPid(std::ostream& os, const std::string& name, int pid); + /*! \brief generate event information following chrome profile file format */ + void EmitEvent(std::ostream& os, const std::string& name, + const std::string& category, const std::string& ph, + uint64_t ts, int pid, int tid); + /*! \brief Profiler instance */ + static Profiler* instance_; + /*! \brief internal mutex of the profiler */ + static std::mutex m_; + /*! \brief indicate whether the profiler is running */ + ProfilerState state_; + /*! \brief once running, enable profiler to output */ + bool enable_output_; + /*! \brief indicate what operator the profiler will record */ + ProfilerMode mode_; + /*! \brief filename to output profile file */ + std::string filename_; + /*! \brief profile statistics consist of multiple device statistics */ + DevStat* profile_stat; + /*! \brief cpu number on the machine */ + unsigned int cpu_num_; + /*! \brief gpu number on the machine */ + unsigned int gpu_num_; + /*! \brief the profiler init time */ + uint64_t init_time_; +}; + +/*! \return current clock time, time unit is microsecond (10^-6 s) */ +inline uint64_t NowInUsec(); +/*! \brief set operation execution start timestamp */ +void SetOprStart(OprExecStat* opr_stat); +/*! 
\brief set operation execution end timestamp */ +void SetOprEnd(OprExecStat* opr_stat); + +} // namespace engine +} // namespace mxnet +#endif // MXNET_ENGINE_PROFILER_H_ diff --git a/src/engine/threaded_engine.cc b/src/engine/threaded_engine.cc index 1dcc92d56b8f..1ee1a0934990 100644 --- a/src/engine/threaded_engine.cc +++ b/src/engine/threaded_engine.cc @@ -186,8 +186,10 @@ ThreadedOpr* ThreadedEngine::NewOperator( ThreadedEngine::AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop) { + FnProperty prop, + const char* opr_name) { auto ret = ThreadedOpr::New(); + ret->opr_name = opr_name; ret->fn = std::move(fn); ret->prop = prop; ret->const_vars.resize(const_vars.size()); @@ -249,10 +251,11 @@ void ThreadedEngine::DeleteOperator(OprHandle op) { threaded_opr->mutable_vars.end()); this->PushSync([threaded_opr](RunContext) { ThreadedOpr::Delete(threaded_opr); - }, Context::CPU(), {}, deps, FnProperty::kAsync); + }, Context::CPU(), {}, deps, FnProperty::kAsync, 0, + PROFILER_MESSAGE("DeleteOperator")); } -void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority) { +void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority, bool profiling) { ThreadedOpr* threaded_opr = ThreadedOpr::CastFromBase(op); OprBlock* opr_block = OprBlock::New(); opr_block->opr = threaded_opr; @@ -262,6 +265,7 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority) { threaded_opr->mutable_vars.size() + 1)); opr_block->ctx = exec_ctx; opr_block->priority = priority; + opr_block->profiling = profiling; ++pending_; // Add read dependencies. for (auto&& i : threaded_opr->const_vars) { @@ -279,10 +283,19 @@ void ThreadedEngine::Push(OprHandle op, Context exec_ctx, int priority) { void ThreadedEngine::PushAsync(AsyncFn fn, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop, int priority) { - ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop); + FnProperty prop, + int priority, + const char* opr_name) { + ThreadedOpr *opr = NewOperator(std::move(fn), const_vars, mutable_vars, prop, opr_name); opr->temporary = true; - Push(opr, exec_ctx, priority); +#if MXNET_USE_PROFILER + Profiler *profiler = Profiler::Get(); + bool profiling = (profiler->GetState() == Profiler::kRunning) && + (profiler->GetMode() == Profiler::kAllOperator); +#else + bool profiling = false; +#endif + Push(opr, exec_ctx, priority, profiling); } void ThreadedEngine::DeleteVariable(SyncFn delete_fn, @@ -294,7 +307,8 @@ void ThreadedEngine::DeleteVariable(SyncFn delete_fn, // so during `ThreadedEngine::OnComplete` it could be recycled. 
threaded_var->SetToDelete(); delete_fn(ctx); - }, exec_ctx, {}, {var}, FnProperty::kAsync); + }, exec_ctx, {}, {var}, FnProperty::kAsync, 0, + PROFILER_MESSAGE("DeleteVariable")); } void ThreadedEngine::WaitForVar(VarHandle var) { @@ -317,7 +331,8 @@ void ThreadedEngine::WaitForVar(VarHandle var) { if (engine_info_) { LOG(INFO) << "Sync is notified"; } - }, Context::CPU(), {var}, {}, FnProperty::kNormal); + }, Context::CPU(), {var}, {}, FnProperty::kNormal, 0, + PROFILER_MESSAGE("WaitForVar")); { std::unique_lock lock{finished_m_}; finished_cv_.wait(lock, [this, &done]() { @@ -379,9 +394,17 @@ inline void ThreadedEngine::OnComplete(ThreadedOpr* threaded_opr) { } void ThreadedEngine::OnCompleteStatic( - Engine *engine, void *threaded_opr) { - static_cast(engine)->OnComplete( - static_cast(threaded_opr)); + Engine *engine, void *opr_block_) { + OprBlock *opr_block = static_cast(opr_block_); + ThreadedOpr *threaded_opr = opr_block->opr; +#if MXNET_USE_PROFILER + if (opr_block->profiling && threaded_opr->opr_name) { + // record operator end timestamp + SetOprEnd(opr_block->opr_stat); + } +#endif + static_cast(engine)->OnComplete(threaded_opr); + OprBlock::Delete(opr_block); } } // namespace engine diff --git a/src/engine/threaded_engine.h b/src/engine/threaded_engine.h index 93aaae3c54f2..f5338ddad713 100644 --- a/src/engine/threaded_engine.h +++ b/src/engine/threaded_engine.h @@ -16,7 +16,9 @@ #include #include #include +#include #include "./engine_impl.h" +#include "./profiler.h" #include "../common/object_pool.h" namespace mxnet { @@ -50,6 +52,10 @@ struct OprBlock : public common::ObjectPoolAllocatable { Context ctx; /*! \brief priority of the function */ int priority; + /*! \brief indicate whether to profile this operator */ + bool profiling{false}; + /*! \brief operator execution statistics */ + OprExecStat *opr_stat; // define possible debug information DEFINE_ENGINE_DEBUG_INFO(OprBlock); /*! @@ -199,8 +205,10 @@ struct ThreadedOpr final : public Opr, std::vector const_vars; /*! \brief The variable this operation will mutate. */ std::vector mutable_vars; - /*! \brief the property of the operator */ + /*! \brief The property of the operator */ FnProperty prop; + /*! \brief The name of the operator */ + const char* opr_name{nullptr}; /*! * \brief Whether this is an temporary operator * that can be deleted right after the operation completed. @@ -234,14 +242,16 @@ class ThreadedEngine : public Engine { ThreadedOpr* NewOperator(AsyncFn fn, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop) override; + FnProperty prop = FnProperty::kNormal, + const char* opr_name = nullptr) override; void DeleteOperator(OprHandle op) override; - void Push(OprHandle op, Context exec_ctx, int priority) override; + void Push(OprHandle op, Context exec_ctx, int priority = 0, bool profiling = false) override; void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, std::vector const& mutable_vars, - FnProperty prop, - int priority) override; + FnProperty prop = FnProperty::kNormal, + int priority = 0, + const char* opr_name = nullptr) override; void DeleteVariable(SyncFn delete_fn, Context exec_ctx, VarHandle var) override; void WaitForVar(VarHandle var) override; void WaitForAll() override; @@ -258,6 +268,13 @@ class ThreadedEngine : public Engine { objpool_var_ref_ = common::ObjectPool::_GetSharedRef(); } ~ThreadedEngine() { +#if MXNET_USE_PROFILER + // dump trace file if profiler is enabled when engine is destructed. 
+ Profiler* profiler = Profiler::Get(); + if (profiler->IsEnableOutput()) { + profiler->DumpProfile(); + } +#endif { std::unique_lock lock{finished_m_}; kill_.store(true); @@ -283,8 +300,19 @@ class ThreadedEngine : public Engine { */ void ExecuteOprBlock(RunContext run_ctx, OprBlock *opr_block) { ThreadedOpr* threaded_opr = opr_block->opr; +#if MXNET_USE_PROFILER + if (opr_block->profiling && threaded_opr->opr_name) { + const Context& ctx = opr_block->ctx; + opr_block->opr_stat = Profiler::Get()->AddOprStat(ctx.dev_type, ctx.dev_id); + uint64_t id = std::hash()(std::this_thread::get_id()); + opr_block->opr_stat->thread_id = id; + opr_block->opr_stat->opr_name = std::string(threaded_opr->opr_name); + // record operator start timestamp + SetOprStart(opr_block->opr_stat); + } +#endif CallbackOnComplete callback = this->CreateCallback( - ThreadedEngine::OnCompleteStatic, threaded_opr); + ThreadedEngine::OnCompleteStatic, opr_block); bool debug_info = (engine_info_ && debug_push_opr_ == opr_block); if (debug_info) { LOG(INFO) << "ExecuteOprBlock " << opr_block @@ -317,8 +345,6 @@ class ThreadedEngine : public Engine { } else { callback(); } - - OprBlock::Delete(opr_block); } private: diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 001ba4ed3d89..980dbfa76c4e 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -11,6 +11,7 @@ #include "./exec_pass.h" #include "./graph_executor.h" +#include "../engine/profiler.h" namespace mxnet { namespace exec { @@ -474,6 +475,11 @@ void GraphExecutor::InitCachedOps() { for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) { const auto& inode = idx[nid]; if (inode.source->is_variable()) continue; +#if MXNET_USE_PROFILER + op_nodes_[nid].opr_name = inode.source->op()->name.c_str(); +#else + op_nodes_[nid].opr_name = nullptr; +#endif if (skip_plus_node.at(nid)) { op_nodes_[nid].skip_exec_node = true; continue; } @@ -557,7 +563,8 @@ void GraphExecutor::InitCachedOps() { dedup(all_vars); Engine::Get()->PushSync([exec](RunContext rctx) { exec->Setup(); - }, Context::CPU(), {}, all_vars); + }, Context::CPU(), {}, all_vars, FnProperty::kNormal, 0, + PROFILER_MESSAGE("SetupExec")); auto exec_fun = [exec, is_async, is_gpu] ( RunContext ctx, Engine::CallbackOnComplete on_complete) { if (is_async) { @@ -578,8 +585,9 @@ void GraphExecutor::InitCachedOps() { } }; // setup the vars - op_nodes_[nid].cached_opr = Engine::Get()->NewOperator( - exec_fun, use_vars, mutate_vars, FnProperty::kNormal); + op_nodes_[nid].cached_opr = Engine::Get()->NewOperator( + exec_fun, use_vars, mutate_vars, FnProperty::kNormal, + PROFILER_MESSAGE(op_nodes_[nid].opr_name)); } } @@ -599,7 +607,12 @@ void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) { CHECK_EQ(opnode.exec->out_array.size(), 1); CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0])); } else if (opnode.cached_opr != nullptr) { - Engine::Get()->Push(opnode.cached_opr, opnode.ctx); +#if MXNET_USE_PROFILER + bool profiling = engine::Profiler::Get()->GetState() == engine::Profiler::kRunning; +#else + bool profiling = false; +#endif + Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling); } else { LOG(FATAL) << "Not accessed"; } diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h index c494e6c4cbdb..cae7c28aafd6 100644 --- a/src/executor/graph_executor.h +++ b/src/executor/graph_executor.h @@ -48,6 +48,8 @@ class GraphExecutor : public Executor { protected: // Information about operational node struct 
OpNode { + // The name of the operator + const char* opr_name; // the context of the node Context ctx; // The executor diff --git a/src/io/image_io.cc b/src/io/image_io.cc index c54ce5e61596..c1f57387c01b 100644 --- a/src/io/image_io.cc +++ b/src/io/image_io.cc @@ -149,7 +149,8 @@ void Imdecode(const nnvm::NodeAttrs& attrs, if (param.to_rgb) { cv::cvtColor(dst, dst, CV_BGR2RGB); } - }, ndout.ctx(), {ndin.var()}, {ndout.var()}); + }, ndout.ctx(), {ndin.var()}, {ndout.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE("Imdecode")); (*outputs)[0] = ndout; #else LOG(FATAL) << "Build with USE_OPENCV=1 for image io."; diff --git a/src/kvstore/comm.h b/src/kvstore/comm.h index 64be5886319d..9c425d6077b6 100644 --- a/src/kvstore/comm.h +++ b/src/kvstore/comm.h @@ -84,7 +84,7 @@ class CommCPU : public Comm { Engine::Get()->PushSync([reduce, this](RunContext rctx) { ReduceSumCPU(reduce); }, Context::CPU(), const_vars, {reduce[0].var()}, - FnProperty::kCPUPrioritized, priority); + FnProperty::kCPUPrioritized, priority, PROFILER_MESSAGE("KVStoreReduce")); return buf.merged; } diff --git a/src/kvstore/kvstore_dist.h b/src/kvstore/kvstore_dist.h index d4b7853fb779..e579f0018167 100644 --- a/src/kvstore/kvstore_dist.h +++ b/src/kvstore/kvstore_dist.h @@ -116,7 +116,9 @@ class KVStoreDist : public KVStoreLocal { pinned_ctx_, {}, {recv_buf.var()}, - FnProperty::kNormal, priority); + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistPull")); comm_->Broadcast(key, recv_buf, grouped_vals[i], priority); } @@ -221,7 +223,9 @@ class KVStoreDist : public KVStoreLocal { pinned_ctx_, {send_buf.var()}, {}, - FnProperty::kNormal, priority); + FnProperty::kNormal, + priority, + PROFILER_MESSAGE("KVStoreDistPush")); } } diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc index 3c551ca1f7df..17d6b48cd4c0 100644 --- a/src/ndarray/ndarray.cc +++ b/src/ndarray/ndarray.cc @@ -66,7 +66,8 @@ void TernaryOp(const NDArray &lhs, ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), mhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), const_vars, { ret.var() }); + }, lhs.ctx(), const_vars, { ret.var() }, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #if MXNET_USE_CUDA @@ -77,7 +78,8 @@ void TernaryOp(const NDArray &lhs, ndarray::Eval(lhs.data(), mhs.data(), rhs.data(), &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, lhs.ctx(), const_vars, { ret.var() }); + }, lhs.ctx(), const_vars, { ret.var() }, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #endif @@ -126,7 +128,8 @@ void BinaryOp(const NDArray &lhs, ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); - }, lhs.ctx(), const_vars, {ret.var()}); + }, lhs.ctx(), const_vars, {ret.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #if MXNET_USE_CUDA @@ -137,7 +140,8 @@ void BinaryOp(const NDArray &lhs, ndarray::Eval(lhs.data(), rhs.data(), &tmp, ctx); // Wait GPU kernel to complete ctx.get_stream()->Wait(); - }, lhs.ctx(), const_vars, {ret.var()}); + }, lhs.ctx(), const_vars, {ret.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #endif @@ -155,7 +159,8 @@ void SetValueOp(const real_t &rhs, NDArray *out) { ret.CheckAndAlloc(); TBlob tmp = ret.data(); ndarray::Eval(rhs, &tmp, ctx); - }, ret.ctx(), {}, {ret.var()}); + }, ret.ctx(), {}, {ret.var()}, + FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME); break; } #if MXNET_USE_CUDA @@ -166,7 +171,8 @@ void SetValueOp(const real_t &rhs, NDArray *out) { 
         ndarray::Eval<gpu>(rhs, &tmp, ctx);
         // Wait GPU kernel to complete
         ctx.get_stream<gpu>()->Wait();
-      }, ret.ctx(), {}, {ret.var()});
+      }, ret.ctx(), {}, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #endif
@@ -204,7 +210,8 @@ void ScalarOp(const NDArray &lhs,
         ret.CheckAndAlloc();
         TBlob tmp = ret.data();
         ndarray::Eval<cpu>(lhs.data(), rhs, &tmp, ctx);
-      }, lhs.ctx(), const_vars, {ret.var()});
+      }, lhs.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #if MXNET_USE_CUDA
@@ -215,7 +222,8 @@ void ScalarOp(const NDArray &lhs,
         ndarray::Eval<gpu>(lhs.data(), rhs, &tmp, ctx);
         // Wait GPU kernel to complete
         ctx.get_stream<gpu>()->Wait();
-      }, lhs.ctx(), const_vars, {ret.var()});
+      }, lhs.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #endif
@@ -247,7 +255,7 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) {
         ndarray::Copy(from.data(), &tmp, from.ctx(), ret.ctx(), ctx);
       }, from.ctx(), const_vars, {ret.var()},
-      FnProperty::kNormal, priority);
+      FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
   } else {
 #if MXNET_USE_CUDA
     if (a == cpu::kDevMask && b == gpu::kDevMask) {
@@ -259,7 +267,7 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) {
         // Wait GPU kernel to complete
         ctx.get_stream<gpu>()->Wait();
       }, ret.ctx(), const_vars, {ret.var()},
-      FnProperty::kCopyToGPU, priority);
+      FnProperty::kCopyToGPU, priority, PROFILER_MESSAGE_FUNCNAME);
     } else if (a == gpu::kDevMask && b == cpu::kDevMask) {
       Engine::Get()->PushSync([from, ret](RunContext ctx) {
           ret.CheckAndAlloc();
@@ -269,7 +277,7 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) {
         // Wait GPU kernel to complete
         ctx.get_stream<gpu>()->Wait();
       }, from.ctx(), const_vars, {ret.var()},
-      FnProperty::kCopyFromGPU, priority);
+      FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE_FUNCNAME);
     } else if (a == gpu::kDevMask && b == gpu::kDevMask) {
       Engine::Get()->PushSync([from, ret](RunContext ctx) {
           ret.CheckAndAlloc();
@@ -279,7 +287,7 @@ void CopyFromTo(const NDArray &from, NDArray *to, int priority) {
         // Wait GPU kernel to complete
         ctx.get_stream<gpu>()->Wait();
       }, from.ctx(), const_vars, {ret.var()},
-      FnProperty::kCopyFromGPU, priority);
+      FnProperty::kCopyFromGPU, priority, PROFILER_MESSAGE_FUNCNAME);
     } else {
       LOG(FATAL) << "unknown device mask";
     }
@@ -320,7 +328,7 @@ void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priority) {
         TBlob tmp = ret.data();
         ndarray::ElementwiseSum<cpu>(source_tblob, &tmp, ctx);
       }, out->ctx(), const_vars, {ret.var()},
-      FnProperty::kNormal, priority);
+      FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #if MXNET_USE_CUDA
@@ -336,7 +344,7 @@ void ElementwiseSum(const std::vector<NDArray> &source, NDArray *out, int priority) {
         // Wait GPU kernel to complete
         ctx.get_stream<gpu>()->Wait();
       }, out->ctx(), const_vars, {ret.var()},
-      FnProperty::kNormal, priority);
+      FnProperty::kNormal, priority, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #endif
@@ -362,7 +370,8 @@ void ClipOp(const NDArray &src,
         ret.CheckAndAlloc();
         TBlob tmp = ret.data();
         ndarray::EvalClip<cpu>(src.data(), a_min, a_max, &tmp, ctx);
-      }, src.ctx(), const_vars, {ret.var()});
+      }, src.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #if MXNET_USE_CUDA
@@ -371,7 +380,8 @@ void ClipOp(const NDArray &src,
         ret.CheckAndAlloc();
         TBlob tmp = ret.data();
         ndarray::EvalClip<gpu>(src.data(), a_min, a_max, &tmp, ctx);
-      }, src.ctx(), const_vars, {ret.var()});
+      }, src.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #endif
@@ -399,7 +409,8 @@ void SampleOP(const real_t &a,
         ret.CheckAndAlloc();
         TBlob tmp = ret.data();
         ndarray::EvalRandom(a, b, resource, &tmp, ctx);
-      }, out->ctx(), {}, {ret.var(), resource.var});
+      }, out->ctx(), {}, {ret.var(), resource.var},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #if MXNET_USE_CUDA
@@ -410,7 +421,8 @@ void SampleOP(const real_t &a,
         ndarray::EvalRandom(a, b, resource, &tmp, ctx);
         // Wait GPU kernel to complete
         ctx.get_stream<gpu>()->Wait();
-      }, out->ctx(), {}, {ret.var(), resource.var});
+      }, out->ctx(), {}, {ret.var(), resource.var},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #endif
@@ -553,7 +565,8 @@ void Broadcast(const NDArray& src, int dim, int size, NDArray *out) {
         NDArray inter_out = ret.Reshape(mshadow::Shape3(before, size, after));
         TBlob tmp = inter_out.data();
         ndarray::EvalBroadcast<cpu>(inter_in.data(), &tmp, size, ctx);
-      }, src.ctx(), const_vars, {ret.var()});
+      }, src.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #if MXNET_USE_CUDA
@@ -566,7 +579,8 @@ void Broadcast(const NDArray& src, int dim, int size, NDArray *out) {
         ndarray::EvalBroadcast<gpu>(inter_in.data(), &tmp, size, ctx);
         // Wait GPU kernel to complete
         ctx.get_stream<gpu>()->Wait();
-      }, src.ctx(), const_vars, {ret.var()});
+      }, src.ctx(), const_vars, {ret.var()},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE_FUNCNAME);
       break;
     }
 #endif
diff --git a/src/operator/cudnn_convolution.cc b/src/operator/cudnn_convolution.cc
index b3d6b481b012..34aec2621547 100644
--- a/src/operator/cudnn_convolution.cc
+++ b/src/operator/cudnn_convolution.cc
@@ -273,7 +273,8 @@ void TuneCudnnConvolution(ConvolutionParam param,
       } else {
         *back_algo = bwd_data_algo[i].algo;
       }
-    }, ctx, {}, {var});
+    }, ctx, {}, {var},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("TuneCudnnConvolution"));
   Engine::Get()->WaitForVar(var);
   Engine::Get()->DeleteVariable([](RunContext s) {}, ctx, var);
diff --git a/src/operator/custom.cc b/src/operator/custom.cc
index 09ab894044df..00a2518a5a38 100644
--- a/src/operator/custom.cc
+++ b/src/operator/custom.cc
@@ -81,7 +81,8 @@ void CustomOp<xpu>::Forward(const OpContext &ctx,
   // NDArray* in ptrs is freed by frontend side. We keep a copy in ndcpy to keep ndvar alive
   Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx) {
       ctx.async_on_complete();
-    }, ndctx, ndvar, {});
+    }, ndctx, ndvar, {},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOpForward"));
 }

 template<typename xpu>
@@ -139,7 +140,8 @@ void CustomOp<xpu>::Backward(const OpContext &ctx,
   // NDArray* in ptrs is freed by frontend side. We keep a copy in ndcpy to keep ndvar alive
   Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx){
       ctx.async_on_complete();
-    }, ndctx, ndvar, {});
+    }, ndctx, ndvar, {},
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("CustomOpBackward"));
 }

 Operator* CustomOpProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
diff --git a/src/operator/ndarray_op.cc b/src/operator/ndarray_op.cc
index 950492632b31..773fe7753930 100644
--- a/src/operator/ndarray_op.cc
+++ b/src/operator/ndarray_op.cc
@@ -67,7 +67,8 @@ void NDArrayOp<xpu>::Forward(const OpContext &ctx,
   CHECK(param_.pinfo->forward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_forward));
   Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx) {ctx.async_on_complete(); },
-                          ndctx, ndvar, {});
+                          ndctx, ndvar, {}, FnProperty::kNormal, 0,
+                          PROFILER_MESSAGE("NDArrayOpForward"));
 }

 template<typename xpu>
@@ -113,7 +114,8 @@ void NDArrayOp<xpu>::Backward(const OpContext &ctx,
   CHECK(param_.pinfo->backward(ptrs.size(), ptrs.data(), tags.data(), param_.pinfo->p_backward));
   Engine::Get()->PushSync([ndcpy, ctx](RunContext rctx){ ctx.async_on_complete(); },
-                          ndctx, ndvar, {});
+                          ndctx, ndvar, {}, FnProperty::kNormal, 0,
+                          PROFILER_MESSAGE("NDArrayOpBackward"));
 }

 Operator* NDArrayOpProp::CreateOperator(Context ctx) const {
diff --git a/src/operator/operator_util.cc b/src/operator/operator_util.cc
index 3d618e5c5340..0b42229a6ee9 100644
--- a/src/operator/operator_util.cc
+++ b/src/operator/operator_util.cc
@@ -487,7 +487,8 @@ void SimpleOpRegEntryImpl::RegisterSourceImperative() {
         ctx.get_stream<gpu>()->Wait();
       }
 #endif
-    }, ret.ctx(), {}, write_vars);
+    }, ret.ctx(), {}, write_vars,
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("RegisterSourceImperative"));
   };
   // register the function.
   NDArrayReg()
@@ -671,7 +672,8 @@ void SimpleOpRegEntryImpl::RegisterUnaryImperative() {
         ctx.get_stream<gpu>()->Wait();
       }
 #endif
-    }, src.ctx(), const_vars, write_vars);
+    }, src.ctx(), const_vars, write_vars,
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("RegisterUnaryImperative"));
   };
   // register the function.
   NDArrayReg()
@@ -945,7 +947,8 @@ void SimpleOpRegEntryImpl::RegisterBinaryImperative() {
         ctx.get_stream<gpu>()->Wait();
       }
 #endif
-    }, lhs.ctx(), const_vars, write_vars);
+    }, lhs.ctx(), const_vars, write_vars,
+    FnProperty::kNormal, 0, PROFILER_MESSAGE("RegisterBinaryImperative"));
   };
   // register the function.
   NDArrayReg()
diff --git a/src/optimizer/sgd-inl.h b/src/optimizer/sgd-inl.h
index c5b6aab528b2..36b45c375b95 100644
--- a/src/optimizer/sgd-inl.h
+++ b/src/optimizer/sgd-inl.h
@@ -122,11 +122,13 @@ class SGDOpt : public Optimizer {
       if (param_.momentum > 0.0f) {
         Engine::Get()->PushSync([this, index, w, g, lr, wd](RunContext ctx) {
             call_sgd_mom_update_cpu(ctx, w.data(), g.data(), mom[index].data(), lr, wd, param_);
-          }, w.ctx(), {g.var()}, {w.var(), mom[index].var()}, FnProperty::kNormal);
+          }, w.ctx(), {g.var()}, {w.var(), mom[index].var()},
+          FnProperty::kNormal, 0, PROFILER_MESSAGE("SGDOptUpdate"));
       } else {
         Engine::Get()->PushSync([this, index, w, g, lr, wd](RunContext ctx) {
             call_sgd_update_cpu(ctx, w.data(), g.data(), lr, wd, param_);
-          }, w.ctx(), {g.var()}, {w.var()}, FnProperty::kNormal);
+          }, w.ctx(), {g.var()}, {w.var()},
+          FnProperty::kNormal, 0, PROFILER_MESSAGE("SGDOptUpdate"));
       }
       break;
     case Context::kGPU:
@@ -134,11 +136,13 @@ class SGDOpt : public Optimizer {
       if (param_.momentum > 0.0f) {
         Engine::Get()->PushSync([this, index, w, g, lr, wd](RunContext ctx) {
             call_sgd_mom_update_gpu(ctx, w.data(), g.data(), mom[index].data(), lr, wd, param_);
-          }, w.ctx(), {g.var()}, {w.var(), mom[index].var()}, FnProperty::kNormal);
+          }, w.ctx(), {g.var()}, {w.var(), mom[index].var()},
+          FnProperty::kNormal, 0, PROFILER_MESSAGE("SGDOptUpdate"));
       } else {
         Engine::Get()->PushSync([this, index, w, g, lr, wd](RunContext ctx) {
             call_sgd_update_gpu(ctx, w.data(), g.data(), lr, wd, param_);
-          }, w.ctx(), {g.var()}, {w.var()}, FnProperty::kNormal);
+          }, w.ctx(), {g.var()}, {w.var()},
+          FnProperty::kNormal, 0, PROFILER_MESSAGE("SGDOptUpdate"));
       }
       break;
 #else
diff --git a/src/resource.cc b/src/resource.cc
index 9123c42c3a69..232ab8cfd975 100644
--- a/src/resource.cc
+++ b/src/resource.cc
@@ -171,7 +171,8 @@ class ResourceManagerImpl : public ResourceManager {
     Engine::Get()->PushSync([r, seed](RunContext rctx) {
         r->set_stream(rctx.get_stream<xpu>());
         r->Seed(seed);
-      }, ctx, {}, {resource.var});
+      }, ctx, {}, {resource.var},
+      FnProperty::kNormal, 0, PROFILER_MESSAGE("ResourceRandomSetSeed"));
   }
 };
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 20ce6c9ab5aa..9ae7c62587ce 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1784,6 +1784,173 @@ def test_special_functions_using_scipy():
                       lambda x: scipy_special.psi(x), 0.5, 0.5)
+
+def mathematical_core_binary(name,
+                             forward_mxnet_call,
+                             forward_numpy_call,
+                             backward_numpy_call1,
+                             backward_numpy_call2,
+                             data1_init=2.,
+                             data2_init=3.,
+                             grad_init=2.):
+    data1 = mx.symbol.Variable('data')
+    data2 = mx.symbol.Variable('data')
+    shape = (3, 4)
+    data_tmp1 = np.random.rand(3, 4)
+    data_tmp2 = np.random.rand(3, 4)
+    data_tmp1[:] = data1_init
+    data_tmp2[:] = data2_init
+
+    arr_data1 = mx.nd.array(data_tmp1)
+    arr_data2 = mx.nd.array(data_tmp2)
+
+    arr_grad1 = mx.nd.empty(shape)
+    arr_grad2 = mx.nd.empty(shape)
+
+    test = forward_mxnet_call(data1, data2)
+    exe_test = test.bind(mx.cpu(), args=[arr_data1, arr_data2], args_grad=[arr_grad1, arr_grad2])
+    exe_test.forward()
+    out = exe_test.outputs[0].asnumpy()
+    npout = forward_numpy_call(data_tmp1, data_tmp2)
+    assert reldiff(out, npout) < 1e-6, "%s mathematical forward failed\n%s\n\n%s" % (name, out, npout)
+
+    out_grad = mx.nd.empty(shape)
+    out_grad[:] = grad_init
+    exe_test.backward(out_grad)
+
+    npout_grad = np.ones(shape)
+    npout_grad[:] = grad_init
+
+    npout_grad1 = npout_grad * backward_numpy_call1(data_tmp1, data_tmp2)
+    npout_grad2 = npout_grad * backward_numpy_call2(data_tmp1, data_tmp2)
+    arr_grad1 = arr_grad1.asnumpy()
+    arr_grad2 = arr_grad2.asnumpy()
+
+    assert reldiff(arr_grad1, npout_grad1) < 1e-6, "%s mathematical backward1 failed\n%s\n\n%s" % (
+        name, arr_grad1, npout_grad)
+    assert reldiff(arr_grad2, npout_grad2) < 1e-6, "%s mathematical backward2 failed\n%s\n\n%s" % (
+        name, arr_grad2, npout_grad)
+
+
+def mathematical_core(name, forward_mxnet_call, forward_numpy_call, backward_numpy_call, data_init=5., grad_init=2.):
+    data = mx.symbol.Variable('data')
+    shape = (3, 4)
+    data_tmp = np.ones(shape)
+    data_tmp[:] = data_init
+    arr_data = mx.nd.array(data_tmp)
+    arr_grad = mx.nd.empty(shape)
+    arr_grad[:] = 3
+
+    test = forward_mxnet_call(data)
+    exe_test = test.bind(mx.cpu(), args=[arr_data], args_grad=[arr_grad])
+    exe_test.forward()
+    out = exe_test.outputs[0].asnumpy()
+    npout = forward_numpy_call(data_tmp)
+    assert reldiff(out, npout) < 1e-6, "%s mathematical forward failed\n%s\n\n%s" % (name, out, npout)
+
+    out_grad = mx.nd.empty(shape)
+    out_grad[:] = grad_init
+    npout_grad = out_grad.asnumpy()
+    temp = backward_numpy_call(data_tmp)
+    npout_grad = npout_grad * temp
+    exe_test.backward(out_grad)
+    arr_grad = arr_grad.asnumpy()
+    # print(name)
+    # print(arr_grad)
+    # print(npout_grad)
+    assert reldiff(arr_grad, npout_grad) < 1e-6, "%s mathematical backward failed\n%s\n\n%s" % (
+        name, arr_grad, npout_grad)
+
+
+def test_mathematical():
+    # rsqrt
+    mathematical_core("rsqrt",
+                      lambda x: mx.sym.rsqrt(x),
+                      lambda x: 1 / np.sqrt(x),
+                      lambda x: -(1.0 / (2.0 * x * np.sqrt(x))))
+    # tan
+    mathematical_core("tan", lambda x: mx.sym.tan(x), lambda x: np.tan(x), lambda x: np.tan(x) ** 2 + 1)
+    # arcsin
+    mathematical_core("arcsin", lambda x: mx.sym.arcsin(x), lambda x: np.arcsin(x),
+                      lambda x: 1. / (1. - x ** 2) ** (1. / 2.), 0.5, 0.5)
+    # arccos
+    mathematical_core("arccos", lambda x: mx.sym.arccos(x), lambda x: np.arccos(x),
+                      lambda x: -1. / (1. - x ** 2.) ** (1. / 2.), 0.5, 0.5)
+    # arctan
+    mathematical_core("arctan", lambda x: mx.sym.arctan(x), lambda x: np.arctan(x),
+                      lambda x: 1. / (x ** 2. + 1.), 0.5, 0.5)
+    # hypot
+    mathematical_core_binary("hypot",
+                             lambda x, y: mx.sym.hypot(x, y),
+                             lambda x, y: np.hypot(x, y),
+                             lambda x, y: x / np.hypot(x, y),
+                             lambda x, y: y / np.hypot(x, y),
+                             0.5, 0.5, 0.5)
+
+    # hypot scalar
+    mathematical_core("hypot scalar",
+                      lambda x: mx.sym.hypot(x, 3),
+                      lambda x: np.hypot(x, 3),
+                      lambda x: x / np.hypot(x, 3),
+                      0.5, 0.5)
+
+    # degrees
+    mathematical_core("degrees",
+                      lambda x: mx.sym.degrees(x),
+                      lambda x: np.degrees(x),
+                      lambda x: 180./np.pi,
+                      0.5, 0.5)
+    # radians
+    mathematical_core("radians",
+                      lambda x: mx.sym.radians(x),
+                      lambda x: np.radians(x),
+                      lambda x: np.pi / 180.,
+                      0.6, 1)
+    # sinh
+    mathematical_core("sinh", lambda x: mx.sym.sinh(x), lambda x: np.sinh(x), lambda x: np.cosh(x))
+
+    # cosh
+    mathematical_core("cosh", lambda x: mx.sym.cosh(x), lambda x: np.cosh(x), lambda x: np.sinh(x), 5, 5)
+
+    # tanh
+    mathematical_core("tanh", lambda x: mx.sym.tanh(x), lambda x: np.tanh(x), lambda x: 1. - np.tanh(x) ** 2, 0.5, 1)
+
+    # arcsinh
+    mathematical_core("arcsinh", lambda x: mx.sym.arcsinh(x), lambda x: np.arcsinh(x),
+                      lambda x: 1./(x**2 + 1.)**(1./2.))
+
+    # arccosh
+    mathematical_core("arccosh", lambda x: mx.sym.arccosh(x), lambda x: np.arccosh(x),
+                      lambda x: 1./(x**2 - 1.)**(1./2.))
+
+    # arctanh
+    mathematical_core("arctanh", lambda x: mx.sym.arctanh(x), lambda x: np.arctanh(x),
+                      lambda x: -1./(x**2 - 1.), 0.5)
+
+    # log1p
+    mathematical_core("log1p", lambda x: mx.sym.log1p(x), lambda x: np.log1p(x),
+                      lambda x: 1. / (1.0 + x), 0.5, 0.5)
+    # expm1
+    mathematical_core("expm1", lambda x: mx.sym.expm1(x), lambda x: np.expm1(x),
+                      lambda x: np.exp(x), 0.5, 0.5)
+
+
+def test_special_functions_using_scipy():
+    try:
+        from scipy import special as scipy_special
+    except:
+        print("Could not import scipy. Skipping unit tests for special functions")
+        return
+
+    # gamma
+    mathematical_core("gamma", lambda x: mx.sym.gamma(x), lambda x: scipy_special.gamma(x),
+                      lambda x: scipy_special.gamma(x) * scipy_special.psi(x), 0.5, 0.5)
+
+    # gammaln
+    mathematical_core("gammaln", lambda x: mx.sym.gammaln(x), lambda x: scipy_special.gammaln(x),
+                      lambda x: scipy_special.psi(x), 0.5, 0.5)
+
+
 if __name__ == '__main__':
     test_expand_dims()
     test_slice_axis()
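
Note: the hunks above only add a trailing name argument (PROFILER_MESSAGE("...") or PROFILER_MESSAGE_FUNCNAME) to PushSync/NewOperator calls; the macros themselves are defined in src/engine/profiler.h, which is not part of these hunks. A minimal sketch of how such macros could be defined, stated as an assumption for illustration rather than copied from this patch:

    // Sketch only (assumption): the real definitions live in src/engine/profiler.h.
    #if MXNET_USE_PROFILER
    // Keep the literal operator name so the engine can attach it to the recorded OprStat.
    #define PROFILER_MESSAGE(msg)     msg
    #else
    // Profiler compiled out: every call site degrades to a null name.
    #define PROFILER_MESSAGE(msg)     nullptr
    #endif
    // Convenience form used by the NDArray functions above:
    // tag the pushed operation with the enclosing function's name.
    #define PROFILER_MESSAGE_FUNCNAME PROFILER_MESSAGE(__FUNCTION__)

With a definition along these lines, builds without MXNET_USE_PROFILER pass a null name, matching the op_nodes_[nid].opr_name = nullptr branch added in graph_executor.cc, so non-profiling builds pay no extra cost; the explicit FnProperty::kNormal, 0 arguments at each call site simply spell out the property and priority defaults that were previously implicit.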