Skip to content

Commit

Permalink
port deprecated uses of freeze and thaw
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Jun 11, 2022
1 parent acda67e commit 97d99cb
Show file tree
Hide file tree
Showing 16 changed files with 76 additions and 87 deletions.
7 changes: 3 additions & 4 deletions examples/moving-geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

from meshmode.array_context import PyOpenCLArrayContext
from meshmode.transform_metadata import FirstAxisIsElementsTag
from arraycontext import thaw

from pytools import keyed_memoize_in
from pytools.obj_array import make_obj_array
Expand Down Expand Up @@ -186,15 +185,15 @@ def velocity_field(nodes, alpha=1.0):

def source(t, x):
discr = reconstruct_discr_from_nodes(actx, discr0, x)
u = velocity_field(thaw(discr.nodes(), actx))
u = velocity_field(actx.thaw(discr.nodes()))

# {{{

# NOTE: these are just here because this was at some point used to
# profile some more operators (turned out well!)

from meshmode.discretization import num_reference_derivative
x = thaw(discr.nodes()[0], actx)
x = actx.thaw(discr.nodes()[0])
gradx = sum(
num_reference_derivative(discr, (i,), x)
for i in range(discr.dim))
Expand All @@ -214,7 +213,7 @@ def source(t, x):
maxiter = int(tmax // timestep) + 1
dt = tmax / maxiter + 1.0e-15

x = thaw(discr0.nodes(), actx)
x = actx.thaw(discr0.nodes())
t = 0.0

if visualize:
Expand Down
5 changes: 2 additions & 3 deletions examples/parallel-vtkhdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,8 @@ def main(*, ambient_dim: int) -> None:

logger.info("[%4d] discretization: finished", mpirank)

from arraycontext import thaw
vector_field = thaw(discr.nodes(), actx)
scalar_field = actx.np.sin(thaw(discr.nodes()[0], actx))
vector_field = actx.thaw(discr.nodes())
scalar_field = actx.np.sin(vector_field[0])
part_id = 1 + mpirank + discr.zeros(actx)
logger.info("[%4d] fields: finished", mpirank)

Expand Down
3 changes: 1 addition & 2 deletions examples/plot-connectivity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import pyopencl as cl

from meshmode.array_context import PyOpenCLArrayContext
from arraycontext import thaw

order = 4

Expand Down Expand Up @@ -30,7 +29,7 @@ def main():
vis = make_visualizer(actx, discr, order)

vis.write_vtk_file("geometry.vtu", [
("f", thaw(discr.nodes()[0], actx)),
("f", actx.thaw(discr.nodes()[0])),
])

from meshmode.discretization.visualization import \
Expand Down
31 changes: 15 additions & 16 deletions examples/simple-dg.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
from meshmode.array_context import (PyOpenCLArrayContext,
PytatoPyOpenCLArrayContext)
from arraycontext import (
freeze, thaw,
ArrayContainer,
map_array_container,
with_container_arithmetic,
Expand All @@ -57,7 +56,7 @@
# {{{ discretization

def parametrization_derivative(actx, discr):
thawed_nodes = thaw(discr.nodes(), actx)
thawed_nodes = actx.thaw(discr.nodes())

from meshmode.discretization import num_reference_derivative
result = np.zeros((discr.ambient_dim, discr.dim), dtype=object)
Expand Down Expand Up @@ -175,17 +174,17 @@ def get_discr(self, where):

@memoize_method
def parametrization_derivative(self):
return freeze(
return self._setup_actx.freeze(
parametrization_derivative(self._setup_actx, self.volume_discr))

@memoize_method
def vol_jacobian(self):
[a, b], [c, d] = thaw(self.parametrization_derivative(), self._setup_actx)
return freeze(a*d-b*c)
[a, b], [c, d] = self._setup_actx.thaw(self.parametrization_derivative())
return self._setup_actx.freeze(a*d - b*c)

@memoize_method
def inverse_parametrization_derivative(self):
[a, b], [c, d] = thaw(self.parametrization_derivative(), self._setup_actx)
[a, b], [c, d] = self._setup_actx.thaw(self.parametrization_derivative())

result = np.zeros((2, 2), dtype=object)
det = a*d-b*c
Expand All @@ -194,13 +193,13 @@ def inverse_parametrization_derivative(self):
result[1, 0] = -c/det
result[1, 1] = a/det

return freeze(result)
return self._setup_actx.freeze(result)

def zeros(self, actx):
return self.volume_discr.zeros(actx)

def grad(self, vec):
ipder = thaw(self.inverse_parametrization_derivative(), vec.array_context)
ipder = vec.array_context.thaw(self.inverse_parametrization_derivative())

from meshmode.discretization import num_reference_derivative
dref = [
Expand All @@ -222,15 +221,15 @@ def normal(self, where):
((a,), (b,)) = parametrization_derivative(self._setup_actx, bdry_discr)

nrm = 1/(a**2+b**2)**0.5
return freeze(flat_obj_array(b*nrm, -a*nrm))
return self._setup_actx.freeze(flat_obj_array(b*nrm, -a*nrm))

@memoize_method
def face_jacobian(self, where):
bdry_discr = self.get_discr(where)

((a,), (b,)) = parametrization_derivative(self._setup_actx, bdry_discr)

return freeze((a**2 + b**2)**0.5)
return self._setup_actx.freeze((a**2 + b**2)**0.5)

@memoize_method
def get_inverse_mass_matrix(self, grp, dtype):
Expand Down Expand Up @@ -261,7 +260,7 @@ def inverse_mass(self, vec):
tagged=(FirstAxisIsElementsTag(),)
) for grp, vec_i in zip(discr.groups, vec)
)
) / thaw(self.vol_jacobian(), actx)
) / actx.thaw(self.vol_jacobian())

@memoize_method
def get_local_face_mass_matrix(self, afgrp, volgrp, dtype):
Expand Down Expand Up @@ -300,7 +299,7 @@ def face_mass(self, vec):
all_faces_discr = all_faces_conn.to_discr
vol_discr = all_faces_conn.from_discr

fj = thaw(self.face_jacobian("all_faces"), vec.array_context)
fj = vec.array_context.thaw(self.face_jacobian("all_faces"))
vec = vec*fj

assert len(all_faces_discr.groups) == len(vol_discr.groups)
Expand Down Expand Up @@ -367,7 +366,7 @@ def wave_flux(actx, discr, c, q_tpair):
u = q_tpair.u
v = q_tpair.v

normal = thaw(discr.normal(q_tpair.where), actx)
normal = actx.thaw(discr.normal(q_tpair.where))

flux_weak = WaveState(
u=np.dot(v.avg, normal),
Expand Down Expand Up @@ -422,7 +421,7 @@ def bump(actx, discr, t=0):
source_width = 0.05
source_omega = 3

nodes = thaw(discr.volume_discr.nodes(), actx)
nodes = actx.thaw(discr.volume_discr.nodes())
center_dist = flat_obj_array([
nodes[0] - source_center[0],
nodes[1] - source_center[1],
Expand Down Expand Up @@ -492,8 +491,8 @@ def rhs(t, q):
compiled_rhs = actx_rhs.compile(rhs)

def rhs_wrapper(t, q):
r = compiled_rhs(t, thaw(freeze(q, actx_outer), actx_rhs))
return thaw(freeze(r, actx_rhs), actx_outer)
r = compiled_rhs(t, actx_rhs.thaw(actx_outer.freeze(q)))
return actx_outer.thaw(actx_rhs.freeze(r))

t = np.float64(0)
t_final = 3
Expand Down
3 changes: 1 addition & 2 deletions examples/to_firedrake.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import pyopencl as cl

from meshmode.array_context import PyOpenCLArrayContext
from arraycontext import thaw


# Nb: Some of the initial setup was adapted from meshmode/examplse/simple-dg.py
Expand Down Expand Up @@ -75,7 +74,7 @@ def main():
# = e^x cos(y)
nodes = discr.nodes()
for i in range(len(nodes)):
nodes[i] = thaw(nodes[i], actx)
nodes[i] = actx.thaw(nodes[i])
# First index is dimension
candidate_sol = actx.np.exp(nodes[0]) * actx.np.cos(nodes[1])

Expand Down
4 changes: 1 addition & 3 deletions meshmode/array_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def thaw(actx, ary):
"meshmode.array_context.thaw will continue to work until 2022.",
DeprecationWarning, stacklevel=2)

from arraycontext import thaw as _thaw
# /!\ arg order flipped
return _thaw(ary, actx)
return actx.thaw(ary)


# {{{ kernel transform function
Expand Down
6 changes: 3 additions & 3 deletions meshmode/discretization/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

from pytools import memoize_method
from pytools.obj_array import make_obj_array
from arraycontext import thaw, flatten
from arraycontext import flatten
from meshmode.dof_array import DOFArray

from modepy.shapes import Shape, Simplex, Hypercube
Expand Down Expand Up @@ -550,7 +550,7 @@ def copy_with_same_connectivity(self, actx, discr, skip_tests=False):
def _vis_nodes_numpy(self):
actx = self.vis_discr._setup_actx
return np.array([
actx.to_numpy(flatten(thaw(ary, actx), actx))
actx.to_numpy(flatten(actx.thaw(ary), actx))
for ary in self.vis_discr.nodes()
])

Expand Down Expand Up @@ -1028,7 +1028,7 @@ def _xdmf_nodes_numpy(self):
actx = self.vis_discr._setup_actx
return _resample_to_numpy(
lambda x: x, self.vis_discr,
thaw(self.vis_discr.nodes(), actx),
actx.thaw(self.vis_discr.nodes()),
stack=True, by_group=True)

def _vtk_to_xdmf_cell_type(self, cell_type):
Expand Down
19 changes: 11 additions & 8 deletions meshmode/dof_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@
from arraycontext import (
ArrayContext, NotAnArrayContainerError,
make_loopy_program, with_container_arithmetic,
serialize_container, deserialize_container,
thaw as _thaw, freeze as _freeze, with_array_context,
serialize_container, deserialize_container, with_array_context,
rec_map_array_container, rec_multimap_array_container,
mapped_over_array_containers, multimapped_over_array_containers)
from arraycontext.container import ArrayOrContainerT
Expand Down Expand Up @@ -274,10 +273,10 @@ def __getstate__(self):

# Make sure metadata inference has been done
# https://github.com/inducer/meshmode/pull/318#issuecomment-1088320970
ary = _thaw(freeze(self, self.array_context), self.array_context)
ary = self.array_context.thaw(self.array_context.freeze(self))

if self.array_context is not actx:
ary = _thaw(actx, _freeze(self))
ary = actx.thaw(actx.freeze(self))

d = {}
d["data"] = [actx.to_numpy(ary_i) for ary_i in ary._data]
Expand Down Expand Up @@ -685,7 +684,7 @@ def flatten_to_numpy(actx: ArrayContext, ary: ArrayOrContainerT, *,

def _flatten_to_numpy(subary):
if isinstance(subary, DOFArray) and subary.array_context is None:
subary = _thaw(subary, actx)
subary = actx.thaw(subary)

return actx.to_numpy(_flatten_dof_array(subary, strict=strict))

Expand Down Expand Up @@ -878,11 +877,15 @@ def thaw(actx, ary):
"meshmode.dof_array.thaw will continue to work until 2022.",
DeprecationWarning, stacklevel=2)

# /!\ arg order flipped
return _thaw(ary, actx)
return actx.thaw(ary)


freeze = MovedFunctionDeprecationWrapper(_freeze, deadline="2022")
def freeze(ary, actx):
warn("meshmode.dof_array.freeze is deprecated. Use arraycontext.freeze instead. "
"meshmode.dof_array.freeze will continue to work until 2022.",
DeprecationWarning, stacklevel=2)

return actx.freeze(ary)

# }}}

Expand Down
11 changes: 5 additions & 6 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
])

from arraycontext import (
thaw, freeze,
dataclass_array_container,
with_container_arithmetic)

Expand Down Expand Up @@ -87,7 +86,7 @@ def test_flatten_unflatten(actx_factory):
a_round_trip = flatten_to_numpy(actx, unflatten_from_numpy(actx, discr, a))
assert np.array_equal(a, a_round_trip)

x = thaw(discr.nodes(), actx)
x = actx.thaw(discr.nodes())
avg_mass = DOFArray(actx, tuple([
(np.pi + actx.zeros((grp.nelements, 1), a.dtype)) for grp in discr.groups
]))
Expand All @@ -114,7 +113,7 @@ def _get_test_containers(actx, ambient_dim=2):
b=(+0.5,)*ambient_dim,
nelements_per_axis=(3,)*ambient_dim, order=1)
discr = Discretization(actx, mesh, default_simplex_group_factory(ambient_dim, 3))
x = thaw(discr.nodes()[0], actx)
x = actx.thaw(discr.nodes()[0])

# pylint: disable=unexpected-keyword-arg, no-value-for-parameter
dataclass_of_dofs = MyContainer(
Expand Down Expand Up @@ -207,9 +206,9 @@ def test_dof_array_pickling_tags(actx_factory):
state = DOFArray(actx, (actx.zeros((10, 10), "float64"),
actx.zeros((10, 10), "float64"),))

state = thaw(freeze(actx.tag(FooTag(), state), actx), actx)
state = thaw(freeze(actx.tag_axis(0, FooAxisTag(), state), actx), actx)
state = thaw(freeze(actx.tag_axis(1, FooAxisTag2(), state), actx), actx)
state = actx.thaw(actx.freeze(actx.tag(FooTag(), state)))
state = actx.thaw(actx.freeze(actx.tag_axis(0, FooAxisTag(), state)))
state = actx.thaw(actx.freeze(actx.tag_axis(1, FooAxisTag2(), state)))

with array_context_for_pickling(actx):
pkl = dumps((state, ))
Expand Down
12 changes: 6 additions & 6 deletions test/test_chained.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
[PytestPyOpenCLArrayContextFactory])

from arraycontext import thaw, flatten
from arraycontext import flatten
from meshmode.dof_array import flat_norm

import logging
Expand Down Expand Up @@ -210,7 +210,7 @@ def f(x):
from functools import reduce
return 0.1 * reduce(lambda x, y: x * actx.np.sin(5 * y), x)

x = thaw(connections[0].from_discr.nodes(), actx)
x = actx.thaw(connections[0].from_discr.nodes())
fx = f(x)
f1 = chained(fx)
f2 = connections[1](connections[0](fx))
Expand Down Expand Up @@ -242,7 +242,7 @@ def f(x):

resample_mat = actx.to_numpy(make_full_resample_matrix(actx, chained))

x = thaw(connections[0].from_discr.nodes(), actx)
x = actx.thaw(connections[0].from_discr.nodes())
fx = f(x)
f1 = resample_mat @ actx.to_numpy(flatten(fx, actx))
f2 = actx.to_numpy(flatten(chained(fx), actx))
Expand Down Expand Up @@ -307,7 +307,7 @@ def f(x):
from functools import reduce
return 0.1 * reduce(lambda x, y: x * actx.np.sin(5 * y), x)

x = thaw(connections[0].from_discr.nodes(), actx)
x = actx.thaw(connections[0].from_discr.nodes())
fx = f(x)

t_start = time.time()
Expand Down Expand Up @@ -370,8 +370,8 @@ def run(nelements, order):
reverse = L2ProjectionInverseDiscretizationConnection(chained)

# create test vector
from_nodes = thaw(chained.from_discr.nodes(), actx)
to_nodes = thaw(chained.to_discr.nodes(), actx)
from_nodes = actx.thaw(chained.from_discr.nodes())
to_nodes = actx.thaw(chained.to_discr.nodes())

from_x = 0
to_x = 0
Expand Down
6 changes: 2 additions & 4 deletions test/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import numpy as np
import pytest

from arraycontext import thaw

from meshmode.array_context import PytestPyOpenCLArrayContextFactory
from arraycontext import pytest_generate_tests_for_array_contexts
pytest_generate_tests = pytest_generate_tests_for_array_contexts(
Expand Down Expand Up @@ -58,10 +56,10 @@ def test_nodal_dg_interop(actx_factory, dim):

for ax in range(dim):
x_ax = ndgctx.pull_dof_array(actx, ndgctx.AXES[ax])
err = flat_norm(x_ax-thaw(discr.nodes()[ax], actx), np.inf)
err = flat_norm(x_ax - actx.thaw(discr.nodes()[ax]), np.inf)
assert err < 1e-15

n0 = thaw(discr.nodes()[0], actx)
n0 = actx.thaw(discr.nodes()[0])

ndgctx.push_dof_array("n0", n0)
n0_2 = ndgctx.pull_dof_array(actx, "n0")
Expand Down
Loading

0 comments on commit 97d99cb

Please sign in to comment.