Skip to content

Commit

Permalink
[Kernel] (2/N) Machete - Integrate into CompressedTensorsWNA16 and GP…
Browse files Browse the repository at this point in the history
…TQMarlin (vllm-project#7701)

Co-authored-by: mgoin <michael@neuralmagic.com>
Co-authored-by: Divakar Verma <137818590+divakar-amd@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
  • Loading branch information
4 people authored Sep 23, 2024
1 parent ee5f34b commit 86e9c8d
Show file tree
Hide file tree
Showing 27 changed files with 1,005 additions and 246 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/gguf/gguf_kernel.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/custom_all_reduce.cu"
"csrc/permute_cols.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu")
Expand Down
74 changes: 61 additions & 13 deletions benchmarks/kernels/benchmark_machete.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import math
import pickle as pkl
import time
from typing import Callable, Iterable, List, Tuple
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple

import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from torch.utils.benchmark import Measurement as TMeasurement
Expand Down Expand Up @@ -84,6 +86,10 @@ def loop_over_weights(
fn(a, w_ref, w_q, w_s)


_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None


def bench(atype: torch.dtype,
wtype: ScalarType,
group_size: int,
Expand All @@ -94,6 +100,8 @@ def bench(atype: torch.dtype,
sub_label: str,
benchmark_marlinv1: bool = True,
sweep_schedules: bool = True) -> Iterable[TMeasurement]:
global _SWEEP_SCHEDULES_RESULTS

a, weights = make_bench_tensors(atype, wtype, group_size, m, n, k)
sub_label += f", L={len(weights)}"

Expand Down Expand Up @@ -163,6 +171,11 @@ def marlinv1_permute_scales(w_s: torch.tensor) -> torch.tensor:
best_schedule = None
schedules = ops.machete_supported_schedules(wtype)
for schedule in reversed(schedules):
schedule_M = int(schedule.split("_")[0].split("x")[1])

# Prune known bad schedules
if schedule_M >= 2 * max(m, 16) or schedule_M < m // 4:
continue

def run(a, _, w_q, w_s, schedule=schedule):
ops.machete_gemm(a,
Expand All @@ -175,6 +188,20 @@ def run(a, _, w_q, w_s, schedule=schedule):
res = bench_fn(label, sub_label, "machete_best",
lambda: loop_over_weights(a, weights_machete, run))

results_row = {
"M": m,
"K": k,
"N": n,
"group_size": group_size,
"schedule": schedule,
"median": res.median,
}
if _SWEEP_SCHEDULES_RESULTS is None:
_SWEEP_SCHEDULES_RESULTS = pd.DataFrame(
columns=results_row.keys())
_SWEEP_SCHEDULES_RESULTS.\
loc[len(_SWEEP_SCHEDULES_RESULTS)] = results_row

print(f" {res.median:5.5} ", schedule)
if not best or res.median < best.median:
best = res
Expand Down Expand Up @@ -235,18 +262,22 @@ def run_square_bench(args):
dim_sizes = list(
range(args.dim_start, args.dim_end + 1, args.dim_increment))
MKNs = list(zip(dim_sizes, dim_sizes, dim_sizes))

data = run(args.dtype, args.sweep_schedules, MKNs)

make_output(data, MKNs, f"square_bench-{args.dtype}")


def run_range_bench(args):
dim_sizes = list(range(args.dim_start, args.dim_end, args.dim_increment))
n = len(dim_sizes)
Ms = [args.m_constant] * n if args.m_constant is not None else dim_sizes
Ks = [args.k_constant] * n if args.k_constant is not None else dim_sizes
Ns = [args.n_constant] * n if args.n_constant is not None else dim_sizes
MKNs = list(zip(Ms, Ks, Ns))
m_start, k_start, n_start = [int(x) for x in args.dim_start.split(",")]
m_end, k_end, n_end = [int(x) for x in args.dim_end.split(",")]
m_increment, k_increment, n_increment = \
[int(x) for x in args.dim_increment.split(",")]
Ms = list(range(m_start, m_end + 1, m_increment))
Ks = list(range(k_start, k_end + 1, k_increment))
Ns = list(range(n_start, n_end + 1, n_increment))
MKNs = list(product(Ms, Ks, Ns))

data = run(args.dtype, args.sweep_schedules, MKNs)

make_output(data, MKNs, f"range_bench-{args.dtype}")
Expand Down Expand Up @@ -333,6 +364,9 @@ def to_torch_dtype(dt):
action="store_true",
help="Run a sweep over all supported schedules",
)
parser.add_argument("--sweep-csv-out",
help="CSV to store sweep results",
default="sch_sweep_results.csv")
subparsers = parser.add_subparsers(dest="cmd", required=True)

square_parser = subparsers.add_parser("square_bench")
Expand All @@ -342,12 +376,21 @@ def to_torch_dtype(dt):
square_parser.set_defaults(func=run_square_bench)

range_parser = subparsers.add_parser("range_bench")
range_parser.add_argument("--dim-start", type=int, required=True)
range_parser.add_argument("--dim-end", type=int, required=True)
range_parser.add_argument("--dim-increment", type=int, required=True)
range_parser.add_argument("--m-constant", type=int, default=None)
range_parser.add_argument("--n-constant", type=int, default=None)
range_parser.add_argument("--k-constant", type=int, default=None)
range_parser.add_argument(
"--dim-start",
type=str,
required=True,
help="Start value for M,K,N as common separated list")
range_parser.add_argument(
"--dim-end",
type=str,
required=True,
help="End value (inclusive) for M,K,N as common separated list")
range_parser.add_argument(
"--dim-increment",
type=str,
required=True,
help="Increment value for M,K,N as common separated list")
range_parser.set_defaults(func=run_range_bench)

model_parser = subparsers.add_parser("model_bench")
Expand All @@ -369,4 +412,9 @@ def to_torch_dtype(dt):
model_parser.set_defaults(func=run_model_bench)

args = parser.parse_args()

_SWEEP_SCHEDULES_RESULTS_CSV = args.sweep_csv_out
args.func(args)

if _SWEEP_SCHEDULES_RESULTS is not None:
_SWEEP_SCHEDULES_RESULTS.to_csv(_SWEEP_SCHEDULES_RESULTS_CSV)
1 change: 1 addition & 0 deletions benchmarks/kernels/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pandas
8 changes: 7 additions & 1 deletion csrc/cutlass_extensions/torch_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
name, ".stride(", idx, ") to be ", StrideEle::value);
return StrideEle{};
} else {
return tensor.stride(idx);
if (tensor.size(idx) == 1) {
// use 0 stride for dim with size 1, this is easier for
// cute/cutlass to optimize (helps the TMA code flatten dims)
return StrideEle{0};
} else {
return tensor.stride(idx);
}
}
} else {
// Extra strides are assumed to be 0 or 1
Expand Down
2 changes: 2 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ torch::Tensor prepack_B(torch::Tensor const& B,

}; // namespace machete

torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);

torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_meta,
torch::Tensor& b_scales,
Expand Down
88 changes: 88 additions & 0 deletions csrc/permute_cols.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
#include <torch/all.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cuda_fp16.h>

static constexpr int default_threads = 256;
static constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }

// For a given "a" of size [M,K] performs a permutation of the K columns based
// on the given "perm" indices.
// Currently only supports 16bit types (since we permute half types)
__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int const* __restrict__ perm_int_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {
int start_row = block_rows * blockIdx.x;
int finish_row = start_row + block_rows;
if (finish_row > size_m) {
finish_row = size_m;
}
int cur_block_rows = std::max(finish_row - start_row, 0);

int row_stride = size_k * sizeof(half) / 16;

auto permute_row = [&](int row) {
int iters = size_k / default_threads;
int rest = size_k % default_threads;

int offset = row * row_stride;

half const* a_row_half = reinterpret_cast<half const*>(a_int4_ptr + offset);
half* out_half = reinterpret_cast<half*>(out_int4_ptr + offset);

int base_k = 0;

for (int i = 0; i < iters; i++) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];

out_half[cur_k] = a_row_half[src_pos];

base_k += default_threads;
}

if (rest) {
if (threadIdx.x < rest) {
int cur_k = base_k + threadIdx.x;
int src_pos = perm_int_ptr[cur_k];

out_half[cur_k] = a_row_half[src_pos];
}
}
};

for (int i = 0; i < cur_block_rows; i++) {
int cur_row = start_row + i;
if (cur_row < size_m) {
permute_row(cur_row);
}
}
}

// More efficient version of A[..., perm]
// taken from gptq_marlin.cu
torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
auto dev = A.get_device();
auto stream = at::cuda::getCurrentCUDAStream(dev);

TORCH_CHECK(A.scalar_type() == at::kHalf || A.scalar_type() == at::kBFloat16,
"Currently only 16bit types are supported");
TORCH_CHECK(A.is_contiguous(), "A must be contiguous");
TORCH_CHECK(A.size(-1) % 8 == 0,
"A columns must be a multiple of 8 (128bits)");
auto A_2d = A.view({-1, A.size(-1)});

torch::Tensor D = torch::empty_like(A);
int sms;
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
int block_rows = div_ceil(A_2d.size(0), sms);
permute_cols_kernel<<<sms, default_threads, 0, stream>>>(
reinterpret_cast<int4 const*>(A_2d.const_data_ptr()),
perm.const_data_ptr<int>(), reinterpret_cast<int4*>(D.mutable_data_ptr()),
A_2d.size(0), A_2d.size(1), block_rows);
return D;
}
Loading

0 comments on commit 86e9c8d

Please sign in to comment.