Skip to content

Commit

Permalink
Attach discretization tags to array axes' coming out of primitive ops
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd authored and inducer committed Jun 11, 2022
1 parent 76fa5d1 commit 1ca4d5a
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 27 deletions.
31 changes: 23 additions & 8 deletions meshmode/discretization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@
import numpy as np

import loopy as lp
from arraycontext import ArrayContext, make_loopy_program
from arraycontext import ArrayContext, make_loopy_program, tag_axes
from pytools import memoize_in, memoize_method, keyed_memoize_in
from pytools.obj_array import make_obj_array
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag, FirstAxisIsElementsTag)
ConcurrentElementInameTag, ConcurrentDOFInameTag,
FirstAxisIsElementsTag, DiscretizationElementAxisTag,
DiscretizationDOFAxisTag)

# underscored because it shouldn't be imported from here.
from meshmode.dof_array import DOFArray as _DOFArray
Expand Down Expand Up @@ -542,9 +544,14 @@ def _new_array(self, actx, creation_func, dtype=None):
else:
dtype = np.dtype(dtype)

return _DOFArray(actx, tuple(
creation_func(shape=(grp.nelements, grp.nunit_dofs), dtype=dtype)
for grp in self.groups))
return tag_axes(actx, {
0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
_DOFArray(actx,
tuple(creation_func(shape=(grp.nelements,
grp.nunit_dofs),
dtype=dtype)
for grp in self.groups)))

def empty(self, actx: ArrayContext,
dtype: Optional[np.dtype] = None) -> _DOFArray:
Expand Down Expand Up @@ -643,7 +650,10 @@ def nodes(self, cached: bool = True) -> np.ndarray:

def resample_mesh_nodes(grp, iaxis):
# TODO: would be nice to have the mesh use an array context already
nodes = actx.from_numpy(grp.mesh_el_group.nodes[iaxis])
nodes = tag_axes(actx,
{0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
actx.from_numpy(grp.mesh_el_group.nodes[iaxis]))

grp_unit_nodes = grp.unit_nodes.reshape(-1)
meg_unit_nodes = grp.mesh_el_group.unit_nodes.reshape(-1)
Expand All @@ -654,7 +664,10 @@ def resample_mesh_nodes(grp, iaxis):
return nodes

return actx.einsum("ij,ej->ei",
actx.from_numpy(grp.from_mesh_interp_matrix()),
actx.tag_axis(
0,
DiscretizationDOFAxisTag(),
actx.from_numpy(grp.from_mesh_interp_matrix())),
nodes,
tagged=(FirstAxisIsElementsTag(),))

Expand Down Expand Up @@ -714,7 +727,9 @@ def get_mat(grp, gref_axes):

return _DOFArray(actx, tuple(
actx.einsum("ij,ej->ei",
get_mat(grp, ref_axes),
actx.tag_axis(0,
DiscretizationDOFAxisTag(),
get_mat(grp, ref_axes)),
vec[igrp],
tagged=(FirstAxisIsElementsTag(),))
for igrp, grp in enumerate(discr.groups)))
Expand Down
23 changes: 21 additions & 2 deletions meshmode/discretization/connection/direct.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@

import loopy as lp
from meshmode.transform_metadata import (
ConcurrentElementInameTag, ConcurrentDOFInameTag)
ConcurrentElementInameTag, ConcurrentDOFInameTag,
DiscretizationElementAxisTag, DiscretizationDOFAxisTag)
from pytools import memoize_in, keyed_memoize_method
from arraycontext import (
ArrayContext, NotAnArrayContainerError,
serialize_container, deserialize_container, make_loopy_program,
tag_axes
)
from arraycontext.container import ArrayT, ArrayOrContainerT

Expand Down Expand Up @@ -372,7 +374,10 @@ def _resample_matrix(self, actx: ArrayContext, to_group_index: int,
from_grp_basis_fcts,
ibatch.result_unit_nodes, from_grp.unit_nodes)

return actx.freeze(actx.from_numpy(result))
# freeze, attach metadata
return actx.freeze(
tag_axes(actx, {1: DiscretizationDOFAxisTag()},
actx.from_numpy(result)))

# }}}

Expand Down Expand Up @@ -726,6 +731,14 @@ def group_pick_knl(is_surjective: bool):
from_el_present.reshape((-1, 1)),
grp_ary_contrib,
0)

# attach metadata
grp_ary_contrib = tag_axes(
actx,
{0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
grp_ary_contrib)

group_array_contributions.append(grp_ary_contrib)
else:
for fgpd in group_pick_info:
Expand Down Expand Up @@ -801,6 +814,12 @@ def group_pick_knl(is_surjective: bool):
self.to_discr.groups[i_tgrp].nunit_dofs)
)["result"]

# attach metadata
batch_result = tag_axes(actx,
{0: DiscretizationElementAxisTag(),
1: DiscretizationDOFAxisTag()},
batch_result,)

group_array_contributions.append(batch_result)

if group_array_contributions:
Expand Down
7 changes: 5 additions & 2 deletions meshmode/discretization/connection/face.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

from dataclasses import dataclass
from meshmode.transform_metadata import DiscretizationElementAxisTag

import numpy as np
import modepy as mp
Expand Down Expand Up @@ -445,8 +446,10 @@ def make_face_to_all_faces_embedding(actx, faces_connection, all_faces_discr,
assert all_faces_grp.nelements == nfaces * vol_grp.nelements

to_element_indices = actx.freeze(
vol_grp.nelements*iface
+ actx.thaw(src_batch.from_element_indices))
actx.tag_axis(0,
DiscretizationElementAxisTag(),
vol_grp.nelements*iface
+ actx.thaw(src_batch.from_element_indices)))

batches.append(
InterpolationBatch(
Expand Down
19 changes: 10 additions & 9 deletions meshmode/discretization/connection/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@

from arraycontext import (
NotAnArrayContainerError, serialize_container, deserialize_container)
from meshmode.transform_metadata import FirstAxisIsElementsTag
from meshmode.transform_metadata import (FirstAxisIsElementsTag,
DiscretizationDOFAxisTag)
from meshmode.discretization import InterpolatoryElementGroupBase
from meshmode.discretization.poly_element import QuadratureSimplexElementGroup
from meshmode.discretization.connection.direct import DiscretizationConnection
Expand Down Expand Up @@ -142,12 +143,12 @@ def quadrature_matrix(grp, mgrp):
grp.unit_nodes)
w_diag = np.diag(grp.quadrature_rule().weights)
vtw = np.dot(vdm.T, w_diag)
return actx.from_numpy(vtw)
return actx.tag_axis(0, DiscretizationDOFAxisTag(),
actx.from_numpy(vtw))

return actx.einsum("ib,eb->ei",
quadrature_matrix(grp, mgrp),
ary,
tagged=(FirstAxisIsElementsTag(),))
quadrature_matrix(grp, mgrp),
ary, tagged=(FirstAxisIsElementsTag(),))

def _compute_coeffs_via_inv_vandermonde(self, actx, ary, grp):

Expand All @@ -160,12 +161,12 @@ def vandermonde_inverse(grp):
vdm = mp.vandermonde(grp.basis_obj().functions,
grp.unit_nodes)
vdm_inv = la.inv(vdm)
return actx.from_numpy(vdm_inv)
return actx.tag_axis(0, DiscretizationDOFAxisTag(),
actx.from_numpy(vdm_inv))

return actx.einsum("ij,ej->ei",
vandermonde_inverse(grp),
ary,
tagged=(FirstAxisIsElementsTag(),))
vandermonde_inverse(grp),
ary, tagged=(FirstAxisIsElementsTag(),))

def __call__(self, ary):
"""Computes modal coefficients data from a functions
Expand Down
11 changes: 7 additions & 4 deletions meshmode/discretization/connection/same_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"""

import numpy as np
from meshmode.transform_metadata import DiscretizationElementAxisTag


# {{{ same-mesh constructor
Expand All @@ -42,10 +43,12 @@ def make_same_mesh_connection(actx, to_discr, from_discr):
groups = []
for igrp, (fgrp, tgrp) in enumerate(zip(from_discr.groups, to_discr.groups)):
all_elements = actx.freeze(
actx.from_numpy(
np.arange(
fgrp.nelements,
dtype=np.intp)))
actx.tag_axis(0,
DiscretizationElementAxisTag(),
actx.from_numpy(
np.arange(
fgrp.nelements,
dtype=np.intp))))
ibatch = InterpolationBatch(
from_group_index=igrp,
from_element_indices=all_elements,
Expand Down
10 changes: 8 additions & 2 deletions meshmode/discretization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from pytools.obj_array import make_obj_array
from arraycontext import flatten
from meshmode.dof_array import DOFArray
from meshmode.transform_metadata import DiscretizationFlattenedDOFAxisTag

from modepy.shapes import Shape, Simplex, Hypercube

Expand Down Expand Up @@ -144,7 +145,9 @@ def _resample_to_numpy(conn, vis_discr, vec, *, stack=False, by_group=False):
from meshmode.dof_array import check_dofarray_against_discr
check_dofarray_against_discr(vis_discr, vec)

return actx.to_numpy(flatten(vec, actx))
return actx.to_numpy(actx.tag_axis(0,
DiscretizationFlattenedDOFAxisTag(),
flatten(vec, actx)))
else:
raise TypeError(f"unsupported array type: {type(vec).__name__}")

Expand Down Expand Up @@ -550,7 +553,10 @@ def copy_with_same_connectivity(self, actx, discr, skip_tests=False):
def _vis_nodes_numpy(self):
actx = self.vis_discr._setup_actx
return np.array([
actx.to_numpy(flatten(actx.thaw(ary), actx))
actx.to_numpy(actx.tag_axis(
0,
DiscretizationFlattenedDOFAxisTag(),
flatten(actx.thaw(ary), actx)))
for ary in self.vis_discr.nodes()
])

Expand Down

0 comments on commit 1ca4d5a

Please sign in to comment.