Skip to content

Commit

Permalink
Add fp8_gemm fallback for non-triton systems (#6916)
Browse files Browse the repository at this point in the history
- Removed try/except from __init__ file in fp_quantizer and added a
single entry point instead
- Renamed file fp8_gemm to fp8_gemm_triton, and the function matmul_fp8
to matmul_fp8_triton
- Added a new entry point fp8_gemm with matmul_fp8 inside, and if the
system supports triton it calls the triton implementation and if not it
calls the fallback

Co-authored-by: Logan Adams <114770087+loadams@users.noreply.github.com>
  • Loading branch information
oelayan7 and loadams authored Jan 6, 2025
1 parent f8c9f31 commit c5e48f4
Show file tree
Hide file tree
Showing 4 changed files with 189 additions and 168 deletions.
7 changes: 1 addition & 6 deletions deepspeed/ops/fp_quantizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,4 @@
# DeepSpeed Team

from .quantize import FP_Quantize, Quantizer

try:
import triton
from .fp8_gemm import matmul_fp8
except ImportError:
pass
from .fp8_gemm import matmul_fp8
163 changes: 10 additions & 153 deletions deepspeed/ops/fp_quantizer/fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,161 +11,18 @@
###################################

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
def matmul_fp8(inp, weight, scale, quantization_group_size, quantizer):
from deepspeed import get_accelerator

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
if not get_accelerator().is_triton_supported():
return matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer)
else:
# Import dynamically to prevent failures on systems without triton.
from .fp8_gemm_triton import matmul_fp8_triton
return matmul_fp8_triton(inp, weight, scale, quantization_group_size)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> bf16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
w = (w + 0x3C00).to(tl.uint16)
w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk
weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
weight = tl.load(weight_data, mask=weight_mask, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
mask=weight_mask,
other=0.0)

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.bfloat16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


@triton.jit
def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> fp16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
w = (w + 0x2000).to(tl.uint16)
w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk

weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.float16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


def matmul_fp8(inp, weight, scale, quantization_group_size):

assert inp.shape[1] == weight.shape[0], \
f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"

M, K = inp.shape
K, N = weight.shape

out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

# GEMM tuning parameters!
# TODO: Add a more configurable tuning for selecting the best GeMM
BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = max(64, quantization_group_size)
GROUP_SIZE_M = 8
num_stages = 4
num_warps = 4
if M >= 256:
BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = max(128, quantization_group_size)
num_stages = 3
num_warps = 8

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
kernel[grid](inp,
weight,
out,
scale,
M,
N,
K,
inp.stride(0),
inp.stride(1),
weight.stride(0),
weight.stride(1),
out.stride(0),
out.stride(1),
quantization_group_size=quantization_group_size,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=GROUP_SIZE_M,
num_stages=num_stages,
num_warps=num_warps)
return out
def matmul_fp8_fallback(inp, weight, scale, quantization_group_size, quantizer):
return torch.matmul(inp, quantizer.dequantize(weight, scale=scale))
171 changes: 171 additions & 0 deletions deepspeed/ops/fp_quantizer/fp8_gemm_triton.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

######## Fused MoE kernel #########
# These kernels are implemented for
# fusing GeMM with dequantization of
# fp8 weight data when using bit-16
# activation.
###################################

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel_fp8_bf16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> bf16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
w = (w + 0x3C00).to(tl.uint16)
w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk
weight_mask = offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K
weight = tl.load(weight_data, mask=weight_mask, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)),
mask=weight_mask,
other=0.0)

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.bfloat16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


@triton.jit
def matmul_kernel_fp8_fp16(inp_ptr, weight_ptr, out_ptr, scale_ptr, M, N, K, stride_am, stride_ak, stride_bk,
stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + (
(pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> fp16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 7)).to(tl.uint16)
w = (w + 0x2000).to(tl.uint16)
w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk

weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset +
(((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

accumulator += tl.dot(inp, w)

out = accumulator.to(tl.float16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


def matmul_fp8_triton(inp, weight, scale, quantization_group_size):

assert inp.shape[1] == weight.shape[0], \
f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"

M, K = inp.shape
K, N = weight.shape

out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

# GEMM tuning parameters!
# TODO: Add a more configurable tuning for selecting the best GeMM
BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = max(64, quantization_group_size)
GROUP_SIZE_M = 8
num_stages = 4
num_warps = 4
if M >= 256:
BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = max(128, quantization_group_size)
num_stages = 3
num_warps = 8

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
kernel[grid](inp,
weight,
out,
scale,
M,
N,
K,
inp.stride(0),
inp.stride(1),
weight.stride(0),
weight.stride(1),
out.stride(0),
out.stride(1),
quantization_group_size=quantization_group_size,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=GROUP_SIZE_M,
num_stages=num_stages,
num_warps=num_warps)
return out
16 changes: 7 additions & 9 deletions tests/unit/ops/fp_quantizer/test_fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,28 @@

from deepspeed.ops.fp_quantizer import FP_Quantize, matmul_fp8

from deepspeed import get_accelerator


@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize("q_bits", [8], ids=[
"qbits8",
])
@pytest.mark.parametrize("M", [1, 2, 4, 8, 32, 64, 128, 256, 512, 1024, 2048])
def test_fp_quant(dtype, q_bits, M):
device_name = get_accelerator().device_name()
quantization_group_size = 128
fpq = FP_Quantize(group_size=quantization_group_size)

N = 8192
H = 4096

x = torch.randn(M, H, dtype=dtype, device='cuda')
weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda')
x = torch.randn(M, H, dtype=dtype, device=device_name)
weight_bf16 = torch.randn(H, N, dtype=dtype, device=device_name)

weight, _ = fpq.quantize(weight_bf16.data, q_bits=8, return_meta_tensor=True)
weight, _ = fpq.quantize(weight_bf16.data, q_bits=q_bits, return_meta_tensor=True)
scale = fpq.get_scales()
out = matmul_fp8(
x,
weight,
scale,
quantization_group_size,
)
out = matmul_fp8(x, weight, scale, quantization_group_size, fpq)

out_q = torch.matmul(x, fpq.dequantize(weight, scale=fpq.scale))

Expand Down

0 comments on commit c5e48f4

Please sign in to comment.