34
34
namespace tvm {
35
35
namespace tir {
36
36
37
- class StackSizeChecker : public StmtExprVisitor {
37
+ // Calculate the statistics of packed function.
38
+ // This information is needed during codegen.
39
+ class BuiltinLower : public StmtExprMutator {
38
40
public:
39
41
struct StackSizes {
40
42
// If a tvm_stack_make_shape call has no arguments, it is still
@@ -46,159 +48,86 @@ class StackSizeChecker : public StmtExprVisitor {
46
48
uint64_t arg_stack{0 };
47
49
};
48
50
49
- static StackSizes Check (Stmt stmt) {
50
- StackSizeChecker visitor;
51
- visitor.VisitStmt (stmt);
52
- return visitor.max_stack_ ;
53
- }
54
-
55
- private:
56
- void VisitStmt_ (const ForNode* op) final {
57
- if (op->kind == ForKind::kParallel ) {
58
- // Parallel for loops have their own stack and allocations, so
59
- // stop the recursion here.
60
- return ;
61
- } else {
62
- this ->VisitStmt (op->body );
63
- }
64
- }
65
- void VisitExpr_ (const CallNode* op) final {
66
- if (op->op .same_as (builtin::tvm_call_packed ())) {
67
- return MakeCallPacked (op, /* use_string_lookup */ true );
68
- } else if (op->op .same_as (builtin::tvm_call_cpacked ())) {
69
- return MakeCallPacked (op, /* use_string_lookup */ false );
70
- } else if (op->op .same_as (builtin::tvm_call_trace_packed ())) {
71
- return MakeCallTracePacked (op);
72
- } else if (op->op .same_as (builtin::tvm_stack_make_shape ())) {
73
- return MakeShape (op);
74
- } else if (op->op .same_as (builtin::tvm_stack_make_array ())) {
75
- return MakeArray (op);
76
- } else {
77
- return StmtExprVisitor::VisitExpr_ (op);
78
- }
79
- }
80
- // call shape
81
- void MakeShape (const CallNode* op) {
82
- // if args.size() == 0, it is still valid and represents a scalar
83
- // shape (). Therefore, -1 is used to represent "no shape
84
- // arguments exist", while 0 represents "shape arguments exist,
85
- // all of which are size 0".
86
- if (current_stack_.shape_stack == -1 ) {
87
- current_stack_.shape_stack = 0 ;
88
- }
89
- current_stack_.shape_stack += op->args .size ();
90
- StmtExprVisitor::VisitExpr_ (op);
91
- }
92
- // make array
93
- void MakeArray (const CallNode* op) {
94
- current_stack_.array_stack += 1 ;
95
- StmtExprVisitor::VisitExpr_ (op);
96
- }
97
- // call packed.
98
- void MakeCallPacked (const CallNode* op, bool use_string_lookup) {
99
- StackSizes restore_stack = current_stack_;
100
-
101
- size_t arg_count = op->args .size ();
102
-
103
- // cpacked expects a resource_handle parameter
104
- if (!use_string_lookup) {
105
- arg_count--;
106
- }
107
-
108
- current_stack_.arg_stack += arg_count;
109
- // Specially handle the buffer packed intrinsic
110
- StmtExprVisitor::VisitExpr_ (op);
111
- // Record the amount of stack space needed, then reset the stack
112
- // position to its previous location.
113
- UpdateMaxStack ();
114
- current_stack_ = restore_stack;
115
- }
116
-
117
- void MakeCallTracePacked (const CallNode* op) {
118
- StackSizes restore_stack = current_stack_;
119
-
120
- size_t args_size = op->args .size ();
121
- ICHECK_GT (args_size, 0 );
122
- current_stack_.arg_stack += args_size;
123
-
124
- StmtExprVisitor::VisitExpr_ (op);
125
- // Record the amount of stack space needed, then reset the stack
126
- // position to its previous location.
127
- UpdateMaxStack ();
128
- current_stack_ = restore_stack;
129
-
130
- // However, the arguments to this CallNode remain on top of the
131
- // stack, so we can use more than one packed function's arguments
132
- // with the one stack.
133
- current_stack_.arg_stack = restore_stack.arg_stack + args_size - 1 ;
134
- }
135
-
136
- void UpdateMaxStack () {
137
- max_stack_.arg_stack = std::max (current_stack_.arg_stack , max_stack_.arg_stack );
138
- max_stack_.shape_stack = std::max (current_stack_.shape_stack , max_stack_.shape_stack );
139
- max_stack_.array_stack = std::max (current_stack_.array_stack , max_stack_.array_stack );
140
- }
141
-
142
- StackSizes current_stack_;
143
- StackSizes max_stack_;
144
- };
145
-
146
- // Calculate the statistics of packed function.
147
- // These information are needed during codegen.
148
- class BuiltinLower : public StmtExprMutator {
149
- public:
150
51
// Record stack frame for existing scope.
151
52
struct AllocaScope {
152
53
Buffer stack_shape;
153
54
Var stack_array = Var(" stack_array" , DataType::Handle());
154
55
Var stack_value = Var(" stack_value" , DataType::Handle());
155
56
Buffer stack_tcode;
156
57
157
- int64_t max_shape_stack{-1 };
158
- uint64_t max_array_stack{0 };
159
- uint64_t max_arg_stack{0 };
58
+ StackSizes max_sizes;
59
+ StackSizes run_sizes;
160
60
161
- int64_t run_shape_stack{-1 };
162
- uint64_t run_array_stack{0 };
163
- uint64_t run_arg_stack{0 };
61
+ void UpdateMax () {
62
+ max_sizes.shape_stack = std::max (max_sizes.shape_stack , run_sizes.shape_stack );
63
+ max_sizes.array_stack = std::max (max_sizes.array_stack , run_sizes.array_stack );
64
+ max_sizes.arg_stack = std::max (max_sizes.arg_stack , run_sizes.arg_stack );
65
+ }
66
+
67
+ void AssertMaxIsValid () const {
68
+ ICHECK ((max_sizes.shape_stack >= run_sizes.shape_stack ) ||
69
+ (max_sizes.array_stack >= run_sizes.array_stack ) ||
70
+ (max_sizes.arg_stack >= run_sizes.arg_stack ));
71
+ }
164
72
};
165
73
166
74
// Entry point: lower `stmt` by handing it to VisitBodyAndRealizeAlloca and
// return the statement that call produces.
Stmt Build(Stmt stmt) {
  Stmt lowered = this->VisitBodyAndRealizeAlloca(stmt);
  return lowered;
}
167
75
76
// Dry-run `stmt` through a separate BuiltinLower instance running in precheck
// mode, and return the maximum stack sizes recorded in that instance's single
// alloca scope.  The result of the dry-run visit itself is discarded.
+ StackSizes GetMaxStack (Stmt stmt) {
77
+ BuiltinLower precheck;
78
// In precheck mode the packed-call handlers record maxima via UpdateMax()
// rather than asserting against them, and VisitBodyAndRealizeAlloca returns
// its input unchanged without opening a new scope.
+ precheck.is_precheck_ = true ;
79
// Carry over the device settings of this instance to the dry run.
+ precheck.device_id_ = this ->device_id_ ;
80
+ precheck.device_type_ = this ->device_type_ ;
81
+
82
// Push the one scope by hand, since VisitBodyAndRealizeAlloca will not do it
// while is_precheck_ is set.
+ precheck.alloca_scope_ .emplace_back ();
83
+ auto & scope = precheck.alloca_scope_ .back ();
84
// Zero-extent placeholder buffers.
// NOTE(review): presumably these exist so the visitors have defined
// stack_shape / stack_tcode buffers to reference during the dry run -- confirm.
+ scope.stack_shape =
85
+ decl_buffer ({IntImm (DataType::Int (64 ), 0 )}, DataType::Int (64 ), " stack_shape" );
86
+ scope.stack_tcode =
87
+ decl_buffer ({IntImm (DataType::UInt (64 ), 0 )}, DataType::Int (32 ), " stack_tcode" );
88
+
89
+ precheck.VisitStmt (stmt);
90
+
91
// Only the manually pushed scope is expected to remain after the visit.
+ ICHECK_EQ (precheck.alloca_scope_ .size (), 1 );
92
+ return precheck.alloca_scope_ [0 ].max_sizes ;
93
+ }
94
+
168
95
// Allocate stack frames, only at parallel-for or root.
169
96
Stmt VisitBodyAndRealizeAlloca (Stmt stmt) {
170
- // Initial check to identify maximum stack sizes. These are used
171
- // to construct Buffer objects to hold the stack, which are then
172
- // used when mutating.
173
- auto max_sizes = StackSizeChecker::Check (stmt);
97
+ // Only perform the precheck up to the point where we would add a
98
+ // new scope.
99
+ if (is_precheck_) {
100
+ return stmt;
101
+ }
174
102
175
103
alloca_scope_.emplace_back ();
176
104
auto & scope = alloca_scope_.back ();
177
105
178
- if (max_sizes.shape_stack != -1 ) {
179
- scope.stack_shape = decl_buffer ({IntImm (DataType::Int (64 ), max_sizes.shape_stack )},
106
+ // Initial check to identify maximum stack sizes. These are used
107
+ // to construct Buffer objects to hold the stack, which are then
108
+ // used when mutating.
109
+ scope.max_sizes = GetMaxStack (stmt);
110
+
111
+ if (scope.max_sizes .shape_stack != -1 ) {
112
+ scope.stack_shape = decl_buffer ({IntImm (DataType::Int (64 ), scope.max_sizes .shape_stack )},
180
113
DataType::Int (64 ), " stack_shape" );
181
- stmt = LetStmt (scope.stack_shape ->data , StackAlloca (" shape" , max_sizes.shape_stack ), stmt);
114
+ stmt =
115
+ LetStmt (scope.stack_shape ->data , StackAlloca (" shape" , scope.max_sizes .shape_stack ), stmt);
182
116
}
183
117
184
- if (max_sizes.array_stack != 0 ) {
185
- stmt = LetStmt (scope.stack_array , StackAlloca (" array" , max_sizes.array_stack ), stmt);
118
+ if (scope. max_sizes .array_stack != 0 ) {
119
+ stmt = LetStmt (scope.stack_array , StackAlloca (" array" , scope. max_sizes .array_stack ), stmt);
186
120
}
187
121
188
- if (max_sizes.arg_stack != 0 ) {
189
- scope.stack_tcode = decl_buffer ({IntImm (DataType::UInt (64 ), max_sizes.arg_stack )},
122
+ if (scope. max_sizes .arg_stack != 0 ) {
123
+ scope.stack_tcode = decl_buffer ({IntImm (DataType::UInt (64 ), scope. max_sizes .arg_stack )},
190
124
DataType::Int (32 ), " stack_tcode" );
191
- stmt = LetStmt (scope.stack_value , StackAlloca (" arg_value" , max_sizes.arg_stack ), stmt);
125
+ stmt = LetStmt (scope.stack_value , StackAlloca (" arg_value" , scope. max_sizes .arg_stack ), stmt);
192
126
193
- stmt = LetStmt (scope.stack_tcode ->data , StackAlloca (" arg_tcode" , max_sizes.arg_stack ), stmt);
127
+ stmt = LetStmt (scope.stack_tcode ->data , StackAlloca (" arg_tcode" , scope.max_sizes .arg_stack ),
128
+ stmt);
194
129
}
195
130
196
- // Copy these values from the earlier search, for use in bounds
197
- // checks.
198
- scope.max_shape_stack = max_sizes.shape_stack ;
199
- scope.max_array_stack = max_sizes.array_stack ;
200
- scope.max_arg_stack = max_sizes.arg_stack ;
201
-
202
131
stmt = this ->VisitStmt (stmt);
203
132
204
133
ICHECK (!alloca_scope_.empty ());
@@ -213,8 +142,8 @@ class BuiltinLower : public StmtExprMutator {
213
142
214
143
auto stmt = StmtExprMutator::VisitStmt (s);
215
144
auto & scope = alloca_scope_.back ();
216
- ICHECK_EQ (scope.run_shape_stack , -1 );
217
- ICHECK_EQ (scope.run_array_stack , 0 );
145
+ ICHECK_EQ (scope.run_sizes . shape_stack , -1 );
146
+ ICHECK_EQ (scope.run_sizes . array_stack , 0 );
218
147
219
148
auto prep_seq = std::move (prep_seq_stack_.back ());
220
149
prep_seq_stack_.pop_back ();
@@ -364,11 +293,11 @@ class BuiltinLower : public StmtExprMutator {
364
293
ICHECK (!alloca_scope_.empty ());
365
294
auto & scope = alloca_scope_.back ();
366
295
auto & prep_seq = prep_seq_stack_.back ();
367
- if (scope.run_shape_stack == -1 ) {
368
- scope.run_shape_stack = 0 ;
296
+ if (scope.run_sizes . shape_stack == -1 ) {
297
+ scope.run_sizes . shape_stack = 0 ;
369
298
}
370
- int64_t stack_begin = scope.run_shape_stack ;
371
- scope.run_shape_stack += op->args .size ();
299
+ int64_t stack_begin = scope.run_sizes . shape_stack ;
300
+ scope.run_sizes . shape_stack += op->args .size ();
372
301
PrimExpr expr = StmtExprMutator::VisitExpr_ (op);
373
302
op = expr.as <CallNode>();
374
303
// no need to perform any store for a scalar shape
@@ -384,8 +313,8 @@ class BuiltinLower : public StmtExprMutator {
384
313
auto & scope = alloca_scope_.back ();
385
314
auto & prep_seq = prep_seq_stack_.back ();
386
315
387
- size_t idx = scope.run_array_stack ;
388
- scope.run_array_stack += 1 ;
316
+ size_t idx = scope.run_sizes . array_stack ;
317
+ scope.run_sizes . array_stack += 1 ;
389
318
PrimExpr expr = StmtExprMutator::VisitExpr_ (op);
390
319
op = expr.as <CallNode>();
391
320
@@ -426,9 +355,9 @@ class BuiltinLower : public StmtExprMutator {
426
355
auto & scope = alloca_scope_.back ();
427
356
auto & prep_seq = prep_seq_stack_.back ();
428
357
429
- int64_t restore_shape_stack = scope.run_shape_stack ;
430
- size_t restore_array_stack = scope.run_array_stack ;
431
- size_t arg_stack_begin = scope.run_arg_stack ;
358
+ int64_t restore_shape_stack = scope.run_sizes . shape_stack ;
359
+ size_t restore_array_stack = scope.run_sizes . array_stack ;
360
+ size_t arg_stack_begin = scope.run_sizes . arg_stack ;
432
361
433
362
size_t arg_count = op->args .size ();
434
363
@@ -437,7 +366,7 @@ class BuiltinLower : public StmtExprMutator {
437
366
arg_count--;
438
367
}
439
368
440
- scope.run_arg_stack += arg_count;
369
+ scope.run_sizes . arg_stack += arg_count;
441
370
// Specially handle the buffer packed intrinsic
442
371
PrimExpr expr = StmtExprMutator::VisitExpr_ (op);
443
372
op = expr.as <CallNode>();
@@ -460,12 +389,14 @@ class BuiltinLower : public StmtExprMutator {
460
389
prep_seq.emplace_back (BufferStore (scope.stack_tcode , ConstInt32 (arg_tcode), {stack_index}));
461
390
}
462
391
// Verify stack size matches earlier value.
463
- ICHECK_LE (scope.run_arg_stack , scope.max_arg_stack );
464
- ICHECK_LE (scope.run_shape_stack , scope.max_shape_stack );
465
- ICHECK_LE (scope.run_array_stack , scope.max_array_stack );
466
- scope.run_shape_stack = restore_shape_stack;
467
- scope.run_array_stack = restore_array_stack;
468
- scope.run_arg_stack = arg_stack_begin;
392
+ if (is_precheck_) {
393
+ scope.UpdateMax ();
394
+ } else {
395
+ scope.AssertMaxIsValid ();
396
+ }
397
+ scope.run_sizes .shape_stack = restore_shape_stack;
398
+ scope.run_sizes .array_stack = restore_array_stack;
399
+ scope.run_sizes .arg_stack = arg_stack_begin;
469
400
Array<PrimExpr> packed_args = {op->args [0 ], scope.stack_value , scope.stack_tcode ->data ,
470
401
ConstInt32 (arg_stack_begin),
471
402
ConstInt32 (arg_stack_begin + op->args .size () - 1 )};
@@ -486,10 +417,10 @@ class BuiltinLower : public StmtExprMutator {
486
417
auto & scope = alloca_scope_.back ();
487
418
auto & prep_seq = prep_seq_stack_.back ();
488
419
489
- int64_t restore_shape_stack = scope.run_shape_stack ;
490
- size_t restore_array_stack = scope.run_array_stack ;
491
- size_t arg_stack_begin = scope.run_arg_stack ;
492
- scope.run_arg_stack += op->args .size ();
420
+ int64_t restore_shape_stack = scope.run_sizes . shape_stack ;
421
+ size_t restore_array_stack = scope.run_sizes . array_stack ;
422
+ size_t arg_stack_begin = scope.run_sizes . arg_stack ;
423
+ scope.run_sizes . arg_stack += op->args .size ();
493
424
size_t args_size = op->args .size ();
494
425
ICHECK_GT (args_size, 0 );
495
426
PrimExpr expr = StmtExprMutator::VisitExpr_ (op);
@@ -510,14 +441,16 @@ class BuiltinLower : public StmtExprMutator {
510
441
prep_seq.emplace_back (BufferStore (scope.stack_tcode , ConstInt32 (arg_tcode), {stack_index}));
511
442
}
512
443
// Verify stack size matches earlier value.
513
- ICHECK_LE (scope.run_arg_stack , scope.max_arg_stack );
514
- ICHECK_LE (scope.run_shape_stack , scope.max_shape_stack );
515
- ICHECK_LE (scope.run_array_stack , scope.max_array_stack );
516
- scope.run_shape_stack = restore_shape_stack;
517
- scope.run_array_stack = restore_array_stack;
444
+ if (is_precheck_) {
445
+ scope.UpdateMax ();
446
+ } else {
447
+ scope.AssertMaxIsValid ();
448
+ }
449
+ scope.run_sizes .shape_stack = restore_shape_stack;
450
+ scope.run_sizes .array_stack = restore_array_stack;
518
451
// Update the top of the stack, so we can use more than one
519
452
// packed function's arguments with the one stack.
520
- scope.run_arg_stack = arg_stack_begin + args_size - 1 ;
453
+ scope.run_sizes . arg_stack = arg_stack_begin + args_size - 1 ;
521
454
Array<PrimExpr> packed_args = {op->args [0 ], scope.stack_value , scope.stack_tcode ->data ,
522
455
ConstInt32 (arg_stack_begin),
523
456
ConstInt32 (arg_stack_begin + op->args .size () - 1 ),
@@ -575,6 +508,8 @@ class BuiltinLower : public StmtExprMutator {
575
508
PrimExpr device_type_;
576
509
PrimExpr device_id_;
577
510
511
+ bool is_precheck_{false };
512
+
578
513
// Record all stack frames.
579
514
std::vector<AllocaScope> alloca_scope_;
580
515
};
0 commit comments