Skip to content

Commit 7489c7d

Browse files
cmake torchao_ops_mps_linear_fp_act_xbit_weight
Differential Revision: D66120124 Pull Request resolved: #1304
1 parent 7446433 commit 7489c7d

File tree

8 files changed

+143
-46
lines changed

8 files changed

+143
-46
lines changed

torchao/experimental/kernels/mps/codegen/gen_metal_shader_lib.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,22 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
from typing import Optional
28
import os
9+
import sys
310
import yaml
411

5-
torchao_root: Optional[str] = os.getenv("TORCHAO_ROOT")
6-
assert torchao_root is not None, "TORCHAO_ROOT is not set"
12+
if len(sys.argv) != 2:
13+
print("Usage: gen_metal_shader_lib.py <output_file>")
14+
sys.exit(1)
15+
16+
# Output file where the generated code will be written
17+
OUTPUT_FILE = sys.argv[1]
718

8-
MPS_DIR = os.path.join(torchao_root, "torchao", "experimental", "kernels", "mps")
19+
MPS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
920

1021
# Path to yaml file containing the list of .metal files to include
1122
METAL_YAML = os.path.join(MPS_DIR, "metal.yaml")
@@ -21,9 +32,6 @@
2132
# Path to the folder containing the .metal files
2233
METAL_DIR = os.path.join(MPS_DIR, "metal")
2334

24-
# Output file where the generated code will be written
25-
OUTPUT_FILE = os.path.join(MPS_DIR, "src", "metal_shader_lib.h")
26-
2735
prefix = """/**
2836
* This file is generated by gen_metal_shader_lib.py
2937
*/
@@ -48,6 +56,7 @@
4856
4957
"""
5058

59+
os.makedirs(os.path.dirname(OUTPUT_FILE), exist_ok=True)
5160
with open(OUTPUT_FILE, "w") as outf:
5261
outf.write(prefix)
5362
for file in metal_files:
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
cmake-out/
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
cmake_minimum_required(VERSION 3.19)
8+
9+
project(torchao_ops_mps_linear_fp_act_xbit_weight)
10+
11+
set(CMAKE_CXX_STANDARD 17)
12+
set(CMAKE_CXX_STANDARD_REQUIRED YES)
13+
14+
if (NOT CMAKE_BUILD_TYPE)
15+
set(CMAKE_BUILD_TYPE Release)
16+
endif()
17+
18+
if (CMAKE_SYSTEM_NAME STREQUAL "Darwin")
19+
if (NOT CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64")
20+
message(FATAL_ERROR "Unified Memory requires Apple Silicon architecture")
21+
endif()
22+
else()
23+
message(FATAL_ERROR "Torchao experimental mps ops can only be built on macOS/iOS")
24+
endif()
25+
26+
find_package(Torch REQUIRED)
27+
28+
# Generate metal_shader_lib.h by running gen_metal_shader_lib.py
29+
set(GENERATED_METAL_SHADER_LIB ${CMAKE_INSTALL_PREFIX}/include/torchao/experimental/kernels/mps/src/metal_shader_lib.h)
30+
add_custom_command(
31+
OUTPUT ${GENERATED_METAL_SHADER_LIB}
32+
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/../../kernels/mps/codegen/gen_metal_shader_lib.py ${GENERATED_METAL_SHADER_LIB}
33+
COMMENT "Generating metal_shader_lib.h using gen_metal_shader_lib.py"
34+
)
35+
add_custom_target(generated_metal_shader_lib ALL DEPENDS ${GENERATED_METAL_SHADER_LIB})
36+
37+
if(NOT TORCHAO_INCLUDE_DIRS)
38+
set(TORCHAO_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR}/../../../..)
39+
endif()
40+
message(STATUS "TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
41+
42+
include_directories(${TORCHAO_INCLUDE_DIRS})
43+
include_directories(${CMAKE_INSTALL_PREFIX}/include)
44+
add_library(torchao_ops_mps_linear_fp_act_xbit_weight_aten SHARED aten/register.mm)
45+
add_dependencies(torchao_ops_mps_linear_fp_act_xbit_weight_aten generated_metal_shader_lib)
46+
47+
target_include_directories(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_INCLUDE_DIRS}")
48+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE "${TORCH_LIBRARIES}")
49+
target_compile_definitions(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE USE_ATEN=1)
50+
51+
# Enable Metal support
52+
find_library(METAL_LIB Metal)
53+
find_library(FOUNDATION_LIB Foundation)
54+
target_link_libraries(torchao_ops_mps_linear_fp_act_xbit_weight_aten PRIVATE ${METAL_LIB} ${FOUNDATION_LIB})
55+
56+
install(
57+
TARGETS torchao_ops_mps_linear_fp_act_xbit_weight_aten
58+
EXPORT _targets
59+
DESTINATION lib
60+
)

torchao/experimental/ops/mps/register.mm renamed to torchao/experimental/ops/mps/aten/register.mm

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// LICENSE file in the root directory of this source tree.
66

77
// clang-format off
8-
#include <torch/extension.h>
8+
#include <torch/library.h>
99
#include <ATen/native/mps/OperationUtils.h>
1010
#include <torchao/experimental/kernels/mps/src/lowbit.h>
1111
// clang-format on
@@ -147,9 +147,6 @@ Tensor pack_weights_cpu_kernel(const Tensor& W) {
147147
return B;
148148
}
149149

150-
// Registers _C as a Python extension module.
151-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {}
152-
153150
TORCH_LIBRARY(torchao, m) {
154151
m.def("_pack_weight_1bit(Tensor W) -> Tensor");
155152
m.def("_pack_weight_2bit(Tensor W) -> Tensor");

torchao/experimental/ops/mps/build.sh

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#!/bin/bash -eu
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
cd "$(dirname "$BASH_SOURCE")"
9+
10+
export CMAKE_PREFIX_PATH=$(python -c 'from distutils.sysconfig import get_python_lib; print(get_python_lib())')
11+
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
12+
export CMAKE_OUT=${PWD}/cmake-out
13+
echo "CMAKE_OUT: ${CMAKE_OUT}"
14+
15+
cmake -DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
16+
-DCMAKE_INSTALL_PREFIX=${CMAKE_OUT} \
17+
-S . \
18+
-B ${CMAKE_OUT}
19+
cmake --build ${CMAKE_OUT} -j 16 --target install --config Release

torchao/experimental/ops/mps/setup.py

Lines changed: 0 additions & 23 deletions
This file was deleted.

torchao/experimental/ops/mps/test/test_lowbit.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,38 @@
44
# This source code is licensed under the license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import os
8+
import sys
79
import torch
8-
import torchao_mps_ops
910
import unittest
1011

12+
from parameterized import parameterized
1113

12-
def parameterized(test_cases):
13-
def decorator(func):
14-
def wrapper(self):
15-
for case in test_cases:
16-
with self.subTest(case=case):
17-
func(self, *case)
14+
libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
15+
libpath = os.path.abspath(
16+
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
17+
)
1818

19-
return wrapper
20-
21-
return decorator
19+
try:
20+
for nbit in range(1, 8):
21+
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
22+
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
23+
except AttributeError:
24+
try:
25+
torch.ops.load_library(libpath)
26+
except:
27+
raise RuntimeError(f"Failed to load library {libpath}")
28+
else:
29+
try:
30+
for nbit in range(1, 8):
31+
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
32+
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
33+
except AttributeError as e:
34+
raise e
2235

2336

2437
class TestLowBitQuantWeightsLinear(unittest.TestCase):
25-
cases = [
38+
CASES = [
2639
(nbit, *param)
2740
for nbit in range(1, 8)
2841
for param in [
@@ -73,7 +86,7 @@ def _reference_linear_lowbit_quant_weights(self, A, W, group_size, S, Z, nbit):
7386
W = scales * W + zeros
7487
return torch.mm(A, W.t())
7588

76-
@parameterized(cases)
89+
@parameterized.expand(CASES)
7790
def test_linear(self, nbit, M=1, K=32, N=32, group_size=32):
7891
print(f"nbit: {nbit}, M: {M}, K: {K}, N: {N}, group_size: {group_size}")
7992
A, W, S, Z = self._init_tensors(group_size, M, K, N, nbit=nbit)

torchao/experimental/ops/mps/test/test_quantizer.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,34 @@
1111
import sys
1212

1313
import torch
14-
import torchao_mps_ops
1514
import unittest
1615

1716
from parameterized import parameterized
1817
from torchao.experimental.quant_api import UIntxWeightOnlyLinearQuantizer
1918
from torchao.experimental.quant_api import _quantize
2019

20+
libname = "libtorchao_ops_mps_linear_fp_act_xbit_weight_aten.dylib"
21+
libpath = os.path.abspath(
22+
os.path.join(os.path.dirname(__file__), "../cmake-out/lib/", libname)
23+
)
24+
25+
try:
26+
for nbit in range(1, 8):
27+
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
28+
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
29+
except AttributeError:
30+
try:
31+
torch.ops.load_library(libpath)
32+
except:
33+
raise RuntimeError(f"Failed to load library {libpath}")
34+
else:
35+
try:
36+
for nbit in range(1, 8):
37+
getattr(torch.ops.torchao, f"_linear_fp_act_{nbit}bit_weight")
38+
getattr(torch.ops.torchao, f"_pack_weight_{nbit}bit")
39+
except AttributeError as e:
40+
raise e
41+
2142

2243
class TestUIntxWeightOnlyLinearQuantizer(unittest.TestCase):
2344
BITWIDTHS = range(1, 8)

0 commit comments

Comments
 (0)