Skip to content

Commit

Permalink
Fix bugged input ordering interpolation action
Browse files Browse the repository at this point in the history
I had managed to not test the whole of the interpolation API when I
originally implemented this and it turned out that I would return None
when, for example, calling interpolate(f_input_ordering, P0DG).
  • Loading branch information
ReubenHill committed Aug 30, 2023
1 parent b55df2b commit 45e62bb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 4 deletions.
12 changes: 8 additions & 4 deletions firedrake/interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,11 +743,15 @@ def make_interpolator(expr, V, subset, access, bcs=None):
# (so for vector function spaces in 2 dimensions we might need a
# concatenation of 2 MPI.DOUBLE types when we are in real mode)
if tensor is not None:
# Callable will do interpolation into tensor (which is a Dat) when
# it is called.
wrapper.mpi_type, _ = get_dat_mpi_type(tensor)
# Callable will do interpolation into our pre-supplied function f
# when it is called.
assert f.dat is tensor
wrapper.mpi_type, _ = get_dat_mpi_type(f.dat)
assert not len(arguments)
callable = partial(wrapper.forward_operation, tensor)

def callable():
wrapper.forward_operation(f.dat)
return f
else:
assert len(arguments) == 1
assert tensor is None
Expand Down
18 changes: 18 additions & 0 deletions tests/vertexonly/test_vertex_only_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def functionspace_tests(vm):
idxs_to_include = input_ordering_parent_cell_nums != -1
assert np.allclose(h.dat.data_ro_with_halos[idxs_to_include], np.prod(vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include].reshape(-1, vm.input_ordering.geometric_dimension()), axis=1))
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == -1)
# check other interpolation APIs work identically
h2 = interpolate(g, W)
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
I = Interpolator(g, W)
h2 = I.interpolate()
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
h2.zero()
I.interpolate(output=h2)
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
# check we can interpolate expressions
h2 = Function(W)
h2.interpolate(2*g*Constant(1, domain=vm))
Expand Down Expand Up @@ -203,6 +212,15 @@ def vectorfunctionspace_tests(vm):
idxs_to_include = input_ordering_parent_cell_nums != -1
assert np.allclose(h.dat.data_ro[idxs_to_include], 2*vm.input_ordering.coordinates.dat.data_ro_with_halos[idxs_to_include])
assert np.all(h.dat.data_ro_with_halos[~idxs_to_include] == -1)
# check other interpolation APIs work identically
h2 = interpolate(g, W)
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
I = Interpolator(g, W)
h2 = I.interpolate()
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
h2.zero()
I.interpolate(output=h2)
assert np.allclose(h2.dat.data_ro_with_halos[idxs_to_include], h.dat.data_ro_with_halos[idxs_to_include])
# check we can interpolate expressions
h2 = Function(W)
h2.interpolate(2*g*Constant(1, domain=vm))
Expand Down

0 comments on commit 45e62bb

Please sign in to comment.