Skip to content

[knobs] Fix environment propagation & scope() API #6664

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion python/test/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,31 @@ def _fresh_knobs_impl(monkeypatch, skipped_attr: Optional[Set[str]] = None):
if isinstance(knobset, knobs.base_knobs) and knobset != knobs.base_knobs and name not in skipped_attr
}

# We store which variables we need to unset below in finally because
# monkeypatch doesn't appear to reset variables that were never set
# before the monkeypatch.delenv call below.
env_to_unset = []
prev_propagate_env = knobs.propagate_env

def fresh_function():
nonlocal env_to_unset
for name, knobset in knobs_map.items():
setattr(knobs, name, knobset.copy().reset())
for knob in knobset.knob_descriptors.values():
monkeypatch.delenv(knob.key, raising=False)
if knob.key in os.environ:
monkeypatch.delenv(knob.key, raising=False)
else:
env_to_unset.append(knob.key)
knobs.propagate_env = True
return knobs

def reset_function():
for name, knobset in knobs_map.items():
setattr(knobs, name, knobset)
for k in env_to_unset:
if k in os.environ:
del os.environ[k]
knobs.propagate_env = prev_propagate_env

return fresh_function, reset_function

Expand Down
34 changes: 31 additions & 3 deletions python/test/unit/test_knobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from pathlib import Path


def test_knobs_utils() -> None:
def test_knobs_utils(fresh_knobs) -> None:
triton.knobs.propagate_env = False

class test_knobs(triton.knobs.base_knobs):
foo: triton.knobs.env_str = triton.knobs.env_str("FOO", "triton")
Expand Down Expand Up @@ -68,10 +69,12 @@ class test_knobs(triton.knobs.base_knobs):


def test_knobs_scope(fresh_knobs, monkeypatch):
monkeypatch.setenv("TRITON_HIP_LOCAL_PREFETCH", "17")
fresh_knobs.amd.global_prefetch = 4
fresh_knobs.amd.local_prefetch = 3

# Update env *after* the __set__() does
monkeypatch.setenv("TRITON_HIP_LOCAL_PREFETCH", "17")

assert fresh_knobs.amd.global_prefetch == 4
assert fresh_knobs.amd.local_prefetch == 3
assert fresh_knobs.amd.use_buffer_ops
Expand Down Expand Up @@ -103,6 +106,17 @@ def test_knobs_scope(fresh_knobs, monkeypatch):
assert fresh_knobs.amd.use_buffer_ops


def test_env_updated(fresh_knobs, monkeypatch):
fresh_knobs.amd.use_buffer_ops = False
assert os.getenv("AMDGCN_USE_BUFFER_OPS") == "0"
# Just triple checking both APIs give us what we expect
assert os.environ["AMDGCN_USE_BUFFER_OPS"] == "0"

fresh_knobs.cache.home_dir = "/foo/bar"
assert os.getenv("TRITON_HOME") == "/foo/bar"
assert os.environ["TRITON_HOME"] == "/foo/bar"


@pytest.mark.parametrize("truthy, falsey", [("1", "0"), ("true", "false"), ("True", "False"), ("TRUE", "FALSE"),
("y", "n"), ("YES", "NO"), ("ON", "OFF")])
def test_read_env(truthy, falsey, fresh_knobs, monkeypatch):
Expand Down Expand Up @@ -170,13 +184,18 @@ def test_set_knob_directly(fresh_knobs, monkeypatch):
monkeypatch.setenv("TRITON_CACHE_DIR", "/tmp/other_triton_cache")
assert fresh_knobs.cache.dir == "/tmp/triton_cache"

# Disable propagation to verify resetting/del behavior
triton.knobs.propagate_env = False

fresh_knobs.cache.dir = fresh_knobs.env
assert fresh_knobs.cache.dir == "/tmp/other_triton_cache"

fresh_knobs.cache.dir = "/tmp/triton_cache"
fresh_knobs.cache.reset()
assert fresh_knobs.cache.dir == "/tmp/other_triton_cache"

triton.knobs.propagate_env = True

# Just in case, lets check all the other datatypes too
fresh_knobs.language.default_fp_fusion = False
fresh_knobs.amd.use_block_pingpong = True
Expand Down Expand Up @@ -226,20 +245,29 @@ def test_nvidia_tool(fresh_knobs, tmp_path, monkeypatch):
triton_root = Path(fresh_knobs.__file__).parent
default_ptxas = triton_root / "backends/nvidia/bin/ptxas"

assert default_ptxas.exists()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CliveUnger I think nixing this assert should address your concern, since everything else is just validating Paths/strings (follow up from here for any other readers).

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea this looks good!

assert Path(fresh_knobs.nvidia.ptxas.path).resolve() == default_ptxas.resolve()

tmp_ptxas = tmp_path / "ptxas-special"
shutil.copy(default_ptxas, tmp_ptxas)
monkeypatch.setenv("TRITON_PTXAS_PATH", str(tmp_ptxas))
assert Path(fresh_knobs.nvidia.ptxas.path).resolve() == tmp_ptxas.resolve()

# Don't prop so that the `del` is correctly tested
fresh_knobs.propagate_env = False
fresh_knobs.nvidia.ptxas = str(default_ptxas)
fresh_knobs.propagate_env = True
assert Path(fresh_knobs.nvidia.ptxas.path).resolve() == default_ptxas.resolve()

del fresh_knobs.nvidia.ptxas
assert Path(fresh_knobs.nvidia.ptxas.path).resolve() == tmp_ptxas.resolve()

# Triple check scope works
with fresh_knobs.nvidia.scope():
fresh_knobs.nvidia.ptxas = str(default_ptxas)
assert Path(fresh_knobs.nvidia.ptxas.path).resolve() == default_ptxas.resolve()

assert Path(fresh_knobs.nvidia.ptxas.path).resolve() == tmp_ptxas.resolve()

monkeypatch.delenv("TRITON_PTXAS_PATH")
assert Path(fresh_knobs.nvidia.ptxas.path).resolve() == default_ptxas.resolve()

Expand Down
57 changes: 44 additions & 13 deletions python/triton/knobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,41 @@ class Env:

env = Env()

propagate_env: bool = True


def getenv(key: str) -> Optional[str]:
res = os.getenv(key)
return res.strip() if res is not None else res


def setenv(key: str, value: Optional[str]) -> None:
if not propagate_env:
return

if value is not None:
os.environ[key] = value
elif key in os.environ:
del os.environ[key]


def toenv(val: Any) -> Union[None, tuple[Optional[str]]]:
if val is None:
return (None, )

t = type(val)
if t is bool:
return ("1" if val else "0", )

if t is str:
return (val, )

if t is int:
return (str(val), )

return None


# There's an asymmetry here so that e.g. env_nvidia_tool can be specified with a
# a string but return an NvidiaTool.
SetType = TypeVar("SetType")
Expand All @@ -52,16 +81,21 @@ def __get__(self, obj: Optional[object], objclass: Optional[Type[object]]) -> Ge
else:
return self.get()

@property
def env_val(self) -> str | None:
return getenv(self.key)

def get(self) -> GetType:
env = getenv(self.key)
env = self.env_val
return self.transform(self.default() if env is None else self.from_env(env))

def __set__(self, obj: object, value: Union[SetType, Env]) -> None:
if isinstance(value, Env):
obj.__dict__.pop(self.name, None)
else:
obj.__dict__[self.name] = value
self.set(value)
if env_val := toenv(value):
setenv(self.key, env_val[0])
Comment on lines +97 to +98
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost exactly the same impl you gave @peterbell10 , but needed a function to convert the value (e.g. "1" instead of "True" and to ignore when e.g. value is a class for env_class)


def __delete__(self, obj: object) -> None:
obj.__dict__.pop(self.name, None)
Expand All @@ -71,21 +105,12 @@ def transform(self, val: SetType) -> GetType:
# if GetType != SetType.
return cast(GetType, val)

def set(self, val: SetType) -> None:
pass

def from_env(self, val: str) -> SetType:
raise NotImplementedError()


class env_str(env_base[str, str]):

def set(self, value: Optional[str]) -> None:
if value is None:
os.unsetenv(self.key)
else:
os.putenv(self.key, value)

def from_env(self, val: str) -> str:
return val

Expand Down Expand Up @@ -253,8 +278,7 @@ def knobs(self) -> dict[str, Any]:

def copy(self: knobs_type) -> knobs_type:
res = type(self)()
for k, v in self.__dict__.items():
res.__dict__[k] = v
res.__dict__.update(self.__dict__)
return res

def reset(self: knobs_type) -> knobs_type:
Expand All @@ -265,12 +289,19 @@ def reset(self: knobs_type) -> knobs_type:
@contextmanager
def scope(self) -> Generator[None, None, None]:
try:
initial_env = {knob.key: knob.env_val for knob in self.knob_descriptors.values()}
orig = dict(self.__dict__)
yield
finally:
self.__dict__.clear()
self.__dict__.update(orig)

for k, v in initial_env.items():
if v is not None:
os.environ[k] = v
elif k in os.environ:
del os.environ[k]


class build_knobs(base_knobs):
"""Configuration controlling how the native compiler is invoked"""
Expand Down
Loading