Skip to content

Commit 72be490

Browse files
authored
[Refactor] Refactor env into a more flexible version (#740)
* Fix environment variable name for compilation print setting in `env.py` * Remove deprecated test file for warp specialized pass configuration and refactor environment variable access in `env.py` to utilize a centralized `EnvVar` class for better management and clarity. * lint fix * Refactor cache check to use `env.is_cache_enabled()` for consistency in `tuner.py`
1 parent e3a80b7 commit 72be490

File tree

10 files changed

+254
-142
lines changed

10 files changed

+254
-142
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import tilelang
2+
import os
3+
4+
5+
def test_env_var():
6+
# test default value
7+
assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
8+
# test forced value
9+
os.environ["TILELANG_PRINT_ON_COMPILATION"] = "0"
10+
assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "0"
11+
# test forced value with class method
12+
tilelang.env.TILELANG_PRINT_ON_COMPILATION = "1"
13+
assert tilelang.env.TILELANG_PRINT_ON_COMPILATION == "1"
14+
15+
16+
if __name__ == "__main__":
17+
test_env_var()

tilelang/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ def _init_logger():
5353

5454
logger = logging.getLogger(__name__)
5555

56-
from .env import SKIP_LOADING_TILELANG_SO
5756
from .env import enable_cache, disable_cache, is_cache_enabled # noqa: F401
57+
from .env import env as env # noqa: F401
5858

5959
import tvm
6060
import tvm.base
@@ -76,12 +76,12 @@ def _load_tile_lang_lib():
7676

7777

7878
# only load once here
79-
if SKIP_LOADING_TILELANG_SO == "0":
79+
if env.SKIP_LOADING_TILELANG_SO == "0":
8080
_LIB, _LIB_PATH = _load_tile_lang_lib()
8181

8282
from .jit import jit, JITKernel, compile # noqa: F401
8383
from .profiler import Profiler # noqa: F401
84-
from .cache import cached # noqa: F401
84+
from .cache import clear_cache # noqa: F401
8585

8686
from .utils import (
8787
TensorSupplyType, # noqa: F401

tilelang/autotuner/tuner.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,7 @@
2525
import traceback
2626
from pathlib import Path
2727

28-
from tilelang.env import (
29-
TILELANG_CACHE_DIR,
30-
TILELANG_AUTO_TUNING_CPU_UTILITIES,
31-
TILELANG_AUTO_TUNING_CPU_COUNTS,
32-
TILELANG_AUTO_TUNING_MAX_CPU_COUNT,
33-
is_cache_enabled,
34-
)
28+
from tilelang import env
3529
from tilelang.autotuner.param import CompileArgs, ProfileArgs, AutotuneResult
3630
from tilelang.autotuner.capture import get_autotune_inputs
3731
from tilelang.jit.param import _P, _RProg
@@ -111,7 +105,7 @@ class AutoTuner:
111105
_kernel_parameters: Optional[Tuple[str, ...]] = None
112106
_lock = threading.Lock() # For thread safety
113107
_memory_cache = {} # In-memory cache dictionary
114-
cache_dir: Path = Path(TILELANG_CACHE_DIR) / "autotuner"
108+
cache_dir: Path = Path(env.TILELANG_CACHE_DIR) / "autotuner"
115109

116110
def __init__(self, fn: Callable, configs):
117111
self.fn = fn
@@ -285,7 +279,7 @@ def run(self, warmup: int = 25, rep: int = 100, timeout: int = 30):
285279
key = self.generate_cache_key(parameters)
286280

287281
with self._lock:
288-
if is_cache_enabled():
282+
if env.is_cache_enabled():
289283
# First check in-memory cache
290284
if key in self._memory_cache:
291285
logger.warning("Found kernel in memory cache. For better performance," \
@@ -437,9 +431,9 @@ def shape_equal(a, b):
437431
return autotuner_result
438432
# get the cpu count
439433
available_cpu_count = get_available_cpu_count()
440-
cpu_utilizations = float(TILELANG_AUTO_TUNING_CPU_UTILITIES)
441-
cpu_counts = int(TILELANG_AUTO_TUNING_CPU_COUNTS)
442-
max_cpu_count = int(TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
434+
cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES)
435+
cpu_counts = int(env.TILELANG_AUTO_TUNING_CPU_COUNTS)
436+
max_cpu_count = int(env.TILELANG_AUTO_TUNING_MAX_CPU_COUNT)
443437
if cpu_counts > 0:
444438
num_workers = min(cpu_counts, available_cpu_count)
445439
logger.info(
@@ -543,7 +537,7 @@ def device_wrapper(func, device, **config_arg):
543537
logger.warning("DLPack backend does not support cache saving to disk.")
544538
else:
545539
with self._lock:
546-
if is_cache_enabled():
540+
if env.is_cache_enabled():
547541
self._save_result_to_disk(key, autotuner_result)
548542

549543
self._memory_cache[key] = autotuner_result

tilelang/cache/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from tvm.target import Target
55
from tvm.tir import PrimFunc
66
from tilelang.jit import JITKernel
7+
from tilelang import env
78
from .kernel_cache import KernelCache
8-
from tilelang.env import TILELANG_CLEAR_CACHE
99

1010
# Create singleton instance of KernelCache
1111
_kernel_cache_instance = KernelCache()
@@ -44,5 +44,5 @@ def clear_cache():
4444
_kernel_cache_instance.clear_cache()
4545

4646

47-
if TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
47+
if env.TILELANG_CLEAR_CACHE.lower() in ("1", "true", "yes", "on"):
4848
clear_cache()

tilelang/cache/kernel_cache.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from tvm.tir import PrimFunc
1515

1616
from tilelang.engine.param import KernelParam
17-
from tilelang.env import TILELANG_CACHE_DIR, TILELANG_TMP_DIR, is_cache_enabled
17+
from tilelang import env
1818
from tilelang.jit import JITKernel
1919
from tilelang.version import __version__
2020

@@ -61,8 +61,8 @@ def __new__(cls):
6161

6262
@staticmethod
6363
def _create_dirs():
64-
os.makedirs(TILELANG_CACHE_DIR, exist_ok=True)
65-
os.makedirs(TILELANG_TMP_DIR, exist_ok=True)
64+
os.makedirs(env.TILELANG_CACHE_DIR, exist_ok=True)
65+
os.makedirs(env.TILELANG_TMP_DIR, exist_ok=True)
6666

6767
def _generate_key(
6868
self,
@@ -132,7 +132,7 @@ def cached(
132132
Returns:
133133
JITKernel: The compiled kernel, either freshly compiled or from cache
134134
"""
135-
if not is_cache_enabled():
135+
if not env.is_cache_enabled():
136136
return JITKernel(
137137
func,
138138
out_idx=out_idx,
@@ -190,7 +190,7 @@ def cached(
190190
self.logger.warning("DLPack backend does not support cache saving to disk.")
191191
else:
192192
with self._lock:
193-
if is_cache_enabled():
193+
if env.is_cache_enabled():
194194
self._save_kernel_to_disk(key, kernel, func, verbose)
195195

196196
# Store in memory cache after compilation
@@ -215,7 +215,7 @@ def _get_cache_path(self, key: str) -> str:
215215
Returns:
216216
str: Absolute path to the cache directory for this kernel.
217217
"""
218-
return os.path.join(TILELANG_CACHE_DIR, key)
218+
return os.path.join(env.TILELANG_CACHE_DIR, key)
219219

220220
@staticmethod
221221
def _load_binary(path: str):
@@ -226,7 +226,7 @@ def _load_binary(path: str):
226226
@staticmethod
227227
def _safe_write_file(path: str, mode: str, operation: Callable):
228228
# Random a temporary file within the same FS as the cache directory
229-
temp_path = os.path.join(TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}")
229+
temp_path = os.path.join(env.TILELANG_TMP_DIR, f"{os.getpid()}_{uuid.uuid4()}")
230230
with open(temp_path, mode) as temp_file:
231231
operation(temp_file)
232232

@@ -396,7 +396,7 @@ def _clear_disk_cache(self):
396396
"""
397397
try:
398398
# Delete the entire cache directory
399-
shutil.rmtree(TILELANG_CACHE_DIR)
399+
shutil.rmtree(env.TILELANG_CACHE_DIR)
400400

401401
# Re-create the cache directory
402402
KernelCache._create_dirs()

tilelang/contrib/nvcc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import os
77
import subprocess
88
import warnings
9-
from ..env import CUDA_HOME
9+
from tilelang.env import CUDA_HOME
1010

1111
import tvm.ffi
1212
from tvm.target import Target

0 commit comments

Comments
 (0)