@@ -34,19 +34,66 @@ def padded_where(x, to_len, padval=-1):
3434class PartialOrder (Transform ):
3535 """Create a PartialOrder transform
3636
37- This is a more flexible version of the pymc ordered transform that
37+ A more flexible version of the pymc ordered transform that
3838 allows specifying a (strict) partial order on the elements.
3939
40- It works in O(N*D) in runtime, but takes O(N^3) in initialization,
41- where N is the number of nodes in the dag and
42- D is the maximum in-degree of a node in the transitive reduction.
43-
40+ Examples
41+ --------
42+ .. code:: python
43+
44+ import numpy as np
45+ import pymc as pm
46+ import pymc_extras as pmx
47+
48+ # Define two partial orders on 4 elements
49+ # am[i,j] = 1 means i < j
50+ adj_mats = np.array([
51+ # 0 < {1, 2} < 3
52+ [[0, 1, 1, 0],
53+ [0, 0, 0, 1],
54+ [0, 0, 0, 1],
55+ [0, 0, 0, 0]],
56+
57+ # 1 < 0 < 3 < 2
58+ [[0, 0, 0, 1],
59+ [1, 0, 0, 0],
60+ [0, 0, 0, 0],
61+ [0, 0, 1, 0]],
62+ ])
63+
64+ # Create the partial order from the adjacency matrices
65+ po = pmx.PartialOrder(adj_mats)
66+
67+ with pm.Model() as model:
68+ # Generate 3 samples from both partial orders
69+ pm.Normal("po_vals", shape=(3,2,4), transform=po,
70+ initval=po.initvals((3,2,4)))
71+
72+ idata = pm.sample()
73+
74+ # Verify that for first po, the zeroth element is always the smallest
75+ assert (idata.posterior['po_vals'][:,:,:,0,0] <
76+ idata.posterior['po_vals'][:,:,:,0,1:]).all()
77+
78+ # Verify that for second po, the second element is always the largest
79+ assert (idata.posterior['po_vals'][:,:,:,1,2] >=
80+ idata.posterior['po_vals'][:,:,:,1,:]).all()
81+
82+ Technical notes
83+ ----------------
84+ Partial order needs to be strict, i.e. without equalities.
85+ A DAG defining the partial order is sufficient, as transitive closure is automatically computed.
86+ Code works in O(N*D) in runtime, but takes O(N^3) in initialization,
87+ where N is the number of nodes in the dag and D is the maximum
88+ in-degree of a node in the transitive reduction.
4489 """
4590
4691 name = "partial_order"
4792
4893 def __init__ (self , adj_mat ):
4994 """
95+ Initialize the PartialOrder transform
96+
5097 Parameters
5198 ----------
5299 adj_mat: ndarray
@@ -99,10 +146,43 @@ def __init__(self, adj_mat):
99146 self .dag = np .swapaxes (dag_T , - 2 , - 1 )
100147 self .is_start = np .all (self .dag [..., :, :] == - 1 , axis = - 1 )
101148
102- def initvals (self , lower = - 1 , upper = 1 ):
149+ def initvals (self , shape = None , lower = - 1 , upper = 1 ):
150+ """
151+ Create a set of appropriate initial values for the variable.
152+ NB! It is important that proper initial values are used,
153+ as only properly ordered values are in the range of the transform.
154+
155+ Parameters
156+ ----------
157+ shape: tuple, default None
158+ shape of the initial values. If None, adj_mat[:-1] is used
159+ lower: float, default -1
160+ lower bound for the initial values
161+ upper: float, default 1
162+ upper bound for the initial values
163+
164+ Returns
165+ -------
166+ vals: ndarray
167+ initial values for the transformed variable
168+ """
169+
170+ if shape is None :
171+ shape = self .dag .shape [:- 1 ]
172+
173+ if shape [- len (self .dag .shape [:- 1 ]) :] != self .dag .shape [:- 1 ]:
174+ raise ValueError ("Shape must match the shape of the adjacency matrix" )
175+
176+ # Create the initial values
103177 vals = np .linspace (lower , upper , self .dag .shape [- 2 ])
104178 inds = np .argsort (self .ts_inds , axis = - 1 )
105- return vals [inds ]
179+ ivals = vals [inds ]
180+
181+ # Expand the initial values to the extra dimensions
182+ extra_dims = shape [: - len (self .dag .shape [:- 1 ])]
183+ ivals = np .tile (ivals , extra_dims + tuple ([1 ] * len (self .dag .shape [:- 1 ])))
184+
185+ return ivals
106186
107187 def backward (self , value , * inputs ):
108188 minv = dtype_minval (value .dtype )
0 commit comments