From 1ca4d5a2ac0e4689ca9361c87df6db3191f3e054 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 17 Nov 2021 23:41:14 -0600 Subject: [PATCH] Attach discretization tags to array axes' coming out of primitive ops --- meshmode/discretization/__init__.py | 31 ++++++++++++++----- meshmode/discretization/connection/direct.py | 23 ++++++++++++-- meshmode/discretization/connection/face.py | 7 +++-- meshmode/discretization/connection/modal.py | 19 ++++++------ .../discretization/connection/same_mesh.py | 11 ++++--- meshmode/discretization/visualization.py | 10 ++++-- 6 files changed, 74 insertions(+), 27 deletions(-) diff --git a/meshmode/discretization/__init__.py b/meshmode/discretization/__init__.py index 6ce8850a2..1a3a3b145 100644 --- a/meshmode/discretization/__init__.py +++ b/meshmode/discretization/__init__.py @@ -36,11 +36,13 @@ import numpy as np import loopy as lp -from arraycontext import ArrayContext, make_loopy_program +from arraycontext import ArrayContext, make_loopy_program, tag_axes from pytools import memoize_in, memoize_method, keyed_memoize_in from pytools.obj_array import make_obj_array from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag, FirstAxisIsElementsTag) + ConcurrentElementInameTag, ConcurrentDOFInameTag, + FirstAxisIsElementsTag, DiscretizationElementAxisTag, + DiscretizationDOFAxisTag) # underscored because it shouldn't be imported from here. from meshmode.dof_array import DOFArray as _DOFArray @@ -542,9 +544,14 @@ def _new_array(self, actx, creation_func, dtype=None): else: dtype = np.dtype(dtype) - return _DOFArray(actx, tuple( - creation_func(shape=(grp.nelements, grp.nunit_dofs), dtype=dtype) - for grp in self.groups)) + return tag_axes(actx, { + 0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + _DOFArray(actx, + tuple(creation_func(shape=(grp.nelements, + grp.nunit_dofs), + dtype=dtype) + for grp in self.groups))) def empty(self, actx: ArrayContext, dtype: Optional[np.dtype] = None) -> _DOFArray: @@ -643,7 +650,10 @@ def nodes(self, cached: bool = True) -> np.ndarray: def resample_mesh_nodes(grp, iaxis): # TODO: would be nice to have the mesh use an array context already - nodes = actx.from_numpy(grp.mesh_el_group.nodes[iaxis]) + nodes = tag_axes(actx, + {0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + actx.from_numpy(grp.mesh_el_group.nodes[iaxis])) grp_unit_nodes = grp.unit_nodes.reshape(-1) meg_unit_nodes = grp.mesh_el_group.unit_nodes.reshape(-1) @@ -654,7 +664,10 @@ def resample_mesh_nodes(grp, iaxis): return nodes return actx.einsum("ij,ej->ei", - actx.from_numpy(grp.from_mesh_interp_matrix()), + actx.tag_axis( + 0, + DiscretizationDOFAxisTag(), + actx.from_numpy(grp.from_mesh_interp_matrix())), nodes, tagged=(FirstAxisIsElementsTag(),)) @@ -714,7 +727,9 @@ def get_mat(grp, gref_axes): return _DOFArray(actx, tuple( actx.einsum("ij,ej->ei", - get_mat(grp, ref_axes), + actx.tag_axis(0, + DiscretizationDOFAxisTag(), + get_mat(grp, ref_axes)), vec[igrp], tagged=(FirstAxisIsElementsTag(),)) for igrp, grp in enumerate(discr.groups))) diff --git a/meshmode/discretization/connection/direct.py b/meshmode/discretization/connection/direct.py index fa95c1308..b6cd6b36c 100644 --- a/meshmode/discretization/connection/direct.py +++ b/meshmode/discretization/connection/direct.py @@ -30,11 +30,13 @@ import loopy as lp from meshmode.transform_metadata import ( - ConcurrentElementInameTag, ConcurrentDOFInameTag) + ConcurrentElementInameTag, ConcurrentDOFInameTag, + DiscretizationElementAxisTag, DiscretizationDOFAxisTag) from pytools import memoize_in, keyed_memoize_method from arraycontext import ( ArrayContext, NotAnArrayContainerError, serialize_container, deserialize_container, make_loopy_program, + tag_axes ) from arraycontext.container import ArrayT, ArrayOrContainerT @@ -372,7 +374,10 @@ def _resample_matrix(self, actx: ArrayContext, to_group_index: int, from_grp_basis_fcts, ibatch.result_unit_nodes, from_grp.unit_nodes) - return actx.freeze(actx.from_numpy(result)) + # freeze, attach metadata + return actx.freeze( + tag_axes(actx, {1: DiscretizationDOFAxisTag()}, + actx.from_numpy(result))) # }}} @@ -726,6 +731,14 @@ def group_pick_knl(is_surjective: bool): from_el_present.reshape((-1, 1)), grp_ary_contrib, 0) + + # attach metadata + grp_ary_contrib = tag_axes( + actx, + {0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + grp_ary_contrib) + group_array_contributions.append(grp_ary_contrib) else: for fgpd in group_pick_info: @@ -801,6 +814,12 @@ def group_pick_knl(is_surjective: bool): self.to_discr.groups[i_tgrp].nunit_dofs) )["result"] + # attach metadata + batch_result = tag_axes(actx, + {0: DiscretizationElementAxisTag(), + 1: DiscretizationDOFAxisTag()}, + batch_result,) + group_array_contributions.append(batch_result) if group_array_contributions: diff --git a/meshmode/discretization/connection/face.py b/meshmode/discretization/connection/face.py index 28a03a599..34f65d893 100644 --- a/meshmode/discretization/connection/face.py +++ b/meshmode/discretization/connection/face.py @@ -21,6 +21,7 @@ """ from dataclasses import dataclass +from meshmode.transform_metadata import DiscretizationElementAxisTag import numpy as np import modepy as mp @@ -445,8 +446,10 @@ def make_face_to_all_faces_embedding(actx, faces_connection, all_faces_discr, assert all_faces_grp.nelements == nfaces * vol_grp.nelements to_element_indices = actx.freeze( - vol_grp.nelements*iface - + actx.thaw(src_batch.from_element_indices)) + actx.tag_axis(0, + DiscretizationElementAxisTag(), + vol_grp.nelements*iface + + actx.thaw(src_batch.from_element_indices))) batches.append( InterpolationBatch( diff --git a/meshmode/discretization/connection/modal.py b/meshmode/discretization/connection/modal.py index c0178563d..2fe813e5b 100644 --- a/meshmode/discretization/connection/modal.py +++ b/meshmode/discretization/connection/modal.py @@ -29,7 +29,8 @@ from arraycontext import ( NotAnArrayContainerError, serialize_container, deserialize_container) -from meshmode.transform_metadata import FirstAxisIsElementsTag +from meshmode.transform_metadata import (FirstAxisIsElementsTag, + DiscretizationDOFAxisTag) from meshmode.discretization import InterpolatoryElementGroupBase from meshmode.discretization.poly_element import QuadratureSimplexElementGroup from meshmode.discretization.connection.direct import DiscretizationConnection @@ -142,12 +143,12 @@ def quadrature_matrix(grp, mgrp): grp.unit_nodes) w_diag = np.diag(grp.quadrature_rule().weights) vtw = np.dot(vdm.T, w_diag) - return actx.from_numpy(vtw) + return actx.tag_axis(0, DiscretizationDOFAxisTag(), + actx.from_numpy(vtw)) return actx.einsum("ib,eb->ei", - quadrature_matrix(grp, mgrp), - ary, - tagged=(FirstAxisIsElementsTag(),)) + quadrature_matrix(grp, mgrp), + ary, tagged=(FirstAxisIsElementsTag(),)) def _compute_coeffs_via_inv_vandermonde(self, actx, ary, grp): @@ -160,12 +161,12 @@ def vandermonde_inverse(grp): vdm = mp.vandermonde(grp.basis_obj().functions, grp.unit_nodes) vdm_inv = la.inv(vdm) - return actx.from_numpy(vdm_inv) + return actx.tag_axis(0, DiscretizationDOFAxisTag(), + actx.from_numpy(vdm_inv)) return actx.einsum("ij,ej->ei", - vandermonde_inverse(grp), - ary, - tagged=(FirstAxisIsElementsTag(),)) + vandermonde_inverse(grp), + ary, tagged=(FirstAxisIsElementsTag(),)) def __call__(self, ary): """Computes modal coefficients data from a functions diff --git a/meshmode/discretization/connection/same_mesh.py b/meshmode/discretization/connection/same_mesh.py index 638bcb598..ba79d056a 100644 --- a/meshmode/discretization/connection/same_mesh.py +++ b/meshmode/discretization/connection/same_mesh.py @@ -21,6 +21,7 @@ """ import numpy as np +from meshmode.transform_metadata import DiscretizationElementAxisTag # {{{ same-mesh constructor @@ -42,10 +43,12 @@ def make_same_mesh_connection(actx, to_discr, from_discr): groups = [] for igrp, (fgrp, tgrp) in enumerate(zip(from_discr.groups, to_discr.groups)): all_elements = actx.freeze( - actx.from_numpy( - np.arange( - fgrp.nelements, - dtype=np.intp))) + actx.tag_axis(0, + DiscretizationElementAxisTag(), + actx.from_numpy( + np.arange( + fgrp.nelements, + dtype=np.intp)))) ibatch = InterpolationBatch( from_group_index=igrp, from_element_indices=all_elements, diff --git a/meshmode/discretization/visualization.py b/meshmode/discretization/visualization.py index 6658bdcf4..8ddfc0518 100644 --- a/meshmode/discretization/visualization.py +++ b/meshmode/discretization/visualization.py @@ -33,6 +33,7 @@ from pytools.obj_array import make_obj_array from arraycontext import flatten from meshmode.dof_array import DOFArray +from meshmode.transform_metadata import DiscretizationFlattenedDOFAxisTag from modepy.shapes import Shape, Simplex, Hypercube @@ -144,7 +145,9 @@ def _resample_to_numpy(conn, vis_discr, vec, *, stack=False, by_group=False): from meshmode.dof_array import check_dofarray_against_discr check_dofarray_against_discr(vis_discr, vec) - return actx.to_numpy(flatten(vec, actx)) + return actx.to_numpy(actx.tag_axis(0, + DiscretizationFlattenedDOFAxisTag(), + flatten(vec, actx))) else: raise TypeError(f"unsupported array type: {type(vec).__name__}") @@ -550,7 +553,10 @@ def copy_with_same_connectivity(self, actx, discr, skip_tests=False): def _vis_nodes_numpy(self): actx = self.vis_discr._setup_actx return np.array([ - actx.to_numpy(flatten(actx.thaw(ary), actx)) + actx.to_numpy(actx.tag_axis( + 0, + DiscretizationFlattenedDOFAxisTag(), + flatten(actx.thaw(ary), actx))) for ary in self.vis_discr.nodes() ])