Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flatten function #112

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions src/protosym/core/differentiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@

from dataclasses import dataclass
from dataclasses import field
from functools import reduce
from typing import TYPE_CHECKING as _TYPE_CHECKING

from protosym.core.tree import forward_graph
from protosym.core.tree import Tr


__all__ = [
Expand All @@ -20,11 +22,239 @@


if _TYPE_CHECKING:
from typing import Callable, Sequence
from protosym.core.atom import AtomType
from protosym.core.tree import Tree, SubsFunc

_DiffRules = dict[tuple[Tree, int], SubsFunc]


@dataclass(frozen=True)
class RingOps:
"""Collection of ring operations."""

Integer: AtomType[int]
iadd: Callable[[int, int], int]
imul: Callable[[int, int], int]
add: Tree
mul: Tree
pow: Tree

def split_integers(self, args: Sequence[Tree]) -> tuple[list[int], list[Tree]]:
integers: list[int] = []
new_args: list[Tree] = []
for arg in args:
if not arg.children and (atom := arg.value).atom_type == self.Integer:
integers.append(atom.value) # type: ignore
else:
new_args.append(arg)
return integers, new_args

def flatten_add(self, args: list[Tree]) -> Tree:
integers: list[int] = []
new_args: list[Tree] = []

# Associativity of add
# (x + y) + z -> x + y + z
for arg in args:
if arg.children and arg.children[0] == self.add:
new_args.extend(arg.children[1:])
else:
new_args.append(arg)

# Collect integer part
# x + 1 + 2 -> 3 + x
new_args2: list[Tree] = []
for arg in new_args:
if not arg.children and (atom := arg.value).atom_type == self.Integer:
integers.append(atom.value) # type:ignore
else:
new_args2.append(arg)

# Process all muls, extracting their coefficients
# 2*x + 3*x -> 5*x
totals = {}
for arg in new_args2:
if arg.children and arg.children[0] == self.mul:
intfacs, factors = self.split_integers(arg.children[1:])
if len(factors) == 1:
[fac] = factors
else:
fac = self.mul(*factors)
integer = reduce(self.imul, intfacs, 1)
if fac not in totals:
totals[fac] = 0
totals[fac] += integer
else:
if arg not in totals:
totals[arg] = 0
totals[arg] += 1

new_args3: list[Tree] = []
for fac, c in totals.items():
if c == 0:
continue
elif c == 1:
new_args3.append(fac)
elif fac.children and fac.children[0] == self.mul:
new_args3.append(self.mul(Tr(self.Integer(c)), *fac.children[1:]))
else:
new_args3.append(self.mul(Tr(self.Integer(c)), fac))

int_value = reduce(self.iadd, integers, 0)

if int_value:
new_args3.insert(0, Tr(self.Integer(int_value)))

if not new_args3:
expr = Tr(self.Integer(0))
elif len(new_args3) == 1:
[expr] = new_args3
else:
expr = self.add(*new_args3)

return expr

def flatten_mul(self, args: list[Tree]) -> Tree:
integers: list[int] = []
new_args: list[Tree] = []

for arg in args:
if arg.children and arg.children[0] == self.mul:
new_args.extend(arg.children[1:])
else:
new_args.append(arg)

new_args2: list[Tree] = []
for arg in new_args:
if not arg.children and (atom := arg.value).atom_type == self.Integer:
integers.append(atom.value) # type:ignore
else:
new_args2.append(arg)

powers = {}
for arg in new_args2:
if arg.children and arg.children[0] == self.pow:
base, s_exp = arg.children[1:]
exp: int
if s_exp.value.atom_type == self.Integer:
exp = s_exp.value.value # type: ignore
else:
base, exp = arg, 1
else:
base, exp = arg, 1

if base not in powers:
powers[base] = 0
powers[base] += exp

new_args3: list[Tree] = []
for base, exp in powers.items():
if exp == 0:
continue
elif exp == 1:
new_args3.append(base)
else:
new_args3.append(self.pow(base, Tr(self.Integer(exp))))

int_value = reduce(self.imul, integers, 1)

if int_value == 0:
return Tr(self.Integer(int_value))
elif int_value != 1:
new_args3.insert(0, Tr(self.Integer(int_value)))

if not new_args3:
expr = Tr(self.Integer(1))
elif len(new_args3) == 1:
[expr] = new_args3
else:
expr = self.mul(*new_args3)

return expr

def flatten_pow(self, args: list[Tree]) -> Tree:
base, exponent = args

# (x**y)**a -> x**(a*y) for integer a
if base.children and base.children[0] == self.pow:
base_base, base_exp = base.children[1:]
if not exponent.children and exponent.value.atom_type == self.Integer:
exponent = self.flatten_mul([base_exp, exponent])
base = base_base

if exponent == Tr(self.Integer(0)):
expr = Tr(self.Integer(1))
elif exponent == Tr(self.Integer(1)):
expr = base
else:
expr = self.pow(base, exponent)
return expr

def flatten(self, expr: Tree) -> Tree:
"""Apply the standard ring simplification rules.

Identity (addition): :math:`x + 0 = x`
Identity (multiplication): :math:`x * 1 = x`
Associativity (addition): :math:`(x + y) + z = x + (y + z)`
Associativity (multiplication): :math:`(x * y) * z = x * (y * z)`
Commutativity (addition): :math:`x + y = y + x`
Commutativity (multiplication): :math:`x * y = y * x`
Add to Mul: :math:`2*x + 3*x = 5*x`
Mul to Pow: :math:`x^2 * x^3 = x^5`
"""
graph = forward_graph(expr)
stack = list(graph.atoms)
for func, indices in graph.operations:

args = [stack[i] for i in indices]

if func == self.add:
expr = self.flatten_add(args)
elif func == self.mul:
expr = self.flatten_mul(args)
elif func == self.pow and len(args) == 2:
expr = self.flatten_pow(args)
else:
expr = func(*args)

stack.append(expr)

return stack[-1]

def flatten(self, expr: Tree) -> Tree:
"""Apply common ring simplification rules.

Identity (addition): :math:`x + 0 = x`
Identity (multiplication): :math:`x * 1 = x`
Associativity (addition): :math:`(x + y) + z = x + (y + z)`
Associativity (multiplication): :math:`(x * y) * z = x * (y * z)`
Commutativity (addition): :math:`x + y = y + x`
Commutativity (multiplication): :math:`x * y = y * x`
Add to Mul: :math:`2*x + 3*x = 5*x`
Mul to Pow: :math:`x^2 * x^3 = x^5`
...
"""
graph = forward_graph(expr)
stack = list(graph.atoms)
for func, indices in graph.operations:

args = [stack[i] for i in indices]

if func == self.add:
expr = self.flatten_add(args)
elif func == self.mul:
expr = self.flatten_mul(args)
elif func == self.pow and len(args) == 2:
expr = self.flatten_pow(args)
else:
expr = func(*args)

stack.append(expr)

return stack[-1]


@dataclass(frozen=True)
class DiffProperties:
"""Collection of properties needed for differentiation."""
Expand Down
21 changes: 21 additions & 0 deletions src/protosym/core/sym.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from protosym.core.atom import AtomType
from protosym.core.differentiate import diff_forward
from protosym.core.differentiate import DiffProperties
from protosym.core.differentiate import RingOps
from protosym.core.evaluate import Evaluator
from protosym.core.exceptions import BadRuleError
from protosym.core.tree import SubsFunc
Expand Down Expand Up @@ -617,6 +618,26 @@ def __call__(
return self.evaluator(expr.rep, values_rep)


class SymRingOps(Generic[T_sym]):
"""Representation of ring operations."""

def __init__(
self,
new_sym: Type[T_sym],
integer: SymAtomType[T_sym, int],
iadd: Callable[[int, int], int],
imul: Callable[[int, int], int],
add: Sym,
mul: Sym,
pow: Sym,
):
self.new_sym = new_sym
self.ringops = RingOps(integer.atom_type, iadd, imul, add.rep, mul.rep, pow.rep)

def __call__(self, expr: T_sym) -> T_sym:
return self.new_sym(self.ringops.flatten(expr.rep))


class SymDifferentiator(Generic[T_sym]):
"""Representation of differentiation rules.

Expand Down
49 changes: 42 additions & 7 deletions src/protosym/simplecas/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

from functools import reduce
from functools import wraps
from operator import add
from operator import mul
from typing import Any
from typing import Callable
from typing import Optional
Expand All @@ -17,6 +19,7 @@
from protosym.core.sym import HeadRule
from protosym.core.sym import Sym
from protosym.core.sym import SymDifferentiator
from protosym.core.sym import SymRingOps
from protosym.core.tree import SubsFunc
from protosym.core.tree import topological_sort
from protosym.simplecas.exceptions import ExpressifyError
Expand Down Expand Up @@ -425,26 +428,50 @@ def count_ops_graph(self) -> int:
"""
return len(topological_sort(self.rep))

def diff(self, sym: Expr, ntimes: int = 1) -> Expr:
def flatten(self) -> Expr:
"""Apply the usual simplification rules for Add, Mul and Pow."""
return ring_ops(self)

def diff(self, sym: Expr, ntimes: int = 1, flatten: bool = True) -> Expr:
"""Differentiate ``expr`` wrt ``sym`` (``ntimes`` times).

>>> from protosym.simplecas import x, sin
>>> sin(x).diff(x)
cos(x)

Currently no simplification is done which can lead to some strange
looking output:
Large expressions can be generated and differentiated efficiently:

>>> expr = sin(sin(sin(sin(sin(x))))).diff(x, 10)
>>> expr.count_ops_graph()
20427
>>> expr.count_ops_tree()
597557

By default ``flatten`` is called during the calculation of derivatives
but that can be disabled by passing ``flatten=False`` which gives
unsimplified derivative expressions:

>>> sin(x).diff(x, 4)
sin(x)
>>> sin(x).diff(x, 4, flatten=False)
(-1*(-1*sin(x)))

Large expressions can be generated and differentiated efficiently:
Although ``flatten`` makes simple expressions look simpler it can also
make the graph structure of large expressions more complicated:

>>> expr = sin(sin(sin(sin(sin(x))))).diff(x, 10)
>>> expr.count_ops_graph()
>>> expr = sin(sin(sin(sin(sin(x)))))
>>> expr.diff(x, 10, flatten=False).count_ops_graph()
1552
>>> expr.count_ops_tree()
>>> expr.diff(x, 10, flatten=True).count_ops_graph()
20427
>>> expr.diff(x, 10, flatten=False).count_ops_tree()
893621974
>>> expr.diff(x, 10, flatten=True).count_ops_tree()
597557

Note that the smallest operation count here arises when not flattening
and when using the graph representation rather than the tree
representation.

Differentiation rules for new functions can be added as needed:

Expand Down Expand Up @@ -473,6 +500,8 @@ def diff(self, sym: Expr, ntimes: int = 1) -> Expr:
deriv = self
for _ in range(ntimes):
deriv = diff(deriv, sym)
if flatten:
deriv = deriv.flatten()
return deriv

def bin_expand(self) -> Expr:
Expand Down Expand Up @@ -582,6 +611,12 @@ def call(self, args: list[Tree]) -> Tree:
count_ops_tree[AtomRule[a]] = one_func(a)
count_ops_tree[HeadRule(a, b)] = sum_plus_one(a, b)

#
# Basic ring operations for simplifying Add, Mul, Pow.
#

ring_ops = SymRingOps(Expr, Integer, iadd=add, imul=mul, add=Add, mul=Mul, pow=Pow)

#
# Differentiation.
#
Expand Down
Loading