Skip to content

Commit

Permalink
[SYMBOL] Add __iter__ and GetChildren for symbol (apache#268)
Browse files Browse the repository at this point in the history
* [SYMBOL] Add __iter__ and GetChildren for symbol

* [SYMBOL] Fix lint
  • Loading branch information
ZihengJiang authored and tqchen committed May 29, 2018
1 parent a0dc865 commit 02141d4
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
8 changes: 8 additions & 0 deletions nnvm/include/nnvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,14 @@ NNVM_DLL int NNSymbolGetNumOutputs(SymbolHandle symbol,
*/
NNVM_DLL int NNSymbolGetInternals(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get a symbol that contains only direct children.
* \param symbol The symbol
* \param out The output symbol whose outputs are the direct children.
* \return 0 when success, -1 when failure happens
*/
NNVM_DLL int NNSymbolGetChildren(SymbolHandle symbol,
SymbolHandle *out);
/*!
* \brief Get index-th outputs of the symbol.
* \param symbol The symbol
Expand Down
14 changes: 14 additions & 0 deletions nnvm/python/nnvm/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,9 @@ def __getitem__(self, index):
self.handle, _base.nn_uint(index), _ctypes.byref(handle)))
return Symbol(handle=handle)

def __iter__(self):
return (self[i] for i in self.list_output_names())

def attr(self, key):
"""Get attribute string from the symbol, this function only works for non-grouped symbol.
Expand Down Expand Up @@ -196,6 +199,17 @@ def get_internals(self):
self.handle, _ctypes.byref(handle)))
return Symbol(handle=handle)

def get_children(self):
"""Gets a new grouped symbol whose output contains
inputs to output nodes of the original symbol."""
handle = _base.SymbolHandle()
_check_call(_LIB.NNSymbolGetChildren(
self.handle, _ctypes.byref(handle)))
ret = Symbol(handle=handle)
if not ret.list_output_names():
return None
return ret

def _get_list_copt(self, option):
"""internal function to get list option"""
if option == 'all':
Expand Down
9 changes: 9 additions & 0 deletions nnvm/src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,15 @@ int NNSymbolGetInternals(SymbolHandle symbol,
API_END_HANDLE_ERROR(delete s);
}

int NNSymbolGetChildren(SymbolHandle symbol,
SymbolHandle *out) {
Symbol *s = new Symbol();
API_BEGIN();
*s = static_cast<Symbol*>(symbol)->GetChildren();
*out = s;
API_END_HANDLE_ERROR(delete s);
}

int NNSymbolFree(SymbolHandle symbol) {
API_BEGIN();
delete static_cast<Symbol*>(symbol);
Expand Down

0 comments on commit 02141d4

Please sign in to comment.