Skip to content

Commit

Permalink
fix iPrePostNEGD input validation
Browse files Browse the repository at this point in the history
  • Loading branch information
jpreszler committed Sep 21, 2023
1 parent 133b987 commit 59e55b1
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from causalpy.custom_exceptions import BadIndexException # NOQA
from causalpy.custom_exceptions import DataException, FormulaException
from causalpy.plot_utils import plot_xY
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels
from causalpy.utils import _is_variable_dummy_coded

LEGEND_FONT_SIZE = 12
az.style.use("arviz-darkgrid")
Expand Down Expand Up @@ -895,6 +895,12 @@ def _input_validation(self):
raise DataException(
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
)
elif self.data["treated"].dtype != bool:
# treated is dummy encoded, but not boolean, so we cast to bool
# so patsy design_info is sensible
bool_treated = self.data["treated"].apply(lambda x: bool(x))
self.data = self.data.drop(labels=["treated"], axis="columns")
self.data["treated"] = bool_treated

def _is_treated(self, x):
"""Returns ``True`` if `x` is greater than or equal to the treatment threshold.
Expand Down Expand Up @@ -978,7 +984,8 @@ class PrePostNEGD(ExperimentalDesign):
:param formula:
A statistical model formula
:param group_variable_name:
Name of the column in data for the group variable
Name of the column in data for the group variable, should be either
binary or boolean
:param pretreatment_variable_name:
Name of the column in data for the pretreatment variable
:param model:
Expand Down Expand Up @@ -1058,17 +1065,19 @@ def __init__(
self.group_variable_name: np.zeros(self.pred_xi.shape),
}
)
(new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
self.pred_untreated = self.model.predict(X=np.asarray(new_x))
(new_x_untreated,) = build_design_matrices(
[self._x_design_info], x_pred_untreated
)
self.pred_untreated = self.model.predict(X=np.asarray(new_x_untreated))
# treated
x_pred_untreated = pd.DataFrame(
x_pred_treated = pd.DataFrame(
{
self.pretreatment_variable_name: self.pred_xi,
self.group_variable_name: np.ones(self.pred_xi.shape),
}
)
(new_x,) = build_design_matrices([self._x_design_info], x_pred_untreated)
self.pred_treated = self.model.predict(X=np.asarray(new_x))
(new_x_treated,) = build_design_matrices([self._x_design_info], x_pred_treated)
self.pred_treated = self.model.predict(X=np.asarray(new_x_treated))

# Evaluate causal impact as equal to the trestment effect
self.causal_impact = self.idata.posterior["beta"].sel(
Expand All @@ -1079,7 +1088,7 @@ def __init__(

def _input_validation(self) -> None:
"""Validate the input data and model formula for correctness"""
if not _series_has_2_levels(self.data[self.group_variable_name]):
if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
raise DataException(
f"""
There must be 2 levels of the grouping variable
Expand Down Expand Up @@ -1165,7 +1174,7 @@ def _get_treatment_effect_coeff(self) -> str:
then we want `C(group)[T.1]`.
"""
for label in self.labels:
if ("group" in label) & (":" not in label):
if (self.group_variable_name in label) & (":" not in label):
return label

raise NameError("Unable to find coefficient name for the treatment effect")
Expand Down

0 comments on commit 59e55b1

Please sign in to comment.