Skip to content

Commit

Permalink
add EqualityMapper to follow pymbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Jun 30, 2022
1 parent ccc8cbf commit 8ea4f38
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 23 deletions.
89 changes: 67 additions & 22 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
CSECachingMapperMixin,
)
import immutables
from pymbolic.mapper.equality import (
EqualityMapper as EqualityMapperBase)
from pymbolic.mapper.evaluator import \
CachedEvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.substitutor import \
Expand Down Expand Up @@ -502,6 +504,60 @@ def map_substitution(self, name, rule, arguments):

return self.rec(expr)


class EqualityMapper(EqualityMapperBase):
def map_loopy_function_identifier(self, expr, other) -> bool:
return True

def map_reduction(self, expr, other) -> bool:
return (
expr.operation == other.operation
and expr.allow_simultaneous == other.allow_simultaneous
and self.rec(expr.expr, other.expr)
and all(iname == other_iname
for iname, other_iname in zip(expr.inames, other.inames)))

def map_group_hw_index(self, expr, other) -> bool:
return expr.axis == other.axis

map_local_hw_index = map_group_hw_index

def map_linear_subscript(self, expr, other) -> bool:
return (
self.rec(expr.index, other.index)
and self.rec(expr.aggregate, other.aggregate))

def map_rule_argument(self, expr, other) -> bool:
return expr.index == other.index

def map_resolved_function(self, expr, other) -> bool:
return self.rec(expr.function, other.function)

def map_sub_array_ref(self, expr, other) -> bool:
return (
len(expr.swept_inames) == len(other.swept_inames)
and self.rec(expr.subscript, other.subscript)
and all(self.rec(iname, other_iname)
for iname, other_iname in zip(
expr.swept_inames,
other.swept_inames))
)

def map_tagged_variable(self, expr, other) -> bool:
return (
expr.name == other.name
and all(tag == other_tag
for tag, other_tag in zip(expr.tags, other.tags))
)

def map_type_cast(self, expr, other) -> bool:
return (
expr.type == other.type
and self.rec(expr.child, other.child))

def map_fortran_division(self, expr, other) -> bool:
return self.map_quotient(expr, other)

# }}}


Expand All @@ -515,15 +571,18 @@ def stringifier(self):
def make_stringifier(self, originating_stringifier=None):
return StringifyMapper()

def make_equality_mapper(self):
return EqualityMapper()


class Literal(LoopyExpressionBase):
"""A literal to be used during code generation.
.. note::
Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""

def __init__(self, s):
Expand All @@ -543,8 +602,8 @@ class ArrayLiteral(LoopyExpressionBase):
.. note::
Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""

def __init__(self, children):
Expand Down Expand Up @@ -573,8 +632,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
.. note::
Only used in the output of
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""
mapper_method = "map_group_hw_index"

Expand All @@ -584,8 +643,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
.. note::
Only used in the output of
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in :mod:`loopy` source representation.
"""
mapper_method = "map_local_hw_index"

Expand Down Expand Up @@ -792,12 +851,6 @@ def __getinitargs__(self):
def get_hash(self):
return hash((self.__class__, self.operation, self.inames, self.expr))

def is_equal(self, other):
return (other.__class__ == self.__class__
and other.operation == self.operation
and other.inames == self.inames
and other.expr == self.expr)

@property
def is_tuple_typed(self):
return self.operation.arg_count > 1
Expand Down Expand Up @@ -994,14 +1047,6 @@ def __getinitargs__(self):
def get_hash(self):
return hash((self.__class__, self.swept_inames, self.subscript))

def is_equal(self, other):
"""
Returns *True* iff the sub-array refs have identical expressions.
"""
return (other.__class__ == self.__class__
and other.subscript == self.subscript
and other.swept_inames == self.swept_inames)

def make_stringifier(self, originating_stringifier=None):
return StringifyMapper()

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
git+https://github.com/inducer/islpy.git#egg=islpy
git+https://github.com/inducer/cgen.git#egg=cgen
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
git+https://github.com/inducer/genpy.git#egg=genpy
git+https://github.com/inducer/codepy.git#egg=codepy

Expand Down

0 comments on commit 8ea4f38

Please sign in to comment.