[Bug Fix] Support threads_per_head < 64 for wavefront size of 64 (#6622)
When launching the apply_rotary_pos_half kernel with a wavefront size of 64, only a threads_per_head value of 64 is supported.
This change adds support for threads_per_head values below 64, such as 4, 8, 16, and 32.

Fixes the issue introduced in #5402

---------

Signed-off-by: Jagadish Krishnamoorthy <jagadish.krishnamoorthy@amd.com>
Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
Co-authored-by: Logan Adams <loadams@microsoft.com>
3 people authored Nov 4, 2024
1 parent 6c08b7f commit 2b41d62
Showing 2 changed files with 47 additions and 1 deletion.
10 changes: 9 additions & 1 deletion csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu
@@ -101,7 +101,15 @@ __global__ void apply_rotary_pos_half(T* mixed_query,

 #if defined(__HIP_PLATFORM_AMD__) and ROCM_WAVEFRONT_SIZE == 64
 #define LAUNCH_FOR_ALIGNMENT(ALIGNMENT) \
-    if (threads_per_head == 64) { \
+    if (threads_per_head == 4) { \
+        LAUNCH_ROT_POS_EMB_HALF(4, ALIGNMENT); \
+    } else if (threads_per_head == 8) { \
+        LAUNCH_ROT_POS_EMB_HALF(8, ALIGNMENT); \
+    } else if (threads_per_head == 16) { \
+        LAUNCH_ROT_POS_EMB_HALF(16, ALIGNMENT); \
+    } else if (threads_per_head == 32) { \
+        LAUNCH_ROT_POS_EMB_HALF(32, ALIGNMENT); \
+    } else if (threads_per_head == 64) { \
         LAUNCH_ROT_POS_EMB_HALF(64, ALIGNMENT); \
     } else { \
         assert(false); \
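A note for readers of the hunk above: the per-value branches exist because LAUNCH_ROT_POS_EMB_HALF bakes threads_per_head into the kernel instantiation, so the launcher has to enumerate every runtime value it intends to support; before this change the ROCm wavefront-64 path only enumerated 64. The following is a minimal, self-contained CUDA sketch of that dispatch pattern, assuming a hypothetical rot_kernel, a 256-thread block, and placeholder per-head work; it is an illustration of the technique, not the DeepSpeed kernel or launch code.

#include <cassert>
#include <cuda_runtime.h>

// Hypothetical stand-in kernel: the per-head thread count is a template parameter.
template <int threadsPerHead>
__global__ void rot_kernel(float* data, int total_heads, int head_size)
{
    // Each group of threadsPerHead consecutive threads cooperates on one head.
    const int head = (blockIdx.x * blockDim.x + threadIdx.x) / threadsPerHead;
    const int lane = threadIdx.x % threadsPerHead;
    if (head >= total_heads) return;

    // The per-head stride is a compile-time constant, which is why the launcher
    // must enumerate every supported threads_per_head value.
    for (int i = lane; i < head_size; i += threadsPerHead) {
        data[head * head_size + i] *= 2.0f;  // placeholder work
    }
}

void launch_rot(float* data, int total_heads, int head_size, int threads_per_head, cudaStream_t stream)
{
    const int block_threads = 256;  // assumed block size, not taken from the DeepSpeed source
    const int heads_per_block = block_threads / threads_per_head;
    const dim3 block(block_threads);
    const dim3 grid((total_heads + heads_per_block - 1) / heads_per_block);

    // Runtime value -> compile-time template parameter, mirroring the if/else
    // chain in LAUNCH_FOR_ALIGNMENT: only the enumerated counts exist as kernels.
    switch (threads_per_head) {
        case 4: rot_kernel<4><<<grid, block, 0, stream>>>(data, total_heads, head_size); break;
        case 8: rot_kernel<8><<<grid, block, 0, stream>>>(data, total_heads, head_size); break;
        case 16: rot_kernel<16><<<grid, block, 0, stream>>>(data, total_heads, head_size); break;
        case 32: rot_kernel<32><<<grid, block, 0, stream>>>(data, total_heads, head_size); break;
        case 64: rot_kernel<64><<<grid, block, 0, stream>>>(data, total_heads, head_size); break;
        default: assert(false && "unsupported threads_per_head"); break;
    }
}

Keeping the count compile-time typically lets the compiler unroll the per-head loop and specialize shuffle widths, which is the usual reason a plain runtime parameter is not used here.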
38 changes: 38 additions & 0 deletions tests/unit/ops/transformer/inference/test_rope.py
@@ -0,0 +1,38 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.accelerator import get_accelerator

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)


@pytest.mark.inference_ops
@pytest.mark.parametrize("num_heads", [64, 32, 16, 8])
def test_rope_warp_size_alignment(num_heads):
    if get_accelerator().device_name() != "cuda":
        pytest.skip("This test runs only on GPU")

    batch = 1
    head = 8
    seq_len = 1024
    head_dim = 32
    rotary_dim = 32
    offset = 8
    rotate_half = False
    rope_theta = 2

    cuda0 = torch.device('cuda:0')
    query = torch.randn(batch, head, seq_len, head_dim, device=cuda0)
    key = torch.randn(batch, head, seq_len, head_dim, device=cuda0)

    inference = InferenceBuilder().load()
    # For num_heads values of 64, 32, 16, 8
    # corresponding threads_per_head (defined in apply_rotary_pos_emb.cu) values are 4, 8, 16, 32
    inference.apply_rotary_pos_emb(query, key, rotary_dim, offset, num_heads, rotate_half, rope_theta)

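One way to read the comment inside the test above: the num_heads → threads_per_head pairs it lists (64→4, 32→8, 16→16, 8→32) are consistent with head_size being taken as 1024 / num_heads (dimension 2 of the test tensors divided by the num_heads argument) and each thread covering four float elements. Both figures are inferences from the test values rather than statements in this diff; the small C++ check below only reproduces that arithmetic and makes no DeepSpeed calls.

#include <cstdio>

int main()
{
    // Assumed inputs, taken from the test above: dimension 2 of the query/key
    // tensors is 1024, and the parametrized num_heads values are 64, 32, 16, 8.
    const int dim2 = 1024;
    const int elems_per_thread = 4;  // assumed float4-style vector width; an inference, not from the source
    const int num_heads_values[] = {64, 32, 16, 8};

    for (int num_heads : num_heads_values) {
        const int head_size = dim2 / num_heads;
        const int threads_per_head = head_size / elems_per_thread;
        std::printf("num_heads=%2d  head_size=%3d  threads_per_head=%2d\n",
                    num_heads, head_size, threads_per_head);
    }
    // Prints threads_per_head = 4, 8, 16, 32 — the values in the test comment,
    // all below 64, which is exactly the regime the kernel-side fix enables.
    return 0;
}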