21
21
22
22
from pymc .blocking import RaveledVars , StatsType
23
23
from pymc .model import modelcontext
24
- from pymc .step_methods .arraystep import ArrayStep
24
+ from pymc .pytensorf import compile_pymc , join_nonshared_inputs , make_shared_replacements
25
+ from pymc .step_methods .arraystep import ArrayStepShared
25
26
from pymc .step_methods .compound import Competence
26
27
from pymc .util import get_value_vars_from_user_vars
27
28
from pymc .vartypes import continuous_types
31
32
LOOP_ERR_MSG = "max slicer iters %d exceeded"
32
33
33
34
34
- class Slice (ArrayStep ):
35
+ class Slice (ArrayStepShared ):
35
36
"""
36
37
Univariate slice sampler step method.
37
38
@@ -57,56 +58,60 @@ class Slice(ArrayStep):
57
58
}
58
59
59
60
def __init__ (self , vars = None , w = 1.0 , tune = True , model = None , iter_limit = np .inf , ** kwargs ):
60
- self . model = modelcontext (model )
61
- self .w = w
61
+ model = modelcontext (model )
62
+ self .w = np . asarray ( w ). copy ()
62
63
self .tune = tune
63
64
self .n_tunes = 0.0
64
65
self .iter_limit = iter_limit
65
66
66
67
if vars is None :
67
- vars = self . model .continuous_value_vars
68
+ vars = model .continuous_value_vars
68
69
else :
69
- vars = get_value_vars_from_user_vars (vars , self . model )
70
+ vars = get_value_vars_from_user_vars (vars , model )
70
71
71
- super ().__init__ (vars , [self .model .compile_logp ()], ** kwargs )
72
+ point = model .initial_point ()
73
+ shared = make_shared_replacements (point , vars , model )
74
+ [logp ], raveled_inp = join_nonshared_inputs (
75
+ point = point , outputs = [model .logp ()], inputs = vars , shared_inputs = shared
76
+ )
77
+ self .logp = compile_pymc ([raveled_inp ], logp )
78
+ self .logp .trust_input = True
72
79
73
- def astep (self , apoint : RaveledVars , * args ) -> Tuple [RaveledVars , StatsType ]:
80
+ super ().__init__ (vars , shared )
81
+
82
+ def astep (self , apoint : RaveledVars ) -> Tuple [RaveledVars , StatsType ]:
74
83
# The arguments are determined by the list passed via `super().__init__(..., fs, ...)`
75
- logp = args [0 ]
76
84
q0_val = apoint .data
77
- self .w = np .resize (self .w , len (q0_val )) # this is a repmat
85
+
86
+ if q0_val .shape != self .w .shape :
87
+ self .w = np .resize (self .w , len (q0_val )) # this is a repmat
78
88
79
89
nstep_out = nstep_in = 0
80
90
81
91
q = np .copy (q0_val )
82
92
ql = np .copy (q0_val ) # l for left boundary
83
93
qr = np .copy (q0_val ) # r for right boundary
84
94
85
- # The points are not copied, so it's fine to update them inplace in the
86
- # loop below
87
- q_ra = RaveledVars (q , apoint .point_map_info )
88
- ql_ra = RaveledVars (ql , apoint .point_map_info )
89
- qr_ra = RaveledVars (qr , apoint .point_map_info )
90
-
95
+ logp = self .logp
91
96
for i , wi in enumerate (self .w ):
92
97
# uniformly sample from 0 to p(q), but in log space
93
- y = logp (q_ra ) - nr .standard_exponential ()
98
+ y = logp (q ) - nr .standard_exponential ()
94
99
95
100
# Create initial interval
96
101
ql [i ] = q [i ] - nr .uniform () * wi # q[i] + r * w
97
102
qr [i ] = ql [i ] + wi # Equivalent to q[i] + (1-r) * w
98
103
99
104
# Stepping out procedure
100
105
cnt = 0
101
- while y <= logp (ql_ra ): # changed lt to leq for locally uniform posteriors
106
+ while y <= logp (ql ): # changed lt to leq for locally uniform posteriors
102
107
ql [i ] -= wi
103
108
cnt += 1
104
109
if cnt > self .iter_limit :
105
110
raise RuntimeError (LOOP_ERR_MSG % self .iter_limit )
106
111
nstep_out += cnt
107
112
108
113
cnt = 0
109
- while y <= logp (qr_ra ):
114
+ while y <= logp (qr ):
110
115
qr [i ] += wi
111
116
cnt += 1
112
117
if cnt > self .iter_limit :
@@ -115,7 +120,7 @@ def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
115
120
116
121
cnt = 0
117
122
q [i ] = nr .uniform (ql [i ], qr [i ])
118
- while y > logp (q_ra ): # Changed leq to lt, to accommodate for locally flat posteriors
123
+ while y > logp (q ): # Changed leq to lt, to accommodate for locally flat posteriors
119
124
# Sample uniformly from slice
120
125
if q [i ] > q0_val [i ]:
121
126
qr [i ] = q [i ]
0 commit comments