Skip to content

Commit

Permalink
[AutoScheduler] New layout rewrite option: Weight pre-transpose (apac…
Browse files Browse the repository at this point in the history
…he#6750)

* Add pre transpose support for layout rewrite

* Update

* Bug fix

* Bug fix

* Update

* Bug fix

* CI Fix

* Update

* Update

* Re-trigger CI

* Update

* Update test_auto_scheduler_layout_rewrite.py

* Update test_auto_scheduler_layout_rewrite.py

* Update task_scheduler ut, re-trigger CI

Co-authored-by: Lianmin Zheng <lianminzheng@gmail.com>
  • Loading branch information
2 people authored and Trevor Morris committed Dec 2, 2020
1 parent 4ad8f51 commit 4683d34
Show file tree
Hide file tree
Showing 8 changed files with 334 additions and 82 deletions.
32 changes: 26 additions & 6 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,24 @@ class ComputeDAGNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
};

/*!
* \brief Options for applying layout rewrite.
* This is an optimization to rewrite the layout of input tensors according to the schedule we get.
*/
enum class LayoutRewriteOption : int {
/*! \brief Do not process layout rewrite. */
NoRewrite = 0,
/*! \brief Insert layout transformation stages for input placeholders in the compute DAG */
InsertTransformStage = 1,
/*!
* \brief Do not insert layout transformation stages and assume the input placeholders
* are pre-transformed.
* \note The lowered function with this option does not accept the origial input shapes,
* so this option must be used along with a layout conversion pass in Relay.
*/
RewriteForPreTransformed = 2,
};

/*!
* \brief Managed reference to ComputeDAGNode.
* \sa ComputeDAGNode
Expand All @@ -214,8 +232,10 @@ class ComputeDAG : public ObjectRef {
* \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders`
* according to the loop nest derived with `transform_steps`.
* \param transform_steps Transform steps of a state.
* \param layout_rewrite Different options in layout rewrite.
* \return The updated ComputeDAG after layout rewrite.
*/
void RewriteLayout(const Array<Step>& transform_steps);
ComputeDAG RewriteLayout(Array<Step>* transform_steps, LayoutRewriteOption layout_rewrite) const;

/*!
* \brief Apply the history transform steps to get a TVM schedule.
Expand All @@ -225,14 +245,14 @@ class ComputeDAG : public ObjectRef {
* \param stage_to_axes The map that stores all axes for one stage.
* Pass a valid pointer if this information needs to be used outside this function.
* \param layout_rewrite Rewrite the layout of placeholders specified by
* attr `layout_free_placeholders`
* attr `layout_free_placeholders`.
* \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower`
* or `tvm.build`.
*/
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(const Array<Step>& transform_steps,
Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr,
bool layout_rewrite = false) const;
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr,
LayoutRewriteOption layout_rewrite = LayoutRewriteOption::NoRewrite) const;

/*!
* \brief Print transform steps as equivalent python schedule API.
Expand Down
46 changes: 31 additions & 15 deletions include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,23 @@ class StepNode : public Object {
*/
class Step : public ObjectRef {
public:
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
/*!
* \brief CopyOnWrite function for Step.
* This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different
* steps.
* \return A base StepNode pointer, need to cast to its real StepNode type before doing any
* modifications.
* \code
*
* SplitStep ref;
* StepNode* mutable_ref = ref.CopyOnWrite();
* dynamic_cast<SplitStepNode*>(mutable_ref)->... = ...;
*
* \endcode
*/
StepNode* CopyOnWrite();

TVM_DEFINE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
};

// Forward declaration
Expand Down Expand Up @@ -267,7 +283,7 @@ class AnnotationStepNode : public StepNode {
static constexpr const char* record_prefix_str = "AN";

static constexpr const char* _type_key = "auto_scheduler.AnnotationStep";
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(AnnotationStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -330,7 +346,7 @@ class FuseStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FU";

static constexpr const char* _type_key = "auto_scheduler.FuseStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FuseStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -390,7 +406,7 @@ class PragmaStepNode : public StepNode {
static constexpr const char* record_prefix_str = "PR";

static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -452,7 +468,7 @@ class ReorderStepNode : public StepNode {
static constexpr const char* record_prefix_str = "RE";

static constexpr const char* _type_key = "auto_scheduler.ReorderStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ReorderStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -527,7 +543,7 @@ class SplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "SP";

static constexpr const char* _type_key = "auto_scheduler.SplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(SplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -607,7 +623,7 @@ class FollowSplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FSP";

static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -688,7 +704,7 @@ class FollowFusedSplitStepNode : public StepNode {
static constexpr const char* record_prefix_str = "FFSP";

static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -754,7 +770,7 @@ class StorageAlignStepNode : public StepNode {
static constexpr const char* record_prefix_str = "SA";

static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -822,7 +838,7 @@ class ComputeAtStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CA";

static constexpr const char* _type_key = "auto_scheduler.ComputeAtStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeAtStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -879,7 +895,7 @@ class ComputeInlineStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CI";

static constexpr const char* _type_key = "auto_scheduler.ComputeInlineStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeInlineStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -938,7 +954,7 @@ class ComputeRootStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CR";

static constexpr const char* _type_key = "auto_scheduler.ComputeRootStep";
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(ComputeRootStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1010,7 +1026,7 @@ class CacheReadStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CHR";

static constexpr const char* _type_key = "auto_scheduler.CacheReadStep";
TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(CacheReadStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1081,7 +1097,7 @@ class CacheWriteStepNode : public StepNode {
static constexpr const char* record_prefix_str = "CHW";

static constexpr const char* _type_key = "auto_scheduler.CacheWriteStep";
TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(CacheWriteStepNode, StepNode);
};

/*!
Expand Down Expand Up @@ -1148,7 +1164,7 @@ class RfactorStepNode : public StepNode {
static constexpr const char* record_prefix_str = "RF";

static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, StepNode);
};

/*!
Expand Down
7 changes: 6 additions & 1 deletion python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ class ComputeDAG(Object):
Input/output tensors or workload key for a compute declaration.
"""

# Layout Rewrite Options
NoRewrite = 0
InsertTransformStage = 1
RewriteForPreTransformed = 2

def __init__(self, compute_or_sche):
if isinstance(compute_or_sche, str):
compute = workload_key_to_tensors(compute_or_sche)
Expand Down Expand Up @@ -81,7 +86,7 @@ def get_init_state(self):
"""
return State(self.init_state, self)

def apply_steps_from_state(self, state, layout_rewrite=False):
def apply_steps_from_state(self, state, layout_rewrite=NoRewrite):
"""
Apply the history transform steps from a State to get a TVM schedule.
Expand Down
Loading

0 comments on commit 4683d34

Please sign in to comment.