Skip to content

Commit 2859c20

Browse files
zxybazhjunrushaovinx13MasterJH5574jinhongyii
authored
[M3a][Meta Schedule] Add Sampling Primitive SampleCategorical. (#8817)
Co-authored-by: Junru Shao <junrushao1994@gmail.com> Co-authored-by: Wuwei Lin <wuwei@apache.org> Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com> Co-authored-by: Hongyi Jin <3231950289@qq.com> Co-authored-by: Siyuan Feng <Hzfengsy@sjtu.edu.cn> Co-authored-by: Bohan Hou <32121147+spectrometerHBH@users.noreply.github.com>
1 parent 44a1d1f commit 2859c20

File tree

12 files changed

+429
-44
lines changed

12 files changed

+429
-44
lines changed

include/tvm/support/random_engine.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
/*!
2121
* \file random_engine.h
22-
* \brief Random number generator, for Sampler and Sampling functions.
22+
* \brief Random number generator. It provides a generic interface consistent with
23+
* `std::uniform_random_bit_generator`
2324
*/
2425

2526
#ifndef TVM_SUPPORT_RANDOM_ENGINE_H_
@@ -41,10 +42,11 @@ namespace support {
4142
* included for simplification. For full member functions of std::minstd_rand, please check out the
4243
* following link: https://en.cppreference.com/w/cpp/numeric/random/linear_congruential_engine
4344
*/
45+
4446
class LinearCongruentialEngine {
4547
public:
4648
/*!
47-
* \brief The result type is defined as int64_t here for meta_schedule sampler usage.
49+
* \brief The result type is defined as uint64_t here to avoid overflow.
4850
* \note The type name is not in Google style because it is used in STL's distribution interface.
4951
*/
5052
using result_type = uint64_t;
@@ -63,13 +65,13 @@ class LinearCongruentialEngine {
6365
* \brief The minimum possible value of random state here.
6466
* \note The function name is uncapitalized because it is used in STL's distribution interface.
6567
*/
66-
result_type min() { return 0; }
68+
static constexpr result_type min() { return 0; }
6769

6870
/*!
6971
* \brief The maximum possible value of random state here.
7072
* \note The function name is uncapitalized because it is used in STL's distribution interface.
7173
*/
72-
result_type max() { return modulus - 1; }
74+
static constexpr result_type max() { return modulus - 1; }
7375

7476
/*!
7577
* \brief Operator to move the random state to the next and return the new random state. According

include/tvm/tir/schedule/schedule.h

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#ifndef TVM_TIR_SCHEDULE_SCHEDULE_H_
2020
#define TVM_TIR_SCHEDULE_SCHEDULE_H_
2121

22+
#include <tvm/support/random_engine.h>
2223
#include <tvm/tir/schedule/state.h>
2324
#include <tvm/tir/schedule/trace.h>
2425

@@ -118,9 +119,9 @@ class ScheduleNode : public runtime::Object {
118119
* \brief Seed the randomness
119120
* \param seed The new random seed, -1 if use device random, otherwise non-negative
120121
*/
121-
virtual void Seed(int64_t seed = -1) {
122-
LOG(FATAL) << "ValueError: The schedule cannot be seeded because no randomness is allowed";
123-
}
122+
virtual void Seed(support::LinearCongruentialEngine::TRandState seed = -1) = 0;
123+
/*! \brief Fork the random state */
124+
virtual support::LinearCongruentialEngine::TRandState ForkSeed() = 0;
124125

125126
public:
126127
/******** Lookup/Remove random variables ********/
@@ -184,6 +185,16 @@ class ScheduleNode : public runtime::Object {
184185

185186
public:
186187
/******** Schedule: Sampling ********/
188+
/*!
189+
* \brief Sample an integer given the probability distribution
190+
* \param candidates The candidates
191+
* \param probs The probability distribution of the candidates
192+
* \param decision The sampling decision
193+
* \return The random variable sampled from candidates
194+
*/
195+
virtual ExprRV SampleCategorical(const Array<Integer>& candidates, const Array<FloatImm>& probs,
196+
Optional<Integer> decision = NullOpt) = 0;
197+
187198
/******** Schedule: Get blocks & loops ********/
188199
/*!
189200
* \brief Retrieve a block in a specific function with its name
@@ -356,6 +367,7 @@ class Schedule : public runtime::ObjectRef {
356367
/*!
357368
* \brief Construct a concrete TensorIR schedule from an IRModule
358369
* \param mod The IRModule to be scheduled
370+
* \param seed The seed value for schedule's random state
359371
* \param debug_mask Do extra correctness checking after the class creation
360372
* and each time after calling the Replace method.
361373
* \param error_render_level The level of error rendering
@@ -365,11 +377,12 @@ class Schedule : public runtime::ObjectRef {
365377
* 1) VerifySRefTree
366378
* 2) VerifyCachedFlags
367379
*/
368-
TVM_DLL static Schedule Concrete(IRModule mod, int debug_mask,
369-
ScheduleErrorRenderLevel error_render_level);
380+
TVM_DLL static Schedule Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
381+
int debug_mask, ScheduleErrorRenderLevel error_render_level);
370382
/*!
371383
* \brief Construct a traced concrete TensorIR schedule from an IRModule
372384
* \param mod The IRModule to be scheduled
385+
* \param seed The seed value for schedule's random state
373386
* \param debug_mask Do extra correctness checking after the class creation
374387
* and each time after calling the Replace method.
375388
* \param error_render_level The level of error rendering
@@ -379,8 +392,8 @@ class Schedule : public runtime::ObjectRef {
379392
* 1) VerifySRefTree
380393
* 2) VerifyCachedFlags
381394
*/
382-
TVM_DLL static Schedule Traced(IRModule mod, int debug_mask,
383-
ScheduleErrorRenderLevel error_render_level);
395+
TVM_DLL static Schedule Traced(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
396+
int debug_mask, ScheduleErrorRenderLevel error_render_level);
384397
TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Schedule, runtime::ObjectRef, ScheduleNode);
385398
};
386399

python/tvm/tir/schedule/schedule.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,16 @@ def _parse_error_render_level(error_render_level: str) -> int:
7979
return _ERROR_RENDER_LEVEL.get(error_render_level)
8080

8181

82+
def _parse_seed(seed: Optional[int]) -> int:
83+
if seed is None:
84+
return -1
85+
if not isinstance(seed, int):
86+
raise TypeError(f"Expected `seed` to be int or None, but gets: {seed}")
87+
if seed < 1 or seed > 2147483647:
88+
raise ValueError(f"seed must be in the range [1, 2147483647], but gets: {seed}")
89+
return seed
90+
91+
8292
@_register_object("tir.Schedule")
8393
class Schedule(Object):
8494
"""The user-facing schedule class
@@ -98,6 +108,7 @@ def __init__(
98108
self,
99109
mod: Union[PrimFunc, IRModule],
100110
*,
111+
seed: Optional[int] = None,
101112
debug_mask: Union[str, int] = "none",
102113
error_render_level: str = "detail",
103114
) -> None:
@@ -107,6 +118,10 @@ def __init__(
107118
----------
108119
mod : Union[PrimFunc, IRModule]
109120
The IRModule or PrimFunc to be scheduled
121+
seed: Optional[int]
122+
The seed value for schedule's random state
123+
Note that None and -1 both mean device random is used; otherwise only an integer between 1 and
124+
2147483647 is allowed.
110125
debug_mask : Union[str, int]
111126
Do extra correctness checking after the class creation and each time
112127
after calling the Replace method.
@@ -130,6 +145,7 @@ def __init__(
130145
self.__init_handle_by_constructor__(
131146
_ffi_api.TracedSchedule, # type: ignore # pylint: disable=no-member
132147
_parse_mod(mod),
148+
_parse_seed(seed),
133149
_parse_debug_mask(debug_mask),
134150
_parse_error_render_level(error_render_level),
135151
)
@@ -138,12 +154,14 @@ def __init__(
138154
def _create_non_traced(
139155
mod: Union[PrimFunc, IRModule],
140156
*,
157+
seed: Optional[int] = None,
141158
debug_mask: Union[str, int] = "none",
142159
error_render_level: str = "detail",
143160
) -> "Schedule":
144161
"""Construct a non-traced TensorIR schedule class from an IRModule."""
145162
return _ffi_api.ConcreteSchedule( # type: ignore # pylint: disable=no-member
146163
_parse_mod(mod),
164+
_parse_seed(seed),
147165
_parse_debug_mask(debug_mask),
148166
_parse_error_render_level(error_render_level),
149167
)
@@ -190,6 +208,16 @@ def seed(self, seed: int) -> None:
190208
"""
191209
return _ffi_api.ScheduleSeed(self, seed) # type: ignore # pylint: disable=no-member
192210

211+
def fork_seed(self) -> int:
    """Fork the schedule's random state into a seed usable by new schedules.

    Returns
    -------
    seed : int
        The forked random state, which differs from the current random state
    """
    forked_seed = _ffi_api.ScheduleForkSeed(self)  # type: ignore # pylint: disable=no-member
    return forked_seed
220+
193221
def show(self, rand_var: RAND_VAR_TYPE) -> str:
194222
"""Returns a string representation of the value that the random variable evaluates to
195223
@@ -268,6 +296,35 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:
268296

269297
########## Schedule: Sampling ##########
270298

299+
def sample_categorical(
    self,
    candidates: List[int],
    probs: List[float],
    decision: Optional[int] = None,
) -> ExprRV:
    """Draw one integer from `candidates` according to the distribution `probs`.

    Parameters
    ----------
    candidates : List[int]
        The integers that may be sampled
    probs : List[float]
        The probability associated with each candidate
    decision : Optional[int]
        The sampling decision, if one is supplied

    Returns
    -------
    result : ExprRV
        The random variable sampled from candidates
    """
    # Arguments are forwarded to the FFI entry point unchanged and in declaration order.
    ffi_args = (self, candidates, probs, decision)
    sampled = _ffi_api.ScheduleSampleCategorical(  # type: ignore # pylint: disable=no-member
        *ffi_args
    )
    return sampled
327+
271328
########## Schedule: Get blocks & loops ##########
272329
def get_block(
273330
self,

src/support/array.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
*/
1919
#ifndef TVM_SUPPORT_ARRAY_H_
2020
#define TVM_SUPPORT_ARRAY_H_
21+
#include <tvm/ir/expr.h>
2122
#include <tvm/runtime/container/array.h>
2223

2324
#include <vector>
@@ -67,6 +68,73 @@ inline bool ArrayWithSameContent(const std::vector<T*>& a, const std::vector<T*>
6768
return true;
6869
}
6970

71+
/*!
72+
* \brief Convert a tvm::runtime::Array to std::vector
73+
* \tparam TSrc The type of elements in the source Array
74+
* \tparam TDst The type of elements in the result vector
75+
* \return The result vector
76+
*/
77+
template <class TSrc, class TDst>
78+
std::vector<TDst> AsVector(const Array<TSrc>& vec);
79+
80+
/********** Implementation details of AsVector<TSrc, TDst> **********/
81+
namespace details {
82+
83+
template <class TSrc, class TDst>
84+
struct AsVectorImpl {};
85+
86+
template <class TSrc>
87+
struct AsVectorImpl<TSrc, TSrc> {
88+
inline std::vector<TSrc> operator()(const Array<TSrc>& vec) const {
89+
return std::vector<TSrc>(vec.begin(), vec.end());
90+
}
91+
};
92+
93+
template <class TSrcObjectRef>
94+
struct AsVectorImpl<TSrcObjectRef, int> {
95+
inline std::vector<int> operator()(const Array<TSrcObjectRef>& vec) const {
96+
std::vector<int> results;
97+
for (const TSrcObjectRef& x : vec) {
98+
const auto* n = x.template as<IntImmNode>();
99+
ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey();
100+
results.push_back(n->value);
101+
}
102+
return results;
103+
}
104+
};
105+
106+
template <class TSrcObjectRef>
107+
struct AsVectorImpl<TSrcObjectRef, int64_t> {
108+
inline std::vector<int64_t> operator()(const Array<TSrcObjectRef>& vec) const {
109+
std::vector<int64_t> results;
110+
for (const TSrcObjectRef& x : vec) {
111+
const auto* n = x.template as<IntImmNode>();
112+
ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey();
113+
results.push_back(n->value);
114+
}
115+
return results;
116+
}
117+
};
118+
119+
template <class TSrcObjectRef>
120+
struct AsVectorImpl<TSrcObjectRef, double> {
121+
inline std::vector<double> operator()(const Array<TSrcObjectRef>& array) const {
122+
std::vector<double> results;
123+
for (const TSrcObjectRef& x : array) {
124+
const auto* n = x.template as<FloatImmNode>();
125+
ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey();
126+
results.push_back(n->value);
127+
}
128+
return results;
129+
}
130+
};
131+
} // namespace details
132+
133+
template <class TSrc, class TDst>
134+
inline std::vector<TDst> AsVector(const Array<TSrc>& vec) {
135+
return details::AsVectorImpl<TSrc, TDst>()(vec);
136+
}
137+
70138
} // namespace support
71139
} // namespace tvm
72140
#endif // TVM_SUPPORT_ARRAY_H_

src/tir/schedule/concrete_schedule.cc

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,19 @@
1818
*/
1919
#include "./concrete_schedule.h"
2020

21+
#include <random>
22+
2123
namespace tvm {
2224
namespace tir {
2325

24-
Schedule Schedule::Concrete(IRModule mod, int debug_mask,
25-
ScheduleErrorRenderLevel error_render_level) {
26+
/*!
 * \brief Construct a concrete (non-traced) schedule from an IRModule.
 * \param mod The IRModule to be scheduled
 * \param seed The seed for the schedule's random state; -1 requests device randomness
 * \param debug_mask Forwarded to the ScheduleState constructor
 * \param error_render_level The level of error rendering
 * \return The schedule created
 */
Schedule Schedule::Concrete(IRModule mod, support::LinearCongruentialEngine::TRandState seed,
                            int debug_mask, ScheduleErrorRenderLevel error_render_level) {
  ObjectPtr<ConcreteScheduleNode> n = make_object<ConcreteScheduleNode>();
  n->state_ = ScheduleState(mod, debug_mask);
  n->error_render_level_ = error_render_level;
  n->symbol_table_ = {};
  n->analyzer_ = std::make_unique<arith::Analyzer>();
  // Seed through ConcreteScheduleNode::Seed rather than the engine directly, so that the
  // -1 sentinel is replaced by a device-random value exactly as an explicit Seed(-1) call does.
  n->Seed(seed);
  return Schedule(std::move(n));
}
3336

@@ -208,6 +211,29 @@ Schedule ConcreteScheduleNode::Copy() const {
208211
}
209212

210213
/******** Schedule: Sampling ********/
214+
215+
/*!
 * \brief Seed the schedule's random state.
 * \param seed The new seed; the sentinel -1 is replaced by a value drawn from
 * std::random_device before it is handed to the engine
 */
void ConcreteScheduleNode::Seed(support::LinearCongruentialEngine::TRandState seed) {
  const bool use_device_random = (seed == -1);
  if (use_device_random) {
    std::random_device device;
    seed = device();
  }
  support::LinearCongruentialEngine engine(&rand_state_);
  engine.Seed(seed);
}
221+
222+
/*!
 * \brief Derive a seed for a new schedule from this schedule's random state.
 * \return The engine's next output multiplied by 32767, reduced modulo 1999999973
 */
support::LinearCongruentialEngine::TRandState ConcreteScheduleNode::ForkSeed() {
  // For reproducibility the forked seed is computed from the RNG's own output, using a
  // multiplier/modulus pair that differs from the parameters of the engine itself.
  // NOTE(review): the result is 0 whenever the engine output is a multiple of 1999999973;
  // the Python-side seed validation rejects 0 -- confirm whether that case can occur.
  constexpr int kForkMultiplier = 32767;
  constexpr int kForkModulus = 1999999973;
  support::LinearCongruentialEngine engine(&rand_state_);
  // Invoking the engine also advances rand_state_.
  const auto next_state = engine();
  return (next_state * kForkMultiplier) % kForkModulus;
}
227+
228+
/*!
 * \brief Sample an integer given the probability distribution.
 * \param candidates The candidates
 * \param probs The probability distribution of the candidates
 * \param decision The sampling decision; its address is passed to tir::SampleCategorical
 * \return The random variable created from the sampled value
 */
ExprRV ConcreteScheduleNode::SampleCategorical(const Array<Integer>& candidates,
                                               const Array<FloatImm>& probs,
                                               Optional<Integer> decision) {
  // NOTE(review): the BEGIN/END macros presumably wrap the body in error handling that
  // renders schedule errors at error_render_level_ -- confirm against their definition.
  TVM_TIR_SCHEDULE_BEGIN();
  auto* rand_state = &this->rand_state_;
  return CreateRV(tir::SampleCategorical(rand_state, candidates, probs, &decision));
  TVM_TIR_SCHEDULE_END("sample-categorical", this->error_render_level_);
  throw;
}
236+
211237
/******** Schedule: Get blocks & loops ********/
212238

213239
BlockRV ConcreteScheduleNode::GetBlock(const String& name, const String& func_name) {

0 commit comments

Comments
 (0)