Skip to content

Commit

Permalink
[TensorIR] New schedule primitive set_dtype (apache#14316)
Browse files Browse the repository at this point in the history
# Motivation
Currently, we miss a schedule primitive to change the data type of allocated buffer (e.g. via `cache_read`/`cache_write`), and thus we cannot perform type conversion while loading data from global to shared memory.

This PR adds a new schedule primitive `set_dtype` that follows the interface of `set_scope` and allows users to customize the allocated buffers' data type.

# Example
Before running `set_dtype`:
```python
@T.prim_func
def before_set_dtype(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float32")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j]
            C[vi, vj] = B[vi, vj] + 1.0
```
then we perform the `set_dtype` schedule:
```python
sch = tir.Schedule(before_set_dtype)
sch.set_dtype("B", buffer_index=0, dtype="float16")
print(sch.mod["main"].script())
```
we get transformed code:
```python
@T.prim_func
def after_set_dtype(
    A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
    B = T.alloc_buffer((128, 128), dtype="float16")

    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
    for i, j in T.grid(128, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j]
            C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
```
where data type conversions are inserted automatically.

# Other Usage
Using the combination of `cache_read` + `set_dtype` can help us load data from the memory hierarchy while converting data to the desired type.
  • Loading branch information
yzh119 authored Mar 22, 2023
1 parent 0c2dd47 commit c7970dd
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 5 deletions.
12 changes: 11 additions & 1 deletion include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -589,13 +589,23 @@ class ScheduleNode : public runtime::Object {
virtual void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) = 0;
/*!
* \brief Set the storage scope of a buffer, where the buffer is specified by the a block and a
* \brief Set the storage scope of a buffer, where the buffer is specified by a block and a
* write-index
* \param block_rv The producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param storage_scope The storage scope to be set
*/
virtual void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) = 0;
/*!
* \brief Set the data type of a buffer, where the buffer is specified by a block and a
* write-index
* \note This schedule primitive is unsafe and may change correctness of program because of
* type conversion, please use with caution.
* \param block_rv The producer block of the buffer
* \param buffer_index the index of the buffer in block's write region
* \param dtype The data type to be set
*/
virtual void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) = 0;
/******** Schedule: Blockize & Tensorize ********/
/*!
* \brief Convert the subtree rooted at a specific loop into a block.
Expand Down
79 changes: 77 additions & 2 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -2369,7 +2369,7 @@ def set_scope(
self, block: Union[BlockRV, str], buffer_index: Union[int, str, Buffer], storage_scope: str
) -> None:
"""Set the storage scope of a buffer, where the buffer is
specified by the a block and a write-index
specified by the a block and a write-index.
Parameters
----------
Expand Down Expand Up @@ -2431,7 +2431,7 @@ def after_set_scope(
Note
----
Set_scope requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
`set_scope` requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
block = self._normalize_block_arg(block)
if not isinstance(buffer_index, int):
Expand All @@ -2442,6 +2442,81 @@ def after_set_scope(
self, block, buffer_index, storage_scope
)

@type_checked
def unsafe_set_dtype(self, block: Union[BlockRV, str], buffer_index: int, dtype: str) -> None:
"""Set the data type of a buffer, where the buffer is
specified by the a block and write-index.
This schedule primitive is unsafe and may change the correctness of program because of
type conversion, please use with caution.
Parameters
----------
block : Union[BlockRV, str]
The producer block of the buffer
buffer_index : int
The index of the buffer in block's write region
dtype : str
The data type to be set
Examples
--------
Before set_dtype, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_set_dtype(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
B = T.alloc_buffer((128, 128), dtype="float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j]
C[vi, vj] = B[vi, vj] + 1.0
Create the schedule and do set_dtype:
.. code-block:: python
sch = tir.Schedule(before_set_dtype)
sch.set_dtype("B", buffer_index=0, dtype="float16")
print(sch.mod["main"].script())
After applying set_dtype, the IR becomes:
.. code-block:: python
@T.prim_func
def after_set_dtype(
A: T.Buffer((128, 128), "float32"), C: T.Buffer((128, 128), "float32")
) -> None:
B = T.alloc_buffer((128, 128), dtype="float16")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = T.cast(A[vi, vj] * 2.0, "float16")
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j]
C[vi, vj] = T.cast(B[vi, vj], "float32") + 1.0
Note
----
`set_dtype` requires the buffer to be an intermediate buffer defined via `alloc_buffer`.
"""
block = self._normalize_block_arg(block)
_ffi_api.ScheduleUnsafeSetDType( # type: ignore # pylint: disable=no-member
self, block, buffer_index, dtype
)

########## Schedule: Blockize & Tensorize ##########

@type_checked
Expand Down
8 changes: 8 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -727,6 +727,14 @@ void ConcreteScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
this->state_->DebugVerify();
}

void ConcreteScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
const String& dtype) {
TVM_TIR_SCHEDULE_BEGIN();
tir::UnsafeSetDType(state_, this->GetSRef(block_rv), buffer_index, dtype);
TVM_TIR_SCHEDULE_END("set-dtype", this->error_render_level_);
this->state_->DebugVerify();
}

/******** Schedule: Reduction ********/

BlockRV ConcreteScheduleNode::DecomposeReduction(const BlockRV& block_rv, const LoopRV& loop_rv) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ class ConcreteScheduleNode : public ScheduleNode {
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) override;
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) override;
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) override;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) override;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) override;
Expand Down
12 changes: 12 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,18 @@ TVM_DLL void StorageAlign(ScheduleState self, const StmtSRef& block_sref, int bu
*/
TVM_DLL void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& storage_scope);
/*!
* \brief Set the data type of a buffer, where the buffer is specified by a block and a
* write-index
* \note This schedule primitive is unsafe and may change correctness of program because of
* type conversion, please use with caution.
* \param self The state of the schedule
* \param block_sref The sref of the producer block of the buffer
* \param buffer_index The index of the buffer in block's write region
* \param dtype The data type to be set
*/
TVM_DLL void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& dtype);
/*!
* \brief Set the axis separator of a buffer, where the buffer is specified by a block and a read
* or write index
Expand Down
117 changes: 117 additions & 0 deletions src/tir/schedule/primitive/block_annotate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
* specific language governing permissions and limitations
* under the License.
*/
#include <tvm/tir/expr.h>

#include "../utils.h"

namespace tvm {
Expand Down Expand Up @@ -297,6 +299,93 @@ void SetScope(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
self->Replace(alloc_site_sref, new_block, block_reuse_map);
}

/*!
* \brief A helper mutator which recursively mutates the old buffer's data type, inserts data type
* conversions, and collecte the block sref reuse information for the following replacement.
*/
class DTypeMutator : private ReplaceBufferMutator {
public:
/*!
* \param allocate_site The block where `old_buffer` was allocated.
* \param old_buffer The old buffer
* \param target_dtype The data type to be set
* \param block_sref_reuse The block sref reuse map to be updated
* \return The new block after the mutation
*/
static Block Mutate(const Block& allocate_site, const Buffer& old_buffer, const DataType& dtype,
Map<Block, Block>* block_sref_reuse) {
Buffer new_buffer = WithDType(old_buffer, dtype);
DTypeMutator mutator(old_buffer, new_buffer, dtype, block_sref_reuse);
Stmt new_block = mutator.VisitStmt(allocate_site);
return Downcast<Block>(new_block);
}

private:
DTypeMutator(const Buffer& old_buffer, Buffer new_buffer, const DataType& dtype,
Map<Block, Block>* block_sref_reuse)
: ReplaceBufferMutator(old_buffer, std::move(new_buffer), block_sref_reuse),
src_dtype_(old_buffer->dtype),
tgt_dtype_(dtype) {}

MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer) final {
auto it = buffer_var_map_.find(match_buffer->source->buffer->data.get());
if (it != buffer_var_map_.end()) {
Buffer new_target_buffer = WithDType(match_buffer->buffer, it->second->dtype);
buffer_var_map_[match_buffer->buffer->data.get()] = new_target_buffer;
return MatchBufferRegion(new_target_buffer,
BufferRegion(it->second, match_buffer->source->region));
} else {
return match_buffer;
}
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore node = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
auto it = buffer_var_map_.find(node->buffer->data.get());
if (it != buffer_var_map_.end()) {
node.CopyOnWrite()->buffer = it->second;
node.CopyOnWrite()->value = Cast(tgt_dtype_, node->value);
}
return node;
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
BufferLoad node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
auto it = buffer_var_map_.find(node->buffer->data.get());
if (it != buffer_var_map_.end()) {
return Cast(src_dtype_, BufferLoad(it->second, node->indices));
}
return node;
}

DataType src_dtype_, tgt_dtype_;
};

void UnsafeSetDType(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
const String& dtype) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
Buffer buffer =
GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, BufferIndexType::kWrite);
DataType target_dtype(runtime::String2DLDataType(dtype));

// Step 1. If `dtype` equals the original data type, just return.
if (buffer->dtype == target_dtype) {
return;
}

// Step 2. Get the allocation site of the target buffer.
StmtSRef alloc_site_sref =
NonAllocatedBufferError::CheckAndGetBufferAllocationSite(self->mod, block_sref, buffer);
const BlockNode* alloc_site = TVM_SREF_TO_BLOCK(alloc_site_sref);

// Step 3. Recursively replace old buffer to a new buffer, where the new buffer has the given
// dtype, and insert data type conversions.
Map<Block, Block> block_reuse_map;
Block new_block =
DTypeMutator::Mutate(GetRef<Block>(alloc_site), buffer, target_dtype, &block_reuse_map);
self->Replace(alloc_site_sref, new_block, block_reuse_map);
}

/******** InstructionKind Registration ********/

struct StorageAlignTraits : public UnpackedInstTraits<StorageAlignTraits> {
Expand Down Expand Up @@ -356,8 +445,36 @@ struct SetScopeTraits : public UnpackedInstTraits<SetScopeTraits> {
friend struct ::tvm::tir::UnpackedInstTraits;
};

struct UnsafeSetDTypeTraits : public UnpackedInstTraits<UnsafeSetDTypeTraits> {
static constexpr const char* kName = "UnsafeSetDType";
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, Integer buffer_index,
String dtype) {
return sch->UnsafeSetDType(block_rv, buffer_index->value, dtype);
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, Integer buffer_index,
String dtype) {
PythonAPICall py("unsafe_set_dtype");
py.Input("block", block_rv);
py.Input("buffer_index", buffer_index);
py.Input("dtype", dtype);
return py.Str();
}

template <typename>
friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(StorageAlignTraits);
TVM_REGISTER_INST_KIND_TRAITS(SetScopeTraits);
TVM_REGISTER_INST_KIND_TRAITS(UnsafeSetDTypeTraits);

} // namespace tir
} // namespace tvm
2 changes: 2 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,8 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleStorageAlign")
.set_body_method<Schedule>(&ScheduleNode::StorageAlign);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleSetScope")
.set_body_method<Schedule>(&ScheduleNode::SetScope);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeSetDType")
.set_body_method<Schedule>(&ScheduleNode::UnsafeSetDType);
/******** (FFI) Blockize & Tensorize ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleBlockize")
.set_body_method<Schedule>(&ScheduleNode::Blockize);
Expand Down
11 changes: 11 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,17 @@ void TracedScheduleNode::SetScope(const BlockRV& block_rv, int buffer_index,
/*outputs=*/{}));
}

void TracedScheduleNode::UnsafeSetDType(const BlockRV& block_rv, int buffer_index,
const String& dtype) {
ConcreteScheduleNode::UnsafeSetDType(block_rv, buffer_index, dtype);
static const InstructionKind& kind = InstructionKind::Get("UnsafeSetDType");
trace_->Append(/*inst=*/Instruction(
/*kind=*/kind,
/*inputs=*/{block_rv},
/*attrs=*/{Integer(buffer_index), dtype},
/*outputs=*/{}));
}

/******** Schedule: Blockize & Tensorize ********/

BlockRV TracedScheduleNode::Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class TracedScheduleNode : public ConcreteScheduleNode {
void StorageAlign(const BlockRV& block_rv, int buffer_index, int axis, int factor,
int offset) final;
void SetScope(const BlockRV& block_rv, int buffer_index, const String& storage_scope) final;
void UnsafeSetDType(const BlockRV& block_rv, int buffer_index, const String& dtype) final;
/******** Schedule: Blockize & Tensorize ********/
BlockRV Blockize(const LoopRV& loop_rv, bool preserve_unit_iters) final;
void Tensorize(const BlockRV& block_rv, const String& intrin, bool preserve_unit_iters) final;
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ Buffer WithScope(const Buffer& buffer, const String& scope) {
return Buffer(new_buffer);
}

Buffer WithDType(const Buffer& buffer, const DataType& dtype) {
ObjectPtr<BufferNode> new_buffer = make_object<BufferNode>(*buffer.get());
new_buffer->dtype = dtype;
const auto* ptr_type = TVM_TYPE_AS(buffer->data->type_annotation, PointerTypeNode);
new_buffer->data =
Var(buffer->data->name_hint, PointerType(PrimType(dtype), ptr_type->storage_scope));
new_buffer->name = buffer->name;
return Buffer(new_buffer);
}

Array<BufferRegion> ReplaceBuffer(Array<BufferRegion> regions, const Buffer& source,
const Buffer& target) {
regions.MutateByApply([&source, &target](BufferRegion region) -> BufferRegion {
Expand Down
12 changes: 10 additions & 2 deletions src/tir/schedule/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ Block WithAnnotation(const BlockNode* block, const String& attr_key, const Objec
*/
Buffer WithScope(const Buffer& buffer, const String& scope);

/*!
* \brief Create a new buffer by changint the data type.
* \param buffer The given buffer.
* \param scope The target data type.
* \return The new buffer with target data type.
*/
Buffer WithDType(const Buffer& buffer, const DataType& dtype);

/*!
* \brief Replaces the buffer within the specific sequence of regions
* \param regions The regions whose buffers are to be replaced
Expand Down Expand Up @@ -131,9 +139,9 @@ class ReplaceBufferMutator : public StmtExprMutator {
return node;
}

Stmt VisitStmt_(const BufferStoreNode* op) final;
Stmt VisitStmt_(const BufferStoreNode* op) override;

PrimExpr VisitExpr_(const BufferLoadNode* op) final;
PrimExpr VisitExpr_(const BufferLoadNode* op) override;

virtual MatchBufferRegion VisitMatchBufferRegion(const MatchBufferRegion& match_buffer);

Expand Down
Loading

0 comments on commit c7970dd

Please sign in to comment.