Skip to content

Commit

Permalink
[IRBuilder][Minor] Add intrinsics like T.int32x4 (apache#13361)
Browse files Browse the repository at this point in the history
This PR adds all common TIR intrinsics like `T.int32x4`, `T.floatx4`.

Co-authored-by: Yaxing Cai <caiyaxing666@gmail.com>
  • Loading branch information
junrushao and cyx-6 authored Nov 11, 2022
1 parent 5ffcfd9 commit 8897983
Show file tree
Hide file tree
Showing 9 changed files with 348 additions and 341 deletions.
16 changes: 10 additions & 6 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -453,8 +453,8 @@ class AllocateFrameNode : public TIRFrameNode {
PrimExpr condition;
/*! \brief Additional annotation hints. */
Map<String, ObjectRef> annotations;
/*! \brief The buffer. */
tvm::tir::Buffer buffer;
/*! \brief The buffer var. */
tvm::tir::Var buffer_var;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
Expand All @@ -463,7 +463,7 @@ class AllocateFrameNode : public TIRFrameNode {
v->Visit("storage_scope", &storage_scope);
v->Visit("condition", &condition);
v->Visit("annotations", &annotations);
v->Visit("buffer", &buffer);
v->Visit("buffer_var", &buffer_var);
}

static constexpr const char* _type_key = "script.ir_builder.tir.AllocateFrame";
Expand Down Expand Up @@ -500,8 +500,8 @@ class AllocateConstFrameNode : public TIRFrameNode {
Array<PrimExpr> extents;
/*! \brief The data associated with the constant. */
tvm::runtime::NDArray data;
/*! \brief The buffer */
tvm::tir::Buffer buffer;
/*! \brief The buffer var */
tvm::tir::Var buffer_var;
/*! \brief Additional annotations about the allocation. */
Map<String, ObjectRef> annotations;

Expand All @@ -510,7 +510,7 @@ class AllocateConstFrameNode : public TIRFrameNode {
v->Visit("dtype", &dtype);
v->Visit("extents", &extents);
v->Visit("data", &data);
v->Visit("buffer", &buffer);
v->Visit("buffer_var", &buffer_var);
v->Visit("annotations", &annotations);
}

Expand Down Expand Up @@ -723,11 +723,15 @@ class ElseFrame : public TIRFrame {

class DeclBufferFrameNode : public TIRFrameNode {
public:
/*! \brief The declared buffer. */
tvm::tir::Buffer buffer;
/*! \brief The buffer allocated or not. */
bool allocated;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("buffer", &buffer);
v->Visit("allocated", &allocated);
}

static constexpr const char* _type_key = "script.ir_builder.tir.DeclBufferFrame";
Expand Down
46 changes: 28 additions & 18 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,8 @@ AllocateFrame Allocate(Array<PrimExpr> extents, DataType dtype, String storage_s
* \param annotations Additional annotation hints.
* \return The created AllocateConstFrame.
*/
AllocateConstFrame AllocateConst(
NDArray data, DataType dtype, Array<PrimExpr> extents,
Map<String, ObjectRef> annotations = NullValue<Map<String, ObjectRef>>());
AllocateConstFrame AllocateConst(NDArray data, DataType dtype, Array<PrimExpr> extents,
Optional<Map<String, ObjectRef>> annotations = NullOpt);

/*!
* \brief Create an attribute.
Expand Down Expand Up @@ -449,21 +448,32 @@ PrimExpr Ptr(runtime::DataType dtype, String storage_scope = "global");
return expr.defined() ? tvm::cast(dtype, expr.value()) : tvm::tir::Var("", dtype); \
}

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int8, DataType::Int(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int16, DataType::Int(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32, DataType::Int(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int64, DataType::Int(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt8, DataType::UInt(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt16, DataType::UInt(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt32, DataType::UInt(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(UInt64, DataType::UInt(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float8, DataType::Float(8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float16, DataType::Float(16));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float32, DataType::Float(32));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Float64, DataType::Float(64));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x4, DataType::Int(32, 4));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x8, DataType::Int(32, 8));
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Int32x16, DataType::Int(32, 16));
#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(DType, FDType) \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##8, FDType(8)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##16, FDType(16)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##32, FDType(32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(DType##64, FDType(64));

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES(Int, DataType::Int);

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(FuncName, FDType, Size) \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x4, FDType(Size, 4)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x8, FDType(Size, 8)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x16, FDType(Size, 16)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x32, FDType(Size, 32)); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(FuncName##x64, FDType(Size, 64));

#define TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(DType, FDType) \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##8, FDType, 8); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##16, FDType, 16); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##32, FDType, 32); \
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_LANES(DType##64, FDType, 64);

TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Float, DataType::Float);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(UInt, DataType::UInt);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST_SIZES_LANES(Int, DataType::Int);
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Boolean, DataType::Bool());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Handle, DataType::Handle());
TVM_TIR_IR_BUILDER_DEF_DTYPE_CAST(Void, DataType::Void());
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,14 @@ class RealizeFrame(TIRFrame):
class AllocateFrame(TIRFrame):
def __enter__(self) -> Buffer:
super().__enter__()
return self.buffer
return self.buffer_var


@_register_object("script.ir_builder.tir.AllocateConstFrame")
class AllocateConstFrame(TIRFrame):
def __enter__(self) -> Buffer:
super().__enter__()
return self.buffer
return self.buffer_var


@_register_object("script.ir_builder.tir.AttrFrame")
Expand Down
Loading

0 comments on commit 8897983

Please sign in to comment.