@@ -248,7 +248,6 @@ class StoreNode : public StmtNode {
248248 * \endcode
249249 * \sa BufferLoad
250250 */
251- class BufferStore ;
252251class BufferStoreNode : public StmtNode {
253252 public:
254253 /* ! \brief The buffer variable. */
@@ -281,6 +280,10 @@ class BufferStoreNode : public StmtNode {
281280 TVM_DECLARE_FINAL_OBJECT_INFO (BufferStoreNode, StmtNode);
282281};
283282
283+ /* !
284+ * \brief Managed reference to BufferStoreNode.
285+ * \sa BufferStoreNode
286+ */
284287class BufferStore : public Stmt {
285288 public:
286289 TVM_DLL explicit BufferStore (Buffer buffer,
@@ -289,8 +292,80 @@ class BufferStore : public Stmt {
289292 TVM_DEFINE_OBJECT_REF_METHODS (BufferStore, Stmt, BufferStoreNode);
290293};
291294
295+ /* !
296+ * \brief Annotate the region where the buffer need to
297+ * be read and write in the body.
298+ * We only need to allocate the space for the corresponding region.
299+ *
300+ * \note There should be at most one BufferRealize for each buffer.
301+ * BufferRealize is not necessary for external buffers,
302+ * since they are assumed to be fully allocated.
303+ *
304+ * \sa BufferLoad, BufferStore
305+ */
306+ class BufferRealizeNode : public StmtNode {
307+ public:
308+ /* ! \brief The buffer variable. */
309+ Buffer buffer;
310+ /* ! \brief Bounds to be realized */
311+ Array<Range> bounds;
312+ /* ! \brief Only realize if condition holds. */
313+ PrimExpr condition;
314+ /* ! \brief The body of realization. */
315+ Stmt body;
316+
317+ void VisitAttrs (AttrVisitor* v) {
318+ v->Visit (" buffer" , &buffer);
319+ v->Visit (" bounds" , &bounds);
320+ v->Visit (" condition" , &condition);
321+ v->Visit (" body" , &body);
322+ }
323+
324+ bool SEqualReduce (const BufferRealizeNode* other, SEqualReducer equal) const {
325+ return
326+ equal (buffer, other->buffer ) &&
327+ equal (bounds, other->bounds ) &&
328+ equal (condition, other->condition ) &&
329+ equal (body, other->body );
330+ }
331+
332+ void SHashReduce (SHashReducer hash_reduce) const {
333+ hash_reduce (buffer);
334+ hash_reduce (bounds);
335+ hash_reduce (condition);
336+ hash_reduce (body);
337+ }
338+
339+ BufferRealizeNode () = default ;
340+ BufferRealizeNode (Buffer buffer,
341+ Array<Range> bounds,
342+ PrimExpr condition,
343+ Stmt body)
344+ : buffer(buffer), bounds(bounds),
345+ condition (condition), body(body) {}
346+
347+ static constexpr const char * _type_key = " BufferRealize" ;
348+ TVM_DECLARE_FINAL_OBJECT_INFO (BufferRealizeNode, StmtNode);
349+ };
350+
351+ /* !
352+ * \brief Managed reference to BufferRealizeNode.
353+ * \sa BufferRealizeNode
354+ */
355+ class BufferRealize : public Stmt {
356+ public:
357+ TVM_DLL explicit BufferRealize (Buffer buffer,
358+ Array<Range> bounds,
359+ PrimExpr condition,
360+ Stmt body);
361+
362+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS (BufferRealize, Stmt, BufferRealizeNode);
363+ };
364+
292365/* !
293366 * \brief Store value into mult-dimensional array defined by func.
367+ *
368+ * \note Deprecated, move to BufferStore in the future.
294369 */
295370class ProvideNode : public StmtNode {
296371 public:
@@ -430,6 +505,8 @@ class FreeNode : public StmtNode {
430505/* !
431506 * \brief Annotate the bounds where func need to be written and read in body.
432507 * We will need to allocate space for the corresponding regions.
508+ *
509+ * \note Deprecated, move to BufferRealize in the future.
433510 */
434511class RealizeNode : public StmtNode {
435512 public:
@@ -747,50 +824,50 @@ class ForNode : public StmtNode {
747824};
748825
749826/* !
750- * \brief A prefetch hint of func.
827+ * \brief A prefetch hint for abuffer
751828 */
752829class PrefetchNode : public StmtNode {
753830 public:
754831 /* ! \brief The function to be prefetched. */
755- FunctionRef func;
756- /* ! \brief The output value index if func's value is a tuple. */
757- int value_index;
758- /* ! \brief The data type of the array. */
759- DataType dtype;
832+ Buffer buffer;
760833 /* ! \brief Bounds to be prefetched. */
761- Region bounds;
834+ Array<Range> bounds;
762835
763836 void VisitAttrs (AttrVisitor* v) {
764- v->Visit (" func" , &func);
765- v->Visit (" value_index" , &value_index);
766- v->Visit (" dtype" , &dtype);
837+ v->Visit (" buffer" , &buffer);
767838 v->Visit (" bounds" , &bounds);
768839 }
769840
770841 bool SEqualReduce (const PrefetchNode* other, SEqualReducer equal) const {
771842 return
772- equal (func, other->func ) &&
773- equal (value_index, other->value_index ) &&
774- equal (dtype, other->dtype ) &&
843+ equal (buffer, other->buffer ) &&
775844 equal (bounds, other->bounds );
776845 }
777846
778847 void SHashReduce (SHashReducer hash_reduce) const {
779- hash_reduce (func);
780- hash_reduce (value_index);
781- hash_reduce (dtype);
848+ hash_reduce (buffer);
782849 hash_reduce (bounds);
783850 }
784851
785- TVM_DLL static Stmt make (FunctionRef func,
786- int value_index,
787- DataType dtype,
788- Region bounds);
852+ PrefetchNode () = default ;
853+ PrefetchNode (Buffer buffer, Array<Range> bounds)
854+ : buffer(buffer), bounds(bounds) {}
789855
790856 static constexpr const char * _type_key = " Prefetch" ;
791857 TVM_DECLARE_FINAL_OBJECT_INFO (PrefetchNode, StmtNode);
792858};
793859
860+ /* !
861+ * \brief Managed reference to PrefetchNode.
862+ * \sa PrefetchNode
863+ */
864+ class Prefetch : public Stmt {
865+ public:
866+ TVM_DLL explicit Prefetch (Buffer buffer, Array<Range> bounds);
867+
868+ TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS (Prefetch, Stmt, PrefetchNode);
869+ };
870+
794871/* !
795872 * \brief Auxiliary data structure used in IR Pass to indicate a tensor.
796873 */
0 commit comments