Introduce lowbit quantized linear MPS kernels #954

Merged (1 commit, Oct 10, 2024)
59 changes: 59 additions & 0 deletions torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py
@@ -0,0 +1,59 @@
from typing import Optional
import os
import yaml

torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
assert torchao_root is not None, "TORCHAO_ROOT is not set"

MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")

# Path to yaml file containing the list of .metal files to include
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")

metal_files = set()
with open(METAL_YAML, "r") as yamlf:
metal_config = yaml.safe_load(yamlf)
for op in metal_config:
if "file" in op:
metal_files.add(op["file"])
metal_files = sorted(metal_files)

# Path to the folder containing the .metal files
METAL_DIR = os.path.join(MPS_DIR, "metal")

# Output file where the generated code will be written
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")

prefix = """/**
* This file is generated by gen_metal_shader_lib.py
*/

#ifdef ATEN
using namespace at::native::mps;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
"""

suffix = """
)METAL_LOWBIT");
"""

comment = """
/**
* Contents of {}
*/

"""

with open(OUTPUT_FILE, "w") as outf:
outf.write(prefix)
for file in metal_files:
with open(os.path.join(METAL_DIR, file), "r") as f:
content = f.read()
outf.write(comment.format(file))
outf.write(content)
outf.write("\n\n")
outf.write(suffix)
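
A possible way to run the generator locally (an illustrative sketch, not part of this PR; the checkout path is hypothetical):

import os
import subprocess

# Hypothetical path to a local torchao checkout; TORCHAO_ROOT must point at the repo root.
repo_root = "/path/to/torchao"
subprocess.run(
    ["python", "torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py"],
    cwd=repo_root,
    env=dict(os.environ, TORCHAO_ROOT=repo_root),
    check=True,
)
# The script writes torchao/experimental/kernels/mps/src/metal_shader_lib.h, embedding the
# concatenated .metal sources listed in metal.yaml inside the METAL_LOWBIT raw string.
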
20 changes: 20 additions & 0 deletions torchao/experimental/kernels/mps/metal.yaml
@@ -0,0 +1,20 @@
- func: int1mm
file: divbit.metal

- func: int2mm
file: divbit.metal

- func: int3mm
file: int3mm.metal

- func: int4mm
file: divbit.metal

- func: int5mm
file: int5mm.metal

- func: int6mm
file: int6mm.metal

- func: int7mm
file: int7mm.metal
106 changes: 106 additions & 0 deletions torchao/experimental/kernels/mps/metal/divbit.metal
@@ -0,0 +1,106 @@
#include <metal_stdlib>
using namespace metal;

/**
* LowBit Quantized Linear for bitwidths that are divisors of 8 (1, 2, and 4), hence the name "divbit".
*
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
* @param[scalesAndZeros] 3D tensor containing the scale and zero point for each group. Expected shape is #groups x N x 2
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
* Dispatched threads: N x M x 1
*/
template<typename T, unsigned nbit, unsigned groupSize>
kernel void divbit_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
constant T *A_ptr = A + m * K;
constant uchar *B_ptr = B;

constexpr uint8_t zero_shift = 1 << (nbit - 1);
constexpr uint8_t values_per_byte = 8 / nbit;
constexpr uint8_t minimask = (1 << nbit) - 1;

float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const T scale = scalesAndZeros[(kb * N + n) * 2 + 0];
const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift);
for(uint idx = 0; idx < groupSize && k < K; idx++, k++) {
const auto a_val = float(A_ptr[k]);
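// B stores 8/nbit consecutive elements of column n per byte: element k lives in
// byte (n*K + k) / values_per_byte, at bit offset nbit * (k % values_per_byte).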
uint8_t b_val = B_ptr[(n * K + k) / values_per_byte];
uint8_t shift = nbit * (k % values_per_byte);
uint8_t mask = minimask << shift;
b_val = (b_val & mask) >> shift;
rc += a_val * float(scale * T(b_val) + zero);
}
}
outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \
template \
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void divbit_mm<DTYPE, NBIT, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_DIVBIT_MM(1, float, 32);
INSTANTIATE_DIVBIT_MM(1, half, 32);
INSTANTIATE_DIVBIT_MM(1, float, 64);
INSTANTIATE_DIVBIT_MM(1, half, 64);
INSTANTIATE_DIVBIT_MM(1, float, 128);
INSTANTIATE_DIVBIT_MM(1, half, 128);
INSTANTIATE_DIVBIT_MM(1, float, 256);
INSTANTIATE_DIVBIT_MM(1, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(1, bfloat, 32);
INSTANTIATE_DIVBIT_MM(1, bfloat, 64);
INSTANTIATE_DIVBIT_MM(1, bfloat, 128);
INSTANTIATE_DIVBIT_MM(1, bfloat, 256);
#endif

INSTANTIATE_DIVBIT_MM(2, float, 32);
INSTANTIATE_DIVBIT_MM(2, half, 32);
INSTANTIATE_DIVBIT_MM(2, float, 64);
INSTANTIATE_DIVBIT_MM(2, half, 64);
INSTANTIATE_DIVBIT_MM(2, float, 128);
INSTANTIATE_DIVBIT_MM(2, half, 128);
INSTANTIATE_DIVBIT_MM(2, float, 256);
INSTANTIATE_DIVBIT_MM(2, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(2, bfloat, 32);
INSTANTIATE_DIVBIT_MM(2, bfloat, 64);
INSTANTIATE_DIVBIT_MM(2, bfloat, 128);
INSTANTIATE_DIVBIT_MM(2, bfloat, 256);
#endif

INSTANTIATE_DIVBIT_MM(4, float, 32);
INSTANTIATE_DIVBIT_MM(4, half, 32);
INSTANTIATE_DIVBIT_MM(4, float, 64);
INSTANTIATE_DIVBIT_MM(4, half, 64);
INSTANTIATE_DIVBIT_MM(4, float, 128);
INSTANTIATE_DIVBIT_MM(4, half, 128);
INSTANTIATE_DIVBIT_MM(4, float, 256);
INSTANTIATE_DIVBIT_MM(4, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(4, bfloat, 32);
INSTANTIATE_DIVBIT_MM(4, bfloat, 64);
INSTANTIATE_DIVBIT_MM(4, bfloat, 128);
INSTANTIATE_DIVBIT_MM(4, bfloat, 256);
#endif
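
For reference, a minimal NumPy sketch of a packing routine that produces the layout divbit_mm expects (illustrative only, not part of this PR; pack_divbit is a hypothetical name, and W is assumed to hold unsigned nbit codes in [0, 2**nbit)):

import numpy as np

def pack_divbit(W: np.ndarray, nbit: int) -> np.ndarray:
    # Pack an (N, K) matrix of nbit codes (nbit in {1, 2, 4}) into an
    # N x (nbit * K / 8) uint8 buffer, mirroring the divbit_mm unpacking:
    # element k of row n lands in byte (n*K + k) // values_per_byte at bit
    # offset nbit * (k % values_per_byte).
    N, K = W.shape
    values_per_byte = 8 // nbit
    assert K % values_per_byte == 0, "K must be a multiple of 8 // nbit"
    mask = (1 << nbit) - 1
    packed = np.zeros((N, K // values_per_byte), dtype=np.uint8)
    for k in range(K):
        shift = nbit * (k % values_per_byte)
        codes = W[:, k].astype(np.int64) & mask
        packed[:, k // values_per_byte] |= (codes << shift).astype(np.uint8)
    return packed

Note that the kernel dequantizes each code w as scale * w + zero with zero = zeroPoint - scale * 2^(nbit-1), i.e. it reconstructs scale * (w - 2^(nbit-1)) + zeroPoint, which is what the zero_shift constant above implements.
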
97 changes: 97 additions & 0 deletions torchao/experimental/kernels/mps/metal/int3mm.metal
@@ -0,0 +1,97 @@
#include <metal_stdlib>
using namespace metal;

/**
* 3-Bit Quantized Linear.
*
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
* @param[scalesAndZeros] 3D tensor containing the scale and zero point for each group. Expected shape is #groups x N x 2
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
* Dispatched threads: N x M x 1
*/
template<typename T, unsigned groupSize>
kernel void int3pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
constant T *A_ptr = A + m * K;
constant uchar *B_ptr = B + n * 3 * K / 8;

float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
const auto a_val2 = float(A_ptr[k + 2]);
const auto a_val3 = float(A_ptr[k + 3]);
const auto a_val4 = float(A_ptr[k + 4]);
const auto a_val5 = float(A_ptr[k + 5]);
const auto a_val6 = float(A_ptr[k + 6]);
const auto a_val7 = float(A_ptr[k + 7]);

uchar b0 = B_ptr[3 * (k / 8) + 0];
uchar b1 = B_ptr[3 * (k / 8) + 1];
uchar b2 = B_ptr[3 * (k / 8) + 2];
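// b0 collects the top bit (bit 2) of all 8 values: bit i of b0 belongs to w_val<i>.
// b1 and b2 hold the low 2 bits of values 0..3 and 4..7 respectively, 2 bits per value.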

uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);

uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);

rc += a_val0 * (scale * float(w_val0) + zero);
rc += a_val1 * (scale * float(w_val1) + zero);
rc += a_val2 * (scale * float(w_val2) + zero);
rc += a_val3 * (scale * float(w_val3) + zero);
rc += a_val4 * (scale * float(w_val4) + zero);
rc += a_val5 * (scale * float(w_val5) + zero);
rc += a_val6 * (scale * float(w_val6) + zero);
rc += a_val7 * (scale * float(w_val7) + zero);
}
}
outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \
template \
[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int3pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT3MM(float, 32);
INSTANTIATE_INT3MM(half, 32);
INSTANTIATE_INT3MM(float, 64);
INSTANTIATE_INT3MM(half, 64);
INSTANTIATE_INT3MM(float, 128);
INSTANTIATE_INT3MM(half, 128);
INSTANTIATE_INT3MM(float, 256);
INSTANTIATE_INT3MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT3MM(bfloat, 32);
INSTANTIATE_INT3MM(bfloat, 64);
INSTANTIATE_INT3MM(bfloat, 128);
INSTANTIATE_INT3MM(bfloat, 256);
#endif
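
As with divbit above, a minimal NumPy sketch of a packing routine matching the int3pack_mm unpacking (illustrative only, not part of this PR; pack_int3 is a hypothetical name, and W is assumed to hold unsigned 3-bit codes in [0, 8)):

import numpy as np

def pack_int3(W: np.ndarray) -> np.ndarray:
    # Pack an (N, K) matrix of 3-bit codes into an N x (3 * K / 8) uint8 buffer.
    # For each group of 8 values, byte 0 collects the top bit of every value
    # (bit i of b0 is bit 2 of value i); bytes 1 and 2 hold the low 2 bits of
    # values 0..3 and 4..7 respectively, two bits per value.
    N, K = W.shape
    assert K % 8 == 0, "K must be a multiple of 8"
    W = W.astype(np.int64)
    packed = np.zeros((N, 3 * K // 8), dtype=np.uint8)
    for g in range(K // 8):
        w = W[:, 8 * g : 8 * g + 8]
        b0 = np.zeros(N, dtype=np.int64)
        b1 = np.zeros(N, dtype=np.int64)
        b2 = np.zeros(N, dtype=np.int64)
        for i in range(8):
            b0 |= ((w[:, i] >> 2) & 1) << i
        for i in range(4):
            b1 |= (w[:, i] & 3) << (2 * i)
            b2 |= (w[:, i + 4] & 3) << (2 * i)
        packed[:, 3 * g + 0] = b0
        packed[:, 3 * g + 1] = b1
        packed[:, 3 * g + 2] = b2
    return packed
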
99 changes: 99 additions & 0 deletions torchao/experimental/kernels/mps/metal/int5mm.metal
@@ -0,0 +1,99 @@
#include <metal_stdlib>
using namespace metal;

/**
* 5-Bit Quantized Linear.
*
* @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
* @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
* @param[scalesAndZeros] 3D tensor containing the scale and zero point for each group. Expected shape is #groups x N x 2
* @param[outputData] M x N output tensor of floating point dtype (same as input)
* @param[sizes] The sizes involved in the order: M, K, N
*
* Dispatched threads: N x M x 1
*/
template<typename T, unsigned groupSize>
kernel void int5pack_mm(
constant T * A [[buffer(0)]],
constant uchar * B [[buffer(1)]],
constant T * scalesAndZeros [[buffer(2)]],
device T * outputData [[buffer(3)]],
constant uint3 & sizes [[buffer(4)]], // M, K, N
uint2 thread_index [[thread_position_in_grid]]) {
const uint K = sizes.y;
const uint N = sizes.z;
const uint m = thread_index.y; // 0..M-1
const uint n = thread_index.x; // 0..N-1
const uint32_t k_block = (K + groupSize - 1) / groupSize;
constant T *A_ptr = A + m * K;
constant uchar *B_ptr = B + n * 5 * K / 8;

float rc = 0.0;
uint k = 0;
for (uint32_t kb = 0; kb < k_block ; kb ++) {
const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16);
for(uint idx = 0; idx < groupSize && k < K; idx+=8, k+=8) {
const auto a_val0 = float(A_ptr[k + 0]);
const auto a_val1 = float(A_ptr[k + 1]);
const auto a_val2 = float(A_ptr[k + 2]);
const auto a_val3 = float(A_ptr[k + 3]);
const auto a_val4 = float(A_ptr[k + 4]);
const auto a_val5 = float(A_ptr[k + 5]);
const auto a_val6 = float(A_ptr[k + 6]);
const auto a_val7 = float(A_ptr[k + 7]);

uchar b0 = B_ptr[5 * (k / 8) + 0];
uchar b1 = B_ptr[5 * (k / 8) + 1];
uchar b2 = B_ptr[5 * (k / 8) + 2];
uchar b3 = B_ptr[5 * (k / 8) + 3];
uchar b4 = B_ptr[5 * (k / 8) + 4];
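// b0 collects the top bit (bit 4) of all 8 values: bit i of b0 belongs to w_val<i>.
// b1..b4 hold the low 4 bits of value pairs (0,1), (2,3), (4,5), (6,7),
// with the even-indexed value in the low nibble.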

uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15);
uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4);
uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15);
uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4);

uchar w_val4 = ((b0 & 16)) | (b3 & 15);
uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4);
uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15);
uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4);

rc += a_val0 * (scale * float(w_val0) + zero);
rc += a_val1 * (scale * float(w_val1) + zero);
rc += a_val2 * (scale * float(w_val2) + zero);
rc += a_val3 * (scale * float(w_val3) + zero);
rc += a_val4 * (scale * float(w_val4) + zero);
rc += a_val5 * (scale * float(w_val5) + zero);
rc += a_val6 * (scale * float(w_val6) + zero);
rc += a_val7 * (scale * float(w_val7) + zero);
}
}
outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_INT5MM(DTYPE, GSIZE) \
template \
[[host_name("int5pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int5pack_mm<DTYPE, GSIZE>( \
constant DTYPE * A [[buffer(0)]], \
constant uchar * B [[buffer(1)]], \
constant DTYPE * scalesAndZeros [[buffer(2)]], \
device DTYPE * outputData [[buffer(3)]], \
constant uint3 & sizes [[buffer(4)]], \
uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT5MM(float, 32);
INSTANTIATE_INT5MM(half, 32);
INSTANTIATE_INT5MM(float, 64);
INSTANTIATE_INT5MM(half, 64);
INSTANTIATE_INT5MM(float, 128);
INSTANTIATE_INT5MM(half, 128);
INSTANTIATE_INT5MM(float, 256);
INSTANTIATE_INT5MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT5MM(bfloat, 32);
INSTANTIATE_INT5MM(bfloat, 64);
INSTANTIATE_INT5MM(bfloat, 128);
INSTANTIATE_INT5MM(bfloat, 256);
#endif
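
And the analogous NumPy sketch for the 5-bit layout consumed by int5pack_mm (illustrative only, not part of this PR; pack_int5 is a hypothetical name, and W is assumed to hold unsigned 5-bit codes in [0, 32)):

import numpy as np

def pack_int5(W: np.ndarray) -> np.ndarray:
    # Pack an (N, K) matrix of 5-bit codes into an N x (5 * K / 8) uint8 buffer.
    # For each group of 8 values, byte 0 collects the top bit of every value
    # (bit i of b0 is bit 4 of value i); bytes 1..4 hold the low nibbles of the
    # value pairs (0,1), (2,3), (4,5), (6,7), even-indexed value in the low nibble.
    N, K = W.shape
    assert K % 8 == 0, "K must be a multiple of 8"
    W = W.astype(np.int64)
    packed = np.zeros((N, 5 * K // 8), dtype=np.uint8)
    for g in range(K // 8):
        w = W[:, 8 * g : 8 * g + 8]
        b0 = np.zeros(N, dtype=np.int64)
        for i in range(8):
            b0 |= ((w[:, i] >> 4) & 1) << i
        packed[:, 5 * g + 0] = b0
        for j in range(4):
            lo = w[:, 2 * j] & 15
            hi = w[:, 2 * j + 1] & 15
            packed[:, 5 * g + 1 + j] = lo | (hi << 4)
    return packed
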