Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions causalpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,5 @@
"RegressionKink",
"skl_models",
"SyntheticControl",
"variable_selection_priors",
]
63 changes: 53 additions & 10 deletions causalpy/experiments/instrumental_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,16 @@ class InstrumentalVariable(BaseExperiment):
If priors are not specified we will substitute MLE estimates for
the beta coefficients. Example: ``priors = {"mus": [0, 0],
"sigmas": [1, 1], "eta": 2, "lkj_sd": 2}``.
vs_prior_type : str or None, default=None
Type of variable selection prior: 'spike_and_slab', 'horseshoe', or None.
If None, uses standard normal priors.
vs_hyperparams : dict, optional
Hyperparameters for variable selection priors. Only used if vs_prior_type
is not None.
binary_treatment : bool, default=False
A indicator for whether the treatment to be modelled is binary or not.
Determines which PyMC model we use to model the joint outcome and
treatment.

Example
--------
Expand Down Expand Up @@ -85,6 +95,16 @@ class InstrumentalVariable(BaseExperiment):
... formula=formula,
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
... )
>>> # With variable selection
>>> iv = cp.InstrumentalVariable(
... instruments_data=instruments_data,
... data=data,
... instruments_formula=instruments_formula,
... formula=formula,
... model=InstrumentalVariableRegression(sample_kwargs=sample_kwargs),
... vs_prior_type="spike_and_slab",
... vs_hyperparams={"slab_sigma": 5.0},
... )
"""

supports_ols = False
Expand All @@ -98,6 +118,9 @@ def __init__(
formula: str,
model: BaseExperiment | None = None,
priors: dict | None = None,
vs_prior_type=None,
vs_hyperparams=None,
binary_treatment=False,
**kwargs: dict,
) -> None:
super().__init__(model=model)
Expand All @@ -107,6 +130,9 @@ def __init__(
self.formula = formula
self.instruments_formula = instruments_formula
self.model = model
self.vs_prior_type = vs_prior_type
self.vs_hyperparams = vs_hyperparams or {}
self.binary_treatment = binary_treatment
self.input_validation()

y, X = dmatrices(formula, self.data)
Expand All @@ -130,15 +156,33 @@ def __init__(
COORDS = {"instruments": self.labels_instruments, "covariates": self.labels}
self.coords = COORDS
if priors is None:
priors = {
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
"sigmas": [1, 1],
"eta": 2,
"lkj_sd": 1,
}
if binary_treatment:
# Different default priors for binary treatment
priors = {
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
"sigmas": [1, 1],
"sigma_U": 1.0,
"rho_bounds": [-0.99, 0.99],
}
else:
# Original continuous treatment priors
priors = {
"mus": [self.ols_beta_first_params, self.ols_beta_second_params],
"sigmas": [1, 1],
"eta": 2,
"lkj_sd": 1,
}
self.priors = priors
self.model.fit( # type: ignore[call-arg,union-attr]
X=self.X, Z=self.Z, y=self.y, t=self.t, coords=COORDS, priors=self.priors
X=self.X,
Z=self.Z,
y=self.y,
t=self.t,
coords=COORDS,
priors=self.priors,
vs_prior_type=vs_prior_type,
vs_hyperparams=vs_hyperparams,
binary_treatment=self.binary_treatment,
)

def input_validation(self) -> None:
Expand All @@ -159,9 +203,8 @@ def input_validation(self) -> None:
if check_binary:
warnings.warn(
"""Warning. The treatment variable is not Binary.
This is not necessarily a problem but it violates
the assumption of a simple IV experiment.
The coefficients should be interpreted appropriately."""
We will use the multivariate normal likelihood
for continuous treatment."""
)

def get_2SLS_fit(self) -> None:
Expand Down
Loading