Skip to content

Commit 4befce3

Browse files
PETSc FEM python wrapper: attach field information
1 parent c2d11dd commit 4befce3

File tree

4 files changed

+128
-1
lines changed

4 files changed

+128
-1
lines changed

cpp/dolfinx/la/petsc.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,52 @@ std::vector<IS> la::petsc::create_index_sets(
147147
return is;
148148
}
149149
//-----------------------------------------------------------------------------
150+
std::vector<IS> la::petsc::create_global_index_sets(
151+
const std::vector<
152+
std::pair<std::reference_wrapper<const common::IndexMap>, int>>& maps)
153+
{
154+
std::vector<IS> is;
155+
156+
std::int64_t offset = 0;
157+
std::int64_t merged_local_size = 0;
158+
MPI_Comm comm = MPI_COMM_NULL;
159+
160+
for (auto& map : maps)
161+
{
162+
if (comm == MPI_COMM_NULL)
163+
{
164+
comm = map.first.get().comm();
165+
}
166+
int result;
167+
MPI_Comm_compare(comm, map.first.get().comm(), &result);
168+
if (result != MPI_IDENT && result != MPI_CONGRUENT)
169+
{
170+
throw std::runtime_error("Not supported on the different communicators.");
171+
}
172+
int bs = map.second;
173+
std::int32_t size = map.first.get().size_local();
174+
merged_local_size += size * bs;
175+
}
176+
if (comm == MPI_COMM_NULL)
177+
return is;
178+
179+
int ierr = MPI_Exscan(&merged_local_size, &offset, 1, MPI_INT64_T,
180+
MPI_SUM, comm);
181+
dolfinx::MPI::check_error(comm, ierr);
182+
183+
for (auto& map : maps)
184+
{
185+
int bs = map.second;
186+
std::int32_t size = map.first.get().size_local();
187+
IS _is;
188+
ISCreateStride(map.first.get().comm(), bs * size, offset, 1, &_is);
189+
is.push_back(_is);
190+
offset += bs * size;
191+
}
192+
193+
return is;
194+
}
195+
//-----------------------------------------------------------------------------
150196
std::vector<std::vector<PetscScalar>> la::petsc::get_local_vectors(
151197
const Vec x,
152198
const std::vector<

cpp/dolfinx/la/petsc.h

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,25 @@ Vec create_vector_wrap(const la::Vector<V>& x)
100100
/// @note The caller is responsible for destruction of each IS.
101101
///
102102
/// @param[in] maps Vector of IndexMaps and corresponding block sizes
103-
/// @return Vector of PETSc Index Sets, created on` PETSC_COMM_SELF`
103+
/// @return Vector of PETSc Index Sets, created on `PETSC_COMM_SELF`
104104
std::vector<IS> create_index_sets(
105105
const std::vector<
106106
std::pair<std::reference_wrapper<const common::IndexMap>, int>>& maps);
107107

108+
/// @brief Compute PETSc IndexSets (IS) for a stack of index maps.
109+
///
110+
/// This function stacks the owned part of the maps and returns
111+
/// indices in the global space. The maps must have the same communicator.
112+
///
113+
/// @note Collective
114+
/// @note The caller is responsible for destruction of each IS.
115+
///
116+
/// @param[in] maps Vector of IndexMaps and corresponding block sizes
117+
/// @return Vector of PETSc Index Sets, created on the index map communicators
118+
std::vector<IS> create_global_index_sets(
119+
const std::vector<
120+
std::pair<std::reference_wrapper<const common::IndexMap>, int>>& maps);
121+
108122
/// Copy blocks from Vec into local arrays
109123
std::vector<std::vector<PetscScalar>> get_local_vectors(
110124
const Vec x,

python/dolfinx/fem/petsc.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -875,6 +875,11 @@ def __init__(
875875
self._solver = PETSc.KSP().create(self.A.comm) # type: ignore[attr-defined]
876876
self.solver.setOperators(self.A, self.P_mat)
877877

878+
# Attach problem information
879+
dm = self.solver.getDM()
880+
dm.setCreateFieldDecomposition(partial(create_field_decomposition, u, self.L))
881+
self.solver.getPC().setDM(dm)
882+
878883
if petsc_options_prefix == "":
879884
raise ValueError("PETSc options prefix cannot be empty.")
880885

@@ -1346,6 +1351,11 @@ def __init__(
13461351
)
13471352
self.solver.setFunction(partial(assemble_residual, u, self.F, self.J, bcs), self.b)
13481353

1354+
# Attach problem information
1355+
dm = self.solver.getDM()
1356+
dm.setCreateFieldDecomposition(partial(create_field_decomposition, u, self.F))
1357+
self.solver.getKSP().setDM(dm)
1358+
13491359
if petsc_options_prefix == "":
13501360
raise ValueError("PETSc options prefix cannot be empty.")
13511361

@@ -1710,3 +1720,35 @@ def _(x: PETSc.Vec, u: typing.Union[_Function, Sequence[_Function]]): # type: i
17101720
dolfinx.la.petsc.assign(x, data0 + data1)
17111721
else:
17121722
dolfinx.la.petsc.assign(x, u.x.array)
1723+
1724+
1725+
def create_field_decomposition(
1726+
u: typing.Union[Sequence[_Function], _Function],
1727+
form: typing.Union[Form, Sequence[Form]],
1728+
_dm: PETSc.DM, # type: ignore[name-defined]
1729+
):
1730+
"""Return index sets of the fields and their associated names.
1731+
1732+
Args:
1733+
u: Function tied to the solution vector.
1734+
form: Form of the residual or of the right-hand side.
1735+
It can be a sequence of forms.
1736+
_dm: The DM instance.
1737+
1738+
Returns:
1739+
names: field names.
1740+
ises: list of index sets in global numbering.
1741+
dms: list of subDMs. This function returns `None`.
1742+
"""
1743+
1744+
if not isinstance(form, Sequence):
1745+
form = [form]
1746+
spaces = _extract_function_spaces(form)
1747+
ises = _cpp.la.petsc.create_global_index_sets(
1748+
[(V.dofmaps(0).index_map, V.dofmaps(0).index_map_bs) for V in spaces] # type: ignore[union-attr]
1749+
)
1750+
if isinstance(u, Sequence):
1751+
names = [f"{v.name + '_' if v.name != 'f' else ''}{i}" for i, v in enumerate(u)]
1752+
else:
1753+
names = [f"dolfinx_field_{i}" for i in range(len(form))]
1754+
return names, ises, None

python/dolfinx/wrappers/petsc.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,31 @@ void petsc_la_module(nb::module_& m)
210210
},
211211
nb::arg("maps"));
212212

213+
m.def(
214+
"create_global_index_sets",
215+
[](const std::vector<std::pair<const dolfinx::common::IndexMap*, int>>&
216+
maps)
217+
{
218+
using X = std::vector<std::pair<
219+
std::reference_wrapper<const dolfinx::common::IndexMap>, int>>;
220+
X _maps;
221+
std::ranges::transform(maps, std::back_inserter(_maps),
222+
[](auto m) -> typename X::value_type
223+
{ return {*m.first, m.second}; });
224+
std::vector<IS> index_sets
225+
= dolfinx::la::petsc::create_global_index_sets(_maps);
226+
227+
std::vector<nb::object> py_index_sets;
228+
for (auto is : index_sets)
229+
{
230+
PyObject* obj = PyPetscIS_New(is);
231+
PetscObjectDereference((PetscObject)is);
232+
py_index_sets.push_back(nb::steal(obj));
233+
}
234+
return py_index_sets;
235+
},
236+
nb::arg("maps"));
237+
213238
m.def(
214239
"scatter_local_vectors",
215240
[](Vec x,

0 commit comments

Comments
 (0)