Introduce lowbit quantized linear MPS kernels #954
Merged: facebook-github-bot merged 1 commit into pytorch:main from manuelcandales:export-D63342895 on Oct 10, 2024.
torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py (59 additions, 0 deletions)
from typing import Optional
import os
import yaml

torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
assert torchao_root is not None, "TORCHAO_ROOT is not set"

MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")

# Path to yaml file containing the list of .metal files to include
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")

metal_files = set()
with open(METAL_YAML, "r") as yamlf:
    metal_config = yaml.safe_load(yamlf)
    for op in metal_config:
        if "file" in op:
            metal_files.add(op["file"])
metal_files = sorted(metal_files)

# Path to the folder containing the .metal files
METAL_DIR = os.path.join(MPS_DIR, "metal")

# Output file where the generated code will be written
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")

prefix = """/**
 * This file is generated by gen_metal_shader_lib.py
 */

#ifdef ATEN
using namespace at::native::mps;
#else
#include <torchao/experimental/kernels/mps/src/OperationUtils.h>
#endif

static MetalShaderLibrary metal_lowbit_quantized_lib(R"METAL_LOWBIT(
"""

suffix = """
)METAL_LOWBIT");
"""

comment = """
/**
 * Contents of {}
 */

"""

with open(OUTPUT_FILE, "w") as outf:
    outf.write(prefix)
    for file in metal_files:
        with open(os.path.join(METAL_DIR, file), "r") as f:
            content = f.read()
        outf.write(comment.format(file))
        outf.write(content)
        outf.write("\n\n")
    outf.write(suffix)
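As a hedged usage sketch (the checkout path below is a placeholder, not from the PR), the generator can be driven like this:

import os
import subprocess

# Hypothetical invocation: point TORCHAO_ROOT at a torchao checkout.
env = dict(os.environ, TORCHAO_ROOT="/path/to/ao")
subprocess.run(
    ["python", "torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py"],
    cwd=env["TORCHAO_ROOT"],  # relative script path resolved against the checkout
    env=env,
    check=True,
)
# Result: torchao/experimental/kernels/mps/src/metal_shader_lib.h, which wraps
# every .metal file listed in metal.yaml in one R"METAL_LOWBIT(...)" literal.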
torchao/experimental/kernels/mps/metal.yaml (20 additions, 0 deletions)
- func: int1mm
  file: divbit.metal

- func: int2mm
  file: divbit.metal

- func: int3mm
  file: int3mm.metal

- func: int4mm
  file: divbit.metal

- func: int5mm
  file: int5mm.metal

- func: int6mm
  file: int6mm.metal

- func: int7mm
  file: int7mm.metal
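Note that int1mm, int2mm, and int4mm all map to divbit.metal: bitwidths 1, 2, and 4 divide 8, so they share the generic kernel below, while 3-, 5-, 6-, and 7-bit each get a dedicated shader. As an illustrative sketch (not part of the PR), here is how these entries line up with the host names registered by the instantiation macros, assuming the int{N}pack_mm_{groupSize}_{dtype} pattern visible in this diff:

import yaml

# Illustrative only: enumerate expected kernel host names per yaml entry.
# bfloat variants exist only when compiled with __METAL_VERSION__ >= 310.
with open("metal.yaml") as f:
    config = yaml.safe_load(f)
for op in config:
    nbit = op["func"][len("int"):-len("mm")]  # e.g. "int3mm" -> "3"
    for gsize in (32, 64, 128, 256):
        for dtype in ("float", "half", "bfloat"):
            print(f"int{nbit}pack_mm_{gsize}_{dtype} -> {op['file']}")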
torchao/experimental/kernels/mps/metal/divbit.metal (106 additions, 0 deletions)
#include <metal_stdlib>
using namespace metal;

/**
 * LowBit Quantized Linear for bitwidths that are divisors of 8. Hence the name.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (nbit * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 */
template<typename T, unsigned nbit, unsigned groupSize>
kernel void divbit_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  const uint32_t k_block = (K + groupSize - 1) / groupSize;

  constant T *A_ptr = A + m * K;
  constant uchar *B_ptr = B;

  constexpr uint8_t zero_shift = 1 << (nbit - 1);
  constexpr uint8_t values_per_byte = 8 / nbit;
  constexpr uint8_t minimask = (1 << nbit) - 1;

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block; kb++) {
    const T scale = scalesAndZeros[(kb * N + n) * 2 + 0];
    const T zero = scalesAndZeros[(kb * N + n) * 2 + 1] - scale * T(zero_shift);
    for (uint idx = 0; idx < groupSize && k < K; idx++, k++) {
      const auto a_val = float(A_ptr[k]);
      uint8_t b_val = B_ptr[(n * K + k) / values_per_byte];
      uint8_t shift = nbit * (k % values_per_byte);
      uint8_t mask = minimask << shift;
      b_val = (b_val & mask) >> shift;
      rc += a_val * float(scale * T(b_val) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_DIVBIT_MM(NBIT, DTYPE, GSIZE) \
template \
[[host_name("int" #NBIT "pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void divbit_mm<DTYPE, NBIT, GSIZE>( \
    constant DTYPE * A [[buffer(0)]], \
    constant uchar * B [[buffer(1)]], \
    constant DTYPE * scalesAndZeros [[buffer(2)]], \
    device DTYPE * outputData [[buffer(3)]], \
    constant uint3 & sizes [[buffer(4)]], \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_DIVBIT_MM(1, float, 32);
INSTANTIATE_DIVBIT_MM(1, half, 32);
INSTANTIATE_DIVBIT_MM(1, float, 64);
INSTANTIATE_DIVBIT_MM(1, half, 64);
INSTANTIATE_DIVBIT_MM(1, float, 128);
INSTANTIATE_DIVBIT_MM(1, half, 128);
INSTANTIATE_DIVBIT_MM(1, float, 256);
INSTANTIATE_DIVBIT_MM(1, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(1, bfloat, 32);
INSTANTIATE_DIVBIT_MM(1, bfloat, 64);
INSTANTIATE_DIVBIT_MM(1, bfloat, 128);
INSTANTIATE_DIVBIT_MM(1, bfloat, 256);
#endif

INSTANTIATE_DIVBIT_MM(2, float, 32);
INSTANTIATE_DIVBIT_MM(2, half, 32);
INSTANTIATE_DIVBIT_MM(2, float, 64);
INSTANTIATE_DIVBIT_MM(2, half, 64);
INSTANTIATE_DIVBIT_MM(2, float, 128);
INSTANTIATE_DIVBIT_MM(2, half, 128);
INSTANTIATE_DIVBIT_MM(2, float, 256);
INSTANTIATE_DIVBIT_MM(2, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(2, bfloat, 32);
INSTANTIATE_DIVBIT_MM(2, bfloat, 64);
INSTANTIATE_DIVBIT_MM(2, bfloat, 128);
INSTANTIATE_DIVBIT_MM(2, bfloat, 256);
#endif

INSTANTIATE_DIVBIT_MM(4, float, 32);
INSTANTIATE_DIVBIT_MM(4, half, 32);
INSTANTIATE_DIVBIT_MM(4, float, 64);
INSTANTIATE_DIVBIT_MM(4, half, 64);
INSTANTIATE_DIVBIT_MM(4, float, 128);
INSTANTIATE_DIVBIT_MM(4, half, 128);
INSTANTIATE_DIVBIT_MM(4, float, 256);
INSTANTIATE_DIVBIT_MM(4, half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_DIVBIT_MM(4, bfloat, 32);
INSTANTIATE_DIVBIT_MM(4, bfloat, 64);
INSTANTIATE_DIVBIT_MM(4, bfloat, 128);
INSTANTIATE_DIVBIT_MM(4, bfloat, 256);
#endif
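A minimal Python sketch (helper name and arguments are illustrative, not part of the PR) of the extract-and-dequantize step the inner loop performs for one weight; it mirrors the kernel's values_per_byte / mask / shift arithmetic and the zero_shift recentering:

def dequant_divbit(B, n, k, K, nbit, scale, stored_zero):
    # Mirrors divbit_mm: 8/nbit quantized values are packed per uint8.
    values_per_byte = 8 // nbit
    minimask = (1 << nbit) - 1
    byte = B[(n * K + k) // values_per_byte]
    shift = nbit * (k % values_per_byte)
    q = (byte >> shift) & minimask                   # unsigned nbit value
    zero = stored_zero - scale * (1 << (nbit - 1))   # zero_shift recentering
    return scale * q + zero

# Example with nbit=4: byte 0xB7 packs q=7 (low nibble, k=0) and q=11
# (high nibble, k=1); with scale=1 and stored zero 0 they dequantize to
# 7-8=-1 and 11-8=3.
B = [0xB7]
assert dequant_divbit(B, n=0, k=0, K=2, nbit=4, scale=1.0, stored_zero=0.0) == -1.0
assert dequant_divbit(B, n=0, k=1, K=2, nbit=4, scale=1.0, stored_zero=0.0) == 3.0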
torchao/experimental/kernels/mps/metal/int3mm.metal (97 additions, 0 deletions)
#include <metal_stdlib>
using namespace metal;

/**
 * 3-Bit Quantized Linear.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (3 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 */
template<typename T, unsigned groupSize>
kernel void int3pack_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  constant uchar *B_ptr = B + n * 3 * K / 8;

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block; kb++) {
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(4);
    for (uint idx = 0; idx < groupSize && k < K; idx += 8, k += 8) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);
      const auto a_val4 = float(A_ptr[k + 4]);
      const auto a_val5 = float(A_ptr[k + 5]);
      const auto a_val6 = float(A_ptr[k + 6]);
      const auto a_val7 = float(A_ptr[k + 7]);

      uchar b0 = B_ptr[3 * (k / 8) + 0];
      uchar b1 = B_ptr[3 * (k / 8) + 1];
      uchar b2 = B_ptr[3 * (k / 8) + 2];

      uchar w_val0 = ((b0 & 1) << 2) | (b1 & 3);
      uchar w_val1 = ((b0 & 2) << 1) | ((b1 & 12) >> 2);
      uchar w_val2 = (b0 & 4) | ((b1 & 48) >> 4);
      uchar w_val3 = ((b0 & 8) >> 1) | ((b1 & 192) >> 6);

      uchar w_val4 = ((b0 & 16) >> 2) | (b2 & 3);
      uchar w_val5 = ((b0 & 32) >> 3) | ((b2 & 12) >> 2);
      uchar w_val6 = ((b0 & 64) >> 4) | ((b2 & 48) >> 4);
      uchar w_val7 = ((b0 & 128) >> 5) | ((b2 & 192) >> 6);

      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
      rc += a_val4 * (scale * float(w_val4) + zero);
      rc += a_val5 * (scale * float(w_val5) + zero);
      rc += a_val6 * (scale * float(w_val6) + zero);
      rc += a_val7 * (scale * float(w_val7) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_INT3MM(DTYPE, GSIZE) \
template \
[[host_name("int3pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int3pack_mm<DTYPE, GSIZE>( \
    constant DTYPE * A [[buffer(0)]], \
    constant uchar * B [[buffer(1)]], \
    constant DTYPE * scalesAndZeros [[buffer(2)]], \
    device DTYPE * outputData [[buffer(3)]], \
    constant uint3 & sizes [[buffer(4)]], \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT3MM(float, 32);
INSTANTIATE_INT3MM(half, 32);
INSTANTIATE_INT3MM(float, 64);
INSTANTIATE_INT3MM(half, 64);
INSTANTIATE_INT3MM(float, 128);
INSTANTIATE_INT3MM(half, 128);
INSTANTIATE_INT3MM(float, 256);
INSTANTIATE_INT3MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT3MM(bfloat, 32);
INSTANTIATE_INT3MM(bfloat, 64);
INSTANTIATE_INT3MM(bfloat, 128);
INSTANTIATE_INT3MM(bfloat, 256);
#endif
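The bit twiddling above implies a specific packing: for each run of eight 3-bit weights, b0 collects the high bit of every value (bit i of b0 is bit 2 of w_i), while b1 and b2 hold the low two bits of w0..w3 and w4..w7 respectively. A small pack/unpack sketch in Python (helper names are illustrative, not part of the PR):

def pack_int3(ws):
    # Pack eight 3-bit values (0..7) into three bytes, matching int3pack_mm.
    assert len(ws) == 8 and all(0 <= w < 8 for w in ws)
    b0 = 0
    for i, w in enumerate(ws):               # bit i of b0 = high bit of w_i
        b0 |= ((w >> 2) & 1) << i
    b1 = sum((ws[i] & 3) << (2 * i) for i in range(4))      # low 2 bits of w0..w3
    b2 = sum((ws[i + 4] & 3) << (2 * i) for i in range(4))  # low 2 bits of w4..w7
    return b0, b1, b2

def unpack_int3(b0, b1, b2):
    # Inverse of pack_int3, mirroring the kernel's mask/shift expressions.
    lows = [b1] * 4 + [b2] * 4
    return [(((b0 >> i) & 1) << 2) | ((lows[i] >> (2 * (i % 4))) & 3)
            for i in range(8)]

assert unpack_int3(*pack_int3(list(range(8)))) == list(range(8))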
torchao/experimental/kernels/mps/metal/int5mm.metal (99 additions, 0 deletions)
#include <metal_stdlib>
using namespace metal;

/**
 * 5-Bit Quantized Linear.
 *
 * @param[A] M x K unquantized input tensor of floating point dtype (Float, Half, BFloat16)
 * @param[B] Packed & quantized weight tensor of uint8 dtype. Expected shape is N x (5 * K / 8)
 * @param[scalesAndZeros] 3D tensor containing the scales and zero point for each group. Expected shape is #groups x N x 2
 * @param[outputData] M x N output tensor of floating point dtype (same as input)
 * @param[sizes] The sizes involved in the order: M, K, N
 *
 * Dispatched threads: N x M x 1
 */
template<typename T, unsigned groupSize>
kernel void int5pack_mm(
    constant T * A [[buffer(0)]],
    constant uchar * B [[buffer(1)]],
    constant T * scalesAndZeros [[buffer(2)]],
    device T * outputData [[buffer(3)]],
    constant uint3 & sizes [[buffer(4)]], // M, K, N
    uint2 thread_index [[thread_position_in_grid]]) {
  const uint K = sizes.y;
  const uint N = sizes.z;
  const uint m = thread_index.y; // 0..M-1
  const uint n = thread_index.x; // 0..N-1
  const uint32_t k_block = (K + groupSize - 1) / groupSize;
  constant T *A_ptr = A + m * K;
  constant uchar *B_ptr = B + n * 5 * K / 8;

  float rc = 0.0;
  uint k = 0;
  for (uint32_t kb = 0; kb < k_block; kb++) {
    const float scale = float(scalesAndZeros[(kb * N + n) * 2 + 0]);
    const float zero = float(scalesAndZeros[(kb * N + n) * 2 + 1]) - scale * float(16);
    for (uint idx = 0; idx < groupSize && k < K; idx += 8, k += 8) {
      const auto a_val0 = float(A_ptr[k + 0]);
      const auto a_val1 = float(A_ptr[k + 1]);
      const auto a_val2 = float(A_ptr[k + 2]);
      const auto a_val3 = float(A_ptr[k + 3]);
      const auto a_val4 = float(A_ptr[k + 4]);
      const auto a_val5 = float(A_ptr[k + 5]);
      const auto a_val6 = float(A_ptr[k + 6]);
      const auto a_val7 = float(A_ptr[k + 7]);

      uchar b0 = B_ptr[5 * (k / 8) + 0];
      uchar b1 = B_ptr[5 * (k / 8) + 1];
      uchar b2 = B_ptr[5 * (k / 8) + 2];
      uchar b3 = B_ptr[5 * (k / 8) + 3];
      uchar b4 = B_ptr[5 * (k / 8) + 4];

      uchar w_val0 = ((b0 & 1) << 4) | (b1 & 15);
      uchar w_val1 = ((b0 & 2) << 3) | ((b1 & 240) >> 4);
      uchar w_val2 = ((b0 & 4) << 2) | (b2 & 15);
      uchar w_val3 = ((b0 & 8) << 1) | ((b2 & 240) >> 4);

      uchar w_val4 = (b0 & 16) | (b3 & 15);
      uchar w_val5 = ((b0 & 32) >> 1) | ((b3 & 240) >> 4);
      uchar w_val6 = ((b0 & 64) >> 2) | (b4 & 15);
      uchar w_val7 = ((b0 & 128) >> 3) | ((b4 & 240) >> 4);

      rc += a_val0 * (scale * float(w_val0) + zero);
      rc += a_val1 * (scale * float(w_val1) + zero);
      rc += a_val2 * (scale * float(w_val2) + zero);
      rc += a_val3 * (scale * float(w_val3) + zero);
      rc += a_val4 * (scale * float(w_val4) + zero);
      rc += a_val5 * (scale * float(w_val5) + zero);
      rc += a_val6 * (scale * float(w_val6) + zero);
      rc += a_val7 * (scale * float(w_val7) + zero);
    }
  }
  outputData[m * N + n] = T(rc);
}

#define INSTANTIATE_INT5MM(DTYPE, GSIZE) \
template \
[[host_name("int5pack_mm_" #GSIZE "_" #DTYPE)]] \
kernel void int5pack_mm<DTYPE, GSIZE>( \
    constant DTYPE * A [[buffer(0)]], \
    constant uchar * B [[buffer(1)]], \
    constant DTYPE * scalesAndZeros [[buffer(2)]], \
    device DTYPE * outputData [[buffer(3)]], \
    constant uint3 & sizes [[buffer(4)]], \
    uint2 thread_index [[thread_position_in_grid]])

INSTANTIATE_INT5MM(float, 32);
INSTANTIATE_INT5MM(half, 32);
INSTANTIATE_INT5MM(float, 64);
INSTANTIATE_INT5MM(half, 64);
INSTANTIATE_INT5MM(float, 128);
INSTANTIATE_INT5MM(half, 128);
INSTANTIATE_INT5MM(float, 256);
INSTANTIATE_INT5MM(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT5MM(bfloat, 32);
INSTANTIATE_INT5MM(bfloat, 64);
INSTANTIATE_INT5MM(bfloat, 128);
INSTANTIATE_INT5MM(bfloat, 256);
#endif
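Analogously to the 3-bit case, each run of eight 5-bit weights is stored as one byte of high bits (bit i of b0 is bit 4 of w_i) followed by four bytes of low nibbles, two weights per byte. A pack/unpack sketch in Python (helper names are illustrative, not part of the PR):

def pack_int5(ws):
    # Pack eight 5-bit values (0..31) into five bytes, matching int5pack_mm.
    assert len(ws) == 8 and all(0 <= w < 32 for w in ws)
    b0 = 0
    for i, w in enumerate(ws):               # bit i of b0 = high bit of w_i
        b0 |= ((w >> 4) & 1) << i
    lows = [w & 15 for w in ws]              # low nibbles, two per byte
    return [b0] + [lows[2 * j] | (lows[2 * j + 1] << 4) for j in range(4)]

def unpack_int5(bs):
    # Inverse of pack_int5, mirroring the kernel's mask/shift expressions.
    return [(((bs[0] >> i) & 1) << 4) | ((bs[1 + i // 2] >> (4 * (i % 2))) & 15)
            for i in range(8)]

assert unpack_int5(pack_int5([0, 7, 15, 16, 21, 26, 30, 31])) == [0, 7, 15, 16, 21, 26, 30, 31]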