Skip to content

Commit

Permalink
add EqualityMapper to follow pymbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed May 4, 2022
1 parent befa5cb commit fda3bdc
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 41 deletions.
124 changes: 84 additions & 40 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,25 +40,28 @@
CallbackMapper as CallbackMapperBase,
CSECachingMapperMixin,
)
from pymbolic.mapper.evaluator import \
EvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.substitutor import \
SubstitutionMapper as SubstitutionMapperBase
from pymbolic.mapper.stringifier import \
StringifyMapper as StringifyMapperBase
from pymbolic.mapper.dependency import \
DependencyMapper as DependencyMapperBase
from pymbolic.mapper.coefficient import \
CoefficientCollector as CoefficientCollectorBase
from pymbolic.mapper.unifier import UnidirectionalUnifier \
as UnidirectionalUnifierBase
from pymbolic.mapper.constant_folder import \
ConstantFoldingMapper as ConstantFoldingMapperBase
from pymbolic.mapper.equality import (
EqualityMapper as EqualityMapperBase)
from pymbolic.mapper.evaluator import (
EvaluationMapper as EvaluationMapperBase)
from pymbolic.mapper.substitutor import (
SubstitutionMapper as SubstitutionMapperBase)
from pymbolic.mapper.stringifier import (
StringifyMapper as StringifyMapperBase)
from pymbolic.mapper.dependency import (
DependencyMapper as DependencyMapperBase)
from pymbolic.mapper.coefficient import (
CoefficientCollector as CoefficientCollectorBase)
from pymbolic.mapper.unifier import (
UnidirectionalUnifier as UnidirectionalUnifierBase)
from pymbolic.mapper.constant_folder import (
ConstantFoldingMapper as ConstantFoldingMapperBase)

from pymbolic.parser import Parser as ParserBase
from loopy.diagnostic import LoopyError
from loopy.diagnostic import (ExpressionToAffineConversionError,
UnableToDetermineAccessRangeError)
from loopy.diagnostic import (
ExpressionToAffineConversionError,
UnableToDetermineAccessRangeError)


import islpy as isl
Expand Down Expand Up @@ -114,8 +117,11 @@ def map_literal(self, expr, *args, **kwargs):
return expr

def map_array_literal(self, expr, *args, **kwargs):
return type(expr)(tuple(self.rec(ch, *args, **kwargs)
for ch in expr.children))
children = [self.rec(ch, *args, **kwargs) for ch in expr.children]
if all(ch is orig for ch, orig in zip(children, expr.children)):
return expr

return type(expr)(tuple(children))

def map_group_hw_index(self, expr, *args, **kwargs):
return expr
Expand Down Expand Up @@ -484,6 +490,55 @@ def map_substitution(self, name, rule, arguments):

return self.rec(expr)


class EqualityMapper(EqualityMapperBase):
def map_loopy_function_identifier(self, expr, other) -> bool:
return True

def map_reduction(self, expr, other) -> bool:
return (
expr.operation == other.operation
and expr.allow_simultaneous == other.allow_simultaneous
and self.rec(expr.expr, other.expr)
and all(iname == other_iname
for iname, other_iname in zip(expr.inames, other.inames)))

def map_group_hw_index(self, expr, other) -> bool:
return expr.axis == other.axis

map_local_hw_index = map_group_hw_index

def map_rule_argument(self, expr, other) -> bool:
return expr.index == other.index

def map_resolved_function(self, expr, other) -> bool:
return self.rec(expr.function, other.function)

def map_sub_array_ref(self, expr, other) -> bool:
return (
len(expr.swept_inames) == len(other.swept_inames)
and self.rec(expr.subscript, other.subscript)
and all(self.rec(iname, other_iname)
for iname, other_iname in zip(
expr.swept_inames,
other.swept_inames))
)

def map_tagged_variable(self, expr, other) -> bool:
return (
expr.name == other.name
and all(tag == other_tag
for tag, other_tag in zip(expr.tags, other.tags))
)

def map_type_cast(self, expr, other) -> bool:
return (
expr.type == other.type
and self.rec(expr.child, other.child))

def map_fortran_division(self, expr, other) -> bool:
return self.map_quotient(expr, other)

# }}}


Expand All @@ -497,15 +552,18 @@ def stringifier(self):
def make_stringifier(self, originating_stringifier=None):
return StringifyMapper()

def make_equality_mapper(self):
return EqualityMapper()


class Literal(LoopyExpressionBase):
"""A literal to be used during code generation.
.. note::
Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""

def __init__(self, s):
Expand All @@ -525,8 +583,8 @@ class ArrayLiteral(LoopyExpressionBase):
.. note::
Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""

def __init__(self, children):
Expand Down Expand Up @@ -555,8 +613,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
.. note::
Only used in the output of
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""
mapper_method = "map_group_hw_index"

Expand All @@ -566,8 +624,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
.. note::
Only used in the output of
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in :mod:`loopy` source representation.
"""
mapper_method = "map_local_hw_index"

Expand Down Expand Up @@ -774,12 +832,6 @@ def __getinitargs__(self):
def get_hash(self):
return hash((self.__class__, self.operation, self.inames, self.expr))

def is_equal(self, other):
return (other.__class__ == self.__class__
and other.operation == self.operation
and other.inames == self.inames
and other.expr == self.expr)

@property
def is_tuple_typed(self):
return self.operation.arg_count > 1
Expand Down Expand Up @@ -977,14 +1029,6 @@ def __getinitargs__(self):
def get_hash(self):
return hash((self.__class__, self.swept_inames, self.subscript))

def is_equal(self, other):
"""
Returns *True* iff the sub-array refs have identical expressions.
"""
return (other.__class__ == self.__class__
and other.subscript == self.subscript
and other.swept_inames == self.swept_inames)

def make_stringifier(self, originating_stringifier=None):
return StringifyMapper()

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
git+https://github.com/inducer/islpy.git#egg=islpy
git+https://github.com/inducer/cgen.git#egg=cgen
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
git+https://github.com/inducer/genpy.git#egg=genpy
git+https://github.com/inducer/codepy.git#egg=codepy

Expand Down

0 comments on commit fda3bdc

Please sign in to comment.