Skip to content

Commit

Permalink
Add raise_on_domain_mismatch parameter to rename_iname (#800)
Browse files Browse the repository at this point in the history
* Add raise_on_domain_mismatch parameter to rename_iname

Adds keyword parameter to `rename_iname` so raise_on_domain_mismatch can be passed into the call to `rename_inames`

* Fix doc string

* remove indentation in doc string

* Centralize rename_inames default logic

* Add Optional[bool] typing
  • Loading branch information
nchristensen authored Aug 13, 2023
1 parent a4840e0 commit 250758b
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions loopy/transform/iname.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from loopy.kernel import LoopKernel
from loopy.kernel.function_interface import CallableKernel

from typing import Optional

__doc__ = """
.. currentmodule:: loopy
Expand Down Expand Up @@ -2368,7 +2369,7 @@ def add_inames_for_unused_hw_axes(kernel, within=None):
@for_each_kernel
@remove_any_newly_unused_inames
def rename_inames(kernel, old_inames, new_iname, existing_ok=False,
within=None, raise_on_domain_mismatch: bool = __debug__):
within=None, raise_on_domain_mismatch: Optional[bool] = None):
r"""
:arg old_inames: A collection of inames that must be renamed to **new_iname**.
:arg within: a stack match as understood by
Expand Down Expand Up @@ -2396,6 +2397,9 @@ def rename_inames(kernel, old_inames, new_iname, existing_ok=False,
raise LoopyError("old_inames contains nested inames"
" -- renaming is illegal.")

if raise_on_domain_mismatch is None:
raise_on_domain_mismatch = __debug__

# sort to have deterministic implementation.
old_inames = sorted(old_inames)

Expand Down Expand Up @@ -2504,18 +2508,23 @@ def does_insn_involve_iname(kernel, insn, *args):

@for_each_kernel
def rename_iname(kernel, old_iname, new_iname, existing_ok=False,
within=None, preserve_tags=True):
"""
within=None, preserve_tags=True,
raise_on_domain_mismatch: Optional[bool] = None):
r"""
Single iname version of :func:`loopy.rename_inames`.
:arg existing_ok: execute even if *new_iname* already exists
:arg existing_ok: execute even if *new_iname* already exists.
:arg within: a stack match understood by :func:`loopy.match.parse_stack_match`.
:arg preserve_tags: copy the tags on the old iname to the new iname
:arg preserve_tags: copy the tags on the old iname to the new iname.
:arg raise_on_domain_mismatch: If *True*, raises an error if
:math:`\exists (i_1,i_2) \in \{\text{old\_inames}\}^2 |
\mathcal{D}_{i_1} \neq \mathcal{D}_{i_2}`.
"""
from itertools import product
from loopy import tag_inames

tags = kernel.inames[old_iname].tags
kernel = rename_inames(kernel, [old_iname], new_iname, existing_ok, within)
kernel = rename_inames(kernel, [old_iname], new_iname, existing_ok,
within, raise_on_domain_mismatch)
if preserve_tags:
kernel = tag_inames(kernel, product([new_iname], tags))
return kernel
Expand Down

0 comments on commit 250758b

Please sign in to comment.