|
8 | 8 | from firedrake.nullspace import VectorSpaceBasis, MixedVectorSpaceBasis
|
9 | 9 | from firedrake.solving_utils import _SNESContext
|
10 | 10 | from firedrake.tsfc_interface import extract_numbered_coefficients
|
11 |
| -from firedrake.utils import ScalarType_c, IntType_c, cached_property |
| 11 | +from firedrake.utils import IntType_c, cached_property |
12 | 12 | from finat.element_factory import create_element
|
13 | 13 | from tsfc import compile_expression_dual_evaluation
|
14 | 14 | from pyop2 import op2
|
@@ -1236,13 +1236,18 @@ def _weight(self):
|
1236 | 1236 |
|
1237 | 1237 | @cached_property
|
1238 | 1238 | def _kernels(self):
|
| 1239 | + from firedrake.interpolation import interpolate, Interpolator |
1239 | 1240 | try:
|
1240 |
| - prolong = partial(firedrake.assemble, firedrake.interpolate(self.uc, self.Vf), tensor=self.uf) |
1241 |
| - prolong() |
1242 |
| - self.rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat) |
1243 |
| - self.rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat) |
1244 |
| - restrict = partial(firedrake.assemble, firedrake.interpolate(firedrake.TestFunction(self.Vc), self.rf), tensor=self.rc) |
1245 |
| - except NotImplementedError: |
| 1241 | + assert self.Vf.ufl_element().mapping() == self.Vc.ufl_element().mapping() |
| 1242 | + P = Interpolator(interpolate(self.uc, self.Vf), self.Vf) |
| 1243 | + prolong = partial(P.assemble, tensor=self.uf) |
| 1244 | + |
| 1245 | + rf = firedrake.Function(self.Vf.dual(), val=self.uf.dat) |
| 1246 | + rc = firedrake.Function(self.Vc.dual(), val=self.uc.dat) |
| 1247 | + vc = firedrake.TestFunction(self.Vc) |
| 1248 | + R = Interpolator(interpolate(vc, rf), self.Vf) |
| 1249 | + restrict = partial(R.assemble, tensor=rc) |
| 1250 | + except (AttributeError, AssertionError, NotImplementedError): |
1246 | 1251 | # We generate custom prolongation and restriction kernels because
|
1247 | 1252 | # dual evaluation of EnrichedElement is not yet implemented in FInAT
|
1248 | 1253 | uf_map = get_permuted_map(self.Vf)
|
@@ -1439,49 +1444,6 @@ def make_blas_kernels(self, Vf, Vc):
|
1439 | 1444 | ldargs=BLASLAPACK_LIB.split(), requires_zeroed_output_arguments=True)
|
1440 | 1445 | return cache.setdefault(key, (prolong_kernel, restrict_kernel, coefficients))
|
1441 | 1446 |
|
1442 |
| - def make_kernels(self, Vf, Vc): |
1443 |
| - """ |
1444 |
| - Interpolation and restriction kernels between arbitrary elements. |
1445 |
| -
|
1446 |
| - This is temporary while we wait for dual evaluation in FInAT. |
1447 |
| - """ |
1448 |
| - cache = self._cache_kernels |
1449 |
| - key = (Vf.ufl_element(), Vc.ufl_element()) |
1450 |
| - try: |
1451 |
| - return cache[key] |
1452 |
| - except KeyError: |
1453 |
| - pass |
1454 |
| - prolong_kernel, _ = prolongation_transfer_kernel_action(Vf, self.uc) |
1455 |
| - matrix_kernel, coefficients = prolongation_transfer_kernel_action(Vf, firedrake.TrialFunction(Vc)) |
1456 |
| - |
1457 |
| - # The way we transpose the prolongation kernel is suboptimal. |
1458 |
| - # A local matrix is generated each time the kernel is executed. |
1459 |
| - element_kernel = cache_generate_code(matrix_kernel, Vf._comm) |
1460 |
| - element_kernel = element_kernel.replace("void expression_kernel", "static void expression_kernel") |
1461 |
| - coef_args = "".join([", c%d" % i for i in range(len(coefficients))]) |
1462 |
| - coef_decl = "".join([", const %s *restrict c%d" % (ScalarType_c, i) for i in range(len(coefficients))]) |
1463 |
| - dimc = Vc.finat_element.space_dimension() * Vc.block_size |
1464 |
| - dimf = Vf.finat_element.space_dimension() * Vf.block_size |
1465 |
| - restrict_code = f""" |
1466 |
| - {element_kernel} |
1467 |
| -
|
1468 |
| - void restriction({ScalarType_c} *restrict Rc, const {ScalarType_c} *restrict Rf, const {ScalarType_c} *restrict w{coef_decl}) |
1469 |
| - {{ |
1470 |
| - {ScalarType_c} Afc[{dimf}*{dimc}] = {{0}}; |
1471 |
| - expression_kernel(Afc{coef_args}); |
1472 |
| - for ({IntType_c} i = 0; i < {dimf}; i++) |
1473 |
| - for ({IntType_c} j = 0; j < {dimc}; j++) |
1474 |
| - Rc[j] += Afc[i*{dimc} + j] * Rf[i] * w[i]; |
1475 |
| - }} |
1476 |
| - """ |
1477 |
| - restrict_kernel = op2.Kernel( |
1478 |
| - restrict_code, |
1479 |
| - "restriction", |
1480 |
| - requires_zeroed_output_arguments=True, |
1481 |
| - events=matrix_kernel.events, |
1482 |
| - ) |
1483 |
| - return cache.setdefault(key, (prolong_kernel, restrict_kernel, coefficients)) |
1484 |
| - |
1485 | 1447 | def multTranspose(self, mat, rf, rc):
|
1486 | 1448 | """
|
1487 | 1449 | Implement restriction: restrict residual on fine grid rf to coarse grid rc.
|
|
0 commit comments