Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][IR] attrs.h -> ir #4709

Merged
merged 1 commit into from
Jan 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -677,13 +677,13 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
return LargeUIntImm(t, static_cast<int64_t>(low), static_cast<int64_t>(high));
}
}
if (t.is_float()) return ir::FloatImmNode::make(t, static_cast<double>(value));
if (t.is_float()) return FloatImm(t, static_cast<double>(value));
// For now, we store const scalar values of custom datatypes within doubles; later, during the
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
return ir::FloatImmNode::make(t, static_cast<double>(value));
return FloatImm(t, static_cast<double>(value));
}
LOG(FATAL) << "cannot make const for type " << t;
return PrimExpr();
Expand Down
18 changes: 1 addition & 17 deletions include/tvm/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,9 @@ namespace tvm {
namespace ir {

using IntImmNode = tvm::IntImmNode;
using FloatImmNode = tvm::FloatImmNode;
using VarNode = tvm::VarNode;

/*! \brief Floating point constants. */
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

TVM_DLL static PrimExpr make(DataType t, double value);

static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};

/*! \brief String constants, only used in asserts. */
class StringImmNode : public PrimExprNode {
public:
Expand Down
53 changes: 21 additions & 32 deletions include/tvm/attrs.h → include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/attrs.h
* \brief TVM attribute module
* \file tvm/ir/attrs.h
* \brief Helpers for attribute objects.
*
* This module enables declaration of named attributes
* which support default value setup and bound checking.
Expand All @@ -42,31 +41,30 @@
*
* \sa AttrsNode, TVM_DECLARE_ATTRS, TVM_ATTR_FIELD
*/
#ifndef TVM_ATTRS_H_
#define TVM_ATTRS_H_
#ifndef TVM_IR_ATTRS_H_
#define TVM_IR_ATTRS_H_

#include <dmlc/common.h>
#include <tvm/ir/expr.h>
#include <tvm/runtime/packed_func.h>

#include <unordered_map>
#include <vector>
#include <functional>
#include <type_traits>
#include <string>
#include <utility>
#include "ir.h"
#include "base.h"
#include "expr.h"
#include "packed_func_ext.h"

namespace tvm {
/*!
* \brief Declare an attribute function.
* \param ClassName The name of the class.
* \param TypeKey The type key to be used by the TVM node system.
*/
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
#define TVM_DECLARE_ATTRS(ClassName, TypeKey) \
static constexpr const char* _type_key = TypeKey; \
TVM_DECLARE_FINAL_OBJECT_INFO(ClassName, ::tvm::BaseAttrsNode) \
template<typename FVisit> \
template<typename FVisit> \
void __VisitAttrs__(FVisit& __fvisit__) // NOLINT(*)


Expand Down Expand Up @@ -481,45 +479,36 @@ template<typename T>
inline void SetValue(T* ptr, const TVMArgValue& val) {
*ptr = val.operator T();
}

template<typename T>
inline void SetIntValue(T* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLInt) {
*ptr = static_cast<T>(val.value().v_int64);
} else {
PrimExpr expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
*ptr = static_cast<T>(op->value);
} else {
LOG(FATAL) << "Expect int value, but get " << expr->GetTypeKey();
}
IntImm expr = val;
*ptr = static_cast<T>(expr->value);
}
}

template<>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) {
*ptr = val.operator std::string();
} else {
PrimExpr expr = val;
const ir::StringImmNode* op = expr.as<ir::StringImmNode>();
CHECK(op != nullptr);
*ptr = op->value;
LOG(FATAL) << "Expect str";
}
}
template<>
inline void SetValue(DataType* ptr, const TVMArgValue& val) {
*ptr = val.operator DataType();
}

template<>
inline void SetValue<double>(double* ptr, const TVMArgValue& val) {
if (val.type_code() == kDLFloat || val.type_code() == kDLInt) {
*ptr = val.operator double();
} else {
PrimExpr expr = val;
ObjectRef expr = val;
CHECK(expr.defined());
if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
if (const IntImmNode* op = expr.as<IntImmNode>()) {
*ptr = static_cast<double>(op->value);
} else if (const ir::IntImmNode* op = expr.as<ir::IntImmNode>()) {
} else if (const FloatImmNode* op = expr.as<FloatImmNode>()) {
*ptr = static_cast<double>(op->value);
} else {
LOG(FATAL) << "Expect float value, but get " << expr->GetTypeKey();
Expand Down Expand Up @@ -611,7 +600,7 @@ struct TypeName<uint64_t> {

template<>
struct TypeName<DataType> {
static constexpr const char* value = "Type";
static constexpr const char* value = "DataType";
};

template<>
Expand Down Expand Up @@ -872,4 +861,4 @@ inline void BaseAttrsNode::PrintDocString(std::ostream &os) const { // NOLINT(*)
}

} // namespace tvm
#endif // TVM_ATTRS_H_
#endif // TVM_IR_ATTRS_H_
50 changes: 50 additions & 0 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,56 @@ class IntImm : public PrimExpr {
using ContainerType = IntImmNode;
};

/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
*/
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;

void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
v->Visit("value", &value);
}

static constexpr const char* _type_key = "FloatImm";
TVM_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};

/*!
* \brief Managed reference class to FloatImmNode.
*
* \sa FloatImmNode
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
TVM_DLL FloatImm(DataType dtype, double value);
/*!
* \brief Get pointer to the container.
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;
};

/*!
* \brief Base node of all non-primitive expressions.
*
Expand Down
14 changes: 8 additions & 6 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#define TVM_IR_OP_H_

#include <dmlc/registry.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/runtime/registry.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/type.h>
Expand Down Expand Up @@ -296,7 +296,8 @@ class OpRegistry {
// return internal pointer to op.
inline OpNode* get();
// update the attribute OpMap
TVM_DLL void UpdateAttr(const std::string& key, TVMRetValue value,
TVM_DLL void UpdateAttr(const std::string& key,
runtime::TVMRetValue value,
int plevel);
};

Expand All @@ -316,7 +317,7 @@ class GenericOpMap {
* \param op The key to the map
* \return the const reference to the content value.
*/
inline const TVMRetValue& operator[](const Op& op) const;
inline const runtime::TVMRetValue& operator[](const Op& op) const;
/*!
* \brief get the corresponding value element at op with default value.
* \param op The key to the map
Expand All @@ -342,7 +343,7 @@ class GenericOpMap {
// the attribute field.
std::string attr_name_;
// internal data
std::vector<std::pair<TVMRetValue, int> > data_;
std::vector<std::pair<runtime::TVMRetValue, int> > data_;
// The value
GenericOpMap() = default;
};
Expand Down Expand Up @@ -532,7 +533,7 @@ template <typename ValueType>
inline OpRegistry& OpRegistry::set_attr( // NOLINT(*)
const std::string& attr_name, const ValueType& value, int plevel) {
CHECK_GT(plevel, 0) << "plevel in set_attr must be greater than 0";
TVMRetValue rv;
runtime::TVMRetValue rv;
rv = value;
UpdateAttr(attr_name, rv, plevel);
return *this;
Expand All @@ -548,7 +549,8 @@ inline int GenericOpMap::count(const Op& op) const {
}
}

inline const TVMRetValue& GenericOpMap::operator[](const Op& op) const {
inline const runtime::TVMRetValue&
GenericOpMap::operator[](const Op& op) const {
CHECK(op.defined());
const uint32_t idx = op->index_;
CHECK(idx < data_.size() && data_[idx].second != 0)
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/ir/type_relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include <tvm/ir/type.h>
#include <tvm/ir/module.h>
#include <tvm/ir/env_func.h>
#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>

namespace tvm {

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/adt.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ADT_H_
#define TVM_RELAY_ADT_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/ir/adt.h>
#include <string>
#include <functional>
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_ALGORITHM_H_
#define TVM_RELAY_ATTRS_ALGORITHM_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/annotation.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_ANNOTATION_H_
#define TVM_RELAY_ATTRS_ANNOTATION_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/bitserial.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
#ifndef TVM_RELAY_ATTRS_BITSERIAL_H_
#define TVM_RELAY_ATTRS_BITSERIAL_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_DEBUG_H_
#define TVM_RELAY_ATTRS_DEBUG_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/device_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_DEVICE_COPY_H_
#define TVM_RELAY_ATTRS_DEVICE_COPY_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_IMAGE_H_
#define TVM_RELAY_ATTRS_IMAGE_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/memory.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_MEMORY_H_
#define TVM_RELAY_ATTRS_MEMORY_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/expr.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_NN_H_
#define TVM_RELAY_ATTRS_NN_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_REDUCE_H_
#define TVM_RELAY_ATTRS_REDUCE_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <string>

namespace tvm {
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_TRANSFORM_H_
#define TVM_RELAY_ATTRS_TRANSFORM_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <tvm/relay/expr.h>
#include <string>
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/vision.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_ATTRS_VISION_H_
#define TVM_RELAY_ATTRS_VISION_H_

#include <tvm/attrs.h>
#include <tvm/ir/attrs.h>
#include <tvm/relay/base.h>
#include <string>

Expand Down
Loading