Skip to content

Commit

Permalink
k
Browse files Browse the repository at this point in the history
  • Loading branch information
ksagiyam committed Sep 7, 2024
1 parent 8a8c6b4 commit d7ab8c7
Showing 1 changed file with 26 additions and 31 deletions.
57 changes: 26 additions & 31 deletions gem/optimise.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,15 @@ def _replace_indices_atomic(i, self, subst):
return substitute.get(i, i)


def _replace_indices_check(func):
def wrapper(node, self, subst):
if any(isinstance(i, VariableIndex) for i, _ in subst):
raise NotImplementedError("Can not replace VariableIndex (will need inverse)")
return func(node, self, subst)
return wrapper

@replace_indices.register(Delta)
@_replace_indices_check
def replace_indices_delta(node, self, subst):
i = _replace_indices_atomic(node.i, self, subst)
j = _replace_indices_atomic(node.j, self, subst)
Expand All @@ -116,6 +124,7 @@ def replace_indices_delta(node, self, subst):


@replace_indices.register(Indexed)
@_replace_indices_check
def replace_indices_indexed(node, self, subst):
child, = node.children
substitute = dict(subst)
Expand All @@ -137,6 +146,7 @@ def replace_indices_indexed(node, self, subst):


@replace_indices.register(FlexiblyIndexed)
@_replace_indices_check
def replace_indices_flexiblyindexed(node, self, subst):
child, = node.children
assert not child.free_indices
Expand Down Expand Up @@ -286,14 +296,6 @@ def select_expression(expressions, index):
return ComponentTensor(selected, alpha)


def _get_base_index(i):
if isinstance(i, VariableIndex):
base_index, = i.expression.free_indices
else:
base_index = i
return base_index


def delta_elimination(sum_indices, factors):
"""IndexSum-Delta cancellation.
Expand All @@ -304,35 +306,28 @@ def delta_elimination(sum_indices, factors):
sum_indices = list(sum_indices) # copy for modification

def substitute(expression, from_, to_):
if isinstance(from_, VariableIndex):
raise NotImplementedError("Can not replace VariableIndex (Will need inverse)")
base_index = _get_base_index(from_)
if base_index not in expression.free_indices:
if from_ not in expression.free_indices:
return expression
elif isinstance(expression, Delta):
mapper = MemoizerArg(filtered_replace_indices)
return mapper(expression, ((from_, to_),))
else:
if isinstance(from_, VariableIndex):
raise NotImplementedError
else:
return Indexed(ComponentTensor(expression, (from_,)), (to_,))

def make_delta_item(factors):
for f in factors:
if isinstance(f, Delta):
for from_, to_ in [(f.i, f.j), (f.j, f.i)]:
base_index = _get_base_index(from_)
if base_index in sum_indices:
return f, from_, to_, base_index
return ()

delta_item = make_delta_item(factors)
while delta_item:
delta, from_, to_, base_index = delta_item
sum_indices.remove(base_index)
return Indexed(ComponentTensor(expression, (from_,)), (to_,))

delta_queue = [(f, index)
for f in factors if isinstance(f, Delta)
for index in (f.i, f.j) if index in sum_indices]
while delta_queue:
delta, from_ = delta_queue[0]
to_, = list({delta.i, delta.j} - {from_})

sum_indices.remove(from_)

factors = [substitute(f, from_, to_) for f in factors]
delta_item = make_delta_item(factors)

delta_queue = [(f, index)
for f in factors if isinstance(f, Delta)
for index in (f.i, f.j) if index in sum_indices]

return sum_indices, factors

Expand Down

0 comments on commit d7ab8c7

Please sign in to comment.