Skip to content

Commit b7bd072

Browse files
authored
Remove Interpolator from public API and other interpolation refactoring (#4595)
Simplifies interpolation code and introduces new features: * Interpolator becomes an internal object. We have removed the __new__ method and dispatch instead with a get_interpolator function. The public API is to call assemble on a symbolic interpolate as detailed in the manual. * Arguments to interpolate are no longer silently renumbered. See Stop renumbering arguments in the expression to interpolate #4582 for details. * Removed frozen_interpolator. New features: * Implemented assembly of cross-mesh interpolation operators, both forward and adjoint. * We can now matfree adjoint interpolate cross-mesh and VomOntoVom. * Removed interp_data dictionary in favour of the InterpolateOptions dataclass. We get type hinting, better IDE support, single source of truth for these options.
1 parent bd4d748 commit b7bd072

File tree

8 files changed

+977
-1112
lines changed

8 files changed

+977
-1112
lines changed

firedrake/assemble.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from firedrake.bcs import DirichletBC, EquationBC, EquationBCSplit
2424
from firedrake.functionspaceimpl import WithGeometry, FunctionSpace, FiredrakeDualSpace
2525
from firedrake.functionspacedata import entity_dofs_key, entity_permutations_key
26+
from firedrake.interpolation import get_interpolator
2627
from firedrake.petsc import PETSc
2728
from firedrake.slate import slac, slate
2829
from firedrake.slate.slac.kernel_builder import CellFacetKernelArg, LayerCountKernelArg
@@ -613,17 +614,8 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
613614
rank = len(expr.arguments())
614615
if rank > 2:
615616
raise ValueError("Cannot assemble an Interpolate with more than two arguments")
616-
# Get the target space
617-
V = v.function_space().dual()
618-
619-
# Get the interpolator
620-
interp_data = expr.interp_data.copy()
621-
default_missing_val = interp_data.pop('default_missing_val', None)
622-
if rank == 1 and isinstance(tensor, firedrake.Function):
623-
V = tensor
624-
interpolator = firedrake.Interpolator(expr, V, bcs=bcs, **interp_data)
625-
# Assembly
626-
return interpolator.assemble(tensor=tensor, default_missing_val=default_missing_val)
617+
interpolator = get_interpolator(expr)
618+
return interpolator.assemble(tensor=tensor, bcs=bcs)
627619
elif tensor and isinstance(expr, (firedrake.Function, firedrake.Cofunction, firedrake.MatrixBase)):
628620
return tensor.assign(expr)
629621
elif tensor and isinstance(expr, ufl.ZeroBaseForm):

firedrake/bcs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# A module implementing strong (Dirichlet) boundary conditions.
22
import numpy as np
33

4-
import functools
4+
from functools import partial, reduce
55
import itertools
66

77
import ufl
@@ -167,7 +167,7 @@ def hermite_stride(bcnodes):
167167
# Edge conditions have only been tested with Lagrange elements.
168168
# Need to expand the list.
169169
bcnodes1.append(hermite_stride(self._function_space.boundary_nodes(ss)))
170-
bcnodes1 = functools.reduce(np.intersect1d, bcnodes1)
170+
bcnodes1 = reduce(np.intersect1d, bcnodes1)
171171
bcnodes.append(bcnodes1)
172172
return np.concatenate(bcnodes)
173173

@@ -359,11 +359,11 @@ def function_arg(self, g):
359359
raise RuntimeError(f"Provided boundary value {g} does not match shape of space")
360360
try:
361361
self._function_arg = firedrake.Function(V)
362-
# Use `Interpolator` instead of assembling an `Interpolate` form
363-
# as the expression compilation needs to happen at this stage to
364-
# determine if we should use interpolation or projection
365-
# -> e.g. interpolation may not be supported for the element.
366-
self._function_arg_update = firedrake.Interpolator(g, self._function_arg)._interpolate
362+
interpolator = firedrake.get_interpolator(firedrake.interpolate(g, V))
363+
# Call this here to check if the element supports interpolation
364+
# TODO: It's probably better to have a more explicit way of checking this
365+
interpolator._get_callable()
366+
self._function_arg_update = partial(interpolator.assemble, tensor=self._function_arg)
367367
except (NotImplementedError, AttributeError):
368368
# Element doesn't implement interpolation
369369
self._function_arg = firedrake.Function(V).project(g)

firedrake/interpolation.py

Lines changed: 905 additions & 1064 deletions
Large diffs are not rendered by default.

firedrake/preconditioners/hiptmair.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from firedrake.preconditioners.hypre_ams import chop
1111
from firedrake.preconditioners.facet_split import restrict
1212
from firedrake.parameters import parameters
13-
from firedrake.interpolation import Interpolator
13+
from firedrake.interpolation import interpolate
1414
from ufl.algorithms.ad import expand_derivatives
1515
import firedrake.dmhooks as dmhooks
1616
import firedrake.utils as utils
@@ -202,7 +202,7 @@ def coarsen(self, pc):
202202

203203
coarse_space_bcs = tuple(coarse_space_bcs)
204204
if G_callback is None:
205-
interp_petscmat = chop(Interpolator(dminus(trial), V, bcs=bcs + coarse_space_bcs).callable().handle)
205+
interp_petscmat = chop(assemble(interpolate(dminus(trial), V), bcs=bcs + coarse_space_bcs).petscmat)
206206
else:
207207
interp_petscmat = G_callback(coarse_space, V, coarse_space_bcs, bcs)
208208

firedrake/preconditioners/pmg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,14 +1248,14 @@ def _kernels(self):
12481248
return self._build_custom_interpolators()
12491249

12501250
def _build_native_interpolators(self):
1251-
from firedrake.interpolation import interpolate, Interpolator
1252-
P = Interpolator(interpolate(self.uc, self.Vf), self.Vf)
1251+
from firedrake.interpolation import interpolate, get_interpolator
1252+
P = get_interpolator(interpolate(self.uc, self.Vf))
12531253
prolong = partial(P.assemble, tensor=self.uf)
12541254

12551255
rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat)
12561256
rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat)
12571257
vc = firedrake.TestFunction(self.Vc)
1258-
R = Interpolator(interpolate(vc, rf), self.Vf)
1258+
R = get_interpolator(interpolate(vc, rf))
12591259
restrict = partial(R.assemble, tensor=rc)
12601260
return prolong, restrict
12611261

tests/firedrake/regression/test_interpolate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def test_interpolator_reuse(family, degree, mode):
584584
u = Function(V.dual())
585585
expr = interpolate(TestFunction(V), u)
586586

587-
I = Interpolator(expr, V)
587+
I = get_interpolator(expr)
588588

589589
for k in range(3):
590590
u.assign(rg.uniform(u.function_space()))

tests/firedrake/regression/test_interpolate_cross_mesh.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,12 +339,18 @@ def test_exact_refinement():
339339
expr_in_V_fine = x**2 + y**2 + 1
340340
f_fine = Function(V_fine).interpolate(expr_in_V_fine)
341341

342+
# Build interpolation matrices in both directions
343+
coarse_to_fine = assemble(interpolate(TrialFunction(V_coarse), V_fine))
344+
coarse_to_fine_adjoint = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual())))
345+
342346
# If we now interpolate f_coarse into V_fine we should get a function
343347
# which has no interpolation error versus f_fine because we were able to
344348
# exactly represent expr_in_V_coarse in V_coarse and V_coarse is a subset
345349
# of V_fine
346350
f_coarse_on_fine = assemble(interpolate(f_coarse, V_fine))
347351
assert np.allclose(f_coarse_on_fine.dat.data_ro, f_fine.dat.data_ro)
352+
f_coarse_on_fine_mat = assemble(coarse_to_fine @ f_coarse)
353+
assert np.allclose(f_coarse_on_fine_mat.dat.data_ro, f_fine.dat.data_ro)
348354

349355
# Adjoint interpolation takes us from V_fine^* to V_coarse^* so we should
350356
# also get an exact result here.
@@ -354,6 +360,10 @@ def test_exact_refinement():
354360
assert np.allclose(
355361
cofunction_fine_on_coarse.dat.data_ro, cofunction_coarse.dat.data_ro
356362
)
363+
cofunction_fine_on_coarse_mat = assemble(action(coarse_to_fine_adjoint, cofunction_fine))
364+
assert np.allclose(
365+
cofunction_fine_on_coarse_mat.dat.data_ro, cofunction_coarse.dat.data_ro
366+
)
357367

358368
# Now we test with expressions which are NOT exactly representable in the
359369
# function spaces by introducing a cube term. This can't be represented
@@ -550,7 +560,7 @@ def test_missing_dofs():
550560
V_src = FunctionSpace(m_src, "CG", 2)
551561
V_dest = FunctionSpace(m_dest, "CG", 3)
552562
with pytest.raises(DofNotDefinedError):
553-
Interpolator(TestFunction(V_src), V_dest)
563+
assemble(interpolate(TrialFunction(V_src), V_dest))
554564
f_src = Function(V_src).interpolate(expr)
555565
f_dest = assemble(interpolate(f_src, V_dest, allow_missing_dofs=True))
556566
dest_eval = PointEvaluator(m_dest, coords)
@@ -680,6 +690,32 @@ def test_interpolate_matrix_cross_mesh():
680690
f_interp2.dat.data_wo[:] = f_at_points_correct_order3.dat.data_ro[:]
681691
assert np.allclose(f_interp2.dat.data_ro, g.dat.data_ro)
682692

693+
interp_mat2 = assemble(interpolate(TrialFunction(U), V))
694+
assert interp_mat2.arguments() == (TestFunction(V.dual()), TrialFunction(U))
695+
f_interp3 = assemble(interp_mat2 @ f)
696+
assert f_interp3.function_space() == V
697+
assert np.allclose(f_interp3.dat.data_ro, g.dat.data_ro)
698+
699+
700+
@pytest.mark.parallel([1, 3])
701+
def test_interpolate_matrix_cross_mesh_adjoint():
702+
mesh_fine = UnitSquareMesh(4, 4)
703+
mesh_coarse = UnitSquareMesh(2, 2)
704+
705+
V_coarse = FunctionSpace(mesh_coarse, "CG", 1)
706+
V_fine = FunctionSpace(mesh_fine, "CG", 1)
707+
708+
cofunc_fine = assemble(conj(TestFunction(V_fine)) * dx)
709+
710+
interp = assemble(interpolate(TestFunction(V_coarse), TrialFunction(V_fine.dual())))
711+
cofunc_coarse = assemble(Action(interp, cofunc_fine))
712+
assert interp.arguments() == (TestFunction(V_coarse), TrialFunction(V_fine.dual()))
713+
assert cofunc_coarse.function_space() == V_coarse.dual()
714+
715+
# Compare cofunc_fine with direct interpolation
716+
cofunc_coarse_direct = assemble(conj(TestFunction(V_coarse)) * dx)
717+
assert np.allclose(cofunc_coarse.dat.data_ro, cofunc_coarse_direct.dat.data_ro)
718+
683719

684720
@pytest.mark.parallel([2, 3, 4])
685721
def test_voting_algorithm_edgecases():

tests/firedrake/vertexonly/test_vertex_only_fs.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def pseudo_random_coords(size):
6868

6969
# Function Space Generation Tests
7070

71-
def functionspace_tests(vm, petsc_raises):
71+
def functionspace_tests(vm):
7272
# Prep
7373
num_cells = len(vm.coordinates.dat.data_ro)
7474
num_cells_mpi_global = MPI.COMM_WORLD.allreduce(num_cells, op=MPI.SUM)
@@ -144,27 +144,25 @@ def functionspace_tests(vm, petsc_raises):
144144
h_star = h.riesz_representation(riesz_map="l2")
145145
g = assemble(interpolate(TestFunction(V), h_star))
146146
assert np.allclose(g.dat.data_ro_with_halos, np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))
147-
with petsc_raises(NotImplementedError):
148-
# Can't use adjoint on interpolates with expressions yet
149-
g2 = assemble(interpolate(2 * TestFunction(V), h_star))
150-
assert np.allclose(g2.dat.data_ro_with_halos, 2*np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))
147+
148+
g2 = assemble(interpolate(2 * TestFunction(V), h_star))
149+
assert np.allclose(g2.dat.data_ro_with_halos, 2*np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))
151150

152151
h_star = assemble(interpolate(TestFunction(W), g))
153152
h = h_star.riesz_representation(riesz_map="l2")
154153
assert np.allclose(h.dat.data_ro_with_halos[idxs_to_include], np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension), axis=1))
155154
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == 0)
156-
with petsc_raises(NotImplementedError):
157-
# Can't use adjoint on interpolates with expressions yet
158-
h2 = assemble(interpolate(2 * TestFunction(W), g))
159-
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], 2*np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension), axis=1))
155+
156+
h2 = assemble(interpolate(2 * TestFunction(W), g))
157+
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], 2*np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension), axis=1))
160158

161159
g = assemble(interpolate(h, V))
162160
assert np.allclose(g.dat.data_ro_with_halos, np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))
163161
g2 = assemble(interpolate(2 * h, V))
164162
assert np.allclose(g2.dat.data_ro_with_halos, 2*np.prod(vm.coordinates.dat.data_ro_with_halos.reshape(-1, vm.geometric_dimension), axis=1))
165163

166164

167-
def vectorfunctionspace_tests(vm, petsc_raises):
165+
def vectorfunctionspace_tests(vm):
168166
# Prep
169167
gdim = vm.geometric_dimension
170168
num_cells = len(vm.coordinates.dat.data_ro)
@@ -240,18 +238,16 @@ def vectorfunctionspace_tests(vm, petsc_raises):
240238
h_star = h.riesz_representation(riesz_map="l2")
241239
g = assemble(interpolate(TestFunction(V), h_star))
242240
assert np.allclose(g.dat.data_ro_with_halos, 2*vm.coordinates.dat.data_ro_with_halos)
243-
with petsc_raises(NotImplementedError):
244-
# Can't use adjoint on interpolate with expressions yet
245-
g2 = assemble(interpolate(2 * TestFunction(V), h_star))
246-
assert np.allclose(g2.dat.data_ro_with_halos, 4*vm.coordinates.dat.data_ro_with_halos)
241+
242+
g2 = assemble(interpolate(2 * TestFunction(V), h_star))
243+
assert np.allclose(g2.dat.data_ro_with_halos, 4*vm.coordinates.dat.data_ro_with_halos)
247244

248245
h_star = assemble(interpolate(TestFunction(W), g))
249246
assert np.allclose(h_star.dat.data_ro[idxs_to_include], 2*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
250247
assert np.all(h_star.dat.data_ro_with_halos[~idxs_to_include] == 0)
251-
with petsc_raises(NotImplementedError):
252-
# Can't use adjoint on interpolate with expressions yet
253-
h2 = assemble(interpolate(2 * TestFunction(W), g))
254-
assert np.allclose(h2.dat.data_ro[idxs_to_include], 4*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
248+
249+
h2 = assemble(interpolate(2 * TestFunction(W), g))
250+
assert np.allclose(h2.dat.data_ro[idxs_to_include], 4*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
255251

256252
h = h_star.riesz_representation(riesz_map="l2")
257253
g = assemble(interpolate(h, V))
@@ -261,12 +257,12 @@ def vectorfunctionspace_tests(vm, petsc_raises):
261257

262258

263259
@pytest.mark.parallel([1, 3])
264-
def test_functionspaces(parentmesh, vertexcoords, petsc_raises):
260+
def test_functionspaces(parentmesh, vertexcoords):
265261
vm = VertexOnlyMesh(parentmesh, vertexcoords, missing_points_behaviour="ignore")
266-
functionspace_tests(vm, petsc_raises)
267-
vectorfunctionspace_tests(vm, petsc_raises)
268-
functionspace_tests(vm.input_ordering, petsc_raises)
269-
vectorfunctionspace_tests(vm.input_ordering, petsc_raises)
262+
functionspace_tests(vm)
263+
vectorfunctionspace_tests(vm)
264+
functionspace_tests(vm.input_ordering)
265+
vectorfunctionspace_tests(vm.input_ordering)
270266

271267

272268
@pytest.mark.parallel(nprocs=2)

0 commit comments

Comments
 (0)