Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Update model to add save/load and period checkpoint.
Browse files Browse the repository at this point in the history
Removed wait for All at threaded Engine.
Still do not know the cause of WaitforVar stalling.
  • Loading branch information
tqchen committed Sep 20, 2015
1 parent a8c5ed1 commit fd201b3
Show file tree
Hide file tree
Showing 11 changed files with 411 additions and 242 deletions.
3 changes: 2 additions & 1 deletion python/mxnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from . import model
from . import initializer
from . import visualization
import atexit
# use viz as short for mx.ndarray
from . import visualization as viz

__version__ = "0.1.0"
41 changes: 41 additions & 0 deletions python/mxnet/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,44 @@ def ctypes2numpy_shared(cptr, shape):
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)


def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
"""Convert ctypes returned doc string information into parameters docstring.
num_args : mx_uint
Number of arguments.
arg_names : ctypes.POINTER(ctypes.c_char_p)
Argument names.
arg_types : ctypes.POINTER(ctypes.c_char_p)
Argument type information.
arg_descs : ctypes.POINTER(ctypes.c_char_p)
Argument description information.
remove_dup : boolean, optional
Whether remove duplication or not.
Returns
-------
docstr : str
Python docstring of parameter sections.
"""
param_keys = set()
param_str = []
for i in range(num_args.value):
key = py_str(arg_names[i])
if key in param_keys and remove_dup:
continue
param_keys.add(key)
type_info = py_str(arg_types[i])
ret = '%s : %s' % (key, type_info)
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
doc_str = ('Parameters\n' +
'----------\n' +
'%s\n')
doc_str = doc_str % ('\n'.join(param_str))
return doc_str
15 changes: 4 additions & 11 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from .base import _LIB
from .base import c_array, c_str, mx_uint, py_str
from .base import DataIterHandle, NDArrayHandle
from .base import check_call
from .base import check_call, ctypes2docstring
from .ndarray import NDArray

class DataIter(object):
Expand Down Expand Up @@ -99,24 +99,17 @@ def _make_io_iterator(handle):
ctypes.byref(arg_types), \
ctypes.byref(arg_descs)))
iter_name = py_str(name.value)
param_str = []
for i in range(num_args.value):
ret = '%s : %s' % (arg_names[i], arg_types[i])
if len(arg_descs[i]) != 0:
ret += '\n ' + py_str(arg_descs[i])
param_str.append(ret)
param_str = ctypes2docstring(num_args, arg_names, arg_types, arg_descs)

doc_str = ('%s\n\n' +
'Parameters\n' +
'----------\n' +
'%s\n' +
'name : string, required.\n' +
' Name of the resulting data iterator.\n\n' +
'Returns\n' +
'-------\n' +
'iterator: Iterator\n'+
'iterator: DataIter\n'+
' The result iterator.')
doc_str = doc_str % (desc.value, '\n'.join(param_str))
doc_str = doc_str % (desc.value, param_str)

def creator(*args, **kwargs):
"""Create an iterator.
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ def __init__(self):
def update(self, pred, label):
pred = pred.asnumpy()
label = label.asnumpy().astype('int32')
y = np.argmax(pred, axis=1)
self.sum_metric += np.sum(y == label)
py = np.argmax(pred, axis=1)
self.sum_metric += np.sum(py == label)
self.num_inst += label.size


Expand Down
Loading

0 comments on commit fd201b3

Please sign in to comment.