Skip to content

Commit

Permalink
[Unity][DLight] Fix general_reduction for GroupNorm (apache#16161)
Browse files Browse the repository at this point in the history
The `lower_thread_allreduce` pass fails during codegen when the spatial
loops are unrolled by the schedule. This PR works around the issue by
changing the schedule rules so that the spatial loops are no longer unrolled.
  • Loading branch information
Hzfengsy authored Nov 28, 2023
1 parent 8f24a27 commit af803cf
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 83 deletions.
19 changes: 14 additions & 5 deletions python/tvm/dlight/gpu/general_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Reduction rule for operators including softmax, layer norm, RMS norm, etc"""
from typing import List, Union

Expand Down Expand Up @@ -75,20 +76,28 @@ def apply( # pylint: disable=too-many-locals
return None

loops = sch.get_loops(block_infos[-1].block_rv)
bx = sch.fuse(*loops[:num_leading_s]) # pylint: disable=invalid-name
_, tx = sch.split(loops[-1], [None, len_tx]) # pylint: disable=invalid-name
bx = sch.fuse(*loops[:num_leading_s])
r_loop, tx = sch.split(loops[-1], [None, len_tx])
sch.reorder(tx, r_loop)
sch.bind(bx, "blockIdx.x")
sch.bind(tx, "threadIdx.x")
sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1)

for block in reversed(block_infos[:-1]):
block = block.block_rv
for i, _ in enumerate(sch.get(block).writes):
sch.set_scope(block, buffer_index=i, storage_scope="shared")
sch.compute_at(block, bx, preserve_unit_loops=True)
r_loop = sch.fuse(*sch.get_loops(block)[-num_trailing_r:])
_, tx = sch.split(r_loop, [None, len_tx]) # pylint: disable=invalid-name
r_loop, tx = sch.split(r_loop, [None, len_tx])
sch.reorder(tx, r_loop)
sch.bind(tx, "threadIdx.x")
sch.annotate(r_loop, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(r_loop, ann_key="pragma_unroll_explicit", ann_val=1)

sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1)
# TODO: This is a workaround that avoids unrolling the spatial loops because of a bug in
# the lower-thread-allreduce pass. The underlying bug should be fixed in the future.
# sch.annotate(bx, ann_key="pragma_auto_unroll_max_step", ann_val=unroll_depth)
# sch.annotate(bx, ann_key="pragma_unroll_explicit", ann_val=1)
return sch
Loading

0 comments on commit af803cf

Please sign in to comment.