Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/actions/install/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ runs:
--extra-index-url https://download.pytorch.org/whl/cpu \
"./firedrake-repo[${{ inputs.deps }}]"

pip install -v --no-deps --ignore-installed git+https://github.com/firedrakeproject/ufl.git@pbrubeck/form-product
firedrake-clean
pip list

Expand Down
267 changes: 253 additions & 14 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,123 @@ def assemble(self, tensor=None, current_state=None):
"""


class _LRCDescriptor:
"""Unassembled low-rank matrix contribution.

The descriptor lets BaseForm traversal carry low-rank columns through
parent nodes such as FormSum before a single MATLRC is created.
"""

def __init__(self, expr, base_terms=(), row_factors=(), col_factors=(), weights=(), bcs=None):
self.expr = expr
self.base_terms = tuple(base_terms)
self.row_factors = tuple(row_factors)
self.col_factors = tuple(col_factors)
self.weights = tuple(weights)
self.bcs = bcs

def scaled(self, weight):
weight = PETSc.ScalarType(weight)
return _LRCDescriptor(
self.expr,
base_terms=tuple((matrix, weight * w) for matrix, w in self.base_terms),
row_factors=self.row_factors,
col_factors=self.col_factors,
weights=tuple(weight * w for w in self.weights),
)

@cached_property
def petscmat(self):
U = self._vector_columns_to_dense(self.row_factors, bcs=self.bcs)
V = self._vector_columns_to_dense(self.col_factors, bcs=self.bcs)
c = self._lrc_weights_to_vec(self.weights)
A = self._lrc_base_petscmat(self.base_terms)
petscmat = PETSc.Mat().createLRC(A, U, c, V)
petscmat.assemble()
return petscmat

@staticmethod
def _vector_columns_to_dense(functions, bcs=None):
if not functions:
raise ValueError("Cannot build an LRC factor matrix with no columns")

first, *_ = functions
with first.dat.vec_ro as vec:
comm = vec.comm
local_size = vec.getLocalSize()
global_size = vec.getSize()

dense = PETSc.Mat().createDense(
size=((local_size, global_size), (len(functions), len(functions))), comm=comm
)
dense.setUp()
values = dense.getDenseArray()
for i, function in enumerate(functions):
function = _LRCDescriptor._apply_bcs(function, bcs)
with function.dat.vec_ro as vec:
if vec.getLocalSize() != local_size or vec.getSize() != global_size:
raise ValueError("LRC factor vectors do not share the same layout")
values[:, i] = vec.getArray(readonly=True)
dense.assemble()
return dense

@staticmethod
def _lrc_weights_to_vec(weights):
c = PETSc.Vec().createSeq(len(weights), comm=PETSc.COMM_SELF)
c.setValues(range(len(weights)), weights)
c.assemble()
return c

@staticmethod
def _bc_matches_space(bc, function_space):
V = bc.function_space()
while True:
if V == function_space or V.dual() == function_space:
return True
if V.parent is None:
return False
V = V.parent

@staticmethod
def _lrc_factor_bcs(factor, bcs):
if not bcs:
return ()

arg, = factor.arguments()
function_space = arg.function_space()
return tuple(bc for bc in bcs if _LRCDescriptor._bc_matches_space(bc, function_space))

@staticmethod
def _lrc_base_petscmat(base_terms):
if not base_terms:
return None

if len(base_terms) == 1:
matrix, weight = base_terms[0]
if weight == 1:
return matrix.petscmat

result = PETSc.Mat()
for i, (matrix, weight) in enumerate(base_terms):
if i == 0:
matrix.petscmat.copy(result=result)
result.scale(weight)
else:
result.axpy(weight, matrix.petscmat)
result.assemble()
return result

@staticmethod
def _apply_bcs(factor, bcs):
bcs = _LRCDescriptor._lrc_factor_bcs(factor, bcs)
if not bcs:
return factor
factor = factor.copy(deepcopy=True)
for bc in bcs:
bc.zero(factor)
return factor


class BaseFormAssembler(AbstractFormAssembler):
"""Base form assembler.

Expand Down Expand Up @@ -400,6 +517,114 @@ def _as_pyop2_type(tensor, indices=None):
assert indices is None
return tensor

@staticmethod
def _as_scalar_value(value):
if isinstance(value, ufl.ZeroBaseForm):
value = 0.0
elif isinstance(value, ufl.constantvalue.Zero):
value = 0.0
elif isinstance(value, ufl.constantvalue.ScalarValue):
value = value.value()
elif isinstance(value, (firedrake.Constant, firedrake.Function)):
value = value.dat.data_ro

if isinstance(value, numpy.ndarray):
# Assert singleton ndarray
value = value.item()
if not isinstance(value, numbers.Complex):
raise ValueError("Expecting a scalar expression")
return value

def _form_product_lrc_descriptor(self, expr, tensor, bcs, assembled_factors, scale=1):
if self._mat_type != "lrc":
raise ValueError("FormProduct assembly requires mat_type='lrc'")
if tensor is not None:
raise NotImplementedError("Assembly of FormProduct into an existing tensor is not supported")
if len(expr.factors()) != 2:
raise NotImplementedError("LRC FormProduct assembly currently supports exactly two factors")
if len(expr.arguments()) != 2:
raise ValueError("LRC FormProduct assembly requires aggregate rank 2")
if any(len(factor.arguments()) != 1 for factor in expr.factors()):
raise ValueError("LRC FormProduct assembly requires rank-1 factors")
if len(assembled_factors) != 2:
raise TypeError("Not enough operands for FormProduct")
if not all(isinstance(factor, (firedrake.Cofunction, firedrake.Function)) for factor in assembled_factors):
raise TypeError("LRC FormProduct factors must assemble to Functions or Cofunctions")

return _LRCDescriptor(expr, row_factors=assembled_factors[:1], col_factors=assembled_factors[1:],
weights=(PETSc.ScalarType(scale),), bcs=bcs)

def _form_sum_lrc_descriptor(self, expr, tensor, bcs, assembled_components, weights):
if not any(isinstance(component, _LRCDescriptor) for component in assembled_components):
return None

if self._mat_type != "lrc":
raise ValueError("FormSum with FormProduct terms requires mat_type='lrc'")
if tensor is not None:
raise NotImplementedError("Assembly of FormSum with LRC terms into an existing tensor is not supported")

base_terms = []
row_factors = []
col_factors = []
lrc_weights = []
for component, weight in zip(assembled_components, weights):
if isinstance(component, _LRCDescriptor):
descriptor = component.scaled(weight)
base_terms.extend(descriptor.base_terms)
row_factors.extend(descriptor.row_factors)
col_factors.extend(descriptor.col_factors)
lrc_weights.extend(descriptor.weights)
elif isinstance(component, MatrixBase):
base_terms.append((component, PETSc.ScalarType(weight)))
else:
raise TypeError("Mismatching FormSum shapes")

return _LRCDescriptor(expr, base_terms=base_terms, row_factors=row_factors,
col_factors=col_factors, weights=lrc_weights)

def _assemble_lrc_descriptor(self, descriptor):
return Matrix(descriptor.expr, descriptor.petscmat, bcs=descriptor.bcs,
options_prefix=self._options_prefix,
fc_params=self._form_compiler_params)

def _assemble_form_product(self, expr, tensor, bcs, assembled_factors):
ranks = tuple(len(factor.arguments()) for factor in expr.factors())
if len(assembled_factors) != len(expr.factors()):
return self._form_product_lrc_descriptor(expr, tensor, bcs, assembled_factors)
if sum(ranks) > 2 or not any(rank == 0 for rank in ranks):
return self._form_product_lrc_descriptor(expr, tensor, bcs, assembled_factors)

scalar_weight = PETSc.ScalarType(1)
higher_rank_factors = []
assembled_higher_rank_factors = []
for factor, rank, assembled_factor in zip(expr.factors(), ranks, assembled_factors):
if rank == 0:
scalar_weight *= self._as_scalar_value(assembled_factor)
else:
higher_rank_factors.append(factor)
assembled_higher_rank_factors.append(assembled_factor)

if not higher_rank_factors:
return tensor.assign(scalar_weight) if tensor else scalar_weight

if len(higher_rank_factors) == 1:
assembled_higher_rank_form = assembled_higher_rank_factors[0]
elif len(higher_rank_factors) == 2 and all(len(factor.arguments()) == 1
for factor in higher_rank_factors):
higher_rank_form = ufl.FormProduct(*higher_rank_factors)
descriptor = self._form_product_lrc_descriptor(
higher_rank_form, tensor, bcs, assembled_higher_rank_factors,
scale=scalar_weight)
return _LRCDescriptor(expr,
row_factors=descriptor.row_factors,
col_factors=descriptor.col_factors,
weights=descriptor.weights)
else:
raise ValueError("FormProduct preprocessing requires remaining aggregate rank <= 2")

weighted_form = ufl.FormSum((assembled_higher_rank_form, scalar_weight))
return self.base_form_assembly_visitor(weighted_form, tensor, bcs, assembled_higher_rank_form)

def assemble(self, tensor=None, current_state=None):
"""Assemble the form.

Expand Down Expand Up @@ -432,6 +657,9 @@ def visitor(e, *operands):
visited = {}
result = BaseFormAssembler.base_form_postorder_traversal(self._form, visitor, visited)

if isinstance(result, _LRCDescriptor):
return self._assemble_lrc_descriptor(result)

# Deal with 1-form bcs outside the visitor
rank = len(self._form.arguments())
if rank == 1 and not isinstance(result, ufl.ZeroBaseForm):
Expand Down Expand Up @@ -460,17 +688,22 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
assembler = OneFormAssembler(form, form_compiler_parameters=self._form_compiler_params,
zero_bc_nodes=self._zero_bc_nodes, diagonal=self._diagonal, weight=self._weight)
elif rank == 2:
mat_type = self._sub_mat_type if self._mat_type == "lrc" else self._mat_type
assembler = TwoFormAssembler(form, bcs=bcs, form_compiler_parameters=self._form_compiler_params,
mat_type=self._mat_type, sub_mat_type=self._sub_mat_type,
mat_type=mat_type, sub_mat_type=self._sub_mat_type,
options_prefix=self._options_prefix, appctx=self._appctx, weight=self._weight,
allocation_integral_types=self.allocation_integral_types)
else:
raise AssertionError
return assembler.assemble(tensor=tensor)
elif isinstance(expr, ufl.FormProduct):
return self._assemble_form_product(expr, tensor, bcs, args)
elif isinstance(expr, ufl.Adjoint):
if len(args) != 1:
raise TypeError("Not enough operands for Adjoint")
mat, = args
if isinstance(mat, _LRCDescriptor):
mat = self._assemble_lrc_descriptor(mat)
result = tensor.petscmat if tensor else PETSc.Mat()
# Out-of-place Hermitian transpose
mat.petscmat.hermitianTranspose(out=result)
Expand All @@ -482,6 +715,10 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
if len(args) != 2:
raise TypeError("Not enough operands for Action")
lhs, rhs = args
if isinstance(lhs, _LRCDescriptor):
lhs = self._assemble_lrc_descriptor(lhs)
if isinstance(rhs, _LRCDescriptor):
rhs = self._assemble_lrc_descriptor(rhs)
if isinstance(lhs, MatrixBase):
if isinstance(rhs, (firedrake.Cofunction, firedrake.Function)):
petsc_mat = lhs.petscmat
Expand Down Expand Up @@ -525,19 +762,11 @@ def base_form_assembly_visitor(self, expr, tensor, bcs, *args):
# Assemble weights
weights = []
for w in expr.weights():
if isinstance(w, ufl.constantvalue.Zero):
w = 0.0
elif isinstance(w, ufl.constantvalue.ScalarValue):
w = w.value()
elif isinstance(w, (firedrake.Constant, firedrake.Function)):
w = w.dat.data_ro

if isinstance(w, numpy.ndarray):
# Assert singleton ndarray
w = w.item()
if not isinstance(w, numbers.Complex):
raise ValueError("Expecting a scalar weight expression")
weights.append(w)
weights.append(self._as_scalar_value(w))

lrc_descriptor = self._form_sum_lrc_descriptor(expr, tensor, bcs, args, weights)
if lrc_descriptor is not None:
return lrc_descriptor

# Scalar FormSum
if all(isinstance(op, numbers.Complex) for op in args):
Expand Down Expand Up @@ -683,12 +912,22 @@ def base_form_preorder_traversal(expr, visitor, visited={}):
def reconstruct_node_from_operands(expr, operands):
if isinstance(expr, (ufl.Adjoint, ufl.Action)):
return expr._ufl_expr_reconstruct_(*operands)
elif isinstance(expr, ufl.FormProduct):
return ufl.FormProduct(*operands) if operands else expr
elif isinstance(expr, ufl.FormSum):
return ufl.FormSum(*[(op, w) for op, w in zip(operands, expr.weights())])
return expr

@staticmethod
def base_form_operands(expr):
if isinstance(expr, ufl.FormProduct):
ranks = tuple(len(factor.arguments()) for factor in expr.factors())
if (len(expr.factors()) == 2 and len(expr.arguments()) == 2
and all(rank == 1 for rank in ranks)):
return expr.ufl_operands
if sum(ranks) <= 2 and any(rank == 0 for rank in ranks):
return expr.ufl_operands
return []
if isinstance(expr, (ufl.FormSum, ufl.Adjoint, ufl.Action)):
return expr.ufl_operands
if isinstance(expr, ufl.Form):
Expand Down
Loading
Loading