Skip to content

Commit

Permalink
[TensorIR][M1b] Schedule class (#7847)
Browse files Browse the repository at this point in the history
Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn>
Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
Co-authored-by: Hongyi Jin <3231950289@qq.com>
Co-authored-by: Wuwei Lin <wuwei@apache.org>
Co-authored-by: Jared Roesch <roeschinc@gmail.com>
  • Loading branch information
7 people authored Apr 17, 2021
1 parent 6d0386b commit 2b690ad
Show file tree
Hide file tree
Showing 11 changed files with 1,210 additions and 1 deletion.
227 changes: 227 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,227 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
#define TVM_TIR_SCHEDULE_SCHEDULE_H_

#include <tvm/tir/schedule/state.h>

namespace tvm {
namespace tir {

/**************** Random variable: BlockRV ****************/

/*! \brief A random variable that evaluates to a TensorIR block */
class BlockRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.BlockRV";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRVNode, runtime::Object);
};

/*!
* \brief Managed reference to BlockRVNode
* \sa BlockRVNode
*/
class BlockRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL BlockRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(BlockRV, runtime::ObjectRef, BlockRVNode);
};

/**************** Random variable: LoopRV ****************/

/*! \brief A random variable that evaluates to a TensorIR for loop */
class LoopRVNode : public runtime::Object {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
static constexpr const char* _type_key = "tir.LoopRV";
TVM_DECLARE_FINAL_OBJECT_INFO(LoopRVNode, runtime::Object);
};

/*!
* \brief Managed reference to LoopRVNode
* \sa LoopRVNode
*/
class LoopRV : public runtime::ObjectRef {
public:
/*! \brief Constructor */
TVM_DLL LoopRV();
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(LoopRV, runtime::ObjectRef, LoopRVNode);
};

/**************** Random variable: IntRV ****************/

/*! \brief An integer random variable */
using IntRV = PrimExpr;

using IntRVNode = PrimExprNode;

/**************** The Schedule class ****************/

class Schedule;

/*! \brief The user-facing schedule class */
class ScheduleNode : public runtime::Object {
friend class Schedule;

public:
virtual ~ScheduleNode() = default;

static constexpr const char* _type_key = "tir.Schedule";
TVM_DECLARE_BASE_OBJECT_INFO(ScheduleNode, runtime::Object);

public:
/*! \brief Get the IRModule associated with this schedule. */
virtual IRModule mod() const { return state()->mod; }
/*! \return The internal state of scheduling */
virtual ScheduleState state() const = 0;
/*!
* \brief Returns a copy of the schedule, including both its state and its symbol table,
* guaranteeing that
* 1) SRef tree is completely reconstructed;
* 2) The IRModule being scheduled is not modified;
* 3) All the random variables are valid in the copy, pointing to the correpsonding sref
* reconstructed
*/
virtual Schedule Copy() const = 0;
/*!
* \brief Seed the randomness
* \param seed The new random seed, -1 if use device random, otherwise non-negative
*/
virtual void Seed(int64_t seed = -1) {
LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed";
}

public:
/******** Lookup/Remove random variables ********/
/*!
* \brief Get the block corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
* \return The corresponding block
*/
virtual Block Get(const BlockRV& block_rv) const = 0;
/*!
* \brief Get the for loop corresponding to the specific LoopRV
* \param loop_rv The LoopRV to be looked up
* \return The corresponding for loop
*/
virtual For Get(const LoopRV& loop_rv) const = 0;
/*!
* \brief Get the value corresponding to the specific random variable
* \param int_rv The random variable to be looked up
* \return The corresponding value
*/
virtual int64_t Get(const IntRV& int_rv) const = 0;
/*!
* \brief Get the block sref corresponding to the specific BlockRV
* \param block_rv The BlockRV to be looked up
* \return The corresponding block sref
*/
virtual StmtSRef GetSRef(const BlockRV& block_rv) const = 0;
/*!
* \brief Get the loop sref corresponding to the specific LoopRV
* \param loop_rv The LoopRV to be looked up
* \return The corresponding loop sref
*/
virtual StmtSRef GetSRef(const LoopRV& loop_rv) const = 0;
/*!
* \brief Get the block/loop sref corresponding to the specific statement
* \param stmt The statement to be looked up
* \return The corresponding block/loop sref
*/
virtual StmtSRef GetSRef(const StmtNode* stmt) const;
/*!
* \brief Get the block/loop sref corresponding to the specific statement
* \param stmt The statement to be looked up
* \return The corresponding block/loop sref
*/
StmtSRef GetSRef(const Stmt& stmt) const { return this->GetSRef(stmt.get()); }
/*!
* \brief Remove a block random variable from the symbol table
* \param block_rv The random variable to be removed
*/
virtual void RemoveRV(const BlockRV& block_rv) = 0;
/*!
* \brief Remove a loop random variable from the symbol table
* \param loop_rv The random variable to be removed
*/
virtual void RemoveRV(const LoopRV& loop_rv) = 0;
/*!
* \brief Remove an integer random variable from the symbol table
* \param int_rv The random variable to be removed
*/
virtual void RemoveRV(const IntRV& int_rv) = 0;

public:
/******** Block/Loop relation ********/
/*!
* \brief Retrieve a block in a specific function with its name
* \param name The name of the block to be retrieved
* \param func_name The name of the function
* \return The block retrieved
* \note Indexing error is raised if 0 or multiple blocks exist with the specific name
*/
virtual BlockRV GetBlock(const String& name, const String& func_name = "main") = 0;
/*!
* \brief Get the parent loops of the block in its scope, from outer to inner
* \param block_rv The query block
* \return A list of loops above the given block in its scope, from outer to inner
*/
virtual Array<LoopRV> GetLoops(const BlockRV& block_rv) = 0;
};

/*!
* \brief Managed reference to ScheduleNode
*
* A schedule is a set of transformations that change the order of computation but
* preserve the semantics of computation. Some example of schedules:
* 1) Split a loop into two;
* 2) Reorder two loops;
* 3) Inline the computation of a specific buffer into its consumer
*
* The schedule class stores auxiliary information to schedule correctly and efficiently.
*
* Link to tutorial: https://tvm.apache.org/docs/tutorials/language/schedule_primitives.html
*
* \sa ScheduleNode
*/
class Schedule : public runtime::ObjectRef {
public:
/*!
* \brief Construct a concrete TensorIR schedule from an IRModule
* \param mod The IRModule to be scheduled
* \param debug_mode Do extra correctness checking after the class creation
* and each time after calling the Replace method.
* \return The concrete schedule created
* \sa ScheduleDebugMask
* \note The checks performed includes:
* 1) VerifySRefTree
* 2) VerifyAffineBinding
* 3) VerifyRegionCover
* 4) VerifyStagePipeline
*/
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mode);
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
};

} // namespace tir
} // namespace tvm

#endif // TVM_TIR_SCHEDULE_SCHEDULE_H_
1 change: 1 addition & 0 deletions python/tvm/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(self, msg):
register_error("TypeError", TypeError)
register_error("AttributeError", AttributeError)
register_error("KeyError", KeyError)
register_error("IndexError", IndexError)


@register_error
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
from .op import comm_reducer, min, max, sum
from .op import q_multiply_shift

from .schedule import StmtSRef, BlockScope, ScheduleState
from .schedule import StmtSRef, BlockScope, ScheduleState, Schedule

from . import schedule
from . import ir_builder
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/schedule/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@

from .block_scope import BlockScope, Dependency, DepKind, StmtSRef
from .state import ScheduleDebugMask, ScheduleState
from .schedule import LoopRV, BlockRV, IntRV, RAND_VAR_TYPE, Schedule
Loading

0 comments on commit 2b690ad

Please sign in to comment.