Skip to content

Commit

Permalink
Temporary fix for causal trees missing values support #733 (#734)
Browse files Browse the repository at this point in the history
* Add temp missing values placeholder for causal trees
* Apply black codestyle
  • Loading branch information
alexander-pv authored Jan 29, 2024
1 parent 0040ac6 commit 9e1f892
Show file tree
Hide file tree
Showing 8 changed files with 82,960 additions and 76 deletions.
40 changes: 20 additions & 20 deletions causalml/dataset/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,30 +383,30 @@ def get_synthetic_preds_holdout(
# fit the model on training data only
learner.fit(X=X_train, treatment=w_train, y=y_train)
try:
preds_dict_train[
"{} Learner ({})".format(label_l, label_m)
] = learner.predict(X=X_train, p=p_hat_train).flatten()
preds_dict_valid[
"{} Learner ({})".format(label_l, label_m)
] = learner.predict(X=X_val, p=p_hat_val).flatten()
preds_dict_train["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_train, p=p_hat_train).flatten()
)
preds_dict_valid["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_val, p=p_hat_val).flatten()
)
except TypeError:
preds_dict_train[
"{} Learner ({})".format(label_l, label_m)
] = learner.predict(
X=X_train, treatment=w_train, y=y_train
).flatten()
preds_dict_valid[
"{} Learner ({})".format(label_l, label_m)
] = learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()
preds_dict_train["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(
X=X_train, treatment=w_train, y=y_train
).flatten()
)
preds_dict_valid["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_val, treatment=w_val, y=y_val).flatten()
)
else:
learner = base_learner(model())
learner.fit(X=X_train, p=p_hat_train, treatment=w_train, y=y_train)
preds_dict_train[
"{} Learner ({})".format(label_l, label_m)
] = learner.predict(X=X_train).flatten()
preds_dict_valid[
"{} Learner ({})".format(label_l, label_m)
] = learner.predict(X=X_val).flatten()
preds_dict_train["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_train).flatten()
)
preds_dict_valid["{} Learner ({})".format(label_l, label_m)] = (
learner.predict(X=X_val).flatten()
)

return preds_dict_train, preds_dict_valid

Expand Down
8 changes: 5 additions & 3 deletions causalml/feature_selection/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,11 @@ def _GetNodeSummary(
if smooth:
results[ti].update(
{
ci: results_series[ti, ci]
if results_series.index.isin([(ti, ci)]).any()
else 1
ci: (
results_series[ti, ci]
if results_series.index.isin([(ti, ci)]).any()
else 1
)
}
)
else:
Expand Down
6 changes: 3 additions & 3 deletions causalml/inference/meta/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ def get_shap_values(self):
for group, mod in self.models_tau.items():
explainer = shap.TreeExplainer(mod)
if self.r_learners is not None:
explainer.model.original_model.params[
"objective"
] = None # hacky way of running shap without error
explainer.model.original_model.params["objective"] = (
None # hacky way of running shap without error
)
shap_values = explainer.shap_values(self.X)
shap_dict[group] = shap_values

Expand Down
Loading

0 comments on commit 9e1f892

Please sign in to comment.