Skip to content

Commit

Permalink
add metadata to describe arrays' types
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Nov 8, 2021
1 parent cc8273f commit 9988827
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 21 deletions.
28 changes: 24 additions & 4 deletions meshmode/discretization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@

import loopy as lp
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag, FirstAxisIsElementsTag)
ConcurrentElementInameTag, ConcurrentDOFInameTag,
FirstAxisIsElementsTag, DiscretizationElementAxisTag,
DiscretizationDOFAxisTag)

from warnings import warn

Expand Down Expand Up @@ -480,7 +482,14 @@ def _new_array(self, actx, creation_func, dtype=None):
dtype = np.dtype(dtype)

return _DOFArray(actx, tuple(
creation_func(shape=(grp.nelements, grp.nunit_dofs), dtype=dtype)
actx.tag_axis(0,
DiscretizationElementAxisTag.from_group(grp),
actx.tag_axis(1,
DiscretizationDOFAxisTag.from_group(grp),
creation_func(shape=(grp.nelements,
grp.nunit_dofs),
dtype=dtype)
))
for grp in self.groups))

def empty(self, actx: ArrayContext, dtype=None):
Expand Down Expand Up @@ -576,6 +585,12 @@ def nodes(self, cached=True):
def resample_mesh_nodes(grp, iaxis):
# TODO: would be nice to have the mesh use an array context already
nodes = actx.from_numpy(grp.mesh_el_group.nodes[iaxis])
nodes = actx.tag_axis(0,
DiscretizationElementAxisTag.from_group(grp),
nodes)
nodes = actx.tag_axis(1,
DiscretizationDOFAxisTag.from_group(grp),
nodes)

grp_unit_nodes = grp.unit_nodes.reshape(-1)
meg_unit_nodes = grp.mesh_el_group.unit_nodes.reshape(-1)
Expand All @@ -586,7 +601,10 @@ def resample_mesh_nodes(grp, iaxis):
return nodes

return actx.einsum("ij,ej->ei",
actx.from_numpy(grp.from_mesh_interp_matrix()),
actx.tag_axis(
0,
DiscretizationDOFAxisTag.from_group(grp),
actx.from_numpy(grp.from_mesh_interp_matrix())),
nodes,
tagged=(FirstAxisIsElementsTag(),))

Expand Down Expand Up @@ -646,7 +664,9 @@ def get_mat(grp, gref_axes):

return _DOFArray(actx, tuple(
actx.einsum("ij,ej->ei",
get_mat(grp, ref_axes),
actx.tag_axis(0,
DiscretizationDOFAxisTag.from_group(grp),
get_mat(grp, ref_axes)),
vec[grp.index],
tagged=(FirstAxisIsElementsTag(),))
for grp in discr.groups))
Expand Down
36 changes: 26 additions & 10 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

import loopy as lp
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag)
ConcurrentElementInameTag, ConcurrentDOFInameTag,
DiscretizationElementAxisTag, DiscretizationDOFAxisTag)
from pytools import memoize_in, keyed_memoize_method
from arraycontext import (
ArrayContext, NotAnArrayContainerError,
Expand Down Expand Up @@ -439,7 +440,15 @@ def batch_pick_knl():
actx, self.to_discr.groups[i_tgrp]),
n_to_nodes=self.to_discr.groups[i_tgrp].nunit_dofs
)["result"]

batch_result = actx.tag_axis(
0,
DiscretizationElementAxisTag.from_group(
self.to_discr.groups[i_tgrp]),
actx.tag_axis(1,
DiscretizationDOFAxisTag.from_group(
self.to_discr.groups[i_tgrp]),
batch_result)
)
else:
batch_result = actx.call_loopy(
batch_pick_knl(),
Expand All @@ -455,17 +464,24 @@ def batch_pick_knl():
# After computing each batched result, take the sum
# to get the entire contribution over the group
if batched_data:
group_data.append(sum(batched_data))
to_group_data = sum(batched_data)
else:
# If no batched data at all, return zeros for this
# particular group array
group_data.append(
actx.zeros(
shape=(self.to_discr.groups[i_tgrp].nelements,
self.to_discr.groups[i_tgrp].nunit_dofs),
dtype=ary.entry_dtype
)
)
to_group_data = actx.zeros(
shape=(self.to_discr.groups[i_tgrp].nelements,
self.to_discr.groups[i_tgrp].nunit_dofs),
dtype=ary.entry_dtype)

group_data.append(
actx.tag_axis(1,
(DiscretizationDOFAxisTag
.from_group(self.to_discr.groups[i_tgrp])),
actx.tag_axis(
0,
(DiscretizationElementAxisTag
.from_group(self.to_discr.groups[i_tgrp])),
to_group_data)))

return DOFArray(actx, data=tuple(group_data))

Expand Down
9 changes: 7 additions & 2 deletions meshmode/discretization/connection/face.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

from pytools import Record
from meshmode.transform_metadata import DiscretizationElementAxisTag

import numpy as np
import modepy as mp
Expand Down Expand Up @@ -442,8 +443,12 @@ def make_face_to_all_faces_embedding(actx, faces_connection, all_faces_discr,
assert all_faces_grp.nelements == nfaces * vol_grp.nelements

to_element_indices = actx.freeze(
vol_grp.nelements*iface
+ actx.thaw(src_batch.from_element_indices))
actx.tag_axis(0,
DiscretizationElementAxisTag.from_batch(
src_batch,
all_faces_grp),
vol_grp.nelements*iface
+ actx.thaw(src_batch.from_element_indices)))

batches.append(
InterpolationBatch(
Expand Down
11 changes: 7 additions & 4 deletions meshmode/discretization/connection/same_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

import numpy as np
from meshmode.transform_metadata import DiscretizationElementAxisTag


# {{{ same-mesh constructor
Expand All @@ -42,10 +43,12 @@ def make_same_mesh_connection(actx, to_discr, from_discr):
groups = []
for igrp, (fgrp, tgrp) in enumerate(zip(from_discr.groups, to_discr.groups)):
all_elements = actx.freeze(
actx.from_numpy(
np.arange(
fgrp.nelements,
dtype=np.intp)))
actx.tag_axis(0,
DiscretizationElementAxisTag.from_group(fgrp),
actx.from_numpy(
np.arange(
fgrp.nelements,
dtype=np.intp))))
ibatch = InterpolationBatch(
from_group_index=igrp,
from_element_indices=all_elements,
Expand Down
84 changes: 83 additions & 1 deletion meshmode/transform_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
.. autoclass:: FirstAxisIsElementsTag
.. autoclass:: ConcurrentElementInameTag
.. autoclass:: ConcurrentDOFInameTag
.. autoclass:: DiscretizationEntityAxisTag
.. autoclass:: DiscretizationElementAxisTag
.. autoclass:: DiscretizationFaceAxisTag
.. autoclass:: DiscretizationDOFAxisTag
"""

__copyright__ = """
Expand All @@ -28,7 +32,12 @@
THE SOFTWARE.
"""

from pytools.tag import Tag
from pytools.tag import Tag, tag_dataclass, UniqueTag
from typing import Tuple, TYPE_CHECKING

if TYPE_CHECKING:
from meshmode.discretization import ElementGroupBase
from meshmode.discretization.connection.direct import InterpolationBatch


class FirstAxisIsElementsTag(Tag):
Expand Down Expand Up @@ -57,3 +66,76 @@ class ConcurrentDOFInameTag(Tag):
computations for all DOFs within each element may be performed
concurrently.
"""


class DiscretizationEntityAxisTag(UniqueTag):
@classmethod
def from_group(cls,
group: "ElementGroupBase") -> "DiscretizationEntityAxisTag":
return cls((group.dim, group.nelements))

@classmethod
def from_batch(cls,
batch: "InterpolationBatch",
group: "ElementGroupBase") -> "DiscretizationEntityAxisTag":
return cls((group.dim, batch.nelements))


@tag_dataclass
class DiscretizationElementAxisTag(DiscretizationEntityAxisTag):
"""
Tagged to an array's axis representing element indices in the group
identified by discretization key :attr:`discretization_key`.
.. attribute:: discretization_key
A tuple of ``(ndim, nelements)`` where ``ndim`` is the topological
dimension of the space on which array's elements reside and
``nelements`` is the number of elements in that discretization.
"""
discretization_key: Tuple[int, int]

def __post_init__(self) -> None:
assert isinstance(self.discretization_key, tuple)
assert len(self.discretization_key) == 2
assert all(isinstance(k, int) for k in self.discretization_key)


@tag_dataclass
class DiscretizationFaceAxisTag(DiscretizationEntityAxisTag):
"""
Tagged to an array's axis representing face indices in the group
identified by discretization key :attr:`discretization_key`.
.. attribute:: discretization_key
A tuple of ``(ndim, nelements)`` where ``ndim`` is the topological
dimension of the space on which array's elements reside and
``nelements`` is the number of elements in that discretization.
"""
discretization_key: Tuple[int, int]

def __post_init__(self) -> None:
assert isinstance(self.discretization_key, tuple)
assert len(self.discretization_key) == 2
assert all(isinstance(k, int) for k in self.discretization_key)


@tag_dataclass
class DiscretizationDOFAxisTag(DiscretizationEntityAxisTag):
"""
Tagged to an array's axis representing DOF indices in the group
identified by discretization key :attr:`discretization_key`.
.. attribute:: discretization_key
A tuple of ``(ndim, nelements)`` where ``ndim`` is the topological
dimension of the space on which array's elements reside and
``nelements`` is the number of elements in that discretization.
"""
discretization_key: Tuple[int, int]

def __post_init__(self) -> None:
assert isinstance(self.discretization_key, tuple)
assert len(self.discretization_key) == 2
assert all(isinstance(k, int) for k in self.discretization_key)

0 comments on commit 9988827

Please sign in to comment.