Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OptimizedLinear updates #5791

Merged
merged 20 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Add fp8-fused gemm kernel
  • Loading branch information
sfc-gh-reyazda committed Jul 11, 2024
commit 4c3b8fd5a44489107a2b32cf35bcf8221626defa
203 changes: 203 additions & 0 deletions deepspeed/ops/fp_quantizer/fp8_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@


# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

######## Fused MoE kernel #########
# These kernels are implemented for
# fusing GeMM with dequantization of
# fp8 weight data when using bit-16
# activation.
###################################

import torch
import triton
import triton.language as tl


@triton.jit
def matmul_kernel_fp8_bf16(
inp_ptr,
weight_ptr,
out_ptr,
scale_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ((pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> bf16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 4)).to(tl.uint16)
w = (w + 0x3C00).to(tl.uint16)
w = (w.to(tl.bfloat16, bitcast=True) * scale).to(tl.bfloat16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk

weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

accumulator = tl.dot(inp, w, accumulator)

out = accumulator.to(tl.bfloat16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))

@triton.jit
def matmul_kernel_fp8_fp16(
inp_ptr,
weight_ptr,
out_ptr,
scale_ptr,
M,
N,
K,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
quantization_group_size: tl.constexpr
):
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)

inp_data = inp_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
weight_data = weight_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
weight_ptrs_offset = offs_k[:, None] * (stride_bk // quantization_group_size) + ((pid_n * BLOCK_SIZE_N) // quantization_group_size)

weight = tl.load(weight_data, mask=offs_k[:, None] < K, other=0.0)
scale = tl.load(scale_ptr + weight_ptrs_offset)

accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
inp = tl.load(inp_data, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
# Dequantize weight (fp8 -> fp16)
w = (((weight & 0x80) << 8) | ((weight & 0x7f) << 10)).to(tl.uint16)
w = (w + 0x2000).to(tl.uint16)
w = (w.to(tl.float16, bitcast=True) * scale).to(tl.float16)

inp_data += BLOCK_SIZE_K * stride_ak
weight_data += BLOCK_SIZE_K * stride_bk

weight = tl.load(weight_data, mask=offs_k[:, None] < K - (k + 1) * BLOCK_SIZE_K, other=0.0)
scale = tl.load(scale_ptr + (weight_ptrs_offset + (((k + 1) * BLOCK_SIZE_K * stride_bk) // quantization_group_size)))

accumulator = tl.dot(inp, w, accumulator)

out = accumulator.to(tl.float16)

offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
out_data = out_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
tl.store(out_data, out, mask=(offs_cm[:, None] < M) & (offs_cn[None, :] < N))


def matmul_fp8(inp,
weight,
scale,
quantization_group_size):

assert inp.shape[1] == weight.shape[0], \
f"Incompatible dimensions (input: {inp.shape}, weight: {weight.shape})"

M, K = inp.shape
K, N = weight.shape

out = torch.empty((M, N), device=inp.device, dtype=inp.dtype)

# GEMM tuning parameters!
# TODO: Add a more configurable tuning for selecting the best GeMM
BLOCK_SIZE_M = 16 if M <= 16 else 32 if M <= 32 else 64 if M <= 64 else 128
BLOCK_SIZE_N = 64
BLOCK_SIZE_K = max(64, quantization_group_size)
GROUP_SIZE_M = 8
num_stages = 4
num_warps = 4
if M >= 256:
BLOCK_SIZE_M = 256
BLOCK_SIZE_N = 128
BLOCK_SIZE_K = max(128, quantization_group_size)
num_stages = 3
num_warps = 8

grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )
params = (inp,
weight,
out,
scale,
M,
N,
K,
inp.stride(0),
inp.stride(1),
weight.stride(0),
weight.stride(1),
out.stride(0),
out.stride(1),
quantization_group_size=quantization_group_size,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
GROUP_SIZE_M=GROUP_SIZE_M,
num_stages=num_stages,
num_warps=num_warps
)
kernel = matmul_kernel_fp8_bf16 if inp.dtype == torch.bfloat16 else matmul_kernel_fp8_fp16
kernel[grid](**params)
return out
42 changes: 42 additions & 0 deletions tests/unit/ops/fp_quantizer/test_fp8_gemm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import pytest
import torch
import deepspeed

from deepspeed.ops.fp_quantizer import FP_Quantize
from deepspeed.ops.op_builder import FPQuantizerBuilder

if not deepspeed.ops.__compatible_ops__[FPQuantizerBuilder.NAME]:
pytest.skip("FPQuantizer op is not available on this system", allow_module_level=True)

@pytest.mark.parametrize("dtype", [torch.bfloat16], ids=["bf16"])
@pytest.mark.parametrize("q_bits", [8], ids=["qbits8",])
def test_fp_quant(dtype, q_bits):
group_size = 128
fpq = FP_Quantize(group_size=group_size)

N = 8192
H = 4096
for M in [256, 512, 1024, 1536, 2048, 3072, 4096, 8192]:
x = torch.randn(M, H, dtype=dtype, device='cuda')
weight_bf16 = torch.randn(H, N, dtype=dtype, device='cuda')
weight, _ = fpq.quantize(
weight_bf16.data,
q_bits=8,
return_meta_tensor=True
)
scale = fpq.get_scales()
out = matmul_fp8(
x,
weight,
scale,
quantization_group_size,
)
out_q = matmul(x, fpq.dequantize(wq, scale=scale))
error = ((out - out_q).abs() / (out.abs() + 1e-5)).sum() / out.numel()

assert 0.004 > error, f"failed on iteration {i}"