Skip to content

Implement SYCL code cache #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@
url = https://github.com/GPUOpen-LibrariesAndSDKs/VulkanMemoryAllocator.git
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
url = https://github.com/codeplaysoftware/cutlass-fork.git
[submodule "third_party/mimalloc"]
path = third_party/mimalloc
url = https://github.com/microsoft/mimalloc.git
Expand Down
2 changes: 1 addition & 1 deletion third_party/cutlass
13 changes: 13 additions & 0 deletions torch/_inductor/async_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
HalideCodeCache,
LambdaFuture,
ROCmCodeCache,
SYCLCodeCache,
torch_key,
)
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
Expand Down Expand Up @@ -410,6 +411,18 @@ def task():

return self.submit(task)

def sycl(self, source_code, dst_file_ext, aot_compile=False):
    """Submit SYCL source code for asynchronous compilation.

    Returns a future (via ``self.submit``) resolving to the loaded kernel
    library produced by ``SYCLCodeCache.load``.
    """
    kernel_code_log.info("SYCL Kernel:\n%s", source_code)

    def task():
        if aot_compile:
            # TODO: Support AoT compilation.
            raise RuntimeError("AoT compilation not yet supported for SYCL")
        # load() returns (DLLWrapper, hash_key, source_code_path); only the
        # loaded library is needed by callers of this task.
        return SYCLCodeCache.load(source_code, dst_file_ext)[0]

    return self.submit(task)


def halide(self, meta: HalideMeta, source_code: str):
kernel_code_log.info("Halide Kernel:\n%r\n%s", meta, source_code)
if get_compile_threads() <= 1:
Expand Down
177 changes: 175 additions & 2 deletions torch/_inductor/codecache.py
Original file line number Diff line number Diff line change
Expand Up @@ -2944,7 +2944,7 @@ def _cutlass_include_paths() -> list[str]:

cutlass_path = parutil.get_dir_path("cutlass-3-headers")
else:
cutlass_path = config.cuda.cutlass_dir
cutlass_path = config.cutlass_dir
return [
# Use realpath to get canonical absolute paths, in order not to mess up cache keys
os.path.realpath(os.path.join(cutlass_path, "include")),
Expand Down Expand Up @@ -3084,7 +3084,6 @@ def close(self) -> None:

def _dlclose(self) -> None:
f_dlclose = None

if is_linux():
syms = CDLL(None)
if not hasattr(syms, "dlclose"):
Expand Down Expand Up @@ -3323,6 +3322,180 @@ def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str
return (DLLWrapper(dst_file_path), hash_key, source_code_path)


def _sycl_compiler() -> Optional[str]:
# TODO: Add detection of `icpx` as release compiler and add option to use
# environment variable to specify specific compiler.
return "clang++"


def _sycl_lib_options() -> list[str]:
    """Build linker flags for the SYCL (XPU) runtime libraries.

    Returns ``-L``/``-rpath`` flags for every XPU library path reported by
    ``torch.utils.cpp_extension``.

    Raises:
        NotImplementedError: on non-Linux platforms.
    """
    _set_gpu_runtime_env()  # cpp_extension consults the env
    from torch.utils import cpp_extension

    lpaths = cpp_extension.library_paths(device_type="xpu")
    extra_ldflags: list[str] = []
    if is_linux():
        for path in lpaths:
            # -rpath ensures the DLL can find its dependencies when loaded, even
            # if the library path is non-standard.
            extra_ldflags.extend([f"-L{path}", "-Xlinker", f"-rpath={path}"])
    else:
        # Fixed: previous message said "cuda libs" — copy-paste from the CUDA
        # helper; this function locates SYCL libraries.
        raise NotImplementedError(
            "Unsupported env, failed to find SYCL libs! Currently only Linux is supported."
        )
    return extra_ldflags


def _dpcpp_compiler_options() -> list[str]:
    """Assemble DPC++ compiler flags for SYCL template kernel compilation."""
    # TODO: Automatically detect device architecture.
    if config.sycl.arch is not None:
        target = f"intel_gpu_{config.sycl.arch}"
    else:
        # No arch configured: emit generic SPIR-V finalized at runtime.
        target = "spir64"
    flags = [
        "-fsycl",
        "-std=c++17",
        "-fPIC",
        "-Xspirv-translator",
        "-spirv-ext=+SPV_INTEL_split_barrier",
        "-fsycl-range-rounding=disable",
        f"-fsycl-targets={target}",
        config.sycl.compile_opt_level,
        "-DCUTLASS_ENABLE_SYCL",
        "-DSYCL_INTEL_TARGET",
    ]
    # TODO: Add special case for FB?
    if config.sycl.enable_debug_info:
        flags += ["-g", "-DCUTLASS_DEBUG_TRACE_LEVEL=1"]
    if config.sycl.use_fast_math:
        flags += ["-ffast-math", "-DCUTLASS_USE_TANH_FOR_SIGMOID=1"]

    return flags


def sycl_compile_command(
    src_files: list[str],
    dst_file: str,
    dst_file_ext: str,
    extra_args: Optional[list[str]] = None,
) -> str:
    """Build the full DPC++ command line compiling *src_files* to *dst_file*.

    *dst_file_ext* selects the artifact kind: "o" (object), "so" (shared
    library) or "exe" (executable). Raises NotImplementedError otherwise.
    """
    # Order matters: compiler flags, caller extras, include paths, link flags.
    options = (
        _dpcpp_compiler_options()
        + (extra_args if extra_args is not None else [])
        + [f"-I{path}" for path in _cutlass_include_paths()]
        + _sycl_lib_options()
    )
    sources = " ".join(src_files)
    if dst_file_ext == "o":
        cmd = f"{_sycl_compiler()} {' '.join(options)} -c -o {dst_file} {sources}"
    elif dst_file_ext == "so":
        options.append("-shared")
        cmd = f"{_sycl_compiler()} {' '.join(options)} -o {dst_file} {sources}"
    elif dst_file_ext == "exe":
        cmd = f"{_sycl_compiler()} {' '.join(options)} -o {dst_file} {sources}"
    else:
        raise NotImplementedError(f"Unsupported output file suffix {dst_file_ext}!")
    log.debug("SYCL command: %s", cmd)
    return cmd



@clear_on_fresh_inductor_cache
class SYCLCodeCache:
    """In-memory + on-disk cache of compiled SYCL kernels.

    Mirrors CUDACodeCache/ROCmCodeCache: source code is hashed together with
    the compile command, written to the filesystem cache, compiled at most
    once per key (guarded by a file lock), and memoized in ``cache``.
    """

    @dataclasses.dataclass
    class CacheEntry:
        # Path of the written source file and of its compiled artifact.
        input_path: str
        output_path: str

    # Maps source hash key -> CacheEntry; cleared on fresh inductor cache.
    cache: dict[str, CacheEntry] = {}
    cache_clear = staticmethod(cache.clear)
    # SYCL kernels are written out as C++ translation units.
    _SOURCE_CODE_SUFFIX = "cpp"

    @classmethod
    def write(cls, source_code: str, dst_file_ext: str) -> tuple[str, str]:
        """
        Writes source code into a file with dst_file_ext as the file extension.
        Returns the hash key of source code, and the path to the file.
        """

        # The compile command is folded into the hash (`extra=`) so that a
        # change of compiler flags invalidates previously cached artifacts.
        sycl_command = repr(
            sycl_compile_command(["dummy_input"], "dummy_output", dst_file_ext)
        )
        key, input_path = write(
            source_code, cls._SOURCE_CODE_SUFFIX, extra=sycl_command
        )
        return key, input_path

    @classmethod
    def compile(
        cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
    ) -> tuple[str, str, str]:
        """
        Compiles SYCL source_code into a file with dst_file_ext extension.
        Returns a tuple of dst_file_path, hash_key, source_code_path
        """
        key, input_path = cls.write(source_code, dst_file_ext)
        if key not in cls.cache:
            from torch.utils._filelock import FileLock

            # File lock serializes compilation of the same key across
            # processes that share the cache directory.
            lock_dir = get_lock_dir()
            lock = FileLock(os.path.join(lock_dir, key + ".lock"), timeout=LOCK_TIMEOUT)
            with lock:
                # Output artifact lives next to the source, suffix swapped.
                output_path = input_path[: -len(cls._SOURCE_CODE_SUFFIX)] + dst_file_ext
                if not os.path.exists(output_path):
                    cmd = sycl_compile_command(
                        [input_path], output_path, dst_file_ext, extra_args
                    )
                    # Append the exact compile command to the source file for
                    # later debugging / manual reproduction.
                    with open(input_path, "a") as f:
                        f.write("\n")
                        f.write(f"// SYCL Compile cmd\n// {cmd}\n")
                    start_time = time()
                    log.debug("SYCL Compilation: %s", cmd)
                    cmd_parts = cmd.split(" ")
                    try:
                        subprocess.check_output(
                            cmd_parts, stderr=subprocess.STDOUT, env=os.environ
                        )
                    except subprocess.CalledProcessError as error:
                        # Surface the command and compiler output to the user.
                        raise exc.SYCLCompileError(cmd_parts, error.output) from error
                    end_time = time()
                    log_duration_msg = f"SYCL Compilation took {end_time - start_time} seconds. Compile command: {cmd}"
                    log.info(log_duration_msg)
                else:
                    log.debug(
                        "SYCL Compilation skipped: %s since output already exists",
                        input_path,
                    )
                cls.cache[key] = SYCLCodeCache.CacheEntry(input_path, output_path)

        return (cls.cache[key].output_path, key, input_path)

    @classmethod
    def load(cls, source_code: str, dst_file_ext: str) -> tuple[DLLWrapper, str, str]:
        """
        Compiles source code and loads the generated .so file.
        Returns a tuple of DLLWrapper, hash_key, source_code_path
        """

        if dst_file_ext != "so":
            raise RuntimeError(
                f"Only support loading a .so file for now. "
                f"Requested file extension: {dst_file_ext}. Source code: {source_code}"
            )
        dst_file_path, hash_key, source_code_path = cls.compile(
            source_code, dst_file_ext
        )
        return (DLLWrapper(dst_file_path), hash_key, source_code_path)



class CodeCacheFuture:
def result(self) -> Callable[..., Any]:
raise NotImplementedError
Expand Down
8 changes: 4 additions & 4 deletions torch/_inductor/codegen/cuda/cutlass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ def try_import_cutlass() -> bool:
log.warning("CUTLASS version < 3.7 is not recommended.")

log.debug(
"Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
"Found cutlass_library in python search path, overriding config.cutlass_dir"
)
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
assert os.path.isdir(cutlass_library_dir), (
f"{cutlass_library_dir} is not a directory"
)
config.cuda.cutlass_dir = os.path.abspath(
config.cutlass_dir = os.path.abspath(
os.path.join(
cutlass_library_dir,
"source",
Expand All @@ -68,15 +68,15 @@ def try_import_cutlass() -> bool:
return True
except ModuleNotFoundError:
log.debug(
"cutlass_library not found in sys.path, trying to import from config.cuda.cutlass_dir"
"cutlass_library not found in sys.path, trying to import from config.cutlass_dir"
)

# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
# This is a temporary hack to avoid CUTLASS module naming conflicts.
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.

cutlass_py_full_path = os.path.abspath(
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
os.path.join(config.cutlass_dir, "python/cutlass_library")
)
tmp_cutlass_py_full_path = os.path.abspath(
os.path.join(cache_dir(), "torch_cutlass_library")
Expand Down
37 changes: 27 additions & 10 deletions torch/_inductor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,6 +871,16 @@ def decide_compile_threads() -> int:
# Adds NVTX annotations around training phases
annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"

# Path to the CUTLASS repo root directory.
# The default path only works under PyTorch local development environment.
# Top-level (rather than under `config.cuda`) so non-CUDA backends, e.g.
# SYCL, can share it; override via TORCHINDUCTOR_CUTLASS_DIR.
cutlass_dir = os.environ.get(
    "TORCHINDUCTOR_CUTLASS_DIR",
    os.path.abspath(
        os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
    ),
)



# config specific to codegen/cpp.py
class cpp:
Expand Down Expand Up @@ -1280,15 +1290,6 @@ class cuda:
# Whether to use fast math.
use_fast_math = False

# Path to the CUTLASS repo root directory.
# The default path only works under PyTorch local development environment.
cutlass_dir = os.environ.get(
"TORCHINDUCTOR_CUTLASS_DIR",
os.path.abspath(
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
),
)

# Configures the maximum number of CUTLASS configs to profile in max_autotune.
# By default it's None, so that all CUTLASS configs are tuned.
# This is mainly used to reduce test time in CI.
Expand Down Expand Up @@ -1402,6 +1403,22 @@ class rocm:
split_k_threshold: int = 16


class sycl:
    # Intel GPU arch to use for SYCL template kernel compilation.
    # e.g. "pvc", "bmg", etc.
    # When arch is None, generates SPIR-V that is finalized at runtime.
    arch: Optional[str] = None

    # Optimization level for the host compiler.
    # NOTE(review): "-OS" mirrors the CUDA config's literal, but clang/icpx
    # spell the size-optimization flag "-Os" — confirm before relying on it.
    compile_opt_level: Literal["-O0", "-O1", "-O2", "-O3", "-OS"] = "-O1"

    # Whether to enable debug info, e.g. line number, cutlass debug info.
    enable_debug_info: bool = False

    # Whether to use fast math.
    use_fast_math: bool = False


# Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental)
cpu_backend: Literal["cpp", "triton", "halide"] = "cpp"

Expand Down Expand Up @@ -1518,7 +1535,7 @@ class trace:
# trace functions are not relevant to config caching
"trace",
# uses absolute path
"cuda.cutlass_dir",
"cutlass_dir",
# not relevant
"worker_start_method",
"compile_threads",
Expand Down
4 changes: 4 additions & 0 deletions torch/_inductor/exc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ class CUDACompileError(CppCompileError):
pass


class SYCLCompileError(CppCompileError):
    """Raised when DPC++ compilation of SYCL source code fails.

    Constructed with the compile command parts and the compiler's captured
    output (see SYCLCodeCache.compile).
    """

    pass


class TritonMissing(ShortenTraceback):
def __init__(self, first_useful_frame: Optional[types.FrameType]) -> None:
super().__init__(
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/fuzzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def keys(self) -> KeysView[ComboType]:
"aot_inductor.presets": DEFAULT, # Typing
"cuda.arch": DEFAULT, # Out of Scope
"cuda.version": DEFAULT, # Out of Scope
"cuda.cutlass_dir": DEFAULT, # Out of Scope
"cutlass_dir": DEFAULT, # Out of Scope
"cuda.cuda_cxx": DEFAULT, # Out of Scope
"rocm.arch": DEFAULT, # Out of Scope
"rocm.ck_supported_arch": DEFAULT, # Out of Scope
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,7 +1405,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
if not try_import_cutlass():
log.warning(
"Failed to import CUTLASS lib. Please check whether "
"_inductor.config.cuda.cutlass_dir is set correctly. "
"_inductor.config.cutlass_dir is set correctly. "
"Skipping CUTLASS backend for now."
)
return False
Expand Down