Skip to content

Commit

Permalink
PytatoPyOpenCLArrayContext: untag any info. on NamedArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd authored and inducer committed Jun 11, 2022
1 parent 1ca4d5a commit dff7d4d
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions meshmode/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,27 @@ def transform_loopy_program(self, t_unit):
# {{{ pytato pyopencl array context subclass

class PytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContextBase):
def transform_dag(self, dag):
dag = super().transform_dag(dag)

# {{{ /!\ Remove tags from NamedArrays
# See <https://www.github.com/inducer/pytato/issues/195>

import pytato as pt

def untag_loopy_call_results(expr):
if isinstance(expr, pt.NamedArray):
return expr.copy(tags=frozenset(),
axes=(pt.Axis(frozenset()),)*expr.ndim)
else:
return expr

dag = pt.transform.map_and_copy(dag, untag_loopy_call_results)

# }}}

return dag

def transform_loopy_program(self, t_unit):
# FIXME: Do not parallelize for now.
return t_unit
Expand Down

0 comments on commit dff7d4d

Please sign in to comment.