"""
From Appendix B in the paper
Implementation of autograd
"""
import scipy.integrate

import autograd.numpy as np
from autograd.extend import primitive, defvjp_argnums
from autograd import make_vjp
from autograd.misc import flatten
from autograd.builtins import tuple

odeint = primitive(scipy.integrate.odeint)
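
# `primitive` makes odeint opaque to autograd's tracer; grad_odeint_all below
# supplies its gradient by solving an augmented ODE backwards in time (the
# adjoint method), and defvjp_argnums registers it at the bottom of this module.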

def grad_odeint_all(yt, func, y0, t, func_args, **kwargs):
    """
    Extended from "Scalable Inference of Ordinary Differential
    Equation Models of Biochemical Processes", Sec. 2.4.2,
    Fabian Froehlich, Carolin Loos, Jan Hasenauer, 2017.
    https://arxiv.org/pdf/1711.08079.pdf
    """
    T, D = np.shape(yt)
    flat_args, unflatten = flatten(func_args)

    def flat_func(y, t, flat_args):
        return func(y, t, *unflatten(flat_args))

    def unpack(x):
        # y, vjp_y, vjp_t, vjp_args
        return x[0:D], x[D:2 * D], x[2 * D], x[2 * D + 1:]
    def augmented_dynamics(augmented_state, t, flat_args):
        # Original system augmented with vjp_y, vjp_t and vjp_args.
        y, vjp_y, _, _ = unpack(augmented_state)
        vjp_all, dy_dt = make_vjp(flat_func, argnum=(0, 1, 2))(y, t, flat_args)
        vjp_y, vjp_t, vjp_args = vjp_all(-vjp_y)
        return np.hstack((dy_dt, vjp_y, vjp_t, vjp_args))
    def vjp_all(g, **kwargs):
        vjp_y = g[-1, :]
        vjp_t0 = 0
        time_vjp_list = []
        vjp_args = np.zeros(np.size(flat_args))

        for i in range(T - 1, 0, -1):
            # Compute effect of moving current time.
            vjp_cur_t = np.dot(func(yt[i, :], t[i], *func_args), g[i, :])
            time_vjp_list.append(vjp_cur_t)
            vjp_t0 = vjp_t0 - vjp_cur_t

            # Run augmented system backwards to the previous observation.
            aug_y0 = np.hstack((yt[i, :], vjp_y, vjp_t0, vjp_args))
            aug_ans = odeint(augmented_dynamics, aug_y0,
                             np.array([t[i], t[i - 1]]), tuple((flat_args,)),
                             **kwargs)
            _, vjp_y, vjp_t0, vjp_args = unpack(aug_ans[1])

            # Add gradient from current output.
            vjp_y = vjp_y + g[i - 1, :]

        time_vjp_list.append(vjp_t0)
        vjp_times = np.hstack(time_vjp_list)[::-1]

        return None, vjp_y, vjp_times, unflatten(vjp_args)

    return vjp_all

def grad_argnums_wrapper(all_vjp_builder):
    """
    A generic autograd helper function. Takes a function that
    builds vjps for all arguments, and wraps it to return only required vjps.
    """
    def build_selected_vjps(argnums, ans, combined_args, kwargs):
        vjp_func = all_vjp_builder(ans, *combined_args, **kwargs)

        def chosen_vjps(g):
            # Return whichever vjps were asked for.
            all_vjps = vjp_func(g)
            return [all_vjps[argnum] for argnum in argnums]

        return chosen_vjps

    return build_selected_vjps

# Register the adjoint VJP at import time so autograd can differentiate
# through odeint wherever this module is used.
defvjp_argnums(odeint, grad_argnums_wrapper(grad_odeint_all))
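

# --- Usage sketch (not part of the appendix code) ---
# A minimal, hypothetical example: differentiate a solution of dy/dt = -a * y
# with respect to the parameter a. The dynamics `f`, the parameter value, and
# the loss are illustrative assumptions, not from the paper; the adjoint
# gradient is compared against the closed-form derivative of y0 * exp(-a * t).
if __name__ == '__main__':
    from autograd import grad

    def f(y, t, a):
        # Linear decay dynamics; the exact solution is y0 * exp(-a * t).
        return -a * y

    y0 = np.array([1.0])
    t = np.linspace(0.0, 1.0, 5)

    def loss(a):
        # Wrap the args in autograd's traceable tuple so gradients flow
        # through odeint's registered adjoint VJP (argument 3 of odeint).
        yt = odeint(f, y0, t, tuple((a,)))
        return np.sum(yt)

    # d/da sum_i y0 * exp(-a * t_i) = sum_i (-t_i) * y0 * exp(-a * t_i)
    a = 2.0
    print('adjoint gradient: ', grad(loss)(a))
    print('analytic gradient:', np.sum(-t * y0 * np.exp(-a * t)))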