-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsample_outcome.py
311 lines (264 loc) · 11.2 KB
/
sample_outcome.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
import inspect
import warnings
import jax
import jax.numpy as jnp
import jax.random as jr
from flowjax.bijections import (
Affine,
Invert,
Tanh,
)
from flowjax.bijections.bijection import AbstractBijection
from flowjax.bijections.utils import Identity
from flowjax.distributions import AbstractDistribution, _StandardUniform
from jaxtyping import ArrayLike
from frugal_flows.bijections import (
LocCond,
MaskedAutoregressiveHeterogeneous,
MaskedAutoregressiveFirstUniform,
UnivariateNormalCDF,
)
def _fake_flow_condition(frugal_flow, n_samples):
    """Return an all-ones placeholder condition for the flow, or None if it is unconditional."""
    if frugal_flow.cond_shape is None:
        return None
    return jnp.ones((n_samples, frugal_flow.cond_shape[0]))


def _quantiles_from_u_yx(u_yx, n_samples, causal_model):
    """Validate user-supplied quantiles and rescale them for the chosen causal model."""
    if len(u_yx) != n_samples:
        raise ValueError(
            f"u_yx must contain n_samples={n_samples} entries, got {len(u_yx)}"
        )
    if u_yx.min() < 0.0 or u_yx.max() > 1.0:
        raise ValueError("u_yx input must be between 0. and 1.")
    if causal_model == "location_translation":
        # This model expects the input to be in (-1,1)
        return (u_yx * 2) - 1
    # This model expects the input to be in (0,1)
    return u_yx


def _check_flow_structure(frugal_flow):
    """Raise ValueError unless the flow has the layer layout that sample_outcome relies on."""
    if not isinstance(frugal_flow.base_dist, _StandardUniform):
        raise ValueError("frugal_flow must have a standard uniform base distribution")
    layers = frugal_flow.bijection.bijections
    if not isinstance(layers[0].tree, Affine):
        raise ValueError("The first bijection of frugal_flow must wrap an Affine")
    maf = layers[1].bijection.bijection.bijections[0]
    if not isinstance(
        maf, (MaskedAutoregressiveFirstUniform, MaskedAutoregressiveHeterogeneous)
    ):
        raise ValueError(
            "The second bijection of frugal_flow must be built on "
            "MaskedAutoregressiveFirstUniform or MaskedAutoregressiveHeterogeneous"
        )
    maf_dim = maf.shape[0]
    # Number of transformer parameters per dimension, recovered from the width
    # of the last layer of the autoregressive network.
    spline_n_params = int(
        maf.masked_autoregressive_mlp.layers[-1].out_features / maf_dim
    )
    transformer = maf.transformer_constructor(jnp.ones((spline_n_params)))
    if not (transformer.interval == 1):
        raise ValueError(
            "The autoregressive transformer of frugal_flow must have interval 1"
        )


def _final_layer_is_identity(frugal_flow):
    """Tell whether the third bijection starts with Identity (True) or is Invert(Affine) (False)."""
    tree = frugal_flow.bijection.bijections[2].tree
    try:
        first_inner = tree.bijections[0]
    except (AttributeError, IndexError, TypeError):
        # e.g. an Invert wrapper, which has no ``bijections`` attribute
        first_inner = None
    if isinstance(first_inner, Identity):
        return True
    if isinstance(tree, Invert) and isinstance(tree.bijection, Affine):
        return False
    raise ValueError(
        "The third bijection of frugal_flow must either start with Identity "
        "or be an inverted Affine"
    )


def _has_tanh_head(frugal_flow):
    """Tell whether the fifth bijection of the flow wraps a Tanh, as location_translation needs."""
    try:
        head = frugal_flow.bijection.bijections[4].bijections[0].bijection
    except (AttributeError, IndexError, TypeError):
        return False
    return isinstance(head, Tanh)


def _quantiles_from_flow(
    key, n_samples, frugal_flow, flow_condition, causal_effect_idx_in_flow
):
    """Draw base uniforms, push them through the first three flow layers, return the outcome column."""
    # Explicit checks (not asserts) so that they still run under ``python -O``.
    _check_flow_structure(frugal_flow)
    final_is_identity = _final_layer_is_identity(frugal_flow)

    layers = frugal_flow.bijection.bijections
    flow_dim = frugal_flow.shape[0]

    # obtain u_y samples from flow
    uni_standard = jr.uniform(key, shape=(n_samples, flow_dim))
    uni_minus1_plus1 = jax.vmap(layers[0].tree.transform)(uni_standard)
    corruni_minus1_plus1 = jax.vmap(layers[1].transform)(
        uni_minus1_plus1, flow_condition
    )
    corruni = jax.vmap(layers[2].tree.transform)(corruni_minus1_plus1, flow_condition)

    if causal_effect_idx_in_flow is None:
        warnings.warn(
            "causal_effect_idx_in_flow has not been provided and is therefore set to the default of 0. This assumes no heterogeneous effects were modelled in frugal flow training."
        )
        causal_effect_idx_in_flow = 0
    corruni_y = corruni[:, causal_effect_idx_in_flow]

    if final_is_identity:
        # in this case the flow expects the input to be in (-1,1)
        return corruni_y
    # in this case the flow expects the input to be in (0,1)
    return (corruni_y / 2) + 0.5


def sample_outcome(
    key: jr.PRNGKey,
    n_samples: int,
    causal_model: str,
    causal_condition: ArrayLike | None = None,
    frugal_flow: AbstractDistribution | None = None,
    causal_effect_idx_in_flow: int | None = None,
    causal_cdf: AbstractBijection | None = UnivariateNormalCDF,
    u_yx: ArrayLike | None = None,
    **treatment_kwargs: dict,
):
    """
    Samples outcomes from a given causal model using frugal flows.

    The outcome quantiles come either from ``u_yx`` (when supplied) or from
    uniform draws pushed through the first three bijections of ``frugal_flow``.
    They are then handed to the outcome function selected by ``causal_model``.

    Args:
        key: The PRNGKey for random number generation. Only used when the quantiles are sampled from the flow.
        n_samples: The number of outcome samples to generate. Must equal ``len(u_yx)`` when u_yx is given.
        causal_model: The causal model to use for outcome generation. Must be one of ["logistic_regression", "causal_cdf", "location_translation"].
        causal_condition: The causal condition forwarded to the outcome model. Default is None.
        frugal_flow: The frugal flow object to use for outcome generation. Default is None. For causal model "location_translation", a frugal flow object is always required, and if u_yx is also provided, the u_yx quantiles will be used to sample from the flow object. For other causal models, exactly one of frugal flow object and u_yx is required.
        causal_effect_idx_in_flow: Column of the flow output used as the outcome quantile. Default is None, in which case 0 is used and a warning is issued.
        causal_cdf: The causal CDF forwarded to the "causal_cdf" model. Default is UnivariateNormalCDF.
        u_yx: The input quantiles for the causal model, with values in [0, 1]. Default is None, in which case a frugal flow object is always required.
        **treatment_kwargs: Additional keyword arguments forwarded to the outcome model.

    Returns:
        outcome_samples: The generated outcome samples.

    Raises:
        ValueError: If the input arguments are invalid or missing, or if the
            frugal flow does not have the expected structure.
    """
    valid_causal_models = ["logistic_regression", "causal_cdf", "location_translation"]
    is_location_translation = causal_model == "location_translation"

    if u_yx is None and frugal_flow is None:
        raise ValueError("Either a frugal flow object or u_yx is required")
    if u_yx is not None and frugal_flow is not None and not is_location_translation:
        raise ValueError(
            f"Only one between frugal flow object and u_yx can be provided for {causal_model} model"
        )
    if is_location_translation and frugal_flow is None:
        raise ValueError(
            f"A frugal flow object is required for simulating outcome with {causal_model} model"
        )
    # Reject an unknown model before doing any sampling work.
    if causal_model not in valid_causal_models:
        raise ValueError(
            f"Invalid causal_model choice. Please choose from: {valid_causal_models}"
        )

    # The placeholder condition is needed whenever a flow is present: for
    # sampling from it, and for location_translation even when u_yx is given.
    flow_fake_condition = None
    if frugal_flow is not None:
        flow_fake_condition = _fake_flow_condition(frugal_flow, n_samples)

    if u_yx is not None:
        if frugal_flow is not None:
            warnings.warn(
                f"Since both frugal flow object and u_yx are provided to {causal_model} model, u_yx quantiles will be used to sample from the flow object. If you want to fully sample from the flow object, please provide only the frugal flow object."
            )
        corruni_standard = _quantiles_from_u_yx(u_yx, n_samples, causal_model)
    else:
        corruni_standard = _quantiles_from_flow(
            key,
            n_samples,
            frugal_flow,
            flow_fake_condition,
            causal_effect_idx_in_flow,
        )

    if causal_model == "logistic_regression":
        outcome_samples = logistic_outcome(
            u_y=corruni_standard,
            causal_condition=causal_condition,
            **treatment_kwargs,
        )
    elif causal_model == "causal_cdf":
        outcome_samples, _ = causal_cdf_outcome(
            u_y=corruni_standard,
            causal_condition=causal_condition,
            causal_cdf=causal_cdf,
            **treatment_kwargs,
        )
    else:
        # Only "location_translation" remains; causal_model was validated above.
        if not _has_tanh_head(frugal_flow):
            raise ValueError(
                f"{causal_model} causal_model requires a 'location_translation' pretrained frugal_flow"
            )
        outcome_samples = location_translation_outcome(
            u_y=corruni_standard,
            causal_condition=causal_condition,
            flow_condition=flow_fake_condition,
            frugal_flow=frugal_flow,
            **treatment_kwargs,
        )
    return outcome_samples
def logistic_outcome(
    u_y: ArrayLike, ate: float, causal_condition: ArrayLike, const: float
):
    """
    Turns quantiles into binary outcomes under a logistic model.

    For each sample the success probability is ``sigmoid(ate * x + const)``,
    where ``x`` is that sample's causal condition. The outcome is 1 when the
    quantile is at least ``1 - p`` and 0 otherwise.

    Args:
        u_y: The input quantiles, of shape (n_samples,)
        ate: The coefficient multiplying the causal condition. Float.
        causal_condition: The (univariate) causal condition. It is an Array with shape (n_samples, 1) or (n_samples,).
        const: The intercept added inside the sigmoid. Float.

    Returns:
        The integer 0/1 outcomes, one per sample.
    """

    def _single_outcome(quantile, effect, treatment, intercept):
        success_prob = jax.nn.sigmoid(effect * treatment + intercept)
        threshold = 1 - success_prob
        is_success = quantile >= threshold
        # squeeze drops the trailing axis left by a (1,)-shaped condition row
        return is_success.astype(int).squeeze()

    # Map over samples: quantiles and conditions are batched, the two scalars are shared.
    per_sample = jax.vmap(_single_outcome, in_axes=(0, None, 0, None))
    return per_sample(u_y, ate, causal_condition, const)
def causal_cdf_outcome(
    u_y: ArrayLike,
    causal_cdf: AbstractBijection,
    causal_condition: ArrayLike,
    **treatment_kwargs: dict,
):
    """
    Generates outcomes by applying the inverse of a causal CDF to quantiles.

    ``causal_cdf`` is called with ``treatment_kwargs`` to build the CDF
    object; its ``inverse`` method is then mapped over the samples.

    Args:
        u_y: The input quantiles, one per sample.
        causal_cdf: The callable (a class, since its ``__init__`` signature is inspected) that builds the causal CDF.
        causal_condition: The causal condition passed to ``inverse`` with each quantile. A one-dimensional array is reshaped to a single column. May be None.
        **treatment_kwargs: Constructor arguments for ``causal_cdf``. ``cond_dim`` defaults to the number of columns of ``causal_condition`` when a condition is given. Any other constructor parameter that is missing is set to None with a warning.

    Returns:
        A tuple ``(samples, causal_cdf_simulate)``: the generated samples and the constructed CDF object.
    """
    if causal_condition is not None:
        if causal_condition.ndim == 1:
            # Promote a flat vector to a single-column matrix
            causal_condition = causal_condition.reshape(-1, 1)
        if "cond_dim" not in treatment_kwargs:
            treatment_kwargs["cond_dim"] = causal_condition.shape[1]

    # Fill every constructor parameter the caller left out with None.
    # "self" and "cond_dim" are never filled in this way.
    init_signature = inspect.signature(causal_cdf.__init__)
    for name in init_signature.parameters:
        if name == "self" or name == "cond_dim":
            continue
        if name in treatment_kwargs:
            continue
        treatment_kwargs[name] = None
        warnings.warn(
            f"The parameter {name} has not been provided and is therefore set to None."
        )

    cdf_instance = causal_cdf(**treatment_kwargs)
    invert_per_sample = jax.vmap(cdf_instance.inverse)
    return invert_per_sample(u_y, causal_condition), cdf_instance
def location_translation_outcome(
    u_y: ArrayLike,
    frugal_flow: AbstractDistribution,
    causal_condition: ArrayLike,
    flow_condition: ArrayLike,
    **treatment_kwargs: dict,
):
    """
    Generate outcome samples for the location_translation causal model.

    The quantiles go through the first element of the flow's fourth
    bijection, then through the first element of its fifth bijection.
    The result is finally transformed by a ``LocCond`` built from
    ``treatment_kwargs``, conditioned on ``causal_condition``.

    Args:
        u_y (ArrayLike): The input quantiles, of shape (n_samples,)
        frugal_flow (AbstractDistribution): The frugal flow object whose bijections are applied.
        causal_condition (ArrayLike): The causal condition passed to the LocCond transform.
        flow_condition (ArrayLike): The condition passed to the two flow bijections.
        **treatment_kwargs (dict): Keyword arguments used to construct LocCond.

    Returns:
        ArrayLike: The generated outcome samples.
    """
    flow_layers = frugal_flow.bijection.bijections

    # Step 1: fourth flow layer, applied to the quantiles as a column vector.
    fourth_layer = flow_layers[3].bijections[0]
    quantile_column = u_y[:, None]
    causal_minus1_plus1 = jax.vmap(fourth_layer.transform)(
        quantile_column, flow_condition
    )

    # Step 2: fifth flow layer, applied to the flattened output of step 1.
    fifth_layer = flow_layers[4].bijections[0]
    causal_reals = jax.vmap(fifth_layer.transform)(
        causal_minus1_plus1.flatten(), flow_condition
    )

    # Step 3: apply the location transform conditioned on the causal condition.
    loc_cond = LocCond(**treatment_kwargs)
    return jax.vmap(loc_cond.transform)(causal_reals, causal_condition)