@@ -79,6 +79,16 @@ def _parse_error_render_level(error_render_level: str) -> int:
79
79
return _ERROR_RENDER_LEVEL .get (error_render_level )
80
80
81
81
82
+ def _parse_seed (seed : Optional [int ]) -> int :
83
+ if seed is None :
84
+ return - 1
85
+ if not isinstance (seed , int ):
86
+ raise TypeError (f"Expected `seed` to be int or None, but gets: { seed } " )
87
+ if seed < 1 or seed > 2147483647 :
88
+ raise ValueError (f"seed must be in the range [1, 2147483647], but gets: { seed } " )
89
+ return seed
90
+
91
+
82
92
@_register_object ("tir.Schedule" )
83
93
class Schedule (Object ):
84
94
"""The user-facing schedule class
@@ -98,6 +108,7 @@ def __init__(
98
108
self ,
99
109
mod : Union [PrimFunc , IRModule ],
100
110
* ,
111
+ seed : Optional [int ] = None ,
101
112
debug_mask : Union [str , int ] = "none" ,
102
113
error_render_level : str = "detail" ,
103
114
) -> None :
@@ -107,6 +118,10 @@ def __init__(
107
118
----------
108
119
mod : Union[PrimFunc, IRModule]
109
120
The IRModule or PrimFunc to be scheduled
121
+ seed: Optional[int]
122
+ The seed value for schedule's random state
123
+ Note that None and -1 means use device random, otherwise only integer between 1 and
124
+ 2147483647 is allowed.
110
125
debug_mask : Union[str, int]
111
126
Do extra correctness checking after the class creation and each time
112
127
after calling the Replace method.
@@ -130,6 +145,7 @@ def __init__(
130
145
self .__init_handle_by_constructor__ (
131
146
_ffi_api .TracedSchedule , # type: ignore # pylint: disable=no-member
132
147
_parse_mod (mod ),
148
+ _parse_seed (seed ),
133
149
_parse_debug_mask (debug_mask ),
134
150
_parse_error_render_level (error_render_level ),
135
151
)
@@ -138,12 +154,14 @@ def __init__(
138
154
def _create_non_traced (
139
155
mod : Union [PrimFunc , IRModule ],
140
156
* ,
157
+ seed : Optional [int ] = None ,
141
158
debug_mask : Union [str , int ] = "none" ,
142
159
error_render_level : str = "detail" ,
143
160
) -> "Schedule" :
144
161
"""Construct a non-traced TensorIR schedule class from an IRModule."""
145
162
return _ffi_api .ConcreteSchedule ( # type: ignore # pylint: disable=no-member
146
163
_parse_mod (mod ),
164
+ _parse_seed (seed ),
147
165
_parse_debug_mask (debug_mask ),
148
166
_parse_error_render_level (error_render_level ),
149
167
)
@@ -190,6 +208,16 @@ def seed(self, seed: int) -> None:
190
208
"""
191
209
return _ffi_api .ScheduleSeed (self , seed ) # type: ignore # pylint: disable=no-member
192
210
211
+ def fork_seed (self ) -> int :
212
+ """Returns a forked random state as seed for new schedules
213
+
214
+ Returns
215
+ -------
216
+ seed : int
217
+ The forked random state, not the same as the current random state
218
+ """
219
+ return _ffi_api .ScheduleForkSeed (self ) # type: ignore # pylint: disable=no-member
220
+
193
221
def show (self , rand_var : RAND_VAR_TYPE ) -> str :
194
222
"""Returns a string representation of the value that the random variable evaluates to
195
223
@@ -268,6 +296,35 @@ def remove_rv(self, rand_var: RAND_VAR_TYPE) -> None:
268
296
269
297
########## Schedule: Sampling ##########
270
298
299
+ def sample_categorical (
300
+ self ,
301
+ candidates : List [int ],
302
+ probs : List [float ],
303
+ decision : Optional [int ] = None ,
304
+ ) -> ExprRV :
305
+ """Sample an integer given the probability distribution
306
+
307
+ Parameters
308
+ ----------
309
+ candidates : List[int]
310
+ The candidates to be sampled from
311
+ probs : List[float]
312
+ The probability of each candidate
313
+ decision : Optional[int]
314
+ The sampling decision, if any
315
+
316
+ Returns
317
+ -------
318
+ result : ExprRV
319
+ The random variable sampled from candidates
320
+ """
321
+ return _ffi_api .ScheduleSampleCategorical ( # type: ignore # pylint: disable=no-member
322
+ self ,
323
+ candidates ,
324
+ probs ,
325
+ decision ,
326
+ )
327
+
271
328
########## Schedule: Get blocks & loops ##########
272
329
def get_block (
273
330
self ,
0 commit comments