[Kernel] [Triton] [AMD] Adding Triton implementations awq_dequantize and awq_gemm to support AWQ #7386

Merged
merged 59 commits on Aug 28, 2024
Changes from 1 commit
Commits (59)
ff27ffa
Add awq_dequantize_triton
rasmith Jul 26, 2024
f9b6e74
Add awq_dequantize_triton
rasmith Jul 26, 2024
7b49a76
Merge branch 'ransmith_awq_dequantize_triton' of github.com:rasmith/v…
rasmith Jul 31, 2024
e2c3ba5
Merge branch 'vllm-project:main' into ransmith_awq_dequantize_triton
rasmith Jul 31, 2024
ec14fe9
Use any instead of all
rasmith Jul 31, 2024
fd80f7f
ruff checks
rasmith Jul 31, 2024
370c9f0
run isort
rasmith Jul 31, 2024
bdd0ab7
run yapf
rasmith Jul 31, 2024
915e0ae
Format for PR
rasmith Jul 31, 2024
3b3a563
Merge branch 'ransmith_awq_dequantize_triton' of github.com:rasmith/v…
rasmith Aug 9, 2024
150db8c
Merge branch 'vllm-project:main' into ransmith_awq_dequantize_triton
rasmith Aug 9, 2024
00dee49
Merge branch 'main' into ransmith_awq_dequantize_triton
rasmith Aug 9, 2024
a8ef8c2
Merge branch 'ransmith_awq_dequantize_triton' of github.com:rasmith/v…
rasmith Aug 9, 2024
2ebd212
Have working awq_gemm in Triton
rasmith Aug 9, 2024
e3073bc
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 10, 2024
5326dde
Optimizations to awq_gemm
rasmith Aug 10, 2024
fb43aa4
Small cleanup
rasmith Aug 10, 2024
43abe7a
ruff and yapf linting/formatting
rasmith Aug 12, 2024
91c6741
isort/ruff fixing
rasmith Aug 12, 2024
962ea59
add env VLLM_USE_TRITON_AWQ
rasmith Aug 14, 2024
c9df260
Add tests
rasmith Aug 16, 2024
c7b63e8
awq for rocm in config
rasmith Aug 16, 2024
5cf14db
add dimension assertions
rasmith Aug 16, 2024
23cf001
fix typo
rasmith Aug 16, 2024
f94c1b0
yappity yapf
rasmith Aug 16, 2024
5887e77
merge main
rasmith Aug 16, 2024
8594e25
Merge branch 'vllm-project:main' into ransmith_awq_gemm_triton
rasmith Aug 16, 2024
64e5251
Merge main
rasmith Aug 16, 2024
86f2ec6
warning message for AWQ on ROCm and not setting VLLM_USE_TRITON_AWQ
rasmith Aug 19, 2024
d32212a
VLLM_USE_TRITON_AWQ enabled automatically
rasmith Aug 19, 2024
6514622
parameterized unit tests
rasmith Aug 20, 2024
8a1f6f2
cleanup
rasmith Aug 20, 2024
39d44a2
ruff
rasmith Aug 20, 2024
34e06b5
yapf
rasmith Aug 20, 2024
4f3148f
yapf
rasmith Aug 20, 2024
010c80e
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 20, 2024
4895074
test cleanup
rasmith Aug 21, 2024
0e1862c
test cleanup
rasmith Aug 21, 2024
24a6b3b
yapf
rasmith Aug 21, 2024
3d2854c
merge main
rasmith Aug 22, 2024
c3b8102
Adjust threshold
rasmith Aug 22, 2024
a84c7d7
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 23, 2024
c7fbacf
simplify unit test and use assert_close
rasmith Aug 24, 2024
11860d6
clean up test
rasmith Aug 24, 2024
bea93a2
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 24, 2024
0c45b68
use marlin tolerance
rasmith Aug 24, 2024
bbfb4d9
update test
rasmith Aug 24, 2024
13bb612
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 24, 2024
226e7fb
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 24, 2024
c4e3fd1
Merge branch 'vllm-project:main' into ransmith_awq_gemm_triton
rasmith Aug 25, 2024
62612ee
Support more group sizes
rasmith Aug 26, 2024
5d91e78
Merge branch 'main' into ransmith_awq_gemm_triton
rasmith Aug 26, 2024
ba434dc
Merge branch 'ransmith_awq_gemm_triton' of github.com:rasmith/vllm in…
rasmith Aug 26, 2024
2db93e0
assert added
rasmith Aug 26, 2024
f07c241
ruff
rasmith Aug 26, 2024
e95dfc4
ruff
rasmith Aug 26, 2024
efbd8a5
isort
rasmith Aug 26, 2024
69573dd
test update
rasmith Aug 26, 2024
d456232
update comment
rasmith Aug 26, 2024
Add awq_dequantize_triton
rasmith committed Jul 31, 2024
commit f9b6e741a51231da8474e3459fce78239d22c700
10 changes: 10 additions & 0 deletions vllm/_custom_ops.py
@@ -8,6 +8,11 @@

logger = init_logger(__name__)

# NOTE: 7/26/24: Use Triton implementation of AWQ functions. Currently, only
# awq_dequantize is implemented; awq_gemm is coming soon.

use_awq_triton = False

try:
import vllm._C
except ImportError as e:
@@ -21,6 +26,7 @@
import vllm._punica_C



def is_custom_op_supported(op_name: str) -> bool:
op, overloads = torch._C._jit_get_operation(op_name)
return op is not None
@@ -183,6 +189,10 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
                   zeros: torch.Tensor, split_k_iters: int, thx: int,
                   thy: int) -> torch.Tensor:
    if use_awq_triton:
        from vllm.model_executor.layers.quantization.awq_triton import (
            awq_dequantize_triton)
        return awq_dequantize_triton(qweight, scales, zeros, split_k_iters,
                                     thx, thy)
    return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters,
                                       thx, thy)

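At this point in the history the Triton path is gated by the hard-coded use_awq_triton flag above; later commits in this PR ("add env VLLM_USE_TRITON_AWQ", "VLLM_USE_TRITON_AWQ enabled automatically") switch to an environment variable. A rough sketch of that direction, assuming the flag is read directly from os.environ rather than through vLLM's env-handling module:

import os

# Hypothetical sketch only -- not the code merged in this PR. The merged PR
# routes VLLM_USE_TRITON_AWQ through vLLM's environment handling instead.
use_awq_triton = os.environ.get("VLLM_USE_TRITON_AWQ", "0") == "1"
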
260 changes: 260 additions & 0 deletions vllm/model_executor/layers/quantization/awq_triton.py
@@ -0,0 +1,260 @@
import torch

import triton
import triton.language as tl


@triton.jit
def awq_dequantize_kernel(
        qweight_ptr,  # quantized matrix
        scales_ptr,  # scales, per group
        zeros_ptr,  # zeros, per group
        split_k_iters,  # Not used
        thx,  # Not used
        thy,  # Not used
        group_size,  # Should always be 128
        result_ptr,  # Output matrix
        num_cols,  # input num cols in qweight
        num_rows,  # input num rows in qweight
        reverse_awq_order_ptr,
        BLOCK_SIZE_X: tl.constexpr,
        BLOCK_SIZE_Y: tl.constexpr):
    # Setup the pids.
    pid_x = tl.program_id(axis=0)
    pid_y = tl.program_id(axis=1)

    # Compute offsets and masks for qweight_ptr.
    offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
    offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]

    masks_y = offsets_y < num_rows
    masks_x = offsets_x < num_cols

    masks = masks_y[:, None] & masks_x[None, :]

    # Compute offsets and masks for result output ptr.
    result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
    result_offsets_x = (pid_x * BLOCK_SIZE_X * 8
                        + tl.arange(0, BLOCK_SIZE_X * 8))
    result_offsets = (8 * num_cols * result_offsets_y[:, None]
                      + result_offsets_x[None, :])

    result_masks_y = result_offsets_y < num_rows
    result_masks_x = result_offsets_x < num_cols * 8
    result_masks = result_masks_y[:, None] & result_masks_x[None, :]

    # Load the weights.
    iweights = tl.load(qweight_ptr + offsets, masks)

    # Load the AWQ reverse order offsets.
    reverse_awq_order_offsets = tl.arange(0, 8)
    reverse_awq_order_tensor = tl.load(reverse_awq_order_ptr +
                                       reverse_awq_order_offsets)

    # Use this to compute a set of shifts that can be used to unpack and
    # reorder the values in iweights and zeros.
    shifts = reverse_awq_order_tensor * 4
    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
    shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    iweights = (iweights >> shifts) & 0xF

    # Compute zero offsets and masks.
    zero_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size
                      + tl.arange(0, BLOCK_SIZE_Y) // group_size)
    zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X * 8) // 8
    zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]

    zero_masks_y = zero_offsets_y < num_rows // group_size
    zero_masks_x = zero_offsets_x < num_cols
    zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]

    # Load the zeros.
    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks)

    # Unpack and reorder: shift out the correct 4-bit value and mask.
    zeros = (zeros >> shifts) & 0xF

    # Compute scale offsets and masks.
    scale_offsets_y = (pid_y * BLOCK_SIZE_Y // group_size
                       + tl.arange(0, BLOCK_SIZE_Y) // group_size)
    scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8
                       + tl.arange(0, BLOCK_SIZE_X * 8))
    scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] +
                     scale_offsets_x[None, :])
    scale_masks_y = scale_offsets_y < num_rows // group_size
    scale_masks_x = scale_offsets_x < num_cols * 8
    scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]

    # Load the scales.
    scales = tl.load(scales_ptr + scale_offsets, scale_masks)

    # Dequantize.
    iweights = (iweights - zeros) * scales
    iweights = iweights.to(tl.float16)

    # Finally, store.
    tl.store(result_ptr + result_offsets, iweights, result_masks)
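
The shift trick above relies on AWQ's nibble layout: the eight logical 4-bit values of a column group are packed into one int32 in the order [0, 2, 4, 6, 1, 3, 5, 7], and the reverse order [0, 4, 1, 5, 2, 6, 3, 7] used to build the shifts undoes that permutation. A small stand-alone check of this (an editor's illustration, not part of the committed file):

import torch

vals = torch.arange(8, dtype=torch.int32)  # logical values 0..7
pack_order = [0, 2, 4, 6, 1, 3, 5, 7]  # nibble k holds vals[pack_order[k]]
packed = torch.zeros((), dtype=torch.int32)
for k, src in enumerate(pack_order):
    packed |= vals[src] << (4 * k)

reverse_order = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])
shifts = reverse_order * 4  # same shifts the kernel computes
unpacked = (packed >> shifts) & 0xF
assert torch.equal(unpacked.to(torch.int32), vals)  # recovers 0..7 in order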

# Example input:
#   qweight.size = torch.Size([3584, 576]), qweight.dtype = torch.int32
#   scales.size = torch.Size([28, 4608]), scales.dtype = torch.float16
#   zeros.size = torch.Size([28, 576]), zeros.dtype = torch.int32
#   split_k_iters = 0, thx = 0, thy = 0
def awq_dequantize_triton(
        qweight: torch.Tensor,
        scales: torch.Tensor,
        zeros: torch.Tensor,
        split_k_iters: int,  # Not used
        thx: int,  # Not used
        thy: int  # Not used
) -> torch.Tensor:
    # Result tensor:
    #   number of rows = same as input tensor
    #   number of cols = 8 x input tensor num cols
    result = torch.empty(qweight.shape[0],
                         qweight.shape[1] * 8,
                         device=qweight.device,
                         dtype=torch.float16)

    block_size_x = 32
    block_size_y = 32

    reverse_awq_order = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7],
                                     dtype=torch.uint8,
                                     device=qweight.device)

    Y = qweight.shape[0]  # num rows
    X = qweight.shape[1]  # num cols
    group_size = 128
    grid = lambda META: (triton.cdiv(X, META['BLOCK_SIZE_X']),
                         triton.cdiv(Y, META['BLOCK_SIZE_Y']))
    awq_dequantize_kernel[grid](qweight, scales, zeros, split_k_iters, thx,
                                thy, group_size, result, X, Y,
                                reverse_awq_order,
                                BLOCK_SIZE_X=block_size_x,
                                BLOCK_SIZE_Y=block_size_y)

    return result
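
A minimal usage sketch for the wrapper above, assuming a CUDA/ROCm device and AWQ-shaped inputs (qweight [R, C // 8] int32, scales [R // group_size, C] float16, zeros [R // group_size, C // 8] int32, group_size = 128); the shapes here are illustrative only:

# Editor's illustration, not part of the committed file.
qweight = torch.randint(0, 2**31 - 1, (256, 16), dtype=torch.int32,
                        device="cuda")
scales = torch.rand(2, 128, dtype=torch.float16, device="cuda")
zeros = torch.randint(0, 2**31 - 1, (2, 16), dtype=torch.int32, device="cuda")
out = awq_dequantize_triton(qweight, scales, zeros, 0, 0, 0)
print(out.shape)  # torch.Size([256, 128])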


def reverse_awq_order(t: torch.Tensor):
    bits = 4
    AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
    reverse_order_tensor = torch.arange(
        t.shape[-1],
        dtype=torch.int32,
        device=t.device,
    )
    reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
    reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
    reverse_order_tensor = reverse_order_tensor.view(-1)

    t = t[:, reverse_order_tensor] & 0xF
    return t
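
A quick sanity check of reverse_awq_order on a single packed row: columns arrive in AWQ (packed) order and come back in logical order (again just an illustration, not part of the commit):

t = torch.tensor([[0, 2, 4, 6, 1, 3, 5, 7]], dtype=torch.int32)
print(reverse_awq_order(t))  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]], dtype=torch.int32)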

# qweight - [R     , C // 8], int32
# scales  - [R // G, C     ], float16
# zeros   - [R // G, C // 8], int32
def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor,
                         qzeros: torch.Tensor, split_k_iters: int, thx: int,
                         thy: int) -> torch.Tensor:
    print(f"awq_dequantize_torch:qweight.shape = {qweight.shape}"
          f", qzeros.shape={qzeros.shape}")
    bits = 4
    group_size = 128
    shifts = torch.arange(0, 32, bits, device=qzeros.device)

    iweights = torch.bitwise_right_shift(qweight[:, :, None],
                                         shifts[None, None, :]).to(torch.int8)

    iweights = iweights.view(iweights.shape[0], -1)

    # iweights = reverse_awq_order(iweights)
    # return (iweights & 0xF).to(torch.float16)

    zeros = torch.bitwise_right_shift(qzeros[:, :, None],
                                      shifts[None, None, :]).to(torch.int8)

    zeros = zeros.view(qzeros.shape[0], -1)

    zeros = reverse_awq_order(zeros)
    iweights = reverse_awq_order(iweights)

    iweights = torch.bitwise_and(iweights, (2**bits) - 1)
    zeros = torch.bitwise_and(zeros, (2**bits) - 1)

    scales = scales.repeat_interleave(group_size, dim=0)
    zeros = zeros.repeat_interleave(group_size, dim=0)
    print(f"awq_dequantize_torch:iweights.shape = {iweights.shape}, "
          f"zeros.shape={zeros.shape}, "
          f"scales.shape={scales.shape}")

    return (iweights - zeros) * scales
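
Note that group_size is hard-coded to 128 in both the kernel and this Torch reference at this point in the PR; a later commit ("Support more group sizes") generalizes it.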

def main():
    use_triton = True
    use_torch = True

    qweight_rows = 3584
    qweight_cols = 576
    group_size = 128
    small_test_size = True
    if small_test_size:
        qweight_rows = 256
        qweight_cols = 128
    print(f"qweight_rows = {qweight_rows}, qweight_cols = {qweight_cols}")
    qweight_dtype = torch.int32
    scales_rows = qweight_rows // group_size
    scales_cols = qweight_cols * 8
    scales_dtype = torch.float16
    zeros_rows = scales_rows
    zeros_cols = qweight_cols
    zeros_dtype = torch.int32
    split_k_iters = 0
    thx = 0
    thy = 0
    device = 'cuda'
    torch.manual_seed(0)

    qweight = torch.randint(0,
                            10000000, (qweight_rows, qweight_cols),
                            dtype=qweight_dtype,
                            device=device)
    scales = torch.rand(scales_rows,
                        scales_cols,
                        dtype=scales_dtype,
                        device=device)
    zeros = torch.randint(0,
                          10000000, (zeros_rows, zeros_cols),
                          dtype=zeros_dtype,
                          device=device)
    print(f"zeros.shape = {zeros.shape}")
    print(f"qweight = {qweight}")
    if use_triton:
        iweights_triton = awq_dequantize_triton(qweight, scales, zeros,
                                                split_k_iters, thx, thy)
        print(f"Triton result:iweights_triton = {iweights_triton}")
        print(f"Any infs in triton result? --> "
              f"{torch.any(torch.isinf(iweights_triton))}")

    if use_torch:
        iweights_torch = awq_dequantize_torch(qweight, scales, zeros,
                                              split_k_iters, thx, thy)
        print(f"Torch result:iweights_torch = {iweights_torch}")

    if use_torch and use_triton:
        # Sum of absolute element-wise differences between the two results.
        diff = iweights_torch - iweights_triton
        error = torch.sum(torch.sqrt(diff * diff))
        print(f"error = {error}")


if __name__ == '__main__':
    main()
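
The main() driver above only eyeballs the two paths; later commits in this PR ("parameterized unit tests", "simplify unit test and use assert_close", "use marlin tolerance") move the comparison into pytest. A hedged sketch of what such a test could look like (shapes, tolerances, and test layout are assumptions, not the merged test):

import pytest


@pytest.mark.parametrize("qweight_rows,qweight_cols", [(256, 128), (3584, 576)])
def test_dequantize(qweight_rows, qweight_cols):
    group_size = 128
    torch.manual_seed(0)
    qweight = torch.randint(0, 2**31 - 1, (qweight_rows, qweight_cols),
                            dtype=torch.int32, device="cuda")
    scales = torch.rand(qweight_rows // group_size, qweight_cols * 8,
                        dtype=torch.float16, device="cuda")
    zeros = torch.randint(0, 2**31 - 1,
                          (qweight_rows // group_size, qweight_cols),
                          dtype=torch.int32, device="cuda")
    ref = awq_dequantize_torch(qweight, scales, zeros, 0, 0, 0)
    out = awq_dequantize_triton(qweight, scales, zeros, 0, 0, 0)
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)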