From b15fe77f774808d628de126516f519a0efdedacc Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Mon, 8 Nov 2021 09:59:06 -0600 Subject: [PATCH] add metadata to describe arrays' types --- meshmode/discretization/__init__.py | 28 ++++++- meshmode/discretization/connection/direct.py | 36 +++++--- meshmode/discretization/connection/face.py | 9 +- .../discretization/connection/same_mesh.py | 11 ++- meshmode/transform_metadata.py | 84 ++++++++++++++++++- 5 files changed, 147 insertions(+), 21 deletions(-) diff --git a/meshmode/discretization/__init__.py b/meshmode/discretization/__init__.py index e6d2a2609..b91fc8525 100644 --- a/meshmode/discretization/__init__.py +++ b/meshmode/discretization/__init__.py @@ -33,7 +33,9 @@ import loopy as lp from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag, FirstAxisIsElementsTag) + ConcurrentElementInameTag, ConcurrentDOFInameTag, + FirstAxisIsElementsTag, DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) from warnings import warn @@ -480,7 +482,14 @@ def _new_array(self, actx, creation_func, dtype=None): dtype = np.dtype(dtype) return _DOFArray(actx, tuple( - creation_func(shape=(grp.nelements, grp.nunit_dofs), dtype=dtype) + actx.tag_axis(0, + DiscretizationElementAxisTag.from_group(grp), + actx.tag_axis(1, + DiscretizationDOFAxisTag.from_group(grp), + creation_func(shape=(grp.nelements, + grp.nunit_dofs), + dtype=dtype) + )) for grp in self.groups)) def empty(self, actx: ArrayContext, dtype=None): @@ -576,6 +585,12 @@ def nodes(self, cached=True): def resample_mesh_nodes(grp, iaxis): # TODO: would be nice to have the mesh use an array context already nodes = actx.from_numpy(grp.mesh_el_group.nodes[iaxis]) + nodes = actx.tag_axis(0, + DiscretizationElementAxisTag.from_group(grp), + nodes) + nodes = actx.tag_axis(1, + DiscretizationDOFAxisTag.from_group(grp), + nodes) grp_unit_nodes = grp.unit_nodes.reshape(-1) meg_unit_nodes = grp.mesh_el_group.unit_nodes.reshape(-1) @@ -586,7 +601,10 @@ def resample_mesh_nodes(grp, iaxis): return nodes return actx.einsum("ij,ej->ei", - actx.from_numpy(grp.from_mesh_interp_matrix()), + actx.tag_axis( + 0, + DiscretizationDOFAxisTag.from_group(grp), + actx.from_numpy(grp.from_mesh_interp_matrix())), nodes, tagged=(FirstAxisIsElementsTag(),)) @@ -646,7 +664,9 @@ def get_mat(grp, gref_axes): return _DOFArray(actx, tuple( actx.einsum("ij,ej->ei", - get_mat(grp, ref_axes), + actx.tag_axis(0, + DiscretizationDOFAxisTag.from_group(grp), + get_mat(grp, ref_axes)), vec[grp.index], tagged=(FirstAxisIsElementsTag(),)) for grp in discr.groups)) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index 1204e3bd3..c0165899d 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -26,7 +26,8 @@ import loopy as lp from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag) + ConcurrentElementInameTag, ConcurrentDOFInameTag, + DiscretizationElementAxisTag, DiscretizationDOFAxisTag) from pytools import memoize_in, keyed_memoize_method from arraycontext import ( ArrayContext, NotAnArrayContainerError, @@ -439,7 +440,15 @@ def batch_pick_knl(): actx, self.to_discr.groups[i_tgrp]), n_to_nodes=self.to_discr.groups[i_tgrp].nunit_dofs )["result"] - + batch_result = actx.tag_axis( + 0, + DiscretizationElementAxisTag.from_group( + self.to_discr.groups[i_tgrp]), + actx.tag_axis(1, + DiscretizationDOFAxisTag.from_group( + self.to_discr.groups[i_tgrp]), + batch_result) + ) else: batch_result = actx.call_loopy( batch_pick_knl(), @@ -455,17 +464,24 @@ def batch_pick_knl(): # After computing each batched result, take the sum # to get the entire contribution over the group if batched_data: - group_data.append(sum(batched_data)) + to_group_data = sum(batched_data) else: # If no batched data at all, return zeros for this # particular group array - group_data.append( - actx.zeros( - shape=(self.to_discr.groups[i_tgrp].nelements, - self.to_discr.groups[i_tgrp].nunit_dofs), - dtype=ary.entry_dtype - ) - ) + to_group_data = actx.zeros( + shape=(self.to_discr.groups[i_tgrp].nelements, + self.to_discr.groups[i_tgrp].nunit_dofs), + dtype=ary.entry_dtype) + + group_data.append( + actx.tag_axis(1, + (DiscretizationDOFAxisTag + .from_group(self.to_discr.groups[i_tgrp])), + actx.tag_axis( + 0, + (DiscretizationElementAxisTag + .from_group(self.to_discr.groups[i_tgrp])), + to_group_data))) return DOFArray(actx, data=tuple(group_data)) diff --git a/meshmode/discretization/connection/face.py b/meshmode/discretization/connection/face.py index 9e57b5377..3905d6ccb 100644 --- a/meshmode/discretization/connection/face.py +++ b/meshmode/discretization/connection/face.py @@ -21,6 +21,7 @@ """ from pytools import Record +from meshmode.transform_metadata import DiscretizationElementAxisTag import numpy as np import modepy as mp @@ -442,8 +443,12 @@ def make_face_to_all_faces_embedding(actx, faces_connection, all_faces_discr, assert all_faces_grp.nelements == nfaces * vol_grp.nelements to_element_indices = actx.freeze( - vol_grp.nelements*iface - + actx.thaw(src_batch.from_element_indices)) + actx.tag_axis(0, + DiscretizationElementAxisTag.from_batch( + src_batch, + all_faces_grp), + vol_grp.nelements*iface + + actx.thaw(src_batch.from_element_indices))) batches.append( InterpolationBatch( diff --git a/meshmode/discretization/connection/same_mesh.py b/meshmode/discretization/connection/same_mesh.py index 638bcb598..281069cbf 100644 --- a/meshmode/discretization/connection/same_mesh.py +++ b/meshmode/discretization/connection/same_mesh.py @@ -21,6 +21,7 @@ """ import numpy as np +from meshmode.transform_metadata import DiscretizationElementAxisTag # {{{ same-mesh constructor @@ -42,10 +43,12 @@ def make_same_mesh_connection(actx, to_discr, from_discr): groups = [] for igrp, (fgrp, tgrp) in enumerate(zip(from_discr.groups, to_discr.groups)): all_elements = actx.freeze( - actx.from_numpy( - np.arange( - fgrp.nelements, - dtype=np.intp))) + actx.tag_axis(0, + DiscretizationElementAxisTag.from_group(fgrp), + actx.from_numpy( + np.arange( + fgrp.nelements, + dtype=np.intp)))) ibatch = InterpolationBatch( from_group_index=igrp, from_element_indices=all_elements, diff --git a/meshmode/transform_metadata.py b/meshmode/transform_metadata.py index 981685c60..cd3ddfc94 100644 --- a/meshmode/transform_metadata.py +++ b/meshmode/transform_metadata.py @@ -2,6 +2,10 @@ .. autoclass:: FirstAxisIsElementsTag .. autoclass:: ConcurrentElementInameTag .. autoclass:: ConcurrentDOFInameTag +.. autoclass:: DiscretizationEntityAxisTag +.. autoclass:: DiscretizationElementAxisTag +.. autoclass:: DiscretizationFaceAxisTag +.. autoclass:: DiscretizationDOFAxisTag """ __copyright__ = """ @@ -28,7 +32,12 @@ THE SOFTWARE. """ -from pytools.tag import Tag +from pytools.tag import Tag, tag_dataclass, UniqueTag +from typing import Tuple, TYPE_CHECKING + +if TYPE_CHECKING: + from meshmode.discretization import ElementGroupBase + from meshmode.discretization.connection.direct import InterpolationBatch class FirstAxisIsElementsTag(Tag): @@ -57,3 +66,76 @@ class ConcurrentDOFInameTag(Tag): computations for all DOFs within each element may be performed concurrently. """ + + +class DiscretizationEntityAxisTag(UniqueTag): + @classmethod + def from_group(cls, + group: "ElementGroupBase") -> "DiscretizationEntityAxisTag": + return cls((group.dim, group.nelements)) + + @classmethod + def from_batch(cls, + batch: "InterpolationBatch", + group: "ElementGroupBase") -> "DiscretizationEntityAxisTag": + return cls((group.dim, batch.nelements)) + + +@tag_dataclass +class DiscretizationElementAxisTag(DiscretizationEntityAxisTag): + """ + Tagged to an array's axis representing element indices in the + discretization identified by discretization key :attr:`discretization_key`. + + .. attribute:: discretization_key + + A tuple of ``(ndim, nelements)`` where ``ndim`` is the topological + dimension of the space on which array's elements reside and + ``nelements`` is the number of entries in that discretization. + """ + discretization_key: Tuple[int, int] + + def __post_init__(self) -> None: + assert isinstance(self.discretization_key, tuple) + assert len(self.discretization_key) == 2 + assert all(isinstance(k, int) for k in self.discretization_key) + + +@tag_dataclass +class DiscretizationFaceAxisTag(DiscretizationEntityAxisTag): + """ + Tagged to an array's axis representing face indices in the discretization + identified by discretization key :attr:`discretization_key`. + + .. attribute:: discretization_key + + A tuple of ``(ndim, nelements)`` where ``ndim`` is the topological + dimension of the space on which array's entries reside and + ``nelements`` is the number of elements in that discretization. + """ + discretization_key: Tuple[int, int] + + def __post_init__(self) -> None: + assert isinstance(self.discretization_key, tuple) + assert len(self.discretization_key) == 2 + assert all(isinstance(k, int) for k in self.discretization_key) + + +@tag_dataclass +class DiscretizationDOFAxisTag(DiscretizationEntityAxisTag): + """ + Tagged to an array's axis representing DOF indices in the discretization + identified by discretization key :attr:`discretization_key`. + + .. attribute:: discretization_key + + A tuple of ``(ndim, nelements)`` where ``ndim`` is the topological + dimension of the space on which array's entries reside and + ``nelements`` is the number of elements in that discretization. + """ + discretization_key: Tuple[int, int] + + def __post_init__(self) -> None: + assert isinstance(self.discretization_key, tuple) + assert len(self.discretization_key) == 2 + assert all(isinstance(k, int) for k in self.discretization_key)