Skip to content

Commit 693ad86

Browse files
committed
Move cutlass_dir to general config
Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent 8ab8cba commit 693ad86

File tree

5 files changed

+18
-17
lines changed

5 files changed

+18
-17
lines changed

torch/_inductor/codecache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2944,7 +2944,7 @@ def _cutlass_include_paths() -> list[str]:
29442944

29452945
cutlass_path = parutil.get_dir_path("cutlass-3-headers")
29462946
else:
2947-
cutlass_path = config.cuda.cutlass_dir
2947+
cutlass_path = config.cutlass_dir
29482948
return [
29492949
# Use realpath to get canonical absolute paths, in order not to mess up cache keys
29502950
os.path.realpath(os.path.join(cutlass_path, "include")),

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ def try_import_cutlass() -> bool:
5353
log.warning("CUTLASS version < 3.7 is not recommended.")
5454

5555
log.debug(
56-
"Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
56+
"Found cutlass_library in python search path, overriding config.cutlass_dir"
5757
)
5858
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
5959
assert os.path.isdir(cutlass_library_dir), (
6060
f"{cutlass_library_dir} is not a directory"
6161
)
62-
config.cuda.cutlass_dir = os.path.abspath(
62+
config.cutlass_dir = os.path.abspath(
6363
os.path.join(
6464
cutlass_library_dir,
6565
"source",
@@ -68,15 +68,15 @@ def try_import_cutlass() -> bool:
6868
return True
6969
except ModuleNotFoundError:
7070
log.debug(
71-
"cutlass_library not found in sys.path, trying to import from config.cuda.cutlass_dir"
71+
"cutlass_library not found in sys.path, trying to import from config.cutlass_dir"
7272
)
7373

7474
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
7575
# This is a temporary hack to avoid CUTLASS module naming conflicts.
7676
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
7777

7878
cutlass_py_full_path = os.path.abspath(
79-
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
79+
os.path.join(config.cutlass_dir, "python/cutlass_library")
8080
)
8181
tmp_cutlass_py_full_path = os.path.abspath(
8282
os.path.join(cache_dir(), "torch_cutlass_library")

torch/_inductor/config.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,16 @@ def decide_compile_threads() -> int:
871871
# Adds NVTX annotations aroung training phases
872872
annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"
873873

874+
# Path to the CUTLASS repo root directory.
875+
# The default path only works under PyTorch local development environment.
876+
cutlass_dir = os.environ.get(
877+
"TORCHINDUCTOR_CUTLASS_DIR",
878+
os.path.abspath(
879+
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
880+
),
881+
)
882+
883+
874884

875885
# config specific to codegen/cpp.py
876886
class cpp:
@@ -1280,15 +1290,6 @@ class cuda:
12801290
# Whether to use fast math.
12811291
use_fast_math = False
12821292

1283-
# Path to the CUTLASS repo root directory.
1284-
# The default path only works under PyTorch local development environment.
1285-
cutlass_dir = os.environ.get(
1286-
"TORCHINDUCTOR_CUTLASS_DIR",
1287-
os.path.abspath(
1288-
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
1289-
),
1290-
)
1291-
12921293
# Configures the maximum number of CUTLASS configs to profile in max_autotune.
12931294
# By default it's None, so that all CUTLASS configs are tuned.
12941295
# This is mainly used to reduce test time in CI.
@@ -1534,7 +1535,7 @@ class trace:
15341535
# trace functions are not relevant to config caching
15351536
"trace",
15361537
# uses absolute path
1537-
"cuda.cutlass_dir",
1538+
"cutlass_dir",
15381539
# not relevant
15391540
"worker_start_method",
15401541
"compile_threads",

torch/_inductor/fuzzer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def keys(self) -> KeysView[ComboType]:
469469
"aot_inductor.presets": DEFAULT, # Typing
470470
"cuda.arch": DEFAULT, # Out of Scope
471471
"cuda.version": DEFAULT, # Out of Scope
472-
"cuda.cutlass_dir": DEFAULT, # Out of Scope
472+
"cutlass_dir": DEFAULT, # Out of Scope
473473
"cuda.cuda_cxx": DEFAULT, # Out of Scope
474474
"rocm.arch": DEFAULT, # Out of Scope
475475
"rocm.ck_supported_arch": DEFAULT, # Out of Scope

torch/_inductor/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1405,7 +1405,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
14051405
if not try_import_cutlass():
14061406
log.warning(
14071407
"Failed to import CUTLASS lib. Please check whether "
1408-
"_inductor.config.cuda.cutlass_dir is set correctly. "
1408+
"_inductor.config.cutlass_dir is set correctly. "
14091409
"Skipping CUTLASS backend for now."
14101410
)
14111411
return False

0 commit comments

Comments
 (0)