Skip to content

Commit

Permalink
patched exponentiation bug in octave/matlab
Browse files Browse the repository at this point in the history
  • Loading branch information
allen-adastra committed May 13, 2020
1 parent 158103c commit 0e38f70
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 16 deletions.
3 changes: 1 addition & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,4 @@ sos_risk_assessment/data
*.egg-info/
.installed.cfg
*.egg
build/

build/
35 changes: 27 additions & 8 deletions algebraic_moments/code_printer.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,46 @@
from sympy.printing import octave_code
from sympy.printing.pycode import pycode
from sympy.utilities.codegen import codegen

class CodePrinter(object):
def print_moment_constraints(self, moment_constraints, moments, language):
def print_moment_constraints(self, moment_constraints, moments, deterministic_variables, language):
# Essentially a switch statement.
method = getattr(self, language, lambda: "Input language is not supported.")
return method(moment_constraints, moments)
return method(moment_constraints, moments, deterministic_variables)

def python(self, moment_constraints, moments):
def python(self, moment_constraints, moments, deterministic_variables):
# Parse required inputs.
print("# Parse required inputs.")
for moment in moments:
print(str(moment) + " = input_moments[\"" + str(moment) + "\"]")

for det_var in deterministic_variables:
print(str(det_var) +" = input_deterministic[\"" + str(det_var) + "\"]" )

# Generate constraint expressions.
print("\n# Moment constraints.")
for i, cons in enumerate(moment_constraints):
print("g" + str(i) + " = " + pycode(cons))

def matlab(self, moment_constraints, moments):
"""The octave function should produce code that is compatible with matlab.
def matlab(self, moment_constraints, moments, deterministic_variables):
"""The sympy function octave_code is designed to produce MATLAB compatible code.
"""
return self.octave(moment_constraints, moments)
return self.octave(moment_constraints, moments, deterministic_variables)

def octave(self, moment_constraints, moments):
def octave(self, moment_constraints, moments, deterministic_variables):
# Parse required inputs.
print("% Parse required inputs.")
for moment in moments:
print(str(moment) + " = input_moments." + str(moment) + ";")

for det_var in deterministic_variables:
print(str(det_var) + " = input_deterministic." + str(det_var) + ";")

# Generate constraint expressions
print("\n% Moment constraints.")
for i, cons in enumerate(moment_constraints):
print(octave_code(cons, assign_to="g"+str(i)))
# There is a bug in SymPy which expresses exponentiation in the form
# of a**b, which is compatible with Octave not MATLAB. An issue was
# opened to replace it with a^b, which is compatible with both languges.
# For now, the temporary fix is to perform a string replace.
print(octave_code(cons, assign_to="g"+str(i)).replace("**", "^"))
7 changes: 4 additions & 3 deletions algebraic_moments/moment_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from code_printer import CodePrinter
from objects import Moment

def generate_moment_constraints(expressions, random_vector, language):
"""
def generate_moment_constraints(expressions, random_vector, deterministic_variables, language):
"""[summary]
Args:
expressions ([type]): [description]
random_vector ([type]): [description]
deterministic_variables ([type]): [description]
language ([type]): [description]
"""
moments = [] # List of generated moments.
Expand All @@ -20,7 +21,7 @@ def generate_moment_constraints(expressions, random_vector, language):
results.append(moment_form(exp, random_vector, moments))

code_printer = CodePrinter()
code_printer.print_moment_constraints(results, moments, language)
code_printer.print_moment_constraints(results, moments, deterministic_variables, language)

def moment_form(expression, random_vector, moments):
raw_polynomial = sp.poly(expression, random_vector.variables)
Expand Down
10 changes: 7 additions & 3 deletions algebraic_moments/test/test_moment_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
dir_path = os.path.dirname(os.path.realpath(__file__))
import sys
sys.path.append(dir_path + "/../")
from objects import RandomVariable, RandomVector
from objects import RandomVariable, RandomVector, DeterministicVariable
from moment_form import generate_moment_constraints

def test_foo():
x = RandomVariable("x")
y = RandomVariable("y")
c = DeterministicVariable("c")
vector = RandomVector([x, y], [])
expressions = [(x*y**2 + y)**2, (y*x**2 + y**2)**3]
generate_moment_constraints(expressions, vector, "matlab")
expressions = [(c * x*y**2 + y)**2, (y*x**2 + c*y**2)**3]
deterministic_variables = [c]
generate_moment_constraints(expressions, vector, deterministic_variables, "matlab")

test_foo()

0 comments on commit 0e38f70

Please sign in to comment.