Skip to content

Commit 484d217

Browse files
committed
Move cutlass_dir to general config
Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
1 parent 30ae902 commit 484d217

File tree

5 files changed

+18
-17
lines changed

5 files changed

+18
-17
lines changed

torch/_inductor/codecache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2797,7 +2797,7 @@ def _cutlass_include_paths() -> list[str]:
27972797

27982798
cutlass_path = parutil.get_dir_path("cutlass-3-headers")
27992799
else:
2800-
cutlass_path = config.cuda.cutlass_dir
2800+
cutlass_path = config.cutlass_dir
28012801
return [
28022802
# Use realpath to get canonical absolute paths, in order not to mess up cache keys
28032803
os.path.realpath(os.path.join(cutlass_path, "include")),

torch/_inductor/codegen/cuda/cutlass_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,13 @@ def try_import_cutlass() -> bool:
5353
log.warning("CUTLASS version < 3.7 is not recommended.")
5454

5555
log.debug(
56-
"Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
56+
"Found cutlass_library in python search path, overriding config.cutlass_dir"
5757
)
5858
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
5959
assert os.path.isdir(cutlass_library_dir), (
6060
f"{cutlass_library_dir} is not a directory"
6161
)
62-
config.cuda.cutlass_dir = os.path.abspath(
62+
config.cutlass_dir = os.path.abspath(
6363
os.path.join(
6464
cutlass_library_dir,
6565
"source",
@@ -68,15 +68,15 @@ def try_import_cutlass() -> bool:
6868
return True
6969
except ModuleNotFoundError:
7070
log.debug(
71-
"cutlass_library not found in sys.path, trying to import from config.cuda.cutlass_dir"
71+
"cutlass_library not found in sys.path, trying to import from config.cutlass_dir"
7272
)
7373

7474
# Copy CUTLASS python scripts to a temp dir and add the temp dir to Python search path.
7575
# This is a temporary hack to avoid CUTLASS module naming conflicts.
7676
# TODO(ipiszy): remove this hack when CUTLASS solves Python scripts packaging structure issues.
7777

7878
cutlass_py_full_path = os.path.abspath(
79-
os.path.join(config.cuda.cutlass_dir, "python/cutlass_library")
79+
os.path.join(config.cutlass_dir, "python/cutlass_library")
8080
)
8181
tmp_cutlass_py_full_path = os.path.abspath(
8282
os.path.join(cache_dir(), "torch_cutlass_library")

torch/_inductor/config.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,16 @@ def decide_compile_threads() -> int:
866866
# Adds NVTX annotations aroung training phases
867867
annotate_training: bool = os.environ.get("TORCHINDUCTOR_ANNOTATE_TRAINING", "0") == "1"
868868

869+
# Path to the CUTLASS repo root directory.
870+
# The default path only works under PyTorch local development environment.
871+
cutlass_dir = os.environ.get(
872+
"TORCHINDUCTOR_CUTLASS_DIR",
873+
os.path.abspath(
874+
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
875+
),
876+
)
877+
878+
869879

870880
# config specific to codegen/cpp.py
871881
class cpp:
@@ -1272,15 +1282,6 @@ class cuda:
12721282
# Whether to use fast math.
12731283
use_fast_math = False
12741284

1275-
# Path to the CUTLASS repo root directory.
1276-
# The default path only works under PyTorch local development environment.
1277-
cutlass_dir = os.environ.get(
1278-
"TORCHINDUCTOR_CUTLASS_DIR",
1279-
os.path.abspath(
1280-
os.path.join(os.path.dirname(torch.__file__), "../third_party/cutlass/")
1281-
),
1282-
)
1283-
12841285
# Configures the maximum number of CUTLASS configs to profile in max_autotune.
12851286
# By default it's None, so that all CUTLASS configs are tuned.
12861287
# This is mainly used to reduce test time in CI.
@@ -1526,7 +1527,7 @@ class trace:
15261527
# trace functions are not relevant to config caching
15271528
"trace",
15281529
# uses absolute path
1529-
"cuda.cutlass_dir",
1530+
"cutlass_dir",
15301531
# not relevant
15311532
"worker_start_method",
15321533
"compile_threads",

torch/_inductor/fuzzer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,7 @@ def keys(self) -> KeysView[ComboType]:
469469
"aot_inductor.presets": DEFAULT, # Typing
470470
"cuda.arch": DEFAULT, # Out of Scope
471471
"cuda.version": DEFAULT, # Out of Scope
472-
"cuda.cutlass_dir": DEFAULT, # Out of Scope
472+
"cutlass_dir": DEFAULT, # Out of Scope
473473
"cuda.cuda_cxx": DEFAULT, # Out of Scope
474474
"rocm.arch": DEFAULT, # Out of Scope
475475
"rocm.ck_supported_arch": DEFAULT, # Out of Scope

torch/_inductor/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1390,7 +1390,7 @@ def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool:
13901390
if not try_import_cutlass():
13911391
log.warning(
13921392
"Failed to import CUTLASS lib. Please check whether "
1393-
"_inductor.config.cuda.cutlass_dir is set correctly. "
1393+
"_inductor.config.cutlass_dir is set correctly. "
13941394
"Skipping CUTLASS backend for now."
13951395
)
13961396
return False

0 commit comments

Comments
 (0)