Skip to content

Commit

Permalink
[TOP] split, reshape, concatenate (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed May 29, 2018
1 parent d34036f commit 81fd125
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 39 deletions.
19 changes: 8 additions & 11 deletions nnvm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
[![Build Status](https://travis-ci.org/dmlc/nnvm.svg?branch=master)](https://travis-ci.org/dmlc/nnvm)
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)

NNVM is a reusable computational graph optimization and compilation stack for deep learning systems. It provides modules to:
NNVM is a reusable computational graph compilation stack for deep learning systems. It provides modules to:

- Represent deep learning workloads from front-end frameworks via a graph IR.
- Optimize computation graphs to improve performance.
Expand All @@ -20,26 +20,23 @@ from tvm.contrib import graph_runtime, rpc
import nnvm.frontend
import nnvm.compiler

# get model from frameworks
# GET model from frameworks
# change xyz to supported framework name.
graph, params = nnvm.frontend.from_xyz(...)

# optimize and compile the graph to get a deployable module
# OPTIMIZE and COMPILE the graph to get a deployable module
# target can be "opencl", "llvm", "metal" or any target supported by tvm
target = "cuda"
graph, lib, params = nnvm.compiler.build(
graph, target, shape={"data", data_shape}, params=params)
graph, lib, params = nnvm.compiler.build(graph, target, {"data", data_shape}, params=params)

# deploy and run on gpu(0)
# DEPLOY and run on gpu(0)
module = graph_runtime.create(graph, lib, tvm.gpu(0))
module.set_input(**params)
module.run(data=data_array)
output = tvm.nd.empty(out_shape, ctx=tvm.gpu(0))
for data_array in dataset:
module.set_input("data", data_array)
module.run()
module.get_output(0, output)
module.get_output(0, output)

# deploy to remote mobile/rasp/browser with minimum tvm rpc runtime
# DEPLOY to REMOTE mobile/rasp/browser with minimum tvm rpc runtime
# useful for quick experiments on mobile devices
remote = rpc.connect(remote_host, remote_port)
lib.export_library("mylib.so")
Expand Down
1 change: 1 addition & 0 deletions nnvm/python/nnvm/compiler/param_dict.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pylint: disable=invalid-name
"""Helper utility to save parameter dict"""
import tvm

Expand Down
49 changes: 30 additions & 19 deletions nnvm/python/nnvm/top/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,12 @@
"""Tensor transformation ops"""
from __future__ import absolute_import

import tvm
import topi
from .tensor import _fschedule_broadcast, _fschedule_injective
from . import registry as reg
from .registry import OpPattern

# Need add reshape
# expand_dims
@reg.register_compute("expand_dims")
def compute_expand_dims(attrs, inputs, out_info):
"""Compute definition of expand_dims"""
Expand All @@ -18,34 +17,46 @@ def compute_expand_dims(attrs, inputs, out_info):
reg.register_pattern("expand_dims", OpPattern.BROADCAST)
reg.register_schedule("expand_dims", _fschedule_broadcast)


# transpose
@reg.register_compute("transpose")
def compute_transpose(attrs, inputs, out_info):
"""Compute definition of expand_dims"""
"""Compute definition of transpose"""
axes = attrs.get_int_tuple("axes")
axes = tuple(axes) if axes else None
return topi.transpose(inputs[0], axes)
reg.register_pattern("transpose", OpPattern.INJECTIVE)
reg.register_schedule("transpose", _fschedule_injective)


def _flatten_index(indices, shape):
"""flatten the index to 1D"""
idx = 0
for i, value in enumerate(shape):
if i != 0:
idx *= value
idx = idx + indices[i]
return idx

# reshape
@reg.register_compute("reshape")
def compute_reshape(attrs, inputs, out_info):
"""Compute definition of softmax"""
# TODO(sxj) add support for general reshape
assert len(inputs[0].shape) == 1, "Only support 1d input for now"
"""Compute definition of reshape"""
oshape = out_info[0].shape
x = inputs[0]
return tvm.compute(oshape, lambda *i: x(_flatten_index(i, oshape)))
return topi.reshape(inputs[0], oshape)
reg.register_pattern("reshape", OpPattern.INJECTIVE)
reg.register_schedule("reshape", _fschedule_injective)

# concatenate
@reg.register_compute("concatenate")
def compute_concatenate(attrs, inputs, out_info):
"""Compute definition of concatenate"""
axis = attrs.get_int("axis")
return topi.concatenate([x for x in inputs], axis=axis)

reg.register_pattern("concatenate", OpPattern.INJECTIVE)
reg.register_schedule("concatenate", _fschedule_injective)

# split
@reg.register_compute("split")
def compute_split(attrs, inputs, out_info):
"""Compute definition of split"""
x = attrs["indices_or_sections"]
if x.startswith("(") or x.startswith("["):
indices_or_sections = attrs.get_int_tuple("indices_or_sections")
else:
indices_or_sections = attrs.get_int("indices_or_sections")
return topi.split(inputs[0], indices_or_sections, axis=attrs.get_int("axis"))


reg.register_pattern("split", OpPattern.INJECTIVE)
reg.register_schedule("split", _fschedule_injective)
1 change: 1 addition & 0 deletions nnvm/src/compiler/graph_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <nnvm/graph.h>
#include <vector>
#include <string>
#include "../../tvm/src/runtime/graph/graph_runtime.h"

namespace nnvm {
Expand Down
20 changes: 11 additions & 9 deletions nnvm/src/top/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,15 +225,17 @@ inline bool SplitInferShape(const NodeAttrs& attrs,
CHECK_EQ(out_shape->size(), static_cast<size_t>(num_outputs));
CHECK_LT(param.axis, dshape.ndim());
TShape oshape = dshape;
dim_t total = 0;
for (size_t i = 1; i < num_outputs; ++i) {
oshape[param.axis] = param.indices_or_sections[i - 1];
total += oshape[param.axis];
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i - 1, oshape);
dim_t begin = 0;
for (size_t i = 0; i < num_outputs - 1; ++i) {
CHECK_GT(param.indices_or_sections[i], begin)
<< "indices_or_sections need to be a sorted ascending list";
oshape[param.axis] = param.indices_or_sections[i] - begin;
begin = param.indices_or_sections[i];
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, i, oshape);
}
CHECK_LT(total, dshape[param.axis])
CHECK_LT(begin, dshape[param.axis])
<< "The sum of sections must match the input.shape[axis]";
oshape[param.axis] = dshape[param.axis] - total;
oshape[param.axis] = dshape[param.axis] - begin;
NNVM_ASSIGN_OUTPUT_SHAPE(attrs, *out_shape, num_outputs - 1, oshape);
}
return true;
Expand All @@ -256,11 +258,11 @@ NNVM_REGISTER_OP(split)
along which to split the array.
)code" NNVM_ADD_FILELINE)
.add_argument("data", "Tensor", "List of arrays to concatenate")
.add_argument("data", "Tensor", "Array to be splitted")
.add_arguments(SplitParam::__FIELDS__())
.set_attr_parser(SplitParamParser)
.set_attr<FInferShape>("FInferShape", SplitInferShape)
.set_attr<FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<FInferType>("FInferType", ElemwiseType<1, -1>)
.set_num_inputs(1)
.set_num_outputs(SplitNumOutputs)
.set_support_level(1);
Expand Down
45 changes: 45 additions & 0 deletions nnvm/tests/python/compiler/test_top_level1.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,52 @@ def test_batchnorm():
res.asnumpy(), res_np, atol=1e-5, rtol=1e-5)


def verify_concatenate(ishape, axis):
x = [sym.Variable("x%d" % i) for i in range(len(ishape))]
y = sym.concatenate(*x, axis=axis) + 1
dtype = "float32"
for target, ctx in ctx_list():
# set input
data = []
for i, shape in enumerate(ishape):
data.append(np.random.uniform(size=shape).astype(dtype))
pdict = {"x%d" % i : v for i, v in enumerate(data)}
shape = {"x%d" % i : v.shape for i, v in enumerate(data)}
graph, lib, _ = nnvm.compiler.build(y, target, shape)
m = graph_runtime.create(graph, lib, ctx)
m.run(**pdict)
out_np = np.concatenate(data, axis=axis) + 1
out = m.get_output(0, tvm.nd.empty(out_np.shape))
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def test_concatenate():
verify_concatenate([(2, 3, 4), (1, 3, 4)], axis=0)
verify_concatenate([(2, 4), (2, 7)], axis=1)


def verify_split(ishape, indices_or_sections, axis):
x = sym.Variable("x")
y = sym.split(x, indices_or_sections=indices_or_sections, axis=axis)
dtype = "float32"
x_np = np.random.uniform(size=ishape).astype(dtype)
res = np.split(x_np, indices_or_sections, axis=axis)
for target, ctx in ctx_list():
# set input
graph, lib, _ = nnvm.compiler.build(y, target, {"x": ishape})
m = graph_runtime.create(graph, lib, ctx)
m.run(x=x_np)
for i, arr in enumerate(res):
out = m.get_output(i, tvm.nd.empty(arr.shape))
np.testing.assert_allclose(out.asnumpy(), arr, atol=1e-5, rtol=1e-5)

def test_split():
verify_split((2, 3), 2, axis=0)
verify_split((5, 3), [3], axis=0)
verify_split((5, 9, 3), [3, 4], axis=1)

if __name__ == "__main__":
test_split()
test_concatenate()
test_log_softmax()
test_batchnorm()
test_dense()
Expand Down
22 changes: 22 additions & 0 deletions nnvm/tests/python/compiler/test_top_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,29 @@ def test_reduce():
verify_reduce((4, 4, 3), np.sum, sym.sum, axis=(0, 2))


def verify_reshape(dshape, oshape):
x = sym.Variable("x")
y = sym.reshape(x, shape=oshape)
y = y + 1
dtype = "float32"
for target, ctx in ctx_list():
graph, lib, _ = nnvm.compiler.build(y, target, {"x": dshape})
m = graph_runtime.create(graph, lib, ctx)
# set input
data = tvm.nd.array(np.random.uniform(size=dshape).astype(dtype))
m.run(x=data)
out_np = data.asnumpy().reshape(oshape) + 1
out = m.get_output(0, tvm.nd.empty(out_np.shape))
np.testing.assert_allclose(out.asnumpy(), out_np, atol=1e-5, rtol=1e-5)

def test_reshape():
verify_reshape((2, 3, 4), (-1, 2, 1))
verify_reshape((2, 3, 4), (8, 3))
verify_reshape((4, 7), (2, 7, 2))


if __name__ == "__main__":
test_reshape()
test_reduce()
test_tranpose()
print(nnvm.compiler.engine.dump())

0 comments on commit 81fd125

Please sign in to comment.