Skip to content

Commit b5eb32d

Browse files
junrushaoSiyuan FengspectrometerHBHjinhongyiiMasterJH5574
committed
Squashed commit
[Meta Schedule][M3c] Schedule Rules, Mutator & Postprocs (apache#485) [Meta Schedule][M3c] PostOrderApply (apache#486) Fix Post Order Apply (apache#490) [MetaSchedule] Relay Integration (apache#489) [M3c][Meta Schedule] Add Trace Correctness Test for PostOrderApply (apache#492) Fix replay trace. (apache#493) [M3c][Meta Schedule] Implement the Replay Func class. (apache#495) [PR] Test script for meta-schedule task extraction. Interface to load… (apache#494) [Meta Schedule Refactor] Get child blocks (apache#500) Read-at && Write-at (apache#497) [M3c][Meta Schedule] Measure Callbacks (apache#498) [Bug] Fix Infinite Loop Caused When Calling Methods Not Overrided In PyClass (apache#496) [MetaSchedule] Sample-Perfect-Tile (apache#501) [MetaSchedule] TE Workloads (apache#502) Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Sunghyun Park <49998730+sunggg@users.noreply.github.com>
1 parent 048994b commit b5eb32d

File tree

64 files changed

+4500
-61
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

64 files changed

+4500
-61
lines changed
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
21+
#define TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
22+
23+
#include <tvm/meta_schedule/builder.h>
24+
#include <tvm/meta_schedule/runner.h>
25+
#include <tvm/meta_schedule/search_strategy.h>
26+
#include <tvm/meta_schedule/tune_context.h>
27+
28+
namespace tvm {
29+
namespace meta_schedule {
30+
31+
class TaskScheduler;
32+
33+
/*! \brief Rules to apply after measure results is available. */
34+
class MeasureCallbackNode : public runtime::Object {
35+
public:
36+
/*! \brief Virtual destructor. */
37+
virtual ~MeasureCallbackNode() = default;
38+
39+
void VisitAttrs(tvm::AttrVisitor* v) {}
40+
41+
/*!
42+
* \brief Apply a measure callback rule with given arguments.
43+
* \param task_scheduler The task scheduler.
44+
* \param tasks The list of tune context to process.
45+
* \param measure_candidates The measure candidates.
46+
* \param builds The builder results by building the measure candidates.
47+
* \param results The runner results by running the built measure candidates.
48+
* \return Whether the measure callback was successfully applied.
49+
*/
50+
virtual bool Apply(const TaskScheduler& task_scheduler, //
51+
const Array<TuneContext> tasks, //
52+
const Array<MeasureCandidate>& measure_candidates, //
53+
const Array<BuilderResult>& builds, //
54+
const Array<RunnerResult>& results) = 0;
55+
56+
static constexpr const char* _type_key = "meta_schedule.MeasureCallback";
57+
TVM_DECLARE_BASE_OBJECT_INFO(MeasureCallbackNode, Object);
58+
};
59+
60+
/*! \brief The measure callback with customized methods on the python-side. */
61+
class PyMeasureCallbackNode : public MeasureCallbackNode {
62+
public:
63+
/*!
64+
* \brief Apply a measure callback to the given schedule.
65+
* \param task_scheduler The task scheduler.
66+
* \param tasks The list of tune context to process.
67+
* \param measure_candidates The measure candidates.
68+
* \param builds The builder results by building the measure candidates.
69+
* \param results The runner results by running the built measure candidates.
70+
* \return Whether the measure callback was successfully applied.
71+
*/
72+
using FApply =
73+
runtime::TypedPackedFunc<bool(const TaskScheduler& task_scheduler, //
74+
const Array<TuneContext> tasks, //
75+
const Array<MeasureCandidate>& measure_candidates, //
76+
const Array<BuilderResult>& builds, //
77+
const Array<RunnerResult>& results)>;
78+
/*!
79+
* \brief Get the measure callback function as string with name.
80+
* \return The string of the measure callback function.
81+
*/
82+
using FAsString = runtime::TypedPackedFunc<String()>;
83+
84+
/*! \brief The packed function to the `Apply` funcion. */
85+
FApply f_apply;
86+
/*! \brief The packed function to the `AsString` funcion. */
87+
FAsString f_as_string;
88+
89+
void VisitAttrs(tvm::AttrVisitor* v) {
90+
// `f_apply` is not visited
91+
// `f_as_string` is not visited
92+
}
93+
94+
bool Apply(const TaskScheduler& task_scheduler, //
95+
const Array<TuneContext> tasks, //
96+
const Array<MeasureCandidate>& measure_candidates, //
97+
const Array<BuilderResult>& builds, //
98+
const Array<RunnerResult>& results) final {
99+
ICHECK(f_apply != nullptr) << "PyMeasureCallback's Apply method not implemented!";
100+
return this->f_apply(task_scheduler, tasks, measure_candidates, builds, results);
101+
}
102+
103+
static constexpr const char* _type_key = "meta_schedule.PyMeasureCallback";
104+
TVM_DECLARE_FINAL_OBJECT_INFO(PyMeasureCallbackNode, MeasureCallbackNode);
105+
};
106+
107+
/*!
108+
* \brief Managed reference to MeasureCallbackNode
109+
* \sa MeasureCallbackNode
110+
*/
111+
class MeasureCallback : public runtime::ObjectRef {
112+
public:
113+
/*!
114+
* \brief Create a measure callback with customized methods on the python-side.
115+
* \param f_apply The packed function of `Apply`.
116+
* \return The measure callback created.
117+
*/
118+
TVM_DLL static MeasureCallback PyMeasureCallback(PyMeasureCallbackNode::FApply f_apply, //
119+
PyMeasureCallbackNode::FAsString f_as_string);
120+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(MeasureCallback, ObjectRef, MeasureCallbackNode);
121+
};
122+
123+
} // namespace meta_schedule
124+
} // namespace tvm
125+
126+
#endif // TVM_META_SCHEDULE_MEASURE_CALLBACK_H_
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_META_SCHEDULE_MUTATOR_H_
21+
#define TVM_META_SCHEDULE_MUTATOR_H_
22+
23+
#include <tvm/tir/schedule/schedule.h>
24+
25+
namespace tvm {
26+
namespace meta_schedule {
27+
28+
class TuneContext;
29+
30+
/*! \brief Mutator is designed to mutate the trace to explore the design space. */
31+
class MutatorNode : public runtime::Object {
32+
public:
33+
/*! \brief Virtual destructor. */
34+
virtual ~MutatorNode() = default;
35+
36+
void VisitAttrs(tvm::AttrVisitor* v) {}
37+
38+
/*!
39+
* \brief The function type of `InitializeWithTuneContext` method.
40+
* \param tune_context The tuning context for initialization.
41+
*/
42+
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
43+
44+
/*!
45+
* \brief Apply the mutator function to the given trace.
46+
* \param trace The given trace for mutation.
47+
* \return None if mutator failed, otherwise return the mutated trace.
48+
*/
49+
virtual Optional<tir::Trace> Apply(const tir::Trace& trace) = 0;
50+
51+
static constexpr const char* _type_key = "meta_schedule.Mutator";
52+
TVM_DECLARE_BASE_OBJECT_INFO(MutatorNode, Object);
53+
};
54+
55+
/*! \brief The mutator with customized methods on the python-side. */
56+
class PyMutatorNode : public MutatorNode {
57+
public:
58+
/*!
59+
* \brief The function type of `InitializeWithTuneContext` method.
60+
* \param tune_context The tuning context for initialization.
61+
*/
62+
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
63+
/*!
64+
* \brief Apply the mutator function to the given trace.
65+
* \param trace The given trace for mutation.
66+
* \return None if mutator failed, otherwise return the mutated trace.
67+
*/
68+
using FApply = runtime::TypedPackedFunc<Optional<tir::Trace>(const tir::Trace&)>;
69+
/*!
70+
* \brief Get the mutator as string with name.
71+
* \return The string of the mutator.
72+
*/
73+
using FAsString = runtime::TypedPackedFunc<String()>;
74+
75+
/*! \brief The packed function to the `InitializeWithTuneContext` funcion. */
76+
FInitializeWithTuneContext f_initialize_with_tune_context;
77+
/*! \brief The packed function to the `Apply` funcion. */
78+
FApply f_apply;
79+
/*! \brief The packed function to the `AsString` funcion. */
80+
FAsString f_as_string;
81+
82+
void VisitAttrs(tvm::AttrVisitor* v) {
83+
// `f_initialize_with_tune_context` is not visited
84+
// `f_apply` is not visited
85+
// `f_as_string` is not visited
86+
}
87+
88+
void InitializeWithTuneContext(const TuneContext& context) final {
89+
ICHECK(f_initialize_with_tune_context != nullptr)
90+
<< "PyMutator's InitializeWithTuneContext method not implemented!";
91+
this->f_initialize_with_tune_context(context);
92+
}
93+
94+
Optional<tir::Trace> Apply(const tir::Trace& trace) final {
95+
ICHECK(f_apply != nullptr) << "PyMutator's Apply method not implemented!";
96+
return this->f_apply(trace);
97+
}
98+
99+
static constexpr const char* _type_key = "meta_schedule.PyMutator";
100+
TVM_DECLARE_FINAL_OBJECT_INFO(PyMutatorNode, MutatorNode);
101+
};
102+
103+
/*!
104+
* \brief Managed reference to MutatorNode
105+
* \sa MutatorNode
106+
*/
107+
class Mutator : public runtime::ObjectRef {
108+
public:
109+
/*!
110+
* \brief Create a mutator with customized methods on the python-side.
111+
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
112+
* \param f_apply The packed function of `Apply`.
113+
* \return The mutator created.
114+
*/
115+
TVM_DLL static Mutator PyMutator(
116+
PyMutatorNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
117+
PyMutatorNode::FApply f_apply, //
118+
PyMutatorNode::FAsString f_as_string);
119+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Mutator, ObjectRef, MutatorNode);
120+
};
121+
122+
} // namespace meta_schedule
123+
} // namespace tvm
124+
125+
#endif // TVM_META_SCHEDULE_MUTATOR_H_
Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
#ifndef TVM_META_SCHEDULE_POSTPROC_H_
21+
#define TVM_META_SCHEDULE_POSTPROC_H_
22+
23+
#include <tvm/tir/schedule/schedule.h>
24+
25+
namespace tvm {
26+
namespace meta_schedule {
27+
28+
class TuneContext;
29+
30+
/*!
31+
* \brief Rules to apply a post processing to a schedule.
32+
* \note Post processing is designed to deal with the problem of undertermined schedule validity
33+
* after applying some schedule primitves at runtime. E.g., Fuse the first X loops to reach the
34+
* maximum number below 1024, X is only decided at runtime.
35+
*/
36+
class PostprocNode : public runtime::Object {
37+
public:
38+
/*! \brief Virtual destructor. */
39+
virtual ~PostprocNode() = default;
40+
41+
void VisitAttrs(tvm::AttrVisitor* v) {}
42+
43+
/*!
44+
* \brief The function type of `InitializeWithTuneContext` method.
45+
* \param tune_context The tuning context for initialization.
46+
*/
47+
virtual void InitializeWithTuneContext(const TuneContext& context) = 0;
48+
49+
/*!
50+
* \brief Apply a post processing to the given schedule.
51+
* \param sch The schedule to be post processed.
52+
* \return Whether the post processing was successfully applied.
53+
*/
54+
virtual bool Apply(const tir::Schedule& schedule) = 0;
55+
56+
static constexpr const char* _type_key = "meta_schedule.Postproc";
57+
TVM_DECLARE_BASE_OBJECT_INFO(PostprocNode, Object);
58+
};
59+
60+
/*! \brief The post processing with customized methods on the python-side. */
61+
class PyPostprocNode : public PostprocNode {
62+
public:
63+
/*!
64+
* \brief The function type of `InitializeWithTuneContext` method.
65+
* \param tune_context The tuning context for initialization.
66+
*/
67+
using FInitializeWithTuneContext = runtime::TypedPackedFunc<void(const TuneContext&)>;
68+
/*!
69+
* \brief Apply a post processing to the given schedule.
70+
* \param sch The schedule to be post processed.
71+
* \return Whether the post processing was successfully applied.
72+
*/
73+
using FApply = runtime::TypedPackedFunc<bool(const tir::Schedule&)>;
74+
/*!
75+
* \brief Get the post processing function as string with name.
76+
* \return The string of the post processing function.
77+
*/
78+
using FAsString = runtime::TypedPackedFunc<String()>;
79+
80+
/*! \brief The packed function to the `InitializeWithTuneContext` funcion. */
81+
FInitializeWithTuneContext f_initialize_with_tune_context;
82+
/*! \brief The packed function to the `Apply` funcion. */
83+
FApply f_apply;
84+
/*! \brief The packed function to the `AsString` funcion. */
85+
FAsString f_as_string;
86+
87+
void VisitAttrs(tvm::AttrVisitor* v) {
88+
// `f_initialize_with_tune_context` is not visited
89+
// `f_apply` is not visited
90+
// `f_as_string` is not visited
91+
}
92+
93+
void InitializeWithTuneContext(const TuneContext& context) final {
94+
ICHECK(f_initialize_with_tune_context != nullptr)
95+
<< "PyPostproc's InitializeWithTuneContext method not implemented!";
96+
this->f_initialize_with_tune_context(context);
97+
}
98+
99+
bool Apply(const tir::Schedule& sch) final {
100+
ICHECK(f_apply != nullptr) << "PyPostproc's Apply method not implemented!";
101+
return this->f_apply(sch);
102+
}
103+
104+
static constexpr const char* _type_key = "meta_schedule.PyPostproc";
105+
TVM_DECLARE_FINAL_OBJECT_INFO(PyPostprocNode, PostprocNode);
106+
};
107+
108+
/*!
109+
* \brief Managed reference to PostprocNode
110+
* \sa PostprocNode
111+
*/
112+
class Postproc : public runtime::ObjectRef {
113+
public:
114+
/*!
115+
* \brief Create a post processing with customized methods on the python-side.
116+
* \param f_initialize_with_tune_context The packed function of `InitializeWithTuneContext`.
117+
* \param f_apply The packed function of `Apply`.
118+
* \return The post processing created.
119+
*/
120+
TVM_DLL static Postproc PyPostproc(
121+
PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, //
122+
PyPostprocNode::FApply f_apply, //
123+
PyPostprocNode::FAsString f_as_string);
124+
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode);
125+
};
126+
127+
} // namespace meta_schedule
128+
} // namespace tvm
129+
130+
#endif // TVM_META_SCHEDULE_POSTPROC_H_

0 commit comments

Comments
 (0)