Skip to content

Commit

Permalink
Merge pull request #218 from uber/fix_xclassifier_bug
Browse files Browse the repository at this point in the history
Fix BaseXClassifier bug
  • Loading branch information
yungmsh authored Jul 21, 2020
2 parents b9fa503 + d0c6191 commit f14c290
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 15 deletions.
11 changes: 4 additions & 7 deletions causalml/inference/meta/rlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,6 @@ class BaseRClassifier(BaseRLearner):
"""

def __init__(self,
learner=None,
outcome_learner=None,
effect_learner=None,
ate_alpha=.05,
Expand All @@ -492,18 +491,16 @@ def __init__(self,
"""Initialize an R-learner classifier.
Args:
learner (optional): a model to estimate outcomes and treatment effects. Even if specified, the user
must still specify either the outcome learner or the effect learner.
outcome_learner (optional): a model to estimate outcomes. Should have a predict_proba() method.
effect_learner (optional): a model to estimate treatment effects. It needs to take `sample_weight` as an
input argument for `fit()`
outcome_learner: a model to estimate outcomes. Should be a classifier.
effect_learner: a model to estimate treatment effects. It needs to take `sample_weight` as an
input argument for `fit()`. Should be a regressor.
ate_alpha (float, optional): the confidence level alpha of the ATE estimate
control_name (str or int, optional): name of control group
n_fold (int, optional): the number of cross validation folds for outcome_learner
random_state (int or RandomState, optional): a seed (int) or random number generator (RandomState)
"""
super().__init__(
learner=learner,
learner=None,
outcome_learner=outcome_learner,
effect_learner=effect_learner,
ate_alpha=ate_alpha,
Expand Down
27 changes: 19 additions & 8 deletions causalml/inference/meta/xlearner.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,8 @@ class BaseXClassifier(BaseXLearner):
"""

def __init__(self,
learner=None,
outcome_learner=None,
effect_learner=None,
control_outcome_learner=None,
treatment_outcome_learner=None,
control_effect_learner=None,
Expand All @@ -572,20 +573,30 @@ def __init__(self,
"""Initialize an X-learner classifier.
Args:
learner (optional): a model to estimate outcomes or treatment effects in both the control and treatment
groups. Even if specified, the user must still input either the outcome learner or the effect learner
pair.
outcome_learner (optional): a model to estimate outcomes in both the control and treatment groups.
Should be a regressor.
effect_learner (optional): a model to estimate treatment effects in both the control and treatment groups.
Should be a classifier.
control_outcome_learner (optional): a model to estimate outcomes in the control group.
Should have a predict_proba() method.
Should be a regressor.
treatment_outcome_learner (optional): a model to estimate outcomes in the treatment group.
Should have a predict_proba() method.
control_effect_learner (optional): a model to estimate treatment effects in the control group
Should be a regressor.
control_effect_learner (optional): a model to estimate treatment effects in the control group.
Should be a classifier.
treatment_effect_learner (optional): a model to estimate treatment effects in the treatment group
Should be a classifier.
ate_alpha (float, optional): the confidence level alpha of the ATE estimate
control_name (str or int, optional): name of control group
"""
if outcome_learner is not None:
control_outcome_learner = outcome_learner
treatment_outcome_learner = outcome_learner
if effect_learner is not None:
control_effect_learner = effect_learner
treatment_effect_learner = effect_learner

super().__init__(
learner=learner,
learner=None,
control_outcome_learner=control_outcome_learner,
treatment_outcome_learner=treatment_outcome_learner,
control_effect_learner=control_effect_learner,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_meta_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ def test_BaseXClassifier(generate_classification_data):
test_size=0.2,
random_state=RANDOM_SEED)

# specify all 4 learners
uplift_model = BaseXClassifier(control_outcome_learner=XGBClassifier(),
control_effect_learner=XGBRegressor(),
treatment_outcome_learner=XGBClassifier(),
Expand All @@ -541,6 +542,18 @@ def test_BaseXClassifier(generate_classification_data):
tau_pred = uplift_model.predict(X=df_test[x_names].values,
p=df_test['propensity_score'].values)

# specify 2 learners
uplift_model = BaseXClassifier(outcome_learner=XGBClassifier(),
effect_learner=XGBRegressor())

uplift_model.fit(X=df_train[x_names].values,
treatment=df_train['treatment_group_key'].values,
y=df_train[CONVERSION].values)

tau_pred = uplift_model.predict(X=df_test[x_names].values,
p=df_test['propensity_score'].values)

# calculate metrics
auuc_metrics = pd.DataFrame({'tau_pred': tau_pred.flatten(),
'W': df_test['treatment_group_key'].values,
CONVERSION: df_test[CONVERSION].values,
Expand Down

0 comments on commit f14c290

Please sign in to comment.