Skip to content

Commit 730119f

Browse files
authored
Implement SYCL code cache (#1)
Implement a `SYCLCodeCache` similar to the `CUDACodeCache` for compilation of CUTLASS code for integration of the CUTLASS SYCL backend into PyTorch TorchInductor. Also changes the third-party submodule for CUTLASS to use Codeplay's fork with the SYCL backend. --------- Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent 00a2c68 commit 730119f

File tree

9 files changed

+227
-20
lines changed

9 files changed

+227
-20
lines changed

.gitmodules

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@
109109
url = https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git
110110
[submodule "third_party/cutlass"]
111111
path = third_party/cutlass
112-
url = https://github.com/NVIDIA/cutlass.git
112+
url = https://github.com/codeplaysoftware/cutlass-fork.git
113113
[submodule "third_party/mimalloc"]
114114
path = third_party/mimalloc
115115
url = https://github.com/microsoft/mimalloc.git

third_party/cutlass

torch/_inductor/async_compile.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
HalideCodeCache,
3333
LambdaFuture,
3434
ROCmCodeCache,
35+
SYCLCodeCache,
3536
torch_key,
3637
)
3738
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
@@ -410,6 +411,18 @@ def task():
410411

411412
return self.submit(task)
412413

414+
def sycl(self, source_code, dst_file_ext, aot_compile=False):
    """Submit a SYCL kernel for compilation and return the pending result.

    The kernel source is logged, then a task is handed to ``self.submit``.
    When the task runs it loads the kernel via ``SYCLCodeCache.load`` and
    yields the first element of the returned tuple.

    Args:
        source_code: SYCL source text of the kernel.
        dst_file_ext: extension of the artifact to build, forwarded to
            ``SYCLCodeCache.load``.
        aot_compile: request ahead-of-time compilation; not supported yet.

    Raises (inside the submitted task, not at call time):
        RuntimeError: if ``aot_compile`` is true.
    """
    kernel_code_log.info("SYCL Kernel:\n%s", source_code)

    def task():
        # TODO: Support AoT compilation.
        if not aot_compile:
            loaded = SYCLCodeCache.load(source_code, dst_file_ext)
            return loaded[0]
        raise RuntimeError("AoT compilation not yet supported for SYCL")

    return self.submit(task)
424+
425+
413426
def halide(self, meta: HalideMeta, source_code: str):
414427
kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
415428
if get_compile_threads() <= 1:

torch/_inductor/codecache.py

Lines changed: 175 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2944,7 +2944,7 @@ def _cutlass_include_paths() -> list[str]:
29442944

29452945
cutlass_path = parutil.get_dir_path("cutlass-3-headers")
29462946
else:
2947-
cutlass_path = config.cuda.cutlass_dir
2947+
cutlass_path = config.cutlass_dir
29482948
return [
29492949
# Use realpath to get canonical absolute paths, in order not to mess up cache keys
29502950
os.path.realpath(os.path.join(cutlass_path, "include")),
@@ -3084,7 +3084,6 @@ def close(self) -> None:
30843084

30853085
def _dlclose(self) -> None:
30863086
f_dlclose = None
3087-
30883087
if is_linux():
30893088
syms = CDLL(None)
30903089
if not hasattr(syms, "dlclose"):
@@ -3323,6 +3322,180 @@ def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str
33233322
return (DLLWrapper(dst_file_path), hash_key, source_code_path)
33243323

33253324

3325+
def _sycl_compiler() -> Optional[str]:
3326+
# TODO: Add detection of `icpx` as release compiler and add option to use
3327+
# environment variable to specify specific compiler.
3328+
return "clang++"
3329+
3330+
3331+
def _sycl_lib_options() -> list[str]:
    """Build the linker flags needed to link against the XPU libraries.

    For every directory reported by
    ``cpp_extension.library_paths(device_type="xpu")`` the result contains
    ``-L<dir>`` followed by ``-Xlinker -rpath=<dir>``.

    Returns:
        The flat list of linker flags, in library-path order.

    Raises:
        NotImplementedError: when not running on Linux.
    """
    _set_gpu_runtime_env()  # cpp_extension consults the env
    from torch.utils import cpp_extension

    lpaths = cpp_extension.library_paths(device_type="xpu")
    if not is_linux():
        # The message previously mentioned "cuda libs", copied from the CUDA
        # helper; this code path looks for the SYCL/XPU libraries.
        raise NotImplementedError(
            "Unsupported env, failed to find SYCL libs! Currently only Linux is supported."
        )

    extra_ldflags: list[str] = []
    for path in lpaths:
        # -rpath ensures the DLL can find its dependencies when loaded, even
        # if the library path is non-standard.
        extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"])
    return extra_ldflags
3347+
3348+
3349+
def _dpcpp_compiler_options() -> list[str]:
    """Assemble the compiler flags for building SYCL CUTLASS kernels.

    The flags are derived from ``config.sycl``: the target passed to
    ``-fsycl-targets`` is ``intel_gpu_<arch>`` when an arch is configured and
    ``spir64`` otherwise; the configured optimization level is passed through
    verbatim; debug-info and fast-math flags are appended when enabled.
    """
    sycl_cfg = config.sycl

    # TODO: Automatically detect device architecture.
    if sycl_cfg.arch is None:
        target = "spir64"
    else:
        target = f"intel_gpu_{sycl_cfg.arch}"

    flags = [
        "-fsycl",
        "-std=c++17",
        "-fPIC",
        "-Xspirv-translator",
        "-spirv-ext=+SPV_INTEL_split_barrier",
        "-fsycl-range-rounding=disable",
        f"-fsycl-targets={target}",
        sycl_cfg.compile_opt_level,
        "-DCUTLASS_ENABLE_SYCL",
        "-DSYCL_INTEL_TARGET",
    ]

    # TODO: Add special case for FB?
    if sycl_cfg.enable_debug_info:
        flags += ["-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
    if sycl_cfg.use_fast_math:
        flags += ["-ffast-math", "-DCUTLASS_USE_TANH_FOR_SIGMOID=1"]

    return flags
3375+
3376+
3377+
def sycl_compile_command(
    src_files: list[str],
    dst_file: str,
    dst_file_ext: str,
    extra_args: Optional[list[str]] = None,
) -> str:
    """Return the single-string command line that compiles SYCL sources.

    Args:
        src_files: source files, joined with spaces at the end of the command.
        dst_file: output path passed after ``-o``.
        dst_file_ext: kind of artifact: ``"o"`` (adds ``-c``), ``"so"``
            (adds ``-shared``) or ``"exe"`` (no extra flag).
        extra_args: additional flags placed after the base compiler options.

    Raises:
        NotImplementedError: for any other ``dst_file_ext``.
    """
    # Gather the pieces in the same order as before: includes, libs, compiler.
    cutlass_includes = _cutlass_include_paths()
    link_flags = _sycl_lib_options()
    compiler_flags = _dpcpp_compiler_options()

    flags = (
        compiler_flags
        + ([] if extra_args is None else extra_args)
        + ["-I" + inc for inc in cutlass_includes]
        + link_flags
    )

    # Flag that distinguishes each supported artifact kind.
    mode_flags: dict[str, list[str]] = {"o": ["-c"], "so": ["-shared"], "exe": []}
    if dst_file_ext not in mode_flags:
        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")

    joined_flags = " ".join(flags + mode_flags[dst_file_ext])
    sources = " ".join(src_files)
    command = f"{_sycl_compiler()} {joined_flags} -o {dst_file} {sources}"
    log.debug("SYCL command: %s", command)
    return command
3407+
3408+
3409+
3410+
@clear_on_fresh_inductor_cache
class SYCLCodeCache:
    """Write, compile and load SYCL source code, caching results by hash key.

    The hash key is computed from the source code together with the compile
    command (including any extra compiler arguments), so a change in either
    produces a distinct source file and a distinct compiled artifact.
    """

    @dataclasses.dataclass
    class CacheEntry:
        # Path of the written source file.
        input_path: str
        # Path of the artifact compiled from that source file.
        output_path: str

    # In-process map from hash key to the paths of an already handled source.
    cache: dict[str, CacheEntry] = {}
    # Exposed so the in-process cache can be emptied; presumably what
    # clear_on_fresh_inductor_cache relies on (confirm against the decorator).
    cache_clear = staticmethod(cache.clear)
    _SOURCE_CODE_SUFFIX = "cpp"

    @classmethod
    def write(
        cls,
        source_code: str,
        dst_file_ext: str,
        extra_args: Optional[list[str]] = None,
    ) -> tuple[str, str]:
        """
        Writes source code into a file with the source-code suffix.
        Returns the hash key of source code, and the path to the file.

        The compile command for ``dst_file_ext`` and ``extra_args`` (built
        with placeholder input/output names) is mixed into the hash, so the
        key changes whenever the compile flags change.
        """

        sycl_command = repr(
            sycl_compile_command(
                ["dummy_input"], "dummy_output", dst_file_ext, extra_args
            )
        )
        key, input_path = write(
            source_code, cls._SOURCE_CODE_SUFFIX, extra=sycl_command
        )
        return key, input_path

    @classmethod
    def compile(
        cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
    ) -> tuple[str, str, str]:
        """
        Compiles SYCL source_code into a file with dst_file_ext extension.
        Returns a tuple of dst_file_path, hash_key, source_code_path

        Raises exc.SYCLCompileError if the compiler exits with an error.
        """
        # extra_args take part in the key: previously they were ignored when
        # hashing, so a second compile of the same source with different
        # flags silently reused the first artifact.
        key, input_path = cls.write(source_code, dst_file_ext, extra_args)
        if key not in cls.cache:
            from torch.utils._filelock import FileLock

            lock_dir = get_lock_dir()
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
            with lock:
                # Swap the source suffix for the requested output extension.
                output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
                if not os.path.exists(output_path):
                    cmd = sycl_compile_command(
                        [input_path], output_path, dst_file_ext, extra_args
                    )
                    # Record the compile command at the end of the source file.
                    with open(input_path, "a") as f:
                        f.write("\n")
                        f.write(f"// SYCL Compile cmd\n// {cmd}\n")
                    start_time = time()
                    log.debug("SYCL Compilation: %s", cmd)
                    # NOTE: the command is split on single spaces, so paths
                    # containing spaces are not supported.
                    cmd_parts = cmd.split(" ")
                    try:
                        subprocess.check_output(
                            cmd_parts, stderr=subprocess.STDOUT, env=os.environ
                        )
                    except subprocess.CalledProcessError as error:
                        raise exc.SYCLCompileError(cmd_parts, error.output) from error
                    end_time = time()
                    log.info(
                        "SYCL Compilation took %s seconds. Compile command: %s",
                        end_time - start_time,
                        cmd,
                    )
                else:
                    log.debug(
                        "SYCL Compilation skipped: %s since output already exists",
                        input_path,
                    )
                cls.cache[key] = SYCLCodeCache.CacheEntry(input_path, output_path)

        return (cls.cache[key].output_path, key, input_path)

    @classmethod
    def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]:
        """
        Compiles source code and loads the generated .so file.
        Returns a tuple of DLLWrapper, hash_key, source_code_path

        Raises RuntimeError if dst_file_ext is anything other than "so".
        """

        if dst_file_ext != "so":
            raise RuntimeError(
                f"Only support loading a .so file for now. "
                f"Requested file extension: {dst_file_ext}. Source code: {source_code}"
            )
        dst_file_path, hash_key, source_code_path = cls.compile(
            source_code, dst_file_ext
        )
        return (DLLWrapper(dst_file_path), hash_key, source_code_path)
3496+
3497+
3498+
33263499
class CodeCacheFuture:
33273500
def result(self) -> Callable[..., Any]:
33283501
raise NotImplementedError

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ def try_import_cutlass() -> bool:
5353
log.warning("CUTLASS version < 3.7 is not recommended.")
5454

5555
log.debug(
56-
"Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
56+
"Found cutlass_library in python search path, overriding config.cutlass_dir"
5757
)
5858
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
5959
assert os.path.isdir(cutlass_library_dir), (
6060
f"{cutlass_library_dir} is not a directory"
6161
)
62-
config.cuda.cutlass_dir = os.path.abspath(
62+
config.cutlass_dir = os.path.abspath(
6363
os.path.join(
6464
cutlass_library_dir,
6565
"source",
@@ -68,15 +68,15 @@ def try_import_cutlass() -> bool:
6868
return True
6969
except ModuleNotFoundError:
7070
log.debug(
71-
"cutlass_library not found in sys.path, trying to import from config.cuda.cutlass_dir"
71+
"cutlass_library not found in sys.path, trying to import from config.cutlass_dir"
7272
)
7373

7474
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
7575
# This is a temporary hack to avoid CUTLASS module naming conflicts.
7676
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
7777

7878
cutlass_py_full_path = os.path.abspath(
79-
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
79+
os.path.join(config.cutlass_dir, "python/cutlass_library")
8080
)
8181
tmp_cutlass_py_full_path = os.path.abspath(
8282
os.path.join(cache_dir(), "torch_cutlass_library")

torch/_inductor/config.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,16 @@ def decide_compile_threads() -> int:
871871
# Adds NVTX annotations around training phases
872872
annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"
873873

874+
# Path to the CUTLASS repo root directory.
875+
# The default path only works under PyTorch local development environment.
876+
cutlass_dir = os.environ.get(
877+
"TORCHINDUCTOR_CUTLASS_DIR",
878+
os.path.abspath(
879+
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
880+
),
881+
)
882+
883+
874884

875885
# config specific to codegen/cpp.py
876886
class cpp:
@@ -1280,15 +1290,6 @@ class cuda:
12801290
# Whether to use fast math.
12811291
use_fast_math = False
12821292

1283-
# Path to the CUTLASS repo root directory.
1284-
# The default path only works under PyTorch local development environment.
1285-
cutlass_dir = os.environ.get(
1286-
"TORCHINDUCTOR_CUTLASS_DIR",
1287-
os.path.abspath(
1288-
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
1289-
),
1290-
)
1291-
12921293
# Configures the maximum number of CUTLASS configs to profile in max_autotune.
12931294
# By default it's None, so that all CUTLASS configs are tuned.
12941295
# This is mainly used to reduce test time in CI.
@@ -1402,6 +1403,22 @@ class rocm:
14021403
split_k_threshold: int = 16
14031404

14041405

1406+
class sycl:
    # Settings for compiling SYCL template kernels.

    # Intel GPU arch to use for SYCL template kernel compilation.
    # e.g. "pvc", "bmg", etc.
    # When arch is None, generates SPIR-V that is finalized at runtime.
    arch: Optional[str] = None

    # Optimization level for the host compiler.
    # NOTE(review): "-OS" is spelled upper-case here; clang-style drivers
    # usually spell the size-optimization level "-Os" -- confirm the flag.
    compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1"

    # Whether to enable debug info, e.g. line number, cutlass debug info.
    enable_debug_info = False

    # Whether to use fast math.
    use_fast_math = False
1420+
1421+
14051422
# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
14061423
cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"
14071424

@@ -1518,7 +1535,7 @@ class trace:
15181535
# trace functions are not relevant to config caching
15191536
"trace",
15201537
# uses absolute path
1521-
"cuda.cutlass_dir",
1538+
"cutlass_dir",
15221539
# not relevant
15231540
"worker_start_method",
15241541
"compile_threads",

torch/_inductor/exc.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,10 @@ class CUDACompileError(CppCompileError):
113113
pass
114114

115115

116+
class SYCLCompileError(CppCompileError):
    """Raised when compiling SYCL source code fails."""
118+
119+
116120
class TritonMissing(ShortenTraceback):
117121
def __init__(self, first_useful_frame: Optional[types.FrameType]) -> None:
118122
super().__init__(

torch/_inductor/fuzzer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def keys(self) -> KeysView[ComboType]:
469469
"aot_inductor.presets": DEFAULT, # Typing
470470
"cuda.arch": DEFAULT, # Out of Scope
471471
"cuda.version": DEFAULT, # Out of Scope
472-
"cuda.cutlass_dir": DEFAULT, # Out of Scope
472+
"cutlass_dir": DEFAULT, # Out of Scope
473473
"cuda.cuda_cxx": DEFAULT, # Out of Scope
474474
"rocm.arch": DEFAULT, # Out of Scope
475475
"rocm.ck_supported_arch": DEFAULT, # Out of Scope

torch/_inductor/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1405,7 +1405,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
14051405
if not try_import_cutlass():
14061406
log.warning(
14071407
"Failed to import CUTLASS lib. Please check whether "
1408-
"_inductor.config.cuda.cutlass_dir is set correctly. "
1408+
"_inductor.config.cutlass_dir is set correctly. "
14091409
"Skipping CUTLASS backend for now."
14101410
)
14111411
return False

0 commit comments

Comments
 (0)