Skip to content

Commit

Permalink
[runtime] Add Metadata classes for AOTExecutor (apache#10282)
Browse files Browse the repository at this point in the history
* Add new Metadata classes and base implementation.

 * These were autogenerated in the original PR, but checking them in
   as plain code until we can revisit the auto-generator approach.

* address masa comments

* Add documentation per Manupa's comments, and move kMetadataVersion namespace.

* remove get_name function, used for debugging

* clang-format
  • Loading branch information
areusch authored Feb 22, 2022
1 parent 91b2e91 commit 33082e0
Show file tree
Hide file tree
Showing 7 changed files with 973 additions and 0 deletions.
160 changes: 160 additions & 0 deletions include/tvm/runtime/metadata.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/metadata.h
* \brief Defines types which can be used in Metadata.
*/
#ifndef TVM_RUNTIME_METADATA_H_
#define TVM_RUNTIME_METADATA_H_

#include <inttypes.h>
#ifdef __cplusplus
#include <memory>
#include <string>
#include <vector>
#endif
#include <tvm/runtime/c_runtime_api.h>
#ifdef __cplusplus
#include <tvm/runtime/metadata_base.h>
#endif
#include <tvm/support/span.h>

// Version number recorded in emitted artifacts for runtime checking.
#define TVM_METADATA_VERSION 1

namespace tvm {
namespace runtime {
namespace metadata {
/*!
* \brief Version of metadata emitted and understood by this compiler/runtime.
* Should be populated into the `version` field of all TVMMetadata.
*/
static const constexpr int64_t kMetadataVersion = TVM_METADATA_VERSION;
} // namespace metadata
} // namespace runtime
} // namespace tvm

#ifdef __cplusplus
extern "C" {
#endif

/*!
* \brief Top-level metadata structure. Holds all other metadata types.
*/
struct TVMMetadata {
/*! \brief Version identifier for this metadata. */
int64_t version;
/*! \brief Inputs to the AOT run_model function.
* The order of the elements is the same as in the arguments to run_model. That is to say,
* this array specifies the first `num_inputs` arguments to run_model.
*/
const struct TVMTensorInfo* inputs;
/*! \brief Number of elements in `inputs` array. */
int64_t num_inputs;
/*! \brief Outputs of the AOT run_model function.
* The order of the elements is the same as in the arguments to run_model. That is to say,
* this array specifies the last `num_outputs` arguments to run_model.
*/
const struct TVMTensorInfo* outputs;
/*! \brief Number of elements in `outputs` array. */
int64_t num_outputs;
/*! \brief Name of the model, as passed to tvm.relay.build. */
const char* mod_name;
};

/*!
* \brief Describes one tensor argument to `run_model`.
* NOTE: while TIR allows for other types of arguments, such as scalars, the AOT run_model
* function does not currently accept these. Therefore it's not possible to express those
* in this metadata. A future patch may modify this.
*/
struct TVMTensorInfo {
/*! \brief Name of the tensor, as specified in the Relay program. */
const char* name;
/*! \brief Shape of the tensor. */
const int64_t* shape;
/*! \brief Rank of this tensor. */
int64_t num_shape;
/*! \brief Data type of one element of this tensor. */
DLDataType dtype;
};
#ifdef __cplusplus
} // extern "C"
#include <tvm/runtime/object.h>
namespace tvm {
namespace runtime {
namespace metadata {

class Metadata;
class TensorInfo;

class MetadataNode : public MetadataBaseNode {
public:
explicit MetadataNode(const struct ::TVMMetadata* data) : data_{data} {}
static constexpr const char* _type_key = "metadata.MetadataNode";
inline int64_t version() const { return int64_t(data_->version); }
inline int64_t num_inputs() const { return data_->num_inputs; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> inputs();
inline int64_t num_outputs() const { return data_->num_outputs; }
ArrayAccessor<struct TVMTensorInfo, TensorInfo> outputs();
inline ::tvm::runtime::String mod_name() const { return ::tvm::runtime::String(data_->mod_name); }
const struct ::TVMMetadata* data() const { return data_; }
TVM_DECLARE_FINAL_OBJECT_INFO(MetadataNode, MetadataBaseNode);

private:
const struct ::TVMMetadata* data_;
};

class Metadata : public MetadataBase {
public:
explicit Metadata(const struct ::TVMMetadata* data);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Metadata, MetadataBase, MetadataNode);
};

class TensorInfoNode : public MetadataBaseNode {
public:
explicit TensorInfoNode(const struct ::TVMTensorInfo* data) : data_{data} {}
static constexpr const char* _type_key = "metadata.TensorInfoNode";
inline ::tvm::runtime::String name() const { return ::tvm::runtime::String(data_->name); }
inline int64_t num_shape() const { return data_->num_shape; }
inline ::tvm::support::Span<const int64_t, int64_t> shape() const {
return ::tvm::support::Span<const int64_t, int64_t>(data_->shape,
data_->shape + data_->num_shape);
}
inline ::tvm::runtime::DataType dtype() const { return ::tvm::runtime::DataType(data_->dtype); }
const struct ::TVMTensorInfo* data() const { return data_; }
TVM_DECLARE_FINAL_OBJECT_INFO(TensorInfoNode, MetadataBaseNode);

private:
const struct ::TVMTensorInfo* data_;
};

class TensorInfo : public MetadataBase {
public:
explicit TensorInfo(const struct ::TVMTensorInfo* data);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TensorInfo, MetadataBase, TensorInfoNode);
};

} // namespace metadata
} // namespace runtime
} // namespace tvm
#endif // defined(__cplusplus)

#endif // TVM_RUNTIME_METADATA_H_
198 changes: 198 additions & 0 deletions include/tvm/runtime/metadata_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/runtime/metadata_base.h
* \brief Defines types which can be used in Metadata.
*/
#ifndef TVM_RUNTIME_METADATA_BASE_H_
#define TVM_RUNTIME_METADATA_BASE_H_

#include <tvm/ir/expr.h>
#include <tvm/runtime/object.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace tvm {
namespace runtime {
namespace metadata {

/*!
* \brief Common base class for all Metadata.
*
* This class is used in the visitor classes as a internal check to ensure that verify that all
* parts of the Metadata struct used in codegen are Metadata objects.
*/
class MetadataBaseNode : public ::tvm::runtime::Object {
public:
static constexpr const char* _type_key = "metadata.MetadataBaseNode";
TVM_DECLARE_BASE_OBJECT_INFO(MetadataBaseNode, ::tvm::runtime::Object);
};

/*! \brief Reference class for the common MetadataBaseNode class. */
class MetadataBase : public ::tvm::runtime::ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataBase, ::tvm::runtime::ObjectRef, MetadataBaseNode);
};

template <typename C, class Ref>
class ArrayAccessor;

/*! \brief An iterator implementation that lazily instantiates the C++ wrapping Metadata class. */
template <typename C, class Ref>
class ArrayIterator {
public:
ArrayIterator(size_t index, const ArrayAccessor<C, Ref>* parent)
: index_{index}, parent_{parent} {}

inline Ref operator*() { return (*parent_)[index_]; }

inline ArrayIterator<C, Ref>& operator++() {
if (index_ < parent_->size()) {
index_++;
}

return *this;
}

inline bool operator==(const ArrayIterator<C, Ref>& other) const {
return parent_ == other.parent_ && index_ == other.index_;
}

inline bool operator!=(const ArrayIterator<C, Ref>& other) const { return !operator==(other); }

private:
size_t index_;
const ArrayAccessor<C, Ref>* parent_;
};

/*! \brief A span-like class which permits access to Array fields with complex elements.
* These array fields should be accessed from C++ using the Metadata wrapper classes. This class
* lazily instantiates those wrappers as they are accessed.
*/
template <typename C, class Ref>
class ArrayAccessor {
public:
using value_type = Ref;
using iterator = ArrayIterator<C, Ref>;
using const_iterator = iterator;

template <typename T = typename std::enable_if<std::is_base_of<ObjectRef, Ref>::value>::type>
ArrayAccessor(const C* data, size_t num_data) : data_{data}, num_data_{num_data} {}

inline size_t size() const { return num_data_; }

inline Ref operator[](size_t index) const {
if (index >= num_data_) {
throw std::runtime_error("Index out of range");
}

return Ref(&data_[index]);
}

inline ArrayIterator<C, Ref> begin() const { return ArrayIterator<C, Ref>{0, this}; }

inline ArrayIterator<C, Ref> end() const { return ArrayIterator<C, Ref>{num_data_, this}; }

private:
const C* data_;
size_t num_data_;
};

/*! \brief A specialization of ArrayAccessor for String.
* This class is needed because the String constructor signature is different from the typical
* Metadata subclass.
*/
template <>
class ArrayAccessor<const char*, ::tvm::runtime::String> {
public:
using value_type = ::tvm::runtime::String;
using iterator = ArrayIterator<const char*, ::tvm::runtime::String>;
using const_iterator = iterator;

ArrayAccessor(const char** data, size_t num_data) : data_{data}, num_data_{num_data} {}

inline size_t size() const { return num_data_; }

inline ::tvm::runtime::String operator[](size_t index) const {
if (index >= num_data_) {
throw std::runtime_error("Index out of range");
}
return ::tvm::runtime::String(data_[index]);
}

inline ArrayIterator<const char*, ::tvm::runtime::String> begin() const {
return ArrayIterator<const char*, ::tvm::runtime::String>{0, this};
}

inline ArrayIterator<const char*, ::tvm::runtime::String> end() const {
return ArrayIterator<const char*, ::tvm::runtime::String>{num_data_, this};
}

private:
const char** data_;
size_t num_data_;
};

/*! \brief Enumerates the primitive types which can be part of a Metadata instance.
*
* These are separate from TIR DataType because TIR does not model structs.
*/
enum MetadataTypeIndex : uint8_t {
kUint64 = 0,
kInt64 = 1,
kBool = 2,
kString = 3,
kHandle = 4,
kMetadata = 5,
};

/*! \brief Container for arrays in the metadata.
*
* Type information is needed when emitting arrays. This container augments the data field with
* the necessary typing information.
*/
class MetadataArrayNode : public MetadataBaseNode {
public:
MetadataArrayNode(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name)
: array(::std::move(array)), type_index{type_index}, struct_name{struct_name} {}

Array<ObjectRef> array;
MetadataTypeIndex type_index;
const char* struct_name;
static constexpr const char* _type_key = "metadata.MetadataArrayNode";
TVM_DECLARE_BASE_OBJECT_INFO(MetadataArrayNode, MetadataBaseNode);
};

/*! \brief Reference class for MetadataArray. */
class MetadataArray : public MetadataBase {
public:
MetadataArray(Array<ObjectRef> array, MetadataTypeIndex type_index, const char* struct_name);

TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MetadataArray, MetadataBase, MetadataArrayNode);
};

} // namespace metadata
} // namespace runtime
} // namespace tvm

#endif // TVM_RUNTIME_METADATA_BASE_H_
Loading

0 comments on commit 33082e0

Please sign in to comment.