Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
4 changes: 2 additions & 2 deletions cpp/demo/custom_kernel/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ void assemble(MPI_Comm comm)
std::array<T, 9> A_hat_b = A_ref<T>(phi, weights);
auto kernel_a
= [A_hat = mdspan2_t<T, 3, 3>(A_hat_b.data()),
detJ](T* A, const T*, const T*, const T* x, const int*, const uint8_t*)
detJ](T* A, const T*, const T*, const T* x, const int*, const uint8_t*, void*)
{
T scale = detJ(mdspan2_t<const T, 3, 3>(x));
mdspan2_t<T, 3, 3> _A(A);
Expand All @@ -250,7 +250,7 @@ void assemble(MPI_Comm comm)
// Finite element RHS (f=1) kernel function
auto kernel_L
= [b_hat = b_ref<T>(phi, weights),
detJ](T* b, const T*, const T*, const T* x, const int*, const uint8_t*)
detJ](T* b, const T*, const T*, const T* x, const int*, const uint8_t*, void*)
{
T scale = detJ(mdspan2_t<const T, 3, 3>(x));
for (std::size_t i = 0; i < 3; ++i)
Expand Down
6 changes: 3 additions & 3 deletions cpp/dolfinx/fem/Expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class Expression
constants,
std::span<const geometry_type> X, std::array<std::size_t, 2> Xshape,
std::function<void(scalar_type*, const scalar_type*, const scalar_type*,
const geometry_type*, const int*, const uint8_t*)>
const geometry_type*, const int*, const uint8_t*, void*)>
fn,
const std::vector<std::size_t>& value_shape,
std::shared_ptr<const FunctionSpace<geometry_type>> argument_space
Expand Down Expand Up @@ -144,7 +144,7 @@ class Expression

/// @brief Function for tabulating the Expression.
const std::function<void(scalar_type*, const scalar_type*, const scalar_type*,
const geometry_type*, const int*, const uint8_t*)>&
const geometry_type*, const int*, const uint8_t*, void*)>&
kernel() const
{
return _fn;
Expand Down Expand Up @@ -179,7 +179,7 @@ class Expression

// Function to evaluate the Expression
std::function<void(scalar_type*, const scalar_type*, const scalar_type*,
const geometry_type*, const int*, const uint8_t*)>
const geometry_type*, const int*, const uint8_t*, void*)>
_fn;

// Shape of the evaluated Expression
Expand Down
6 changes: 3 additions & 3 deletions cpp/dolfinx/fem/Form.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ struct integral_data
requires std::is_convertible_v<
std::remove_cvref_t<K>,
std::function<void(T*, const T*, const T*, const U*,
const int*, const uint8_t*)>>
const int*, const uint8_t*, void*)>>
and std::is_convertible_v<std::remove_cvref_t<V>,
std::vector<std::int32_t>>
and std::is_convertible_v<std::remove_cvref_t<W>,
Expand All @@ -130,7 +130,7 @@ struct integral_data

/// @brief The integration kernel.
std::function<void(T*, const T*, const T*, const U*, const int*,
const uint8_t*)>
const uint8_t*, void*)>
kernel;

/// @brief The entities to integrate over for this integral. These are
Expand Down Expand Up @@ -390,7 +390,7 @@ class Form
/// kernels for a given ID in mixed-topology meshes).
/// @return Function to call for `tabulate_tensor`.
std::function<void(scalar_type*, const scalar_type*, const scalar_type*,
const geometry_type*, const int*, const uint8_t*)>
const geometry_type*, const int*, const uint8_t*, void*)>
kernel(IntegralType type, int id, int kernel_idx) const
{
auto it = _integrals.find({type, id, kernel_idx});
Expand Down
4 changes: 2 additions & 2 deletions cpp/dolfinx/fem/assemble_expression_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void tabulate_expression(
std::next(coord_dofs.begin(), 3 * i));
}
fn(values_local.data(), &coeffs(e, 0), constants.data(),
coord_dofs.data(), nullptr, nullptr);
coord_dofs.data(), nullptr, nullptr, nullptr);
}
else
{
Expand All @@ -104,7 +104,7 @@ void tabulate_expression(
std::next(coord_dofs.begin(), 3 * i));
}
fn(values_local.data(), &coeffs(e, 0), constants.data(),
coord_dofs.data(), &entities(e, 1), nullptr);
coord_dofs.data(), &entities(e, 1), nullptr, nullptr);
}

P0(values_local, cell_info, e, size0);
Expand Down
6 changes: 3 additions & 3 deletions cpp/dolfinx/fem/assemble_matrix_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void assemble_cells(
// Tabulate tensor
std::ranges::fill(Ae, 0);
kernel(Ae.data(), &coeffs(c, 0), constants.data(), cdofs.data(), nullptr,
nullptr);
nullptr, nullptr);

// Compute A = P_0 \tilde{A} P_1^T (dof transformation)
P0(_Ae, cell_info0, cell0, ndim1); // B = P0 \tilde{A}
Expand Down Expand Up @@ -253,7 +253,7 @@ void assemble_exterior_facets(
// Tabulate tensor
std::ranges::fill(Ae, 0);
kernel(Ae.data(), &coeffs(f, 0), constants.data(), cdofs.data(),
&local_facet, &perm);
&local_facet, &perm, nullptr);

P0(_Ae, cell_info0, cell0, ndim1);
P1T(_Ae, cell_info1, cell1, ndim0);
Expand Down Expand Up @@ -423,7 +423,7 @@ void assemble_interior_facets(
: std::array{perms(cells[0], local_facet[0]),
perms(cells[1], local_facet[1])};
kernel(Ae.data(), &coeffs(f, 0, 0), constants.data(), cdofs.data(),
local_facet.data(), perm.data());
local_facet.data(), perm.data(), nullptr);

// Local element layout is a 2x2 block matrix with structure
//
Expand Down
6 changes: 3 additions & 3 deletions cpp/dolfinx/fem/assemble_scalar_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ T assemble_cells(mdspan2_t x_dofmap,
}

fn(&value, &coeffs(index, 0), constants.data(), cdofs.data(), nullptr,
nullptr);
nullptr, nullptr);
}

return value;
Expand Down Expand Up @@ -92,7 +92,7 @@ T assemble_exterior_facets(
// Permutations
std::uint8_t perm = perms.empty() ? 0 : perms(cell, local_facet);
fn(&value, &coeffs(f, 0), constants.data(), cdofs.data(), &local_facet,
&perm);
&perm, nullptr);
}

return value;
Expand Down Expand Up @@ -144,7 +144,7 @@ T assemble_interior_facets(
: std::array{perms(cells[0], local_facet[0]),
perms(cells[1], local_facet[1])};
fn(&value, &coeffs(f, 0, 0), constants.data(), cdofs.data(),
local_facet.data(), perm.data());
local_facet.data(), perm.data(), nullptr);
}

return value;
Expand Down
12 changes: 6 additions & 6 deletions cpp/dolfinx/fem/assemble_vector_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ void _lift_bc_cells(
Ae.resize(num_rows * num_cols);
std::ranges::fill(Ae, 0);
kernel(Ae.data(), &coeffs(index, 0), constants.data(), cdofs.data(),
nullptr, nullptr);
nullptr, nullptr, nullptr);
P0(Ae, cell_info0, c0, num_cols);
P1T(Ae, cell_info1, c1, num_rows);

Expand Down Expand Up @@ -345,7 +345,7 @@ void _lift_bc_exterior_facets(
Ae.resize(num_rows * num_cols);
std::ranges::fill(Ae, 0);
kernel(Ae.data(), &coeffs(index, 0), constants.data(), cdofs.data(),
&local_facet, &perm);
&local_facet, &perm, nullptr);
P0(Ae, cell_info0, cell0, num_cols);
P1T(Ae, cell_info1, cell1, num_rows);

Expand Down Expand Up @@ -544,7 +544,7 @@ void _lift_bc_interior_facets(
: std::array{perms(cells[0], local_facet[0]),
perms(cells[1], local_facet[1])};
kernel(Ae.data(), &coeffs(f, 0, 0), constants.data(), cdofs.data(),
local_facet.data(), perm.data());
local_facet.data(), perm.data(), nullptr);

std::span<T> _Ae(Ae);
std::span<T> sub_Ae0 = _Ae.subspan(bs0 * dmap0_cell0.size() * num_cols,
Expand Down Expand Up @@ -676,7 +676,7 @@ void assemble_cells(
// Tabulate vector for cell
std::ranges::fill(be, 0);
kernel(be.data(), &coeffs(index, 0), constants.data(), cdofs.data(),
nullptr, nullptr);
nullptr, nullptr, nullptr);
P0(_be, cell_info0, c0, 1);

// Scatter cell vector to 'global' vector array
Expand Down Expand Up @@ -769,7 +769,7 @@ void assemble_exterior_facets(
// Tabulate element vector
std::ranges::fill(be, 0);
fn(be.data(), &coeffs(f, 0), constants.data(), cdofs.data(), &local_facet,
&perm);
&perm, nullptr);

P0(_be, cell_info0, cell0, 1);

Expand Down Expand Up @@ -877,7 +877,7 @@ void assemble_interior_facets(
: std::array{perms(cells[0], local_facet[0]),
perms(cells[1], local_facet[1])};
fn(be.data(), &coeffs(f, 0, 0), constants.data(), cdofs.data(),
local_facet.data(), perm.data());
local_facet.data(), perm.data(), nullptr);

std::span<T> _be(be);
std::span<T> sub_be = _be.subspan(bs * dmap0.size(), bs * dmap1.size());
Expand Down
2 changes: 1 addition & 1 deletion cpp/dolfinx/fem/traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ concept DofTransformKernel
template <class U, class T>
concept FEkernel = std::is_invocable_v<U, T*, const T*, const T*,
const scalar_value_t<T>*,
const int*, const std::uint8_t*>;
const int*, const std::uint8_t*, void*>;

/// @brief Concept for mdspan of rank 1 or 2.
template <class T>
Expand Down
20 changes: 10 additions & 10 deletions cpp/dolfinx/fem/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ Form<T, U> create_form_factory(
// Get list of integral IDs, and load tabulate tensor into memory for
// each
using kern_t = std::function<void(T*, const T*, const T*, const U*,
const int*, const std::uint8_t*)>;
const int*, const std::uint8_t*, void*)>;
std::map<std::tuple<IntegralType, int, int>, integral_data<T, U>> integrals;

auto check_geometry_hash
Expand Down Expand Up @@ -498,7 +498,7 @@ Form<T, U> create_form_factory(
{
k = reinterpret_cast<void (*)(
T*, const T*, const T*, const scalar_value_t<T>*, const int*,
const unsigned char*)>(integral->tabulate_tensor_complex64);
const unsigned char*, void*)>(integral->tabulate_tensor_complex64);
}
#endif // DOLFINX_NO_STDC_COMPLEX_KERNELS
else if constexpr (std::is_same_v<T, double>)
Expand All @@ -508,7 +508,7 @@ Form<T, U> create_form_factory(
{
k = reinterpret_cast<void (*)(
T*, const T*, const T*, const scalar_value_t<T>*, const int*,
const unsigned char*)>(integral->tabulate_tensor_complex128);
const unsigned char*, void*)>(integral->tabulate_tensor_complex128);
}
#endif // DOLFINX_NO_STDC_COMPLEX_KERNELS

Expand Down Expand Up @@ -583,7 +583,7 @@ Form<T, U> create_form_factory(
{
k = reinterpret_cast<void (*)(
T*, const T*, const T*, const scalar_value_t<T>*, const int*,
const unsigned char*)>(integral->tabulate_tensor_complex64);
const unsigned char*, void*)>(integral->tabulate_tensor_complex64);
}
#endif // DOLFINX_NO_STDC_COMPLEX_KERNELS
else if constexpr (std::is_same_v<T, double>)
Expand All @@ -593,7 +593,7 @@ Form<T, U> create_form_factory(
{
k = reinterpret_cast<void (*)(
T*, const T*, const T*, const scalar_value_t<T>*, const int*,
const unsigned char*)>(integral->tabulate_tensor_complex128);
const unsigned char*, void*)>(integral->tabulate_tensor_complex128);
}
#endif // DOLFINX_NO_STDC_COMPLEX_KERNELS
assert(k);
Expand Down Expand Up @@ -689,7 +689,7 @@ Form<T, U> create_form_factory(
{
k = reinterpret_cast<void (*)(
T*, const T*, const T*, const scalar_value_t<T>*, const int*,
const unsigned char*)>(integral->tabulate_tensor_complex64);
const unsigned char*, void*)>(integral->tabulate_tensor_complex64);
}
#endif // DOLFINX_NO_STDC_COMPLEX_KERNELS
else if constexpr (std::is_same_v<T, double>)
Expand All @@ -699,7 +699,7 @@ Form<T, U> create_form_factory(
{
k = reinterpret_cast<void (*)(
T*, const T*, const T*, const scalar_value_t<T>*, const int*,
const unsigned char*)>(integral->tabulate_tensor_complex128);
const unsigned char*, void*)>(integral->tabulate_tensor_complex128);
}
#endif // DOLFINX_NO_STDC_COMPLEX_KERNELS
assert(k);
Expand Down Expand Up @@ -918,7 +918,7 @@ Expression<T, U> create_expression(
std::vector<std::size_t> value_shape(e.value_shape,
e.value_shape + e.num_components);
std::function<void(T*, const T*, const T*, const scalar_value_t<T>*,
const int*, const std::uint8_t*)>
const int*, const std::uint8_t*, void*)>
tabulate_tensor = nullptr;
if constexpr (std::is_same_v<T, float>)
tabulate_tensor = e.tabulate_tensor_float32;
Expand All @@ -927,7 +927,7 @@ Expression<T, U> create_expression(
{
tabulate_tensor = reinterpret_cast<void (*)(
T*, const T*, const T*, const scalar_value_t<T>*, const int*,
const unsigned char*)>(e.tabulate_tensor_complex64);
const unsigned char*, void*)>(e.tabulate_tensor_complex64);
}
#endif // DOLFINX_NO_STDC_COMPLEX_KERNELS
else if constexpr (std::is_same_v<T, double>)
Expand All @@ -937,7 +937,7 @@ Expression<T, U> create_expression(
{
tabulate_tensor = reinterpret_cast<void (*)(
T*, const T*, const T*, const scalar_value_t<T>*, const int*,
const unsigned char*)>(e.tabulate_tensor_complex128);
const unsigned char*, void*)>(e.tabulate_tensor_complex128);
}
#endif // DOLFINX_NO_STDC_COMPLEX_KERNELS
else
Expand Down
33 changes: 29 additions & 4 deletions python/demo/demo_static-condensation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
from dolfinx.io import XDMFFile
from dolfinx.jit import ffcx_jit
from dolfinx.mesh import locate_entities_boundary, meshtags
from ffcx.codegeneration.utils import empty_void_pointer
from ffcx.codegeneration.utils import numba_ufcx_kernel_signature as ufcx_signature

if np.issubdtype(PETSc.RealType, np.float32): # type: ignore
Expand Down Expand Up @@ -146,21 +147,45 @@ def sigma_u(u):


@numba.cfunc(ufcx_signature(PETSc.ScalarType, PETSc.RealType), nopython=True) # type: ignore
def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL):
def tabulate_A(A_, w_, c_, coords_, entity_local_index, permutation=ffi.NULL, custom_data=None):
"""Element kernel that applies static condensation."""

# Prepare target condensed local element tensor
A = numba.carray(A_, (Usize, Usize), dtype=PETSc.ScalarType)

# Tabulate all sub blocks locally
A00 = np.zeros((Ssize, Ssize), dtype=PETSc.ScalarType)
kernel00(ffi.from_buffer(A00), w_, c_, coords_, entity_local_index, permutation)
kernel00(
ffi.from_buffer(A00),
w_,
c_,
coords_,
entity_local_index,
permutation,
empty_void_pointer(),
)

A01 = np.zeros((Ssize, Usize), dtype=PETSc.ScalarType)
kernel01(ffi.from_buffer(A01), w_, c_, coords_, entity_local_index, permutation)
kernel01(
ffi.from_buffer(A01),
w_,
c_,
coords_,
entity_local_index,
permutation,
empty_void_pointer(),
)

A10 = np.zeros((Usize, Ssize), dtype=PETSc.ScalarType)
kernel10(ffi.from_buffer(A10), w_, c_, coords_, entity_local_index, permutation)
kernel10(
ffi.from_buffer(A10),
w_,
c_,
coords_,
entity_local_index,
permutation,
empty_void_pointer(),
)

# A = - A10 * A00^{-1} * A01
A[:, :] = -A10 @ np.linalg.solve(A00, A01)
Expand Down
10 changes: 5 additions & 5 deletions python/dolfinx/wrappers/fem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -480,9 +480,9 @@ void declare_objects(nb::module_& m, const std::string& type)

std::array<std::size_t, 2> shape{value_size, x.size() / 3};
std::vector<T> values(shape[0] * shape[1]);
std::function<void(T*, int, int, const U*)> f
= reinterpret_cast<void (*)(T*, int, int, const U*)>(addr);
f(values.data(), shape[1], shape[0], x.data());
std::function<void(T*, int, int, const U*, void*)> f
= reinterpret_cast<void (*)(T*, int, int, const U*, void*)>(addr);
f(values.data(), shape[1], shape[0], x.data(), nullptr);
dolfinx::fem::interpolate(self, std::span<const T>(values), shape,
std::span(cells.data(), cells.size()));
},
Expand Down Expand Up @@ -568,7 +568,7 @@ void declare_objects(nb::module_& m, const std::string& type)
auto tabulate_expression_ptr
= (void (*)(T*, const T*, const T*,
const typename geom_type<T>::value_type*,
const int*, const std::uint8_t*))fn_addr;
const int*, const std::uint8_t*, void*))fn_addr;
new (ex) dolfinx::fem::Expression<T, U>(
coefficients, constants, std::span(X.data(), X.size()),
{X.shape(0), X.shape(1)}, tabulate_expression_ptr, value_shape,
Expand Down Expand Up @@ -660,7 +660,7 @@ void declare_form(nb::module_& m, std::string type)
auto kn_ptr
= (void (*)(T*, const T*, const T*,
const typename geom_type<T>::value_type*,
const int*, const std::uint8_t*))ptr;
const int*, const std::uint8_t*, void*))ptr;
_integrals.insert(
{{type, id, 0},
{kn_ptr,
Expand Down
Loading