Skip to content

Commit

Permalink
test: add numjac test and refactor the files
Browse files Browse the repository at this point in the history
  • Loading branch information
rzyu45 committed Feb 20, 2024
1 parent 15bf810 commit 922575f
Show file tree
Hide file tree
Showing 32 changed files with 297 additions and 197 deletions.
4 changes: 2 additions & 2 deletions Solverz/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
from Solverz.equation.param import Param, IdxParam, TimeSeriesParam
from Solverz.sym_algebra.symbols import idx, Para, Var, AliasVar
from Solverz.sym_algebra.functions import Sign, Abs, transpose, exp, Diag, Mat_Mul, sin, cos
from Solverz.numerical_interface.custom_function import minmod_flag, minmod
from Solverz.num_api.custom_function import minmod_flag, minmod
from Solverz.variable.variables import Vars, TimeVars, as_Vars
from Solverz.solvers.nlaesolver import nr_method, continuous_nr
from Solverz.solvers.daesolver import Rodas, Opt, implicit_trapezoid, backward_euler
from Solverz.solvers.fdesolver import fdae_solver, fdae_ss_solver
from Solverz.numerical_interface.num_eqn import made_numerical, parse_dae_v, parse_ae_v, render_modules
from Solverz.code_printer.make_pyfunc import made_numerical
from Solverz.utilities.io import save, load
from Solverz.utilities.profile import count_time
File renamed without changes.
32 changes: 32 additions & 0 deletions Solverz/code_printer/make_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from typing import List

from Solverz.code_printer.py_printer import render_modules as py_render
from Solverz.equation.equations import AE, FDAE, DAE
from Solverz.variable.variables import Vars


class module_printer:
def __init__(self,
mdl: AE | FDAE | DAE,
variables: Vars | List[Vars],
name: str,
lang='python',
directory=None,
jit=False):
self.name = name
self.lang = lang
self.mdl = mdl
if isinstance(variables, Vars):
self.variables = [variables]
else:
self.variables = variables
self.directory = directory
self.jit = jit

def render(self):
if self.lang == 'python':
py_render(self.mdl,
*self.variables,
name=self.name,
directory=self.directory,
numba=self.jit)
5 changes: 5 additions & 0 deletions Solverz/code_printer/make_pyfunc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
import numpy as np

from Solverz.code_printer.py_printer import made_numerical


File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,125 @@
from numbers import Number

from Solverz.equation.equations import Equations as SymEquations, FDAE as SymFDAE, DAE as SymDAE, AE as SymAE
from Solverz.variable.variables import Vars
from Solverz.variable.variables import Vars, TimeVars
from Solverz.equation.param import TimeSeriesParam
from Solverz.sym_algebra.symbols import Var, SolDict, Para, idx, IdxSymBasic
from Solverz.sym_algebra.functions import zeros, CSC_array, Arange
from Solverz.utilities.address import Address
from Solverz.utilities.io import save

from Solverz.variable.variables import combine_Vars
from Solverz.num_api.custom_function import numerical_interface
from Solverz.num_api.num_eqn import nAE, nFDAE, nDAE

# %%


def parse_p(ae: SymEquations):
p = dict()
for param_name, param in ae.PARAM.items():
if isinstance(param, TimeSeriesParam):
p.update({param_name: param})
else:
p.update({param_name: param.v})
return p


def parse_trigger_fun(ae: SymEquations):
func = dict()
for para_name, param in ae.PARAM.items():
func.update({para_name + '_trigger_func': param.trigger_fun})

return func


def made_numerical(eqn: SymEquations, *xys, sparse=False, output_code=False):
"""
factory method of numerical equations
"""
print(f"Printing numerical codes of {eqn.name}")
eqn.assign_eqn_var_address(*xys)
code_F = print_F(eqn)
code_J = print_J(eqn, sparse)
custom_func = dict()
custom_func.update(numerical_interface)
custom_func.update(parse_trigger_fun(eqn))
F = Solverzlambdify(code_F, 'F_', modules=[custom_func, 'numpy'])
J = Solverzlambdify(code_J, 'J_', modules=[custom_func, 'numpy'])
p = parse_p(eqn)
print('Complete!')
if isinstance(eqn, SymAE) and not isinstance(eqn, SymFDAE):
num_eqn = nAE(F, J, p)
elif isinstance(eqn, SymFDAE):
num_eqn = nFDAE(F, J, p, eqn.nstep)
elif isinstance(eqn, SymDAE):
num_eqn = nDAE(eqn.M, F, J, p)
else:
raise ValueError(f'Unknown equation type {type(eqn)}')
if output_code:
return num_eqn, {'F': code_F, 'J': code_J}
else:
return num_eqn


def render_modules(eqn: SymEquations, *xys, name, directory=None, numba=False):
"""
factory method of numerical equations
"""
print(f"Printing python codes of {eqn.name}...")
eqn.assign_eqn_var_address(*xys)
p = parse_p(eqn)
code_F = print_F_numba(eqn)
code_inner_F = print_inner_F(eqn)
code_sub_inner_F = print_sub_inner_F(eqn)
code_J = print_J_numba(eqn)
codes = print_inner_J(eqn, *xys)
code_inner_J = codes['code_inner_J']
code_sub_inner_J = codes['code_sub_inner_J']
custom_func = dict()
custom_func.update(numerical_interface)
custom_func.update(parse_trigger_fun(eqn))

print('Complete!')

eqn_parameter = {}
if isinstance(eqn, SymAE) and not isinstance(eqn, SymFDAE):
eqn_type = 'AE'
elif isinstance(eqn, SymFDAE):
eqn_type = 'FDAE'
eqn_parameter.update({'nstep': eqn.nstep})
elif isinstance(eqn, SymDAE):
eqn_type = 'DAE'
eqn_parameter.update({'M': eqn.M})
else:
raise ValueError(f'Unknown equation type {type(eqn)}')

if len(xys) == 1:
y = xys[0]
else:
y = xys[0]
for arg in xys[1:]:
y = combine_Vars(y, arg)

code_dict = {'F': code_F,
'inner_F': code_inner_F,
'sub_inner_F': code_sub_inner_F,
'J': code_J,
'inner_J': code_inner_J,
'sub_inner_J': code_sub_inner_J}
eqn_parameter.update({'row': codes['row'], 'col': codes['col'], 'data': codes['data']})
print(f"Rendering python modules!")
render_as_modules(name,
code_dict,
eqn_type,
p,
eqn_parameter,
y,
[custom_func, 'numpy'],
numba,
directory)
print('Complete!')


def is_valid_python_module_name(module_name):
pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$')
return bool(pattern.match(module_name))
Expand Down Expand Up @@ -58,20 +168,20 @@ def print_init_code(eqn_type: str, module_name, eqn_param):
code += 'import time\n'
match eqn_type:
case 'AE':
code += 'from Solverz.numerical_interface.num_eqn import nAE\n'
code += 'from Solverz.num_api.num_eqn import nAE\n'
code += 'mdl = nAE(F_, J_, p)\n'
code_compile = 'mdl.F(y.array, p)\nmdl.J(y.array, p)\n'
case 'FDAE':
try:
nstep = eqn_param['nstep']
except KeyError as e:
raise ValueError("Cannot parse nstep attribute for FDAE object printing!")
code += 'from Solverz.numerical_interface.num_eqn import nFDAE\n'
code += 'from Solverz.num_api.num_eqn import nFDAE\n'
code += 'mdl = nFDAE(F_, J_, p, setting["nstep"])\n'
args_str = ', '.join(['y.array' for i in range(nstep)])
code_compile = f'mdl.F(0, y.array, p, {args_str})\nmdl.J(0, y.array, p, {args_str})\n'
case 'DAE':
code += 'from Solverz.numerical_interface.num_eqn import nDAE\n'
code += 'from Solverz.num_api.num_eqn import nDAE\n'
code += 'mdl = nDAE(setting["M"], F_, J_, p)\n'
code_compile = 'mdl.F(0, y.array, p)\nmdl.J(0, y.array, p)\n'
case _:
Expand Down Expand Up @@ -516,7 +626,8 @@ def print_J_numba(ae: SymEquations):
body.extend(param_assignments)
body.extend(print_trigger(ae))
body.extend([Assignment(Var('data', internal_use=True),
FunctionCall('inner_J', [symbols('_data_', real=True)] + [arg.name for arg in var_list + param_list]))])
FunctionCall('inner_J', [symbols('_data_', real=True)] + [arg.name for arg in
var_list + param_list]))])
body.extend([Return(coo_2_csc(ae))])
fd = FunctionDefinition.from_FunctionPrototype(fp, body)
return pycode(fd, fully_qualified_modules=False)
Expand Down Expand Up @@ -577,7 +688,8 @@ def print_F_numba(ae: SymEquations):
param_assignments, param_list = print_param(ae, numba_printer=True)
body.extend(param_assignments)
body.extend(print_trigger(ae))
body.extend([Return(FunctionCall('inner_F', [symbols('_F_', real=True)] + [arg.name for arg in var_list + param_list]))])
body.extend(
[Return(FunctionCall('inner_F', [symbols('_F_', real=True)] + [arg.name for arg in var_list + param_list]))])
fd = FunctionDefinition.from_FunctionPrototype(fp, body)
return pycode(fd, fully_qualified_modules=False)

Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion Solverz/equation/eqn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from Solverz.sym_algebra.functions import Mat_Mul, Slice, F
from Solverz.sym_algebra.matrix_calculus import MixedEquationDiff
from Solverz.sym_algebra.transform import finite_difference, semi_descritize
from Solverz.numerical_interface.custom_function import numerical_interface
from Solverz.num_api.custom_function import numerical_interface


class Eqn:
Expand Down
2 changes: 1 addition & 1 deletion Solverz/equation/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
from scipy.interpolate import interp1d

from Solverz.numerical_interface.Array import Array
from Solverz.num_api.Array import Array


class Param:
Expand Down
File renamed without changes.
Empty file added Solverz/num_api/__init__.py
Empty file.
File renamed without changes.
43 changes: 43 additions & 0 deletions Solverz/num_api/num_eqn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import Callable, Dict

import numpy as np

from Solverz.equation.param import TimeSeriesParam
from Solverz.variable.variables import Vars, TimeVars, combine_Vars


class nAE:

def __init__(self,
F: Callable,
J: Callable,
p: Dict):
self.F = F
self.J = J
self.p = p


class nFDAE:

def __init__(self,
F: callable,
J: callable,
p: dict,
nstep: int = 0):
self.F = F
self.J = J
self.p = p
self.nstep = nstep


class nDAE:

def __init__(self,
M,
F: Callable,
J: Callable,
p: Dict):
self.M = M
self.F = F
self.J = J
self.p = p
39 changes: 39 additions & 0 deletions Solverz/solvers/numjac.py → Solverz/num_api/numjac.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,42 @@ def numjac(F, t, y, Fty, thresh, S, g, vecon, central_diff):
dFdyt = csc_array(df)

return Fty, dFdyt, nfevals


def numjac_ae(F, y, thresh):
# Add t to the end of y and adjust thresh accordingly

facmax = 0.1
ny = len(y)
fac = np.sqrt(np.finfo(float).eps) + np.zeros(ny)
yscale = np.maximum(0.1 * np.abs(y), thresh)
del_ = (y + fac * yscale) - y
jj = np.where(del_ == 0)[0]

for j in jj:
while True:
if fac[j] < facmax:
fac[j] = min(100 * fac[j], facmax)
del_[j] = (y[j] + fac[j] * yscale[j]) - y[j]
if del_[j] != 0:
break
else:
del_[j] = thresh[j]
break

nfevals = ny
df = np.zeros((ny, ny))

for jj in range(0, ny):
ydel = y.copy()
ydel[jj] += del_[jj]

ydel_1 = y.copy()
ydel_1[jj] -= del_[jj]
df[:, jj] = (F(ydel) - F(ydel_1)) / (2 * del_[jj])
nfevals += 1

# Convert df to sparse matrix dFdyt
dFdyt = csc_array(df)

return dFdyt, nfevals
Empty file.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from scipy.sparse import csc_array

from Solverz.numerical_interface.Array import Array
from Solverz.num_api.Array import Array


def test_Array():
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from sympy import symbols, pycode, Integer

from Solverz.numerical_interface.code_printer import _parse_jac_eqn_address, _parse_jac_var_address, _parse_jac_data, \
from Solverz.code_printer.py_printer import _parse_jac_eqn_address, _parse_jac_var_address, _parse_jac_data, \
print_J_block, _print_F_assignment, _print_var_parser
from Solverz.sym_algebra.symbols import idx, Var, Para

Expand Down
25 changes: 25 additions & 0 deletions Solverz/num_api/test/test_numjac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy as np

from Solverz.num_api.numjac import numjac_ae


def f(x):
a = 1 + np.exp(x[0]) + np.sin(x[1])
b = x[1] ** 2 + np.cos(x[0])
return np.array([a, b])


def Janaly(x):
return np.array([[np.exp(x[0]), np.cos(x[1])],
[-np.sin(x[0]), 2 * x[1]]])


def Jnum(x):
return numjac_ae(lambda xi: f(xi),
x,
np.ones(len(x)) * 1.e-12)[0]


point = np.array([1.0, 2.0])

np.testing.assert_allclose(Jnum(point).toarray(), Janaly(point), rtol=1e-5)
Loading

0 comments on commit 922575f

Please sign in to comment.