
Commit 05bf838

Initial support of SYCL CUTLASS for XPU backend through Inductor (#2)
Summary: This patch enables an initial execution of `torch.mm` through an Inductor-generated SYCL CUTLASS kernel for Intel PVC. Following the reference CUDA implementation, it provides the following key functionality:

- Template generation & rendering:
  - `SYCLTemplate`, `CUTLASSTemplate`, `CUTLASSGemmTemplate`, `CUTLASS3xGemmTemplate`: handle generating the full C++ code, from the call to `GeneratePVC` to obtain the `GemmOperations` (exposed by cutlass_library), through filtering the operations, constructing the Manifest and extracting the GEMM instance from the emitter (exposed by cutlass_library), up to the full wrapping of the C++ template code with runtime arguments and the final kernel launch.
  - `cutlass_utils.py`: utility file containing relevant functions used across the codegen process.
  - `SYCLKernel`, `SYCLTemplateKernel`: handle the higher-level kernel template and the host-side kernel call. Used within the template classes above.
- Autotuning:
  - `SYCLBenchmarkRequest`: currently added as an almost-dummy class that does not really benchmark, since a single generated configuration is selected for this PoC.
- Wrapping & triggering the above:
  - `SYCLTemplateCaller`: wrapper holding a ready-to-compile, execute, and benchmark SYCL template kernel. This is the higher-level construct added to the list of "choices" in the autotuning process for selecting the best configuration.
- Scheduling/execution:
  - `SYCLCPPScheduling` & `SYCLCombinedScheduling`: orchestrate kernel calls across nodes that may have different lowerings (e.g. Triton and CUTLASS SYCL). Few changes have been made here compared to the original CUDA implementation.

The current state was fine-tuned to support the only type configuration exposed by CUTLASS on PVC so far, i.e. `bfloat16` inputs with `fp32` accumulation, which forces some workarounds on the PyTorch side, namely around the D (layout/output node) and C (source/input_node[2]) dtypes. Unsupported and/or partially implemented features are highlighted with `TODO (SYCL)` comments. A rough usage sketch follows below.

---------

Co-authored-by: Lukas Sommer <lukas.sommer@codeplay.com>
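As a rough end-to-end illustration (not part of this patch), the path is meant to be exercised by compiling a `bfloat16` `torch.mm` on an XPU device with max-autotune enabled. The config knobs below mirror the CUDA CUTLASS backend and are an assumption; the exact XPU/SYCL switches may differ:

    # Hypothetical usage sketch; `max_autotune_gemm_backends` mirrors the CUDA
    # CUTLASS backend and may not be the final knob for the XPU/SYCL path.
    import torch
    from torch._inductor import config as inductor_config

    inductor_config.max_autotune = True
    inductor_config.max_autotune_gemm_backends = "CUTLASS"  # assumption

    a = torch.randn(512, 512, device="xpu", dtype=torch.bfloat16)
    b = torch.randn(512, 512, device="xpu", dtype=torch.bfloat16)

    compiled_mm = torch.compile(torch.mm)
    out = compiled_mm(a, b)  # lowered to an Inductor-generated SYCL CUTLASS kernel on PVC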
1 parent 2efa9ee commit 05bf838

File tree

13 files changed (+2414, -5 lines)

torch/_inductor/autotune_process.py

Lines changed: 90 additions & 0 deletions
@@ -30,6 +30,7 @@
     DLLWrapper,
     get_hash,
     PyCodeCache,
+    SYCLCodeCache,
 )
 from torch._inductor.utils import get_gpu_type, get_ld_library_path, is_gpu
 from torch._logging import getArtifactLogger

@@ -886,6 +887,95 @@ def get_tuning_process_pool() -> TuningProcessPool:
     return pool


+class SYCLBenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
+    # Important: Instances of this class have to be serializable
+    # across process boundaries. Do not put Tensors in here!
+    # TODO (SYCL): Complete the benchmark request class to enable full autotuning
+    def __init__(
+        self,
+        kernel_name: str,
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        extra_args: Iterable[Any],
+        source_code: str,
+    ) -> None:
+        super().__init__(kernel_name, input_tensor_meta, output_tensor_meta, extra_args)
+        self.source_code = source_code
+        self.workspace_size: int = 0  # TODO (SYCL): workspace size remains 0
+        self.workspace: Optional[torch.Tensor] = None
+        self.DLL: Optional[DLLWrapper] = None
+        self._workspace_size_updated = False
+        self.hash_key: str = ""
+        self.source_file: str = ""
+        self.hash_key, self.source_file = SYCLCodeCache.write(self.source_code, "so")
+
+    def precompile(self):
+        # Prepopulate SYCLCodeCache
+        autotuning_log.debug("Precompiling %s", self)
+        SYCLCodeCache.compile(self.source_code, "so")
+        autotuning_log.debug("Done precompiling %s", self)
+
+    def make_run_fn(
+        self, *input_tensors: torch.Tensor, output_tensor: torch.Tensor
+    ) -> Callable[[], None]:
+        self.ensure_dll_loaded()
+        self.update_workspace_size()  # TODO (SYCL): No effect since workspace_size is unused (remains = 0)
+        args = [
+            c_void_p(tensor.data_ptr())
+            for tensor in list(input_tensors) + [output_tensor]
+        ]
+        autotuning_log.debug(
+            "make_run_fn: self.kernel_name=%s, self.source_file=%s, self.hash_key=%s, self.DLL=%s, args=%s, self.extra_args=%s",
+            self.kernel_name,
+            self.source_file,
+            self.hash_key,
+            self.DLL,
+            args,
+            self.extra_args,
+        )
+        queue_ptr = c_void_p(torch.xpu.current_stream().sycl_queue)
+        run_method = getattr(self.DLL, self.kernel_name)
+        workspace_ptr = c_void_p(0)
+        if self.workspace_size > 0:
+            self.workspace = torch.zeros(
+                (self.workspace_size + 7) // 8,
+                dtype=torch.float64,
+                device=output_tensor.device,
+            )
+            workspace_ptr = c_void_p(self.workspace.data_ptr())
+
+        # Generate partial function.
+        return functools.partial(
+            run_method,
+            *args,
+            *self.extra_args,
+            None,  # null workspace size ptr
+            workspace_ptr,  # set workspace ptr,
+            queue_ptr,
+        )
+
+    def update_workspace_size(self) -> None:
+        if self._workspace_size_updated:
+            return
+        # TODO (SYCL): Hardcoded to zero since no SLM is used on PVC at the moment
+        self.workspace_size = 0
+        self._workspace_size_updated = True
+
+    def ensure_dll_loaded(self):
+        if self.DLL is None:
+            self.DLL, self.hash_key, self.source_file = SYCLCodeCache.load(
+                self.source_code, "so"
+            )
+
+    def cleanup_run_fn(self) -> None:
+        if self.DLL is not None:
+            self.DLL.close()
+        self.workspace = None
+
+    def __str__(self) -> str:
+        return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}"
+
+
 def benchmark_in_sub_process(
     choices: list[TritonTemplateCaller],
 ) -> dict[TritonTemplateCaller, float]:
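For reference, the callable returned by `make_run_fn` forwards pointers to the generated kernel in a fixed order: the input/output tensor pointers, the extra args, a null workspace-size pointer, the workspace pointer, and finally the SYCL queue pointer taken from `torch.xpu.current_stream()`. A minimal lifecycle sketch, assuming `bmreq` is an already-constructed `SYCLBenchmarkRequest` and `a`, `b`, `out` are XPU tensors matching its tensor metadata (all placeholder names):

    # Sketch only: `bmreq`, `a`, `b` and `out` are hypothetical placeholders.
    bmreq.precompile()                                # prepopulate SYCLCodeCache
    run = bmreq.make_run_fn(a, b, output_tensor=out)  # load the .so and bind pointers
    run()                                             # launch on the current XPU stream
    torch.xpu.synchronize()                           # wait before timing or reading `out`
    bmreq.cleanup_run_fn()                            # close the DLL and drop the workspace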

torch/_inductor/codegen/common.py

Lines changed: 3 additions & 1 deletion
@@ -395,6 +395,7 @@ def init_backend_registration() -> None:
     from .mps import MetalScheduling
     from .triton import TritonScheduling
     from .wrapper import PythonWrapperCodegen
+    from .xpu_combined_scheduling import SYCLCombinedScheduling

     if get_scheduling_for_device("cpu") is None:
         cpu_backends = {

@@ -425,9 +426,10 @@ def init_backend_registration() -> None:
         )

     if get_scheduling_for_device("xpu") is None:
+        # SYCLCombinedScheduling combines Triton and SYCL C++ scheduling for XPU devices via delegation
         register_backend_for_device(
             "xpu",
-            TritonScheduling,
+            SYCLCombinedScheduling,
             PythonWrapperCodegen,
             CppWrapperGpu,
         )
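With this change, scheduling lookups for the `xpu` device resolve to the combined scheduler instead of plain `TritonScheduling`. A small sanity-check sketch using the registration helpers from this same module:

    # Sketch: inspect the registered XPU scheduling after backend registration has run.
    from torch._inductor.codegen.common import (
        get_scheduling_for_device,
        init_backend_registration,
    )

    init_backend_registration()
    print(get_scheduling_for_device("xpu"))  # expected: SYCLCombinedScheduling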
Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
+# mypy: allow-untyped-defs
+import functools
+import logging
+import os
+import sys
+from dataclasses import dataclass
+from typing import Any, Optional
+
+import sympy
+
+import torch
+from torch._inductor.utils import clear_on_fresh_inductor_cache
+
+from ... import config
+from ...ir import Layout
+from ...runtime.runtime_utils import cache_dir
+from ...virtualized import V
+
+
+log = logging.getLogger(__name__)
+
+
+@functools.lru_cache(None)
+def try_import_cutlass() -> bool:
+    """
+    Currently only supporting a user-specified cutlass_dir or falling back to the
+    default ../third_party/cutlass/ (build-from-source setups).
+    """
+    # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
+
+    cutlass_py_full_path = os.path.abspath(
+        os.path.join(config.cutlass_dir, "python/cutlass_library")
+    )
+    tmp_cutlass_py_full_path = os.path.abspath(
+        os.path.join(cache_dir(), "torch_cutlass_library")
+    )
+    dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library")
+
+    if os.path.isdir(cutlass_py_full_path):
+        if tmp_cutlass_py_full_path not in sys.path:
+            if os.path.exists(dst_link):
+                assert os.path.islink(dst_link), (
+                    f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
+                )
+                assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
+                    cutlass_py_full_path
+                ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
+            else:
+                os.makedirs(tmp_cutlass_py_full_path, exist_ok=True)
+                os.symlink(cutlass_py_full_path, dst_link)
+            sys.path.append(tmp_cutlass_py_full_path)
+        try:
+            import cutlass_library.generator  # noqa: F401
+            import cutlass_library.library  # noqa: F401
+            import cutlass_library.manifest  # noqa: F401
+
+            return True
+        except ImportError as e:
+            log.debug(
+                "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.",
+                str(e),
+            )
+    else:
+        log.debug(
+            "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s",
+            cutlass_py_full_path,
+        )
+    return False
+
+
+@functools.lru_cache(8)
+def _normalize_sycl_arch(arch: str) -> str:
+    if int(arch) == 11:
+        return "11"
+    else:
+        raise NotImplementedError(f"Unsupported sycl arch: {arch}")
+
+
+@dataclass
+class CUTLASSArgs:
+    """
+    CUTLASS args used to initialize a CUTLASS Manifest.
+    """
+
+    architectures: Optional[str] = None
+    cuda_version: Optional[str] = None  # Unused in generator.py for PVC
+    instantiation_level: Optional[str] = None  # Unused YET in generator.py for PVC
+
+    operations = "all"
+    build_dir = ""
+    curr_build_dir = ""
+    generator_target = ""
+    kernels = "all"
+    ignore_kernels = ""
+    exclude_kernels = ""
+    # UNUSED at the moment, part of Manifest class in cutlass_library
+    kernel_filter_file: None = None
+    selected_kernel_list: None = None
+    interface_dir: None = None
+    filter_by_cc = False
+    disable_full_archs_compilation = False
+
+    def __post_init__(self):
+        if self.architectures is None:
+            raise RuntimeError(f"{self.architectures=} is None!")
+        self.architectures = _normalize_sycl_arch(self.architectures)
+
+
+@clear_on_fresh_inductor_cache
+@functools.lru_cache(None)
+def _gen_ops_cached(arch) -> list[Any]:
+    # Import cutlass python scripts.
+    assert try_import_cutlass()
+    import cutlass_library.generator as cutlass_generator
+    import cutlass_library.manifest as cutlass_manifest
+
+    if arch is None:
+        log.error(
+            "Cannot detect XPU arch %s. "
+            "Will discard all cutlass ops. "
+            "Please consider setting _inductor.xpu.arch",
+            arch,
+        )
+        return []
+    arch = _normalize_sycl_arch(arch)
+
+    sycl_version = "2025.0.1"  # Placeholder, unused in GeneratePVC
+
+    args = CUTLASSArgs(
+        architectures=arch,
+        instantiation_level="0",  # TODO (SYCL): Make it a config param once enabled in cutlass_library/generator.py
+        cuda_version=sycl_version,
+    )
+    manifest = cutlass_manifest.Manifest(args)
+
+    if arch == "11":
+        cutlass_generator.GeneratePVC(manifest, sycl_version)
+    else:
+        log.error("Invalid XPU arch")
+        return []
+    return manifest.operations
+
+
+def gen_ops() -> list[Any]:
+    """
+    Generates all supported CUTLASS operations.
+    """
+    # Currently limited to PVC (arch 1100), hardcoding the arch
+    # TODO (SYCL): get_xpu_arch()
+    arch = "11"
+    return _gen_ops_cached(arch)
+
+
+def torch_dtype_to_cutlass_type(
+    torch_dtype: torch.dtype,
+) -> "cutlass_library.library.DataType":  # type: ignore[name-defined]  # noqa: F821
+    # Import cutlass python scripts.
+    assert try_import_cutlass()
+    import cutlass_library  # type: ignore[import]
+
+    if torch_dtype == torch.float:
+        return cutlass_library.library.DataType.f32
+    elif torch_dtype == torch.half:
+        return cutlass_library.library.DataType.f16
+    elif torch_dtype == torch.bfloat16:
+        return cutlass_library.library.DataType.bf16
+    else:
+        raise NotImplementedError(f"Unsupported data type: {torch_dtype=}")
+
+
+def dtype_match(
+    torch_dtype: Optional[torch.dtype],
+    cutlass_dtype: "cutlass_library.library.DataType",  # type: ignore[name-defined]  # noqa: F821
+) -> bool:
+    # Import cutlass python scripts.
+    assert try_import_cutlass()
+    import cutlass_library
+
+    if torch_dtype == torch.float:
+        return cutlass_dtype == cutlass_library.library.DataType.f32
+    elif torch_dtype == torch.half:
+        return cutlass_dtype == cutlass_library.library.DataType.f16
+    elif torch_dtype == torch.bfloat16:
+        return cutlass_dtype == cutlass_library.library.DataType.bf16
+    elif torch_dtype == torch.int8:
+        return cutlass_dtype == cutlass_library.library.DataType.s8
+    elif torch_dtype == torch.uint8:
+        return cutlass_dtype == cutlass_library.library.DataType.u8
+    elif torch_dtype == torch.int32:
+        return cutlass_dtype == cutlass_library.library.DataType.s32
+    else:
+        return False
+
+
+def get_accumulator_dtype(
+    input_torch_dtypes: list[torch.dtype],
+) -> Optional[torch.dtype]:
+    """
+    Given a pair of input torch dtypes, returns the inferred accumulator torch dtype.
+    """
+    # TODO (SYCL): Extend this once other accumulator & input types are supported
+    if len(input_torch_dtypes) != 2:
+        return None
+
+    if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes):
+        return torch.float
+    else:
+        raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes}")
+
+
+def get_alignments(torch_dtype: torch.dtype) -> list[int]:
+    """
+    Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype.
+    """
+    # TODO (SYCL): Extend for other types & double-check alignments
+    if torch_dtype == torch.bfloat16:
+        return [8, 4, 2, 1]
+    elif torch_dtype == torch.float:
+        return [4, 2, 1]
+    else:
+        raise NotImplementedError(f"unsupported {torch_dtype=} for alignments")
+
+
+def get_max_alignment(inductor_layout: Layout) -> int:
+    """
+    Returns the max alignment (in terms of number of elements) for a given Inductor Layout.
+    """
+
+    dtype = inductor_layout.dtype
+    size = inductor_layout.size
+    offset = inductor_layout.offset
+
+    def is_static_int(number):
+        return isinstance(number, (int, sympy.Integer))
+
+    def a_factor_of(x, alignment):
+        if is_static_int(x) and is_static_int(alignment):
+            return x % alignment == 0
+        rem = sympy.Mod(x, alignment)
+        return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0))
+
+    try:
+        contiguous_dim = inductor_layout.stride.index(1)
+    except ValueError:
+        # No dim with stride 1 found, return 1
+        return 1
+    alignments = get_alignments(dtype)
+    for alignment in alignments:
+        if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of(
+            offset, alignment
+        ):
+            continue
+        if all(
+            (dim == contiguous_dim)
+            or a_factor_of(inductor_layout.stride[dim], alignment)
+            for dim in range(len(size))
+        ):
+            return alignment
+    return 1
+
+
+# TODO (SYCL): Add helpers for CUTLASS kernels testing & benchmarking once the standalone
+# runner is enabled.