Skip to content

Commit 4242faf

Browse files
committed
Speedup Slice sampler
1 parent fee9a02 commit 4242faf

File tree

1 file changed

+25
-20
lines changed

1 file changed

+25
-20
lines changed

pymc/step_methods/slicer.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@
2121

2222
from pymc.blocking import RaveledVars, StatsType
2323
from pymc.model import modelcontext
24-
from pymc.step_methods.arraystep import ArrayStep
24+
from pymc.pytensorf import compile_pymc, join_nonshared_inputs, make_shared_replacements
25+
from pymc.step_methods.arraystep import ArrayStepShared
2526
from pymc.step_methods.compound import Competence
2627
from pymc.util import get_value_vars_from_user_vars
2728
from pymc.vartypes import continuous_types
@@ -31,7 +32,7 @@
3132
LOOP_ERR_MSG = "max slicer iters %d exceeded"
3233

3334

34-
class Slice(ArrayStep):
35+
class Slice(ArrayStepShared):
3536
"""
3637
Univariate slice sampler step method.
3738
@@ -57,56 +58,60 @@ class Slice(ArrayStep):
5758
}
5859

5960
def __init__(self, vars=None, w=1.0, tune=True, model=None, iter_limit=np.inf, **kwargs):
60-
self.model = modelcontext(model)
61-
self.w = w
61+
model = modelcontext(model)
62+
self.w = np.asarray(w).copy()
6263
self.tune = tune
6364
self.n_tunes = 0.0
6465
self.iter_limit = iter_limit
6566

6667
if vars is None:
67-
vars = self.model.continuous_value_vars
68+
vars = model.continuous_value_vars
6869
else:
69-
vars = get_value_vars_from_user_vars(vars, self.model)
70+
vars = get_value_vars_from_user_vars(vars, model)
7071

71-
super().__init__(vars, [self.model.compile_logp()], **kwargs)
72+
point = model.initial_point()
73+
shared = make_shared_replacements(point, vars, model)
74+
[logp], raveled_inp = join_nonshared_inputs(
75+
point=point, outputs=[model.logp()], inputs=vars, shared_inputs=shared
76+
)
77+
self.logp = compile_pymc([raveled_inp], logp)
78+
self.logp.trust_input = True
7279

73-
def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
80+
super().__init__(vars, shared)
81+
82+
def astep(self, apoint: RaveledVars) -> Tuple[RaveledVars, StatsType]:
7483
# The arguments are determined by the list passed via `super().__init__(..., fs, ...)`
75-
logp = args[0]
7684
q0_val = apoint.data
77-
self.w = np.resize(self.w, len(q0_val)) # this is a repmat
85+
86+
if q0_val.shape != self.w.shape:
87+
self.w = np.resize(self.w, len(q0_val)) # this is a repmat
7888

7989
nstep_out = nstep_in = 0
8090

8191
q = np.copy(q0_val)
8292
ql = np.copy(q0_val) # l for left boundary
8393
qr = np.copy(q0_val) # r for right boundary
8494

85-
# The points are not copied, so it's fine to update them inplace in the
86-
# loop below
87-
q_ra = RaveledVars(q, apoint.point_map_info)
88-
ql_ra = RaveledVars(ql, apoint.point_map_info)
89-
qr_ra = RaveledVars(qr, apoint.point_map_info)
90-
95+
logp = self.logp
9196
for i, wi in enumerate(self.w):
9297
# uniformly sample from 0 to p(q), but in log space
93-
y = logp(q_ra) - nr.standard_exponential()
98+
y = logp(q) - nr.standard_exponential()
9499

95100
# Create initial interval
96101
ql[i] = q[i] - nr.uniform() * wi # q[i] + r * w
97102
qr[i] = ql[i] + wi # Equivalent to q[i] + (1-r) * w
98103

99104
# Stepping out procedure
100105
cnt = 0
101-
while y <= logp(ql_ra): # changed lt to leq for locally uniform posteriors
106+
while y <= logp(ql): # changed lt to leq for locally uniform posteriors
102107
ql[i] -= wi
103108
cnt += 1
104109
if cnt > self.iter_limit:
105110
raise RuntimeError(LOOP_ERR_MSG % self.iter_limit)
106111
nstep_out += cnt
107112

108113
cnt = 0
109-
while y <= logp(qr_ra):
114+
while y <= logp(qr):
110115
qr[i] += wi
111116
cnt += 1
112117
if cnt > self.iter_limit:
@@ -115,7 +120,7 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
115120

116121
cnt = 0
117122
q[i] = nr.uniform(ql[i], qr[i])
118-
while y > logp(q_ra): # Changed leq to lt, to accommodate for locally flat posteriors
123+
while y > logp(q): # Changed leq to lt, to accommodate for locally flat posteriors
119124
# Sample uniformly from slice
120125
if q[i] > q0_val[i]:
121126
qr[i] = q[i]

0 commit comments

Comments
 (0)