Skip to content

Commit

Permalink
[RUNTIME] ShapeTuple Container (apache#8200)
Browse files Browse the repository at this point in the history
* Add ShapeTuple.

* Update NDArray.

* Documents.

* Lint.

* Lint.

* Lint.

* Address comment.

* Address comment.

* Address comment.

* Lint.

* Lint.
  • Loading branch information
ZihengJiang authored Jun 9, 2021
1 parent 1f2ca06 commit 4d9bc9b
Show file tree
Hide file tree
Showing 9 changed files with 270 additions and 23 deletions.
180 changes: 180 additions & 0 deletions include/tvm/runtime/container/shape_tuple.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/container/shape_tuple.h
* \brief Runtime ShapeTuple container types.
*/
#ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
#define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_

#include <utility>
#include <vector>

#include "./base.h"

namespace tvm {
namespace runtime {

/*! \brief An object representing a shape tuple. */
class ShapeTupleObj : public Object {
public:
/*! \brief The type of shape index element. */
using index_type = int64_t;
/*! \brief The pointer to shape tuple data. */
index_type* data;
/*! \brief The size of the shape tuple object. */
uint64_t size;

static constexpr const uint32_t _type_index = runtime::TypeIndex::kRuntimeShapeTuple;
static constexpr const char* _type_key = "runtime.ShapeTuple";
TVM_DECLARE_FINAL_OBJECT_INFO(ShapeTupleObj, Object);

private:
/*! \brief ShapeTuple object which is moved from std::vector container. */
class FromStd;

friend class ShapeTuple;
};

/*! \brief An object representing shape tuple moved from std::vector. */
class ShapeTupleObj::FromStd : public ShapeTupleObj {
public:
/*! \brief The type of shape index element. */
using index_type = ShapeTupleObj::index_type;
/*!
* \brief Construct a new FromStd object
*
* \param other The moved/copied std::vector object
*
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
explicit FromStd(std::vector<index_type> other) : data_container{other} {}

private:
/*! \brief Container that holds the memory. */
std::vector<index_type> data_container;

friend class ShapeTuple;
};

/*!
* \brief Reference to shape tuple objects.
*/
class ShapeTuple : public ObjectRef {
public:
/*! \brief The type of shape index element. */
using index_type = ShapeTupleObj::index_type;

/*!
* \brief Construct an empty shape tuple.
*/
ShapeTuple() : ShapeTuple(std::vector<index_type>()) {}

/*!
* \brief Constructor from iterator
* \param begin begin of iterator
* \param end end of iterator
* \tparam IterType The type of iterator
*/
template <typename IterType>
ShapeTuple(IterType begin, IterType end) : ShapeTuple(std::vector<index_type>(begin, end)) {}

/*!
* \brief constructor from initializer list
* \param shape The initializer list
*/
ShapeTuple(std::initializer_list<index_type> shape) : ShapeTuple(shape.begin(), shape.end()) {}

/*!
* \brief Construct a new ShapeTuple object
*
* \param shape The moved/copied std::vector object
*
* \note If user passes const reference, it will trigger copy. If it's rvalue,
* it will be moved into other.
*/
ShapeTuple(std::vector<index_type> shape); // NOLINT(*)

/*!
* \brief Return the data pointer
*
* \return const index_type* data pointer
*/
const index_type* data() const { return get()->data; }

/*!
* \brief Return the size of the shape tuple
*
* \return size_t shape tuple size
*/
size_t size() const { return get()->size; }

/*!
* \brief Immutably read i-th element from the shape tuple.
* \param idx The index
* \return the i-th element.
*/
index_type operator[](size_t idx) const {
ICHECK(0 <= idx && idx < this->size())
<< "IndexError: indexing " << idx << " on an array of size " << this->size();
return this->data()[idx];
}

/*!
* \brief Immutably read i-th element from the shape tuple.
* \param idx The index
* \return the i-th element.
*/
index_type at(size_t idx) const { return this->operator[](idx); }

/*! \return Whether shape tuple is empty */
bool empty() const { return size() == 0; }

/*! \return The first element of the shape tuple */
index_type front() const { return this->at(0); }

/*! \return The last element of the shape tuple */
index_type back() const { return this->at(this->size() - 1); }

/*! \return begin iterator */
const index_type* begin() const { return get()->data; }

/*! \return end iterator */
const index_type* end() const { return (get()->data + size()); }

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ShapeTuple, ObjectRef, ShapeTupleObj);
};

inline ShapeTuple::ShapeTuple(std::vector<index_type> shape) {
auto ptr = make_object<ShapeTupleObj::FromStd>(std::move(shape));
ptr->size = ptr->data_container.size();
ptr->data = ptr->data_container.data();
data_ = std::move(ptr);
}

} // namespace runtime

// expose the functions to the root namespace.
using runtime::ShapeTuple;
using runtime::ShapeTupleObj;
} // namespace tvm

#endif // TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_
19 changes: 9 additions & 10 deletions include/tvm/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
#define TVM_RUNTIME_NDARRAY_H_

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/optional.h>
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/object.h>
Expand Down Expand Up @@ -128,7 +128,7 @@ class NDArray : public ObjectRef {
* \param dtype The data type of the new array.
* \note The memory size of new array must be smaller than the current one.
*/
TVM_DLL NDArray CreateView(std::vector<int64_t> shape, DLDataType dtype);
TVM_DLL NDArray CreateView(ShapeTuple shape, DLDataType dtype);
/*!
* \brief Create a reference view of NDArray that
* represents as DLManagedTensor.
Expand All @@ -143,7 +143,7 @@ class NDArray : public ObjectRef {
* \param mem_scope The memory scope of the array.
* \return The created Array
*/
TVM_DLL static NDArray Empty(std::vector<int64_t> shape, DLDataType dtype, Device dev,
TVM_DLL static NDArray Empty(ShapeTuple shape, DLDataType dtype, Device dev,
Optional<String> mem_scope = NullOpt);
/*!
* \brief Create a NDArray backed by a dlpack tensor.
Expand All @@ -166,7 +166,7 @@ class NDArray : public ObjectRef {
TVM_DLL static void CopyFromTo(const DLTensor* from, DLTensor* to,
TVMStreamHandle stream = nullptr);

TVM_DLL std::vector<int64_t> Shape() const;
TVM_DLL ShapeTuple Shape() const;
TVM_DLL runtime::DataType DataType() const;
// internal namespace
struct Internal;
Expand Down Expand Up @@ -241,7 +241,7 @@ class NDArray::ContainerBase {
* \brief The shape container,
* can be used used for shape data.
*/
std::vector<int64_t> shape_;
ShapeTuple shape_;
};

/*!
Expand All @@ -261,13 +261,13 @@ class NDArray::Container : public Object, public NDArray::ContainerBase {
dl_tensor.byte_offset = 0;
}

Container(void* data, std::vector<int64_t> shape, DLDataType dtype, Device dev) {
Container(void* data, ShapeTuple shape, DLDataType dtype, Device dev) {
// Initialize the type index.
type_index_ = Container::RuntimeTypeIndex();
dl_tensor.data = data;
shape_ = std::move(shape);
dl_tensor.ndim = static_cast<int>(shape_.size());
dl_tensor.shape = dmlc::BeginPtr(shape_);
dl_tensor.shape = const_cast<ShapeTuple::index_type*>(shape_.data());
dl_tensor.dtype = dtype;
dl_tensor.strides = nullptr;
dl_tensor.byte_offset = 0;
Expand Down Expand Up @@ -357,8 +357,7 @@ inline void NDArray::CopyTo(const NDArray& other) const {
inline NDArray NDArray::CopyTo(const Device& dev) const {
ICHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
NDArray ret =
Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev);
NDArray ret = Empty(ShapeTuple(dptr->shape, dptr->shape + dptr->ndim), dptr->dtype, dev);
this->CopyTo(ret);
return ret;
}
Expand Down Expand Up @@ -460,7 +459,7 @@ inline bool NDArray::Load(dmlc::Stream* strm) {
if (ndim != 0) {
ICHECK(strm->ReadArray(&shape[0], ndim)) << "Invalid DLTensor file format";
}
NDArray ret = NDArray::Empty(shape, dtype, dev);
NDArray ret = NDArray::Empty(ShapeTuple(shape), dtype, dev);
int64_t num_elems = 1;
int elem_bytes = (ret->dtype.bits + 7) / 8;
for (int i = 0; i < ret->ndim; ++i) {
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ struct TypeIndex {
kRuntimeArray = 4,
/*! \brief runtime::Map. */
kRuntimeMap = 5,
/*! \brief runtime::ShapeTuple. */
kRuntimeShapeTuple = 6,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
Expand Down
1 change: 1 addition & 0 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/runtime/c_runtime_api.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>
Expand Down
24 changes: 24 additions & 0 deletions python/tvm/runtime/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,27 @@ def __from_tvm_object__(cls, obj):
val = str.__new__(cls, content)
val.__tvm_object__ = obj
return val


@tvm._ffi.register_object("runtime.ShapeTuple")
class ShapeTuple(Object):
"""TVM runtime ShapeTuple object.
Parameters
----------
shape : list[int]
The shape list used to construct the object.
"""

def __init__(self, shape):
assert isinstance(shape, (list, tuple)), "Expect list of tuple, but received : {0}".format(
type(shape)
)
for x in shape:
assert isinstance(x, int), "Expect int type, but received : {0}".format(type(x))
self.__init_handle_by_constructor__(_ffi_api.ShapeTuple, *shape)

def __len__(self):
return _ffi_api.GetShapeTupleSize(self)

def __getitem__(self, idx):
return getitem_helper(self, _ffi_api.GetShapeTupleElem, len(self), idx)
12 changes: 12 additions & 0 deletions src/node/container_printing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << '}';
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ShapeTupleObj>([](const ObjectRef& node, ReprPrinter* p) {
auto* op = static_cast<const ShapeTupleObj*>(node.get());
p->stream << '[';
for (size_t i = 0; i < op->size; ++i) {
if (i != 0) {
p->stream << ", ";
}
p->stream << op->data[i];
}
p->stream << ']';
});
} // namespace tvm
23 changes: 21 additions & 2 deletions src/runtime/container.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/container/closure.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/shape_tuple.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/memory.h>
#include <tvm/runtime/object.h>
Expand Down Expand Up @@ -108,7 +109,6 @@ TVM_REGISTER_GLOBAL("runtime.ADT").set_body([](TVMArgs args, TVMRetValue* rv) {
});

// String

TVM_REGISTER_OBJECT_TYPE(StringObj);

TVM_REGISTER_GLOBAL("runtime.String").set_body_typed([](std::string str) {
Expand All @@ -120,7 +120,6 @@ TVM_REGISTER_GLOBAL("runtime.GetFFIString").set_body_typed([](String str) {
});

// Map

TVM_REGISTER_OBJECT_TYPE(MapNode);

TVM_REGISTER_GLOBAL("runtime.Map").set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down Expand Up @@ -185,7 +184,27 @@ TVM_REGISTER_GLOBAL("runtime.MapItems").set_body([](TVMArgs args, TVMRetValue* r
TVM_DLL constexpr uint64_t DenseMapNode::kNextProbeLocation[];
#endif

// Closure
TVM_REGISTER_OBJECT_TYPE(ClosureObj);

// ShapeTuple
TVM_REGISTER_OBJECT_TYPE(ShapeTupleObj);

TVM_REGISTER_GLOBAL("runtime.ShapeTuple").set_body([](TVMArgs args, TVMRetValue* rv) {
std::vector<ShapeTuple::index_type> shape;
for (int i = 0; i < args.size(); i++) {
shape.push_back(args[i]);
}
*rv = ShapeTuple(shape);
});

TVM_REGISTER_GLOBAL("runtime.GetShapeTupleSize").set_body_typed([](ShapeTuple shape) {
return static_cast<int64_t>(shape.size());
});

TVM_REGISTER_GLOBAL("runtime.GetShapeTupleElem").set_body_typed([](ShapeTuple shape, int idx) {
ICHECK_LT(idx, shape.size());
return shape[idx];
});
} // namespace runtime
} // namespace tvm
Loading

0 comments on commit 4d9bc9b

Please sign in to comment.