|
| 1 | +# mypy: allow-untyped-defs |
| 2 | +import functools |
| 3 | +import logging |
| 4 | +import os |
| 5 | +import sys |
| 6 | +from dataclasses import dataclass |
| 7 | +from typing import Any, Optional |
| 8 | + |
| 9 | +import sympy |
| 10 | + |
| 11 | +import torch |
| 12 | +from torch._inductor.utils import clear_on_fresh_inductor_cache |
| 13 | + |
| 14 | +from ... import config |
| 15 | +from ...ir import Layout |
| 16 | +from ...runtime.runtime_utils import cache_dir |
| 17 | +from ...virtualized import V |
| 18 | + |
| 19 | + |
| 20 | +log = logging.getLogger(__name__) |
| 21 | + |
| 22 | + |
| 23 | +@functools.lru_cache(None) |
| 24 | +def try_import_cutlass() -> bool: |
| 25 | + """ |
| 26 | + Currently only supporting user specified cutlass_dir or falling to the |
| 27 | + default ../third_party/cutlass/ (build from source setups). |
| 28 | + """ |
| 29 | + # Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path. |
| 30 | + |
| 31 | + cutlass_py_full_path = os.path.abspath( |
| 32 | + os.path.join(config.cutlass_dir, "python/cutlass_library") |
| 33 | + ) |
| 34 | + tmp_cutlass_py_full_path = os.path.abspath( |
| 35 | + os.path.join(cache_dir(), "torch_cutlass_library") |
| 36 | + ) |
| 37 | + dst_link = os.path.join(tmp_cutlass_py_full_path, "cutlass_library") |
| 38 | + |
| 39 | + if os.path.isdir(cutlass_py_full_path): |
| 40 | + if tmp_cutlass_py_full_path not in sys.path: |
| 41 | + if os.path.exists(dst_link): |
| 42 | + assert os.path.islink(dst_link), ( |
| 43 | + f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again." |
| 44 | + ) |
| 45 | + assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath( |
| 46 | + cutlass_py_full_path |
| 47 | + ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}" |
| 48 | + else: |
| 49 | + os.makedirs(tmp_cutlass_py_full_path, exist_ok=True) |
| 50 | + os.symlink(cutlass_py_full_path, dst_link) |
| 51 | + sys.path.append(tmp_cutlass_py_full_path) |
| 52 | + try: |
| 53 | + import cutlass_library.generator # noqa: F401 |
| 54 | + import cutlass_library.library # noqa: F401 |
| 55 | + import cutlass_library.manifest # noqa: F401 |
| 56 | + |
| 57 | + return True |
| 58 | + except ImportError as e: |
| 59 | + log.debug( |
| 60 | + "Failed to import CUTLASS packages: %s, ignoring the CUTLASS backend.", |
| 61 | + str(e), |
| 62 | + ) |
| 63 | + else: |
| 64 | + log.debug( |
| 65 | + "Failed to import CUTLASS packages: CUTLASS repo does not exist: %s", |
| 66 | + cutlass_py_full_path, |
| 67 | + ) |
| 68 | + return False |
| 69 | + |
| 70 | + |
| 71 | +@functools.lru_cache(8) |
| 72 | +def _normalize_sycl_arch(arch: str) -> str: |
| 73 | + if int(arch) == 11: |
| 74 | + return "11" |
| 75 | + else: |
| 76 | + raise NotImplementedError(f"Unsupported sycl arch: {arch}") |
| 77 | + |
| 78 | + |
| 79 | +@dataclass |
| 80 | +class CUTLASSArgs: |
| 81 | + """ |
| 82 | + CUTLASS args used to initialize a CUTLASS Manifest. |
| 83 | + """ |
| 84 | + |
| 85 | + architectures: Optional[str] = None |
| 86 | + cuda_version: Optional[str] = None # Unused in generator.py for PVC |
| 87 | + instantiation_level: Optional[str] = None # Unused YET in generator.py for PVC |
| 88 | + |
| 89 | + operations = "all" |
| 90 | + build_dir = "" |
| 91 | + curr_build_dir = "" |
| 92 | + generator_target = "" |
| 93 | + kernels = "all" |
| 94 | + ignore_kernels = "" |
| 95 | + exclude_kernels = "" |
| 96 | + # UNUSED at the moment, part of Manifest class in cutlass_library |
| 97 | + kernel_filter_file: None = None |
| 98 | + selected_kernel_list: None = None |
| 99 | + interface_dir: None = None |
| 100 | + filter_by_cc = False |
| 101 | + disable_full_archs_compilation = False |
| 102 | + |
| 103 | + def __post_init__(self): |
| 104 | + if self.architectures is None: |
| 105 | + raise RuntimeError(f"{self.architectures=} is None!") |
| 106 | + self.architectures = _normalize_sycl_arch(self.architectures) |
| 107 | + |
| 108 | + |
| 109 | +@clear_on_fresh_inductor_cache |
| 110 | +@functools.lru_cache(None) |
| 111 | +def _gen_ops_cached(arch) -> list[Any]: |
| 112 | + # Import cutlass python scripts. |
| 113 | + assert try_import_cutlass() |
| 114 | + import cutlass_library.generator as cutlass_generator |
| 115 | + import cutlass_library.manifest as cutlass_manifest |
| 116 | + |
| 117 | + if arch is None: |
| 118 | + log.error( |
| 119 | + "Cannot detect XPU arch %s." |
| 120 | + "Will discard all cutlass ops. " |
| 121 | + "Please consider setting _inductor.xpu.arch", |
| 122 | + arch, |
| 123 | + ) |
| 124 | + return [] |
| 125 | + arch = _normalize_sycl_arch(arch) |
| 126 | + |
| 127 | + sycl_version = "2025.0.1" # Placeholder, Unused in GeneratePVC |
| 128 | + |
| 129 | + args = CUTLASSArgs( |
| 130 | + architectures=arch, |
| 131 | + instantiation_level="0", # TODO (SYCL) : Make it config param once enabled in cutlass_library/generator.py |
| 132 | + cuda_version=sycl_version, |
| 133 | + ) |
| 134 | + manifest = cutlass_manifest.Manifest(args) |
| 135 | + |
| 136 | + if arch == "11": |
| 137 | + cutlass_generator.GeneratePVC(manifest, sycl_version) |
| 138 | + else: |
| 139 | + log.error("Invalid XPU arch") |
| 140 | + return [] |
| 141 | + return manifest.operations |
| 142 | + |
| 143 | + |
| 144 | +def gen_ops() -> list[Any]: |
| 145 | + """ |
| 146 | + Generates all supported CUTLASS operations. |
| 147 | + """ |
| 148 | + # Currently limited to PVC (arch 1100), harcoding arch |
| 149 | + # TODO :(SYCL) get_xpu_arch() |
| 150 | + arch = "11" |
| 151 | + return _gen_ops_cached(arch) |
| 152 | + |
| 153 | + |
| 154 | +def torch_dtype_to_cutlass_type( |
| 155 | + torch_dtype: torch.dtype, |
| 156 | +) -> "cutlass_library.library.DataType": # type: ignore[name-defined] # noqa: F821 |
| 157 | + # Import cutlass python scripts. |
| 158 | + assert try_import_cutlass() |
| 159 | + import cutlass_library # type: ignore[import] |
| 160 | + |
| 161 | + if torch_dtype == torch.float: |
| 162 | + return cutlass_library.library.DataType.f32 |
| 163 | + elif torch_dtype == torch.half: |
| 164 | + return cutlass_library.library.DataType.f16 |
| 165 | + elif torch_dtype == torch.bfloat16: |
| 166 | + return cutlass_library.library.DataType.bf16 |
| 167 | + else: |
| 168 | + raise NotImplementedError(f"Unsupported data type: {torch_dtype=}") |
| 169 | + |
| 170 | + |
| 171 | +def dtype_match( |
| 172 | + torch_dtype: Optional[torch.dtype], |
| 173 | + cutlass_dtype: "cutlass_library.library.DataType", # type: ignore[name-defined] # noqa: F821 |
| 174 | +) -> bool: |
| 175 | + # Import cutlass python scripts. |
| 176 | + assert try_import_cutlass() |
| 177 | + import cutlass_library |
| 178 | + |
| 179 | + if torch_dtype == torch.float: |
| 180 | + return cutlass_dtype == cutlass_library.library.DataType.f32 |
| 181 | + elif torch_dtype == torch.half: |
| 182 | + return cutlass_dtype == cutlass_library.library.DataType.f16 |
| 183 | + elif torch_dtype == torch.bfloat16: |
| 184 | + return cutlass_dtype == cutlass_library.library.DataType.bf16 |
| 185 | + elif torch_dtype == torch.int8: |
| 186 | + return cutlass_dtype == cutlass_library.library.DataType.s8 |
| 187 | + elif torch_dtype == torch.uint8: |
| 188 | + return cutlass_dtype == cutlass_library.library.DataType.u8 |
| 189 | + elif torch_dtype == torch.int32: |
| 190 | + return cutlass_dtype == cutlass_library.library.DataType.s32 |
| 191 | + else: |
| 192 | + return False |
| 193 | + |
| 194 | + |
| 195 | +def get_accumulator_dtype( |
| 196 | + input_torch_dtypes: list[torch.dtype], |
| 197 | +) -> Optional[torch.dtype]: |
| 198 | + """ |
| 199 | + Given a pair of input torch dtypes, returns the inferred accumulator torch dtype. |
| 200 | + """ |
| 201 | + # TODO (SYCL) : Extend this once other accumulator & input types are supported |
| 202 | + if len(input_torch_dtypes) != 2: |
| 203 | + return None |
| 204 | + |
| 205 | + if all(dtype == torch.bfloat16 for dtype in input_torch_dtypes): |
| 206 | + return torch.float |
| 207 | + else: |
| 208 | + raise NotImplementedError(f"Unsupported data types: {input_torch_dtypes}") |
| 209 | + |
| 210 | + |
| 211 | +def get_alignments(torch_dtype: torch.dtype) -> list[int]: |
| 212 | + """ |
| 213 | + Returns all possible valid CUTLASS alignments in terms of the number of elements for a given dtype. |
| 214 | + """ |
| 215 | + # TODO (SYCL): Extend for other types & double-check alignments |
| 216 | + if torch_dtype == torch.bfloat16: |
| 217 | + return [8, 4, 2, 1] |
| 218 | + elif torch_dtype == torch.float: |
| 219 | + return [4, 2, 1] |
| 220 | + else: |
| 221 | + raise NotImplementedError(f"unsupported {torch_dtype=} for alignments") |
| 222 | + |
| 223 | + |
| 224 | +def get_max_alignment(inductor_layout: Layout) -> int: |
| 225 | + """ |
| 226 | + Returns the max alignment (in terms of number of elements) for a given Inductor Layout. |
| 227 | + """ |
| 228 | + |
| 229 | + dtype = inductor_layout.dtype |
| 230 | + size = inductor_layout.size |
| 231 | + offset = inductor_layout.offset |
| 232 | + |
| 233 | + def is_static_int(number): |
| 234 | + return isinstance(number, (int, sympy.Integer)) |
| 235 | + |
| 236 | + def a_factor_of(x, alignment): |
| 237 | + if is_static_int(x) and is_static_int(alignment): |
| 238 | + return x % alignment == 0 |
| 239 | + rem = sympy.Mod(x, alignment) |
| 240 | + return V.graph.sizevars.evaluate_expr(sympy.Eq(rem, 0)) |
| 241 | + |
| 242 | + try: |
| 243 | + contiguous_dim = inductor_layout.stride.index(1) |
| 244 | + except ValueError: |
| 245 | + # No dim with stride 1 found, return 1 |
| 246 | + return 1 |
| 247 | + alignments = get_alignments(dtype) |
| 248 | + for alignment in alignments: |
| 249 | + if not a_factor_of(size[contiguous_dim], alignment) or not a_factor_of( |
| 250 | + offset, alignment |
| 251 | + ): |
| 252 | + continue |
| 253 | + if all( |
| 254 | + (dim == contiguous_dim) |
| 255 | + or a_factor_of(inductor_layout.stride[dim], alignment) |
| 256 | + for dim in range(len(size)) |
| 257 | + ): |
| 258 | + return alignment |
| 259 | + return 1 |
| 260 | + |
| 261 | + |
| 262 | +# TODO (SYCL) : Add helpers for CUTLASS kernels testing & benchmarking once standalone |
| 263 | +# runner is enabled. |
0 commit comments