Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
auto generate code
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Oct 15, 2017
1 parent 024cfe6 commit c4002f3
Show file tree
Hide file tree
Showing 18 changed files with 244 additions and 49 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,8 @@ target
bin/im2rec

model/

# generated function signature for IDE auto-complete
python/mxnet/symbol/gen_*
python/mxnet/ndarray/gen_*
python/.eggs
86 changes: 86 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"""ctypes library of mxnet and helper functions."""
from __future__ import absolute_import

import os
import sys
import ctypes
import atexit
Expand Down Expand Up @@ -444,3 +445,88 @@ def _init_op_module(root_namespace, module_name, make_op_func):
function.__module__ = contrib_module_name_old
setattr(contrib_module_old, function.__name__, function)
contrib_module_old.__all__.append(function.__name__)


def _generate_op_module_signature(root_namespace, module_name, op_code_gen_func):
"""
Generate op functions created by `op_code_gen_func` and write to the source file
of `root_namespace.module_name.[submodule_name]`,
where `submodule_name` is one of `_OP_SUBMODULE_NAME_LIST`.
Parameters
----------
root_namespace : str
Top level module name, `mxnet` in the current cases.
module_name : str
Second level module name, `ndarray` and `symbol` in the current cases.
op_code_gen_func : function
Function for creating op functions for `ndarray` and `symbol` modules.
"""
def get_module_file(module_name):
"""Return the generated module file based on module name."""
path = os.path.dirname(__file__)
module_path = module_name.split('.')
module_path[-1] = 'gen_'+module_path[-1]
file_name = os.path.join(path, '..', *module_path) + '.py'
module_file = open(file_name, 'w')
dependencies = {'symbol': ['from ._internal import SymbolBase',
'from ..base import _Null'],
'ndarray': ['from ._internal import NDArrayBase',
'from ..base import _Null']}
module_file.write('# File content is auto-generated. Do not modify.'+os.linesep)
module_file.write('# pylint: skip-file'+os.linesep)
module_file.write(os.linesep.join(dependencies[module_name.split('.')[1]]))
return module_file
def write_all_str(module_file, module_all_list):
"""Write the proper __all__ based on available operators."""
module_file.write(os.linesep)
module_file.write(os.linesep)
all_str = '__all__ = [' + ', '.join(["'%s'"%s for s in module_all_list]) + ']'
module_file.write(all_str)

plist = ctypes.POINTER(ctypes.c_char_p)()
size = ctypes.c_uint()

check_call(_LIB.MXListAllOpNames(ctypes.byref(size),
ctypes.byref(plist)))
op_names = []
for i in range(size.value):
op_names.append(py_str(plist[i]))

module_op_file = get_module_file("%s.%s.op" % (root_namespace, module_name))
module_op_all = []
module_internal_file = get_module_file("%s.%s._internal"%(root_namespace, module_name))
module_internal_all = []
submodule_dict = {}
for op_name_prefix in _OP_NAME_PREFIX_LIST:
submodule_dict[op_name_prefix] =\
(get_module_file("%s.%s.%s" % (root_namespace, module_name,
op_name_prefix[1:-1])), [])
for name in op_names:
hdl = OpHandle()
check_call(_LIB.NNGetOpHandle(c_str(name), ctypes.byref(hdl)))
op_name_prefix = _get_op_name_prefix(name)
if len(op_name_prefix) > 0:
func_name = name[len(op_name_prefix):]
cur_module_file, cur_module_all = submodule_dict[op_name_prefix]
elif name.startswith('_'):
func_name = name
cur_module_file = module_internal_file
cur_module_all = module_internal_all
else:
func_name = name
cur_module_file = module_op_file
cur_module_all = module_op_all

code, _ = op_code_gen_func(hdl, name, func_name, True)
cur_module_file.write(os.linesep)
cur_module_file.write(code)
cur_module_all.append(func_name)

for (submodule_f, submodule_all) in submodule_dict.values():
write_all_str(submodule_f, submodule_all)
submodule_f.close()
write_all_str(module_op_file, module_op_all)
module_op_file.close()
write_all_str(module_internal_file, module_internal_all)
module_internal_file.close()
4 changes: 4 additions & 0 deletions python/mxnet/ndarray/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

from . import _internal, contrib, linalg, op, random, sparse, utils
# pylint: disable=wildcard-import, redefined-builtin
try:
from .gen_op import * # pylint: disable=unused-wildcard-import
except ImportError:
pass
from . import register
from .op import *
from .ndarray import *
Expand Down
12 changes: 11 additions & 1 deletion python/mxnet/ndarray/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,13 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=wildcard-import, unused-import
"""NDArray namespace used to register internal functions."""
import sys as _sys
import os as _os
import sys as _sys

import numpy as np

try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.ndarray import NDArrayBase, CachedOp
Expand All @@ -34,4 +38,10 @@
from .._ctypes.ndarray import NDArrayBase, CachedOp
from .._ctypes.ndarray import _set_ndarray_class, _imperative_invoke

from ..base import _Null
try:
from .gen__internal import * # pylint: disable=unused-wildcard-import
except ImportError:
pass

__all__ = ['NDArrayBase', 'CachedOp', '_imperative_invoke', '_set_ndarray_class']
7 changes: 7 additions & 0 deletions python/mxnet/ndarray/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,12 @@
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import
"""Contrib NDArray API of MXNet."""
try:
from .gen_contrib import *
except ImportError:
pass

__all__ = []
7 changes: 7 additions & 0 deletions python/mxnet/ndarray/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,12 @@
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import
"""Linear Algebra NDArray API of MXNet."""
try:
from .gen_linalg import *
except ImportError:
pass

__all__ = []
2 changes: 1 addition & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
from ..context import Context
from . import _internal
from . import op
from .op import NDArrayBase
from ._internal import NDArrayBase

__all__ = ["NDArray", "concatenate", "_DTYPE_NP_TO_MX", "_DTYPE_MX_TO_NP", "_GRAD_REQ_MAP",
"ones", "add", "arange", "divide", "equal", "full", "greater", "greater_equal",
Expand Down
10 changes: 7 additions & 3 deletions python/mxnet/ndarray/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,12 @@
# under the License.

# coding: utf-8
# pylint: disable=unused-import
# pylint: disable=wildcard-import, unused-wildcard-import, redefined-builtin
"""Backend ops in mxnet.ndarray namespace"""
__all__ = ['CachedOp']
from ._internal import CachedOp
try:
from .gen_op import * # pylint: disable=unused-wildcard-import
except ImportError:
pass

from ._internal import NDArrayBase, CachedOp
__all__ = ['CachedOp']
42 changes: 23 additions & 19 deletions python/mxnet/ndarray/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,13 @@
import ctypes
import numpy as np # pylint: disable=unused-import

from . import _internal
from ._internal import NDArrayBase, _imperative_invoke # pylint: disable=unused-import
from ..ndarray_doc import _build_doc

from ..base import mx_uint, check_call, _LIB, py_str, _init_op_module, _Null # pylint: disable=unused-import


def _generate_ndarray_function_code(handle, name, func_name):
def _generate_ndarray_function_code(handle, name, func_name, signature_only=False):
"""Generate function for ndarray op by handle and function name."""
real_name = ctypes.c_char_p()
desc = ctypes.c_char_p()
Expand Down Expand Up @@ -92,53 +91,59 @@ def _generate_ndarray_function_code(handle, name, func_name):
if arr_name:
code.append("""
def %s(*%s, **kwargs):"""%(func_name, arr_name))
code.append("""
if not signature_only:
code.append("""
ndargs = []
for i in {}:
assert isinstance(i, NDArrayBase), \\
"Positional arguments must have NDArray type, " \\
"but got %s"%str(i)
ndargs.append(i)""".format(arr_name))
if dtype_name is not None:
code.append("""
if dtype_name is not None:
code.append("""
if '%s' in kwargs:
kwargs['%s'] = np.dtype(kwargs['%s']).name"""%(
dtype_name, dtype_name, dtype_name))
code.append("""
code.append("""
_ = kwargs.pop('name', None)
out = kwargs.pop('out', None)
keys = list(kwargs.keys())
vals = list(kwargs.values())""")
else:
code.append("""
def %s(%s):"""%(func_name, ', '.join(signature)))
code.append("""
if not signature_only:
code.append("""
ndargs = []
keys = list(kwargs.keys())
vals = list(kwargs.values())""")
# NDArray args
for name in ndarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
# NDArray args
for name in ndarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
if {name} is not None:
assert isinstance({name}, NDArrayBase), \\
"Argument {name} must have NDArray type, but got %s"%str({name})
ndargs.append({name})""".format(name=name))
# kwargs
for name in kwarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
# kwargs
for name in kwarg_names: # pylint: disable=redefined-argument-from-local
code.append("""
if %s is not _Null:
keys.append('%s')
vals.append(%s)"""%(name, name, name))
# dtype
if dtype_name is not None:
code.append("""
# dtype
if dtype_name is not None:
code.append("""
if %s is not _Null:
keys.append('%s')
vals.append(np.dtype(%s).name)"""%(dtype_name, dtype_name, dtype_name))

code.append("""
if not signature_only:
code.append("""
return _imperative_invoke(%d, ndargs, keys, vals, out)"""%(
handle.value))
else:
code.append("""
return (0,)""")

doc_str_lines = _os.linesep+''.join([' '+s if s.strip() else s
for s in 'r"""{doc_str}"""'.format(doc_str=doc_str)
Expand All @@ -160,5 +165,4 @@ def _make_ndarray_function(handle, name, func_name):
ndarray_function.__module__ = 'mxnet.ndarray'
return ndarray_function

if not _internal.__dict__.get('skip_register'):
_init_op_module('mxnet', 'ndarray', _make_ndarray_function)
_init_op_module('mxnet', 'ndarray', _make_ndarray_function)
6 changes: 6 additions & 0 deletions python/mxnet/ndarray/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import, too-many-lines
"""Sparse NDArray API of MXNet."""

from __future__ import absolute_import
Expand All @@ -41,13 +42,18 @@
from ..context import Context
from . import _internal
from . import op
try:
from .gen_sparse import * # pylint: disable=redefined-builtin
except ImportError:
pass
from ._internal import _set_ndarray_class
from .ndarray import NDArray, _storage_type, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP
from .ndarray import _STORAGE_TYPE_STR_TO_ID, _STORAGE_TYPE_ROW_SPARSE, _STORAGE_TYPE_CSR
from .ndarray import _STORAGE_TYPE_UNDEFINED, _STORAGE_TYPE_DEFAULT
from .ndarray import zeros as _zeros_ndarray
from .ndarray import array as _array


try:
import scipy.sparse as spsp
except ImportError:
Expand Down
4 changes: 4 additions & 0 deletions python/mxnet/symbol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@

from . import _internal, contrib, linalg, op, random, sparse
# pylint: disable=wildcard-import, redefined-builtin
try:
from .gen_op import * # pylint: disable=unused-wildcard-import
except ImportError:
pass
from . import register
from .op import *
from .symbol import *
Expand Down
12 changes: 11 additions & 1 deletion python/mxnet/symbol/_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@
# specific language governing permissions and limitations
# under the License.

# pylint: disable=wildcard-import, unused-import
"""Symbol namespace used to register internal functions."""
# Use different version of SymbolBase
# When possible, use cython to speedup part of computation.
# pylint: disable=unused-import
import sys as _sys
import os as _os

import numpy as np

try:
if int(_os.environ.get("MXNET_ENABLE_CYTHON", True)) == 0:
from .._ctypes.symbol import SymbolBase, _set_symbol_class
Expand All @@ -36,5 +39,12 @@
raise ImportError("Cython Module cannot be loaded but MXNET_ENFORCE_CYTHON=1")
from .._ctypes.symbol import SymbolBase, _set_symbol_class
from .._ctypes.symbol import _symbol_creator
from ..attribute import AttrScope
from ..base import _Null
from ..name import NameManager
try:
from .gen__internal import * # pylint: disable=unused-wildcard-import
except ImportError:
pass

__all__ = ['SymbolBase', '_set_symbol_class', '_symbol_creator']
7 changes: 7 additions & 0 deletions python/mxnet/symbol/contrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,12 @@
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import
"""Contrib Symbol API of MXNet."""
try:
from .gen_contrib import *
except ImportError:
pass

__all__ = []
7 changes: 7 additions & 0 deletions python/mxnet/symbol/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,12 @@
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=wildcard-import, unused-wildcard-import
"""Linear Algebra Symbol API of MXNet."""
try:
from .gen_linalg import *
except ImportError:
pass

__all__ = []
7 changes: 6 additions & 1 deletion python/mxnet/symbol/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@
# under the License.

# coding: utf-8
# pylint: disable=unused-import
# pylint: disable=wildcard-import, unused-wildcard-import, redefined-builtin
"""Backend ops in mxnet.symbol namespace."""
try:
from .gen_op import *
except ImportError:
pass

__all__ = []
Loading

0 comments on commit c4002f3

Please sign in to comment.