-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
526cbfb
commit 31e2863
Showing
11 changed files
with
1,458 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# To make module work | ||
|
||
__author__ = 'Mitchell, FHT' | ||
__date__ = (2017, 8, 20) | ||
__verbose__ = True | ||
|
||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[.ShellClassInfo] | ||
InfoTip=This folder is shared online. | ||
IconFile=C:\Program Files (x86)\Google\Drive\googledrivesync.exe | ||
IconIndex=16 | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
# driver.py | ||
|
||
|
||
import collections | ||
import numpy as np | ||
|
||
from method import methods | ||
from stepper import Stepper | ||
from problem import Problem | ||
from timers import Timer | ||
from typing import * | ||
|
||
__author__ = 'Mitchell, FHT' | ||
__date__ = (2017, 8, 20) | ||
__verbose__ = True | ||
|
||
def static_stepsize_control_factory(h: float) -> Callable: | ||
def static_stepsize_control(t, y, yd): | ||
return h | ||
return static_stepsize_control | ||
|
||
|
||
valid_methods = {'obr', 'adams'} | ||
valid_problems = {'rtbp', 'shm'} | ||
|
||
class Driver: | ||
def __init__(self, | ||
problem: Problem, | ||
h0: float, | ||
K: int, | ||
L: int = 1, | ||
t0: float = 0., | ||
tf: float = None, | ||
method_name: str = 'obr', | ||
corrector_steps: int = 1, | ||
stepsize_control: Callable = None, | ||
explicit_free_params: Mapping[str, float] = None, | ||
implicit_free_params: Mapping[str, float] = None): | ||
|
||
""" | ||
:param problem: | ||
:param h0: | ||
:param K: | ||
:param L: | ||
:param method_name: | ||
:param corrector_steps: | ||
:param stepsize_control: | ||
:return: | ||
""" | ||
|
||
|
||
assert isinstance(problem, Problem), repr(problem) | ||
assert h0 > 0, repr(h0) | ||
assert isinstance(K, int) and K >= 1, repr(K) | ||
assert isinstance(L, int) and L >= 1, repr(L) | ||
assert isinstance(method_name, str), repr(method_name) | ||
assert isinstance(corrector_steps, int) and corrector_steps >= 0, repr(corrector_steps) | ||
|
||
|
||
|
||
|
||
self.problem = problem | ||
self.h0 = h0 | ||
self.K = K | ||
self.L = L | ||
self.t0 = t0 | ||
self.tf = tf | ||
self.method_name = method_name.lower().strip() | ||
self.corrector_steps = corrector_steps | ||
self.stepsize_control = stepsize_control if stepsize_control is not None \ | ||
else static_stepsize_control_factory(h0) | ||
|
||
self.explicit_free_params = explicit_free_params or {} | ||
self.implicit_free_params = implicit_free_params or {} | ||
|
||
self.Method = methods()[self.method_name] | ||
self.explicit = self.Method(K, L, 'explicit', self.explicit_free_params) | ||
if corrector_steps > 0: | ||
self.implicit = self.Method(K, L, 'implicit', self.implicit_free_params) | ||
else: | ||
self.implicit = None | ||
|
||
self.stepper = Stepper(self.stepsize_control, self.explicit, self.implicit) | ||
|
||
|
||
def run(self, y0: np.ndarray, *, t0=None, tf=None) -> Tuple[np.ndarray, np.ndarray]: | ||
|
||
if t0 is None: | ||
t0 = self.t0 | ||
if tf is None: | ||
tf = self.tf | ||
assert tf is not None, (tf, self.tf) | ||
|
||
y0 = np.array(y0) | ||
# make sure y0 is an array | ||
if y0.shape == (): | ||
# This allows floats to be entered for 1 D | ||
y0 = y0[np.newaxis] | ||
assert len(y0.shape) == 1 | ||
|
||
y = [y0] | ||
yd = collections.deque([self.problem.vector_derivs(t0, y0)], maxlen=self.K) | ||
|
||
if __verbose__: | ||
print('Running driver...') | ||
|
||
t = t0 | ||
k = 1 | ||
counter = 0 | ||
ts = [t0] | ||
if __verbose__: | ||
total_iters = int((tf - t0)/self.h0) | ||
timer = Timer(5) | ||
|
||
while t < tf: | ||
h1, t, y1 = self.stepper.predict(t, y, yd, k) | ||
for _ in range(self.corrector_steps): | ||
y1d = self.problem.vector_derivs(t, y1) | ||
y2 = self.stepper.correct(yd, y1d, k, h1) | ||
# do anything fancy with y2 and y1 | ||
y1 = y2 | ||
|
||
ts.append(t) | ||
y.append(y1) | ||
yd.append(self.problem.vector_derivs(t, y[-1])) | ||
if k < self.K: # Use maximum number of previous step up to K | ||
k += 1 | ||
counter += 1 | ||
if __verbose__ and timer.check(): | ||
print(f'Time elapsed: {timer()}, Iterations: {counter:,} ' | ||
f'(~{100.*counter/total_iters:.1f}%).') | ||
|
||
if __verbose__: | ||
print('Finished iterating.') | ||
print(f'y has length {len(y):,} and final value') | ||
print(f' yf = [{", ".join(format(x, ".3g") for x in y[-1])}].') | ||
print(f'Total time elapsed: {timer()}.') | ||
print('_'*30, end='\n\n') | ||
|
||
return np.array(ts), np.array(y) | ||
|
||
|
||
def displacement(self, y0: np.array, **kwargs): | ||
|
||
t, y = self.run(y0, **kwargs) | ||
return y[-1] - y[0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
# dsympy.py - helper functions for converting differential equations in sympy | ||
# into numeric lambda functions | ||
|
||
import sympy as sp | ||
import collections | ||
from sympy.utilities.lambdify import lambdastr as _lambdastr | ||
from sympy.core.function import AppliedUndef | ||
from typing import * | ||
|
||
__author__ = 'Mitchell, FHT' | ||
__date__ = (2017, 8, 20) | ||
__verbose__ = True | ||
|
||
# needs to be at the top for typing. | ||
class AutoFunc: | ||
""" | ||
Holds the representation of an auto-generated function | ||
""" | ||
|
||
def __init__(self, func, as_str, args, sym_func=None, name=None): | ||
self.func = func | ||
self.as_str = as_str | ||
self.args = self.params = tuple(str(arg) for arg in args) | ||
self.sym_func = sym_func | ||
self.name = name | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self.func(*args, **kwargs) | ||
|
||
def __str__(self): | ||
return self.as_str | ||
|
||
def __repr__(self): | ||
if self.name is None: | ||
return "<AutoFunc with arg{}: {}>".format( | ||
's' if len(self.args) != 1 else '', | ||
', '.join(self.args) if self.args else () | ||
) | ||
return f'<AutoFunc: {self.name}>' | ||
|
||
|
||
def dummify_undefined_functions(expr:sp.Expr, ret_map:bool=False) -> sp.expr: | ||
""" | ||
Used for solving issues with lambdify and differentials. Replaces undefined | ||
functions (eg. f(t), df/dt, dg^2/dxdy) with symbols (named f and df_dt and | ||
df_dxdy for the previous examples respectively). | ||
By u/PeterE from | ||
http://stackoverflow.com/questions/29920641/lambdify-a-sp-expression-that-contains-a-derivative-of-undefinedfunction | ||
Issues (u/PeterE): | ||
* no guard against name-collisions | ||
* perhaps not the best possible name-scheme: df_dxdy for Derivative(f(x,y), x, y) | ||
* it is assumed that all derivatives are of the form: Derivative(s(t), t, ...) | ||
with s(t) being an UndefinedFunction and t a Symbol. I have no idea what will | ||
happen if any argument to Derivative is a more complex expression. I kind of | ||
think/hope that the (automatic) simplification process will reduce any more | ||
complex derivative into an expression consisting of 'basic' derivatives. But | ||
I certainly do not guard against it. | ||
* largely untested (except for my specific use-cases) | ||
""" | ||
|
||
mapping = {} | ||
|
||
# replace all Derivative terms | ||
for der in expr.atoms(sp.Derivative): | ||
f_name = der.expr.func.__name__ | ||
var_names = [var.name for var in der.variables] | ||
name = "d%s_d%s" % (f_name, 'd'.join(var_names)) | ||
mapping[der] = sp.Symbol(name) | ||
|
||
# replace undefined functions | ||
for f in expr.atoms(AppliedUndef): | ||
f_name = f.func.__name__ | ||
mapping[f] = sp.Symbol(f_name) | ||
|
||
new_expr = expr.subs(mapping) | ||
return new_expr if not ret_map else (new_expr, mapping) | ||
|
||
|
||
def dlambdify(params: tuple, | ||
expr: sp.Expr, | ||
*, | ||
show: bool=False, | ||
retstr: bool=False, | ||
**kwargs | ||
) -> Callable: | ||
""" | ||
See sp.lambdify. Used to create lambdas (or strings of lambda expressions | ||
if `retstr` is True) from sp expressions. Fixes the issues of derivatives | ||
not working in sp's basic lambdify. | ||
""" | ||
|
||
try: | ||
iter(params) | ||
except TypeError: | ||
params = (params,) | ||
|
||
|
||
#dparams = [dummify_undefined_functions(s) for s in params] | ||
dexpr = dummify_undefined_functions(expr) | ||
|
||
if show or retstr: | ||
s = _lambdastr(params, dexpr, dummify=False, **kwargs) | ||
if show: | ||
print(s) | ||
if retstr: | ||
return s | ||
|
||
return sp.lambdify(params, dexpr, dummify=False, **kwargs) | ||
|
||
def dlambdastr(params: tuple, expr: sp.Expr, **kwargs) -> str: | ||
""" | ||
Equivalent to dlambdify(params, expr, retstr=True, **kwargs) | ||
""" | ||
return dlambdify(params, expr, retstr=True, **kwargs) | ||
|
||
|
||
def auto(expr, | ||
consts: dict = None, | ||
params: List[str] = None, | ||
dfuncs: dict = None, | ||
name: str = None, | ||
*, | ||
show: bool = False, | ||
just_func: bool = False, | ||
**kwargs | ||
) -> AutoFunc: | ||
""" | ||
Similar to dlambdify, but automatically discovers all parameters in | ||
`expr`. | ||
`consts`, if used, should be dict of {sp.Symbol: float}, and will | ||
replace any constants in `expr` with values. Otherwise, they will be included | ||
in the final lambda. | ||
If `show` is True, will print the lambda python expression made. | ||
If `just_func` is True, will only return the function, otherwise a | ||
AutoFunc instance with attributes: func, args, as_str. | ||
""" | ||
|
||
assert hasattr(params, '__iter__'), \ | ||
f'prams must be iterable, currently {type(params).__name__!r}' | ||
|
||
if consts is None: | ||
consts = {} | ||
for const, value in consts.items(): | ||
expr = expr.subs(const, value) | ||
|
||
dexpr = dummify_undefined_functions(expr) | ||
|
||
if dfuncs is None: | ||
dfuncs = {} | ||
for dfunc, value in dfuncs.items(): | ||
dexpr = dexpr.subs(dexpr, value) | ||
|
||
if params is None: | ||
params = sorted(dexpr.atoms(sp.Symbol), | ||
key=lambda s: [len(str(s)), str(s)]) | ||
elif any(isinstance(p, str) for p in params): | ||
# this actually works if params is just a str, not a tuple | ||
params = sp.symbols(params) | ||
|
||
s = _lambdastr(params, dexpr, dummify=False, **kwargs) | ||
if show: | ||
print(s) | ||
|
||
f = sp.lambdify(params, dexpr, dummify=False, **kwargs) | ||
return AutoFunc(f, s, params, sym_func=expr, name=name) if not just_func else f | ||
|
||
|
||
|
||
def test_func() -> NamedTuple: | ||
""" | ||
Returns t, x0, x, f, df | ||
""" | ||
t, x0 = sp.symbols('t x0') | ||
x = sp.Function('x')(t) | ||
|
||
f = 1/sp.sqrt(x - x0) | ||
df = sp.diff(f, t) | ||
|
||
return collections.namedtuple('Syms', 't, x0, x, f, df')(t, x0, x, f, df) | ||
|
||
|
||
def test() -> Tuple[sp.Expr, str]: | ||
""" | ||
Test case to ensure all is working smoothly | ||
""" | ||
|
||
t, x0, x, f, df = test_func() | ||
return df, dlambdastr([x, sp.diff(x, t)], df) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# main.py | ||
|
||
from driver import Driver | ||
from problem import Lagrange | ||
from solver import Newton | ||
|
||
__author__ = 'Mitchell, FHT' | ||
__date__ = (2017, 8, 20) | ||
__verbose__ = True | ||
|
||
def lagrange(x0=-1e-2, v0=1e-3, dx=1e-5): | ||
""" | ||
4 months of work... for this?!?! | ||
:param x0: | ||
:param v0: | ||
:param dx: | ||
:return: | ||
""" | ||
|
||
problem = Lagrange(1, legendre_order=3) | ||
driver = Driver(problem, h0=0.001, K=2, L=1, tf=0.5) | ||
solver = Newton(driver.displacement) | ||
ans = solver.solve([x0, 0, 0, 0, v0, 0], dx=dx) | ||
|
||
t, y = driver.run(ans[-1]) | ||
problem.plot(t, y) | ||
return ans | ||
|
||
|
||
if __name__ == '__main__': | ||
ans = lagrange() | ||
print(ans[-1]) |
Oops, something went wrong.