Commit a47e05c
Initial support of SYCL CUTLASS for XPU backend through Inductor (#2)
Summary: This patch enables an initial execution of `torch.mm` through an Inductor-generated SYCL CUTLASS kernel for Intel PVC. Following the reference CUDA implementation, it adds the following key pieces:

- Template generation & rendering:
  - `SYCLTemplate`, `CUTLASSTemplate`, `CUTLASSGemmTemplate`, `CUTLASS3xGemmTemplate`: generate the full C++ code, from the call to `GeneratePVC` to obtain the `GemmOperations` (exposed by cutlass_library), through filtering the operations, constructing the Manifest and extracting the GEMM instance from the emitter (exposed by cutlass_library), up to wrapping the C++ template code with runtime arguments and the final kernel launch.
  - `cutlass_utils.py`: utility functions used across the codegen process.
  - `SYCLKernel`, `SYCLTemplateKernel`: handle the higher-level kernel template and host-side kernel invocation. Used by the template classes above.
- Autotuning:
  - `SYCLBenchmarkRequest`: currently little more than a stub; it does not really benchmark, since a single generated configuration is selected for this PoC.
- Wrapping & triggering the above:
  - `SYCLTemplateCaller`: wrapper holding a ready-to-compile, execute, and benchmark SYCL template kernel. This is the higher-level construct added to the list of "choices" in the autotuning process for selecting the best configuration.
- Scheduling/execution:
  - `SYCLCPPScheduling` & `SYCLCombinedScheduling`: orchestrate kernel calls across nodes that may have different lowerings (e.g. Triton and CUTLASS SYCL). Few changes have been made here compared to the original CUDA implementation.

The current state is tuned to the only type configuration exposed by CUTLASS on PVC so far, i.e. `bfloat16` inputs with `fp32` accumulation, which forces some workarounds on the PyTorch side, namely around the D (layout/output node) and C (source/input_node[2]) dtypes. Unsupported and/or partially implemented features are marked with `TODO (SYCL)` comments.

---------

Co-authored-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent 730119f commit a47e05c
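For orientation, the sketch below shows one way to exercise this path end to end: compiling a bf16 `torch.mm` on an XPU device with max-autotune and the CUTLASS backend selected. It is a minimal, hedged sketch; the config knobs shown (`max_autotune_gemm_backends`, `cuda.cutlass_dir`) mirror the existing CUDA CUTLASS backend and are assumptions here, as the XPU path may expose different settings.

```python
# Minimal sketch (config knobs are assumptions, see note above): trigger Inductor
# autotuning with the CUTLASS backend for a bf16 matmul on an XPU device.
# bf16 inputs with fp32 accumulation is the only combination supported on PVC in this PoC.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CUTLASS"   # assumed to also apply to the XPU path
# inductor_config.cuda.cutlass_dir = "/path/to/cutlass"  # SYCL-enabled CUTLASS checkout (assumed knob)


@torch.compile(mode="max-autotune")
def mm(a, b):
    return torch.mm(a, b)


a = torch.randn(1024, 1024, device="xpu", dtype=torch.bfloat16)
b = torch.randn(1024, 1024, device="xpu", dtype=torch.bfloat16)
c = mm(a, b)
```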

File tree: 13 files changed, +2414 -5 lines changed


torch/_inductor/autotune_process.py

Lines changed: 90 additions & 0 deletions
@@ -27,6 +27,7 @@
     DLLWrapper,
     get_hash,
     PyCodeCache,
+    SYCLCodeCache,
 )
 from torch._inductor.utils import get_gpu_type, is_gpu
 from torch._logging import getArtifactLogger
@@ -935,6 +936,95 @@ def __str__(self) -> str:
         return f"{self.kernel_name=}"


+class SYCLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
+    # Important: Instances of this class have to be serializable
+    # across process boundaries. Do not put Tensors in here!
+    # TODO (SYCL) : Complete the bmrq class to enable full autotuning
+    def __init__(
+        self,
+        kernel_name: str,
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        extra_args: Iterable[Any],
+        source_code: str,
+    ) -> None:
+        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
+        self.source_code = source_code
+        self.workspace_size: int = 0  # TODO (SYCL): workspace size remains 0
+        self.workspace: Optional[torch.Tensor] = None
+        self.DLL: Optional[DLLWrapper] = None
+        self._workspace_size_updated = False
+        self.hash_key: str = ""
+        self.source_file: str = ""
+        self.hash_key, self.source_file = SYCLCodeCache.write(self.source_code, "so")
+
+    def precompile(self):
+        # Prepopulate SYCLCodeCache
+        autotuning_log.debug("Precompiling %s", self)
+        SYCLCodeCache.compile(self.source_code, "so")
+        autotuning_log.debug("Done precompiling %s", self)
+
+    def make_run_fn(
+        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
+    ) -> Callable[[], None]:
+        self.ensure_dll_loaded()
+        self.update_workspace_size()  # TODO (SYCL): No effect on workspace_size being unused (remains = 0)
+        args = [
+            c_void_p(tensor.data_ptr())
+            for tensor in list(input_tensors) + [output_tensor]
+        ]
+        autotuning_log.debug(
+            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
+            self.kernel_name,
+            self.source_file,
+            self.hash_key,
+            self.DLL,
+            args,
+            self.extra_args,
+        )
+        queue_ptr = c_void_p(torch.xpu.current_stream().sycl_queue)
+        run_method = getattr(self.DLL, self.kernel_name)
+        workspace_ptr = c_void_p(0)
+        if self.workspace_size > 0:
+            self.workspace = torch.zeros(
+                (self.workspace_size + 7) // 8,
+                dtype=torch.float64,
+                device=output_tensor.device,
+            )
+            workspace_ptr = c_void_p(self.workspace.data_ptr())
+
+        # Generate partial function.
+        return functools.partial(
+            run_method,
+            *args,
+            *self.extra_args,
+            None,  # null workspace size ptr
+            workspace_ptr,  # set workspace ptr,
+            queue_ptr,
+        )
+
+    def update_workspace_size(self) -> None:
+        if self._workspace_size_updated:
+            return
+        # TODO (SYCL): Hardcoded to zero since no SLM is used on PVC at the moment
+        self.workspace_size = 0
+        self._workspace_size_updated = True
+
+    def ensure_dll_loaded(self):
+        if self.DLL is None:
+            self.DLL, self.hash_key, self.source_file = SYCLCodeCache.load(
+                self.source_code, "so"
+            )
+
+    def cleanup_run_fn(self) -> None:
+        if self.DLL is not None:
+            self.DLL.close()
+        self.workspace = None
+
+    def __str__(self) -> str:
+        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
+
+
 def benchmark_in_sub_process(
     choices: list[TritonTemplateCaller],
 ) -> dict[TritonTemplateCaller, float]:
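To make the lifecycle above concrete: the autotuner materializes tensors from the serialized `TensorMeta`, precompiles the source, builds the run function, times it, and cleans up. The snippet below is a rough, hedged sketch of driving a `SYCLBenchmarkRequest` manually (it assumes a single output tensor and that `source_code` holds a complete, compilable SYCL CUTLASS kernel); it is not part of this patch.

```python
# Hedged sketch (not part of this patch): manually drive a SYCLBenchmarkRequest.
import time

import torch
from torch._inductor.autotune_process import SYCLBenchmarkRequest


def time_request(bmreq: SYCLBenchmarkRequest) -> float:
    metas = bmreq.input_tensor_meta
    metas = metas if isinstance(metas, (list, tuple)) else [metas]
    input_tensors = tuple(m.to_tensor() for m in metas)   # random tensors matching the recorded metadata
    output_tensor = bmreq.output_tensor_meta.to_tensor()  # assumes a single output

    bmreq.precompile()  # prepopulate SYCLCodeCache
    fn = bmreq.make_run_fn(*input_tensors, output_tensor=output_tensor)
    try:
        torch.xpu.synchronize()
        start = time.perf_counter()
        fn()  # launch the kernel on the current SYCL queue
        torch.xpu.synchronize()
        return (time.perf_counter() - start) * 1e3  # milliseconds
    finally:
        bmreq.cleanup_run_fn()  # close the DLL and drop the workspace
```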

torch/_inductor/codegen/common.py

Lines changed: 3 additions & 1 deletion
@@ -394,6 +394,7 @@ def init_backend_registration() -> None:
     from .mps import MetalScheduling
     from .triton import TritonScheduling
     from .wrapper import PythonWrapperCodegen
+    from .xpu_combined_scheduling import SYCLCombinedScheduling

     if get_scheduling_for_device("cpu") is None:
         cpu_backends = {
@@ -424,9 +425,10 @@
         )

     if get_scheduling_for_device("xpu") is None:
+        # SYCLCombinedScheduling combines Triton and SYCL C++ scheduling for XPU devices via delegation
         register_backend_for_device(
             "xpu",
-            TritonScheduling,
+            SYCLCombinedScheduling,
             PythonWrapperCodegen,
             CppWrapperGpu,
         )
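The registered `SYCLCombinedScheduling` mirrors the CUDA-side `CUDACombinedScheduling`: a facade that picks, per node, either the Triton scheduler or the SYCL C++ scheduler and forwards the call. The following is a generic, self-contained sketch of that delegation idea only; it is not the actual class body, and the dispatch condition is illustrative.

```python
# Generic delegation sketch (illustrative only, not SYCLCombinedScheduling itself):
# a combined scheduler forwards each node to the backend able to lower it.
from dataclasses import dataclass


@dataclass
class Node:
    name: str
    is_cutlass_template: bool = False


class TritonLikeBackend:
    def codegen(self, node: Node) -> str:
        return f"<Triton kernel for {node.name}>"


class SyclCppLikeBackend:
    def codegen(self, node: Node) -> str:
        return f"<SYCL CUTLASS kernel for {node.name}>"


class CombinedScheduling:
    """Delegates each node to the backend that knows how to lower it."""

    def __init__(self) -> None:
        self._triton = TritonLikeBackend()
        self._sycl_cpp = SyclCppLikeBackend()

    def codegen(self, node: Node) -> str:
        backend = self._sycl_cpp if node.is_cutlass_template else self._triton
        return backend.codegen(node)


print(CombinedScheduling().codegen(Node("mm", is_cutlass_template=True)))
```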
cutlass_utils.py

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
+# mypy: allow-untyped-defs
+import functools
+import logging
+import os
+import sys
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import sympy
+
+import torch
+from torch._inductor.utils import clear_on_fresh_inductor_cache
+
+from ... import config
+from ...ir import Layout
+from ...runtime.runtime_utils import cache_dir
+from ...virtualized import V
+
+
+log = logging.getLogger(__name__)
+
+
+@functools.lru_cache(None)
+def try_import_cutlass() -> bool:
+    """
+    Currently only supporting a user-specified cutlass_dir or falling back to the
+    default ../third_party/cutlass/ (build-from-source setups).
+    """
+    # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
+
+    cutlass_py_full_path = os.path.abspath(
+        os.path.join(config.cutlass_dir, "python/cutlass_library")
+    )
+    tmp_cutlass_py_full_path = os.path.abspath(
+        os.path.join(cache_dir(), "torch_cutlass_library")
+    )
+    dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
+
+    if os.path.isdir(cutlass_py_full_path):
+        if tmp_cutlass_py_full_path not in sys.path:
+            if os.path.exists(dst_link):
+                assert os.path.islink(dst_link), (
+                    f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
+                )
+                assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
+                    cutlass_py_full_path
+                ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
+            else:
+                os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
+                os.symlink(cutlass_py_full_path, dst_link)
+            sys.path.append(tmp_cutlass_py_full_path)
+        try:
+            import cutlass_library.generator  # noqa: F401
+            import cutlass_library.library  # noqa: F401
+            import cutlass_library.manifest  # noqa: F401
+
+            return True
+        except ImportError as e:
+            log.debug(
+                "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.",
+                str(e),
+            )
+    else:
+        log.debug(
+            "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
+            cutlass_py_full_path,
+        )
+    return False
+
+
+@functools.lru_cache(8)
+def _normalize_sycl_arch(arch: str) -> str:
+    if int(arch) == 11:
+        return "11"
+    else:
+        raise NotImplementedError(f"Unsupported sycl arch: {arch}")
+
+
+@dataclass
+class CUTLASSArgs:
+    """
+    CUTLASS args used to initialize a CUTLASS Manifest.
+    """
+
+    architectures: Optional[str] = None
+    cuda_version: Optional[str] = None  # Unused in generator.py for PVC
+    instantiation_level: Optional[str] = None  # Unused YET in generator.py for PVC
+
+    operations = "all"
+    build_dir = ""
+    curr_build_dir = ""
+    generator_target = ""
+    kernels = "all"
+    ignore_kernels = ""
+    exclude_kernels = ""
+    # UNUSED at the moment, part of Manifest class in cutlass_library
+    kernel_filter_file: None = None
+    selected_kernel_list: None = None
+    interface_dir: None = None
+    filter_by_cc = False
+    disable_full_archs_compilation = False
+
+    def __post_init__(self):
+        if self.architectures is None:
+            raise RuntimeError(f"{self.architectures=} is None!")
+        self.architectures = _normalize_sycl_arch(self.architectures)
+
+
+@clear_on_fresh_inductor_cache
+@functools.lru_cache(None)
+def _gen_ops_cached(arch) -> list[Any]:
+    # Import cutlass python scripts.
+    assert try_import_cutlass()
+    import cutlass_library.generator as cutlass_generator
+    import cutlass_library.manifest as cutlass_manifest
+
+    if arch is None:
+        log.error(
+            "Cannot detect XPU arch %s."
+            "Will discard all cutlass ops. "
+            "Please consider setting _inductor.xpu.arch",
+            arch,
+        )
+        return []
+    arch = _normalize_sycl_arch(arch)
+
+    sycl_version = "2025.0.1"  # Placeholder, unused in GeneratePVC
+
+    args = CUTLASSArgs(
+        architectures=arch,
+        instantiation_level="0",  # TODO (SYCL) : Make it a config param once enabled in cutlass_library/generator.py
+        cuda_version=sycl_version,
+    )
+    manifest = cutlass_manifest.Manifest(args)
+
+    if arch == "11":
+        cutlass_generator.GeneratePVC(manifest, sycl_version)
+    else:
+        log.error("Invalid XPU arch")
+        return []
+    return manifest.operations
+
+
+def gen_ops() -> list[Any]:
+    """
+    Generates all supported CUTLASS operations.
+    """
+    # Currently limited to PVC (arch 1100), hardcoding arch
+    # TODO (SYCL) : get_xpu_arch()
+    arch = "11"
+    return _gen_ops_cached(arch)
+
+
+def torch_dtype_to_cutlass_type(
+    torch_dtype: torch.dtype,
+) -> "cutlass_library.library.DataType":  # type: ignore[name-defined] # noqa: F821
+    # Import cutlass python scripts.
+    assert try_import_cutlass()
+    import cutlass_library  # type: ignore[import]
+
+    if torch_dtype == torch.float:
+        return cutlass_library.library.DataType.f32
+    elif torch_dtype == torch.half:
+        return cutlass_library.library.DataType.f16
+    elif torch_dtype == torch.bfloat16:
+        return cutlass_library.library.DataType.bf16
+    else:
+        raise NotImplementedError(f"Unsupported data type: {torch_dtype=}")
+
+
+def dtype_match(
+    torch_dtype: Optional[torch.dtype],
+    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined] # noqa: F821
+) -> bool:
+    # Import cutlass python scripts.
+    assert try_import_cutlass()
+    import cutlass_library
+
+    if torch_dtype == torch.float:
+        return cutlass_dtype == cutlass_library.library.DataType.f32
+    elif torch_dtype == torch.half:
+        return cutlass_dtype == cutlass_library.library.DataType.f16
+    elif torch_dtype == torch.bfloat16:
+        return cutlass_dtype == cutlass_library.library.DataType.bf16
+    elif torch_dtype == torch.int8:
+        return cutlass_dtype == cutlass_library.library.DataType.s8
+    elif torch_dtype == torch.uint8:
+        return cutlass_dtype == cutlass_library.library.DataType.u8
+    elif torch_dtype == torch.int32:
+        return cutlass_dtype == cutlass_library.library.DataType.s32
+    else:
+        return False
+
+
+def get_accumulator_dtype(
+    input_torch_dtypes: list[torch.dtype],
+) -> Optional[torch.dtype]:
+    """
+    Given a pair of input torch dtypes, returns the inferred accumulator torch dtype.
+    """
+    # TODO (SYCL) : Extend this once other accumulator & input types are supported
+    if len(input_torch_dtypes) != 2:
+        return None
+
+    if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes):
+        return torch.float
+    else:
+        raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes}")
+
+
+def get_alignments(torch_dtype: torch.dtype) -> list[int]:
+    """
+    Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype.
+    """
+    # TODO (SYCL): Extend for other types & double-check alignments
+    if torch_dtype == torch.bfloat16:
+        return [8, 4, 2, 1]
+    elif torch_dtype == torch.float:
+        return [4, 2, 1]
+    else:
+        raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")
+
+
+def get_max_alignment(inductor_layout: Layout) -> int:
+    """
+    Returns the max alignment (in terms of number of elements) for a given Inductor Layout.
+    """
+
+    dtype = inductor_layout.dtype
+    size = inductor_layout.size
+    offset = inductor_layout.offset
+
+    def is_static_int(number):
+        return isinstance(number, (int, sympy.Integer))
+
+    def a_factor_of(x, alignment):
+        if is_static_int(x) and is_static_int(alignment):
+            return x % alignment == 0
+        rem = sympy.Mod(x, alignment)
+        return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0))
+
+    try:
+        contiguous_dim = inductor_layout.stride.index(1)
+    except ValueError:
+        # No dim with stride 1 found, return 1
+        return 1
+    alignments = get_alignments(dtype)
+    for alignment in alignments:
+        if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of(
+            offset, alignment
+        ):
+            continue
+        if all(
+            (dim == contiguous_dim)
+            or a_factor_of(inductor_layout.stride[dim], alignment)
+            for dim in range(len(size))
+        ):
+            return alignment
+    return 1
+
+
+# TODO (SYCL) : Add helpers for CUTLASS kernels testing & benchmarking once standalone
+# runner is enabled.
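A quick, hedged illustration of how these helpers compose. The `cutlass_utils` import path below is hypothetical (the file's final location is not shown in this view), and the `try_import_cutlass()`-gated part additionally requires a SYCL-enabled CUTLASS checkout with `cutlass_library` importable.

```python
# Hedged usage sketch of the helpers above; the import path is hypothetical.
import torch
from torch._inductor.ir import FixedLayout

import cutlass_utils  # hypothetical import path for the new module

# bf16 x bf16 with fp32 accumulation is the only combination wired up for PVC so far.
assert cutlass_utils.get_accumulator_dtype([torch.bfloat16, torch.bfloat16]) == torch.float

# Max alignment (in elements) of a contiguous bf16 [128, 64] layout: the contiguous
# dim (64), the offset (0) and the outer stride (64) are all divisible by 8, so this prints 8.
layout = FixedLayout(torch.device("xpu"), torch.bfloat16, size=[128, 64], stride=[64, 1])
print(cutlass_utils.get_max_alignment(layout))

# With cutlass_library importable, dtype mapping and PVC op generation become available.
if cutlass_utils.try_import_cutlass():
    import cutlass_library

    assert cutlass_utils.torch_dtype_to_cutlass_type(torch.bfloat16) == cutlass_library.library.DataType.bf16
    assert cutlass_utils.dtype_match(torch.bfloat16, cutlass_library.library.DataType.bf16)
    ops = cutlass_utils.gen_ops()  # manifest.operations produced by GeneratePVC
    print(f"Generated CUTLASS operation table for PVC: {type(ops)}")
```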
