diff --git a/python/mxnet/libinfo.py b/python/mxnet/libinfo.py index 0669a03c6520..57c73e5943af 100644 --- a/python/mxnet/libinfo.py +++ b/python/mxnet/libinfo.py @@ -96,10 +96,18 @@ def find_include_path(): logging.warning("MXNET_INCLUDE_PATH '%s' doesn't exist", incl_from_env) curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) - incl_path = os.path.join(curr_path, '../../include/') - if not os.path.isdir(incl_path): - raise RuntimeError('Cannot find the MXNet include path.\n') - return incl_path + # include path in pip package + pip_incl_path = os.path.join(curr_path, 'include/') + if os.path.isdir(pip_incl_path): + return pip_incl_path + else: + # include path if build from source + src_incl_path = os.path.join(curr_path, '../../include/') + if os.path.isdir(src_incl_path): + return src_incl_path + else: + raise RuntimeError('Cannot find the MXNet include path in either ' + pip_incl_path + + ' or ' + src_incl_path + '\n') # current version