Skip to content

IIVM: Subgroups option to adapt to cases with and without the subgroups of always-takers and never-takers. #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion doubleml/double_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,7 +869,9 @@ def set_ml_nuisance_params(self, learner, treat_var, params):
raise ValueError('Invalid treatment variable ' + treat_var + '. ' +
'Valid treatment variable ' + ' or '.join(self._dml_data.d_cols) + '.')

if isinstance(params, dict):
if params is None:
all_params = [None] * self.n_rep
elif isinstance(params, dict):
if self.apply_cross_fitting:
all_params = [[params] * self.n_folds] * self.n_rep
else:
Expand Down
65 changes: 53 additions & 12 deletions doubleml/double_ml_iivm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ class DoubleMLIIVM(DoubleML):
``psi_a, psi_b = score(y, z, d, g_hat0, g_hat1, m_hat, r_hat0, r_hat1, smpls)``.
Default is ``'LATE'``.

subgroups: dict or None
Dictionary with options to adapt to cases with and without the subgroups of always-takers and never-takers. The
logical item ``always_takers`` specifies whether there are always takers in the sample. The logical item
``never_takers`` specifies whether there are never takers in the sample.
Default is ``{'always_takers': True, 'never_takers': True}``.

dml_procedure : str
A str (``'dml1'`` or ``'dml2'``) specifying the double machine learning algorithm.
Default is ``'dml2'``.
Expand Down Expand Up @@ -115,6 +121,7 @@ def __init__(self,
n_folds=5,
n_rep=1,
score='LATE',
subgroups=None,
dml_procedure='dml2',
trimming_rule='truncate',
trimming_threshold=1e-12,
Expand All @@ -138,6 +145,25 @@ def __init__(self,
if trimming_rule not in valid_trimming_rule:
raise ValueError('Invalid trimming_rule ' + trimming_rule + '. ' +
'Valid trimming_rule ' + ' or '.join(valid_trimming_rule) + '.')

if subgroups is None:
# this is the default for subgroups; via None to prevent a mutable default argument
subgroups = {'always_takers': True, 'never_takers': True}
else:
if not isinstance(subgroups, dict):
raise TypeError('Invalid subgroups ' + str(subgroups) + '. ' +
'subgroups must be of type dictionary.')
if (not all(k in subgroups for k in ['always_takers', 'never_takers']))\
| (not all(k in ['always_takers', 'never_takers'] for k in subgroups)):
raise ValueError('Invalid subgroups ' + str(subgroups) + '. ' +
'subgroups must be a dictionary with keys always_takers and never_takers.')
if not isinstance(subgroups['always_takers'], bool):
raise TypeError("subgroups['always_takers'] must be True or False. "
f'Got {str(subgroups["always_takers"])}.')
if not isinstance(subgroups['never_takers'], bool):
raise TypeError("subgroups['never_takers'] must be True or False. "
f'Got {str(subgroups["never_takers"])}.')
self.subgroups = subgroups
self.trimming_rule = trimming_rule
self.trimming_threshold = trimming_threshold

Expand Down Expand Up @@ -196,10 +222,16 @@ def _ml_nuisance_and_score_elements(self, smpls, n_jobs_cv):
est_params=self._get_params('ml_m'), method=self._predict_method['ml_m'])

# nuisance r
r_hat0 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z0, n_jobs=n_jobs_cv,
est_params=self._get_params('ml_r0'), method=self._predict_method['ml_r'])
r_hat1 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z1, n_jobs=n_jobs_cv,
est_params=self._get_params('ml_r1'), method=self._predict_method['ml_r'])
if self.subgroups['always_takers']:
r_hat0 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z0, n_jobs=n_jobs_cv,
est_params=self._get_params('ml_r0'), method=self._predict_method['ml_r'])
else:
r_hat0 = np.zeros_like(d)
if self.subgroups['never_takers']:
r_hat1 = _dml_cv_predict(self._learner['ml_r'], x, d, smpls=smpls_z1, n_jobs=n_jobs_cv,
est_params=self._get_params('ml_r1'), method=self._predict_method['ml_r'])
else:
r_hat1 = np.ones_like(d)

psi_a, psi_b = self._score_elements(y, z, d, g_hat0, g_hat1, m_hat, r_hat0, r_hat1, smpls)
preds = {'ml_g0': g_hat0,
Expand Down Expand Up @@ -262,18 +294,27 @@ def _ml_nuisance_tuning(self, smpls, param_grids, scoring_methods, n_folds_tune,
m_tune_res = _dml_tune(z, x, train_inds,
self._learner['ml_m'], param_grids['ml_m'], scoring_methods['ml_m'],
n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
r0_tune_res = _dml_tune(d, x, train_inds_z0,
self._learner['ml_r'], param_grids['ml_r'], scoring_methods['ml_r'],
n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
r1_tune_res = _dml_tune(d, x, train_inds_z1,
self._learner['ml_r'], param_grids['ml_r'], scoring_methods['ml_r'],
n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)

if self.subgroups['always_takers']:
r0_tune_res = _dml_tune(d, x, train_inds_z0,
self._learner['ml_r'], param_grids['ml_r'], scoring_methods['ml_r'],
n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
r0_best_params = [xx.best_params_ for xx in r0_tune_res]
else:
r0_tune_res = None
r0_best_params = [None] * len(smpls)
if self.subgroups['never_takers']:
r1_tune_res = _dml_tune(d, x, train_inds_z1,
self._learner['ml_r'], param_grids['ml_r'], scoring_methods['ml_r'],
n_folds_tune, n_jobs_cv, search_mode, n_iter_randomized_search)
r1_best_params = [xx.best_params_ for xx in r1_tune_res]
else:
r1_tune_res = None
r1_best_params = [None] * len(smpls)

g0_best_params = [xx.best_params_ for xx in g0_tune_res]
g1_best_params = [xx.best_params_ for xx in g1_tune_res]
m_best_params = [xx.best_params_ for xx in m_tune_res]
r0_best_params = [xx.best_params_ for xx in r0_tune_res]
r1_best_params = [xx.best_params_ for xx in r1_tune_res]

params = {'ml_g0': g0_best_params,
'ml_g1': g1_best_params,
Expand Down
55 changes: 34 additions & 21 deletions doubleml/tests/_utils_iivm_manual.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

def fit_nuisance_iivm(y, x, d, z, learner_m, learner_g, learner_r, smpls,
g0_params=None, g1_params=None, m_params=None, r0_params=None, r1_params=None,
trimming_threshold=1e-12):
trimming_threshold=1e-12, always_takers=True, never_takers=True):
ml_g0 = clone(learner_g)
g_hat0 = []
for idx, (train_index, test_index) in enumerate(smpls):
Expand Down Expand Up @@ -41,21 +41,28 @@ def fit_nuisance_iivm(y, x, d, z, learner_m, learner_g, learner_r, smpls,
if r0_params is not None:
ml_r0.set_params(**r0_params[idx])
train_index0 = np.intersect1d(np.where(z == 0)[0], train_index)
r_hat0.append(ml_r0.fit(x[train_index0], d[train_index0]).predict_proba(x[test_index])[:, 1])
if always_takers:
r_hat0.append(ml_r0.fit(x[train_index0], d[train_index0]).predict_proba(x[test_index])[:, 1])
else:
r_hat0.append(np.zeros_like(d[test_index]))

ml_r1 = clone(learner_r)
r_hat1 = []
for idx, (train_index, test_index) in enumerate(smpls):
if r1_params is not None:
ml_r1.set_params(**r1_params[idx])
train_index1 = np.intersect1d(np.where(z == 1)[0], train_index)
r_hat1.append(ml_r1.fit(x[train_index1], d[train_index1]).predict_proba(x[test_index])[:, 1])
if never_takers:
r_hat1.append(ml_r1.fit(x[train_index1], d[train_index1]).predict_proba(x[test_index])[:, 1])
else:
r_hat1.append(np.ones_like(d[test_index]))

return g_hat0, g_hat1, m_hat, r_hat0, r_hat1


def tune_nuisance_iivm(y, x, d, z, ml_m, ml_g, ml_r, smpls, n_folds_tune,
param_grid_g, param_grid_m, param_grid_r):
param_grid_g, param_grid_m, param_grid_r,
always_takers=True, never_takers=True):
g0_tune_res = [None] * len(smpls)
for idx, (train_index, _) in enumerate(smpls):
g0_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
Expand All @@ -79,27 +86,33 @@ def tune_nuisance_iivm(y, x, d, z, ml_m, ml_g, ml_r, smpls, n_folds_tune,
cv=m_tune_resampling)
m_tune_res[idx] = m_grid_search.fit(x[train_index, :], z[train_index])

r0_tune_res = [None] * len(smpls)
for idx, (train_index, _) in enumerate(smpls):
r0_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
r0_grid_search = GridSearchCV(ml_r, param_grid_r,
cv=r0_tune_resampling)
train_index0 = np.intersect1d(np.where(z == 0)[0], train_index)
r0_tune_res[idx] = r0_grid_search.fit(x[train_index0, :], d[train_index0])

r1_tune_res = [None] * len(smpls)
for idx, (train_index, _) in enumerate(smpls):
r1_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
r1_grid_search = GridSearchCV(ml_r, param_grid_r,
cv=r1_tune_resampling)
train_index1 = np.intersect1d(np.where(z == 1)[0], train_index)
r1_tune_res[idx] = r1_grid_search.fit(x[train_index1, :], d[train_index1])
if always_takers:
r0_tune_res = [None] * len(smpls)
for idx, (train_index, _) in enumerate(smpls):
r0_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
r0_grid_search = GridSearchCV(ml_r, param_grid_r,
cv=r0_tune_resampling)
train_index0 = np.intersect1d(np.where(z == 0)[0], train_index)
r0_tune_res[idx] = r0_grid_search.fit(x[train_index0, :], d[train_index0])
r0_best_params = [xx.best_params_ for xx in r0_tune_res]
else:
r0_best_params = None

if never_takers:
r1_tune_res = [None] * len(smpls)
for idx, (train_index, _) in enumerate(smpls):
r1_tune_resampling = KFold(n_splits=n_folds_tune, shuffle=True)
r1_grid_search = GridSearchCV(ml_r, param_grid_r,
cv=r1_tune_resampling)
train_index1 = np.intersect1d(np.where(z == 1)[0], train_index)
r1_tune_res[idx] = r1_grid_search.fit(x[train_index1, :], d[train_index1])
r1_best_params = [xx.best_params_ for xx in r1_tune_res]
else:
r1_best_params = None

g0_best_params = [xx.best_params_ for xx in g0_tune_res]
g1_best_params = [xx.best_params_ for xx in g1_tune_res]
m_best_params = [xx.best_params_ for xx in m_tune_res]
r0_best_params = [xx.best_params_ for xx in r0_tune_res]
r1_best_params = [xx.best_params_ for xx in r1_tune_res]

return g0_best_params, g1_best_params, m_best_params, r0_best_params, r1_best_params

Expand Down
30 changes: 30 additions & 0 deletions doubleml/tests/test_doubleml_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,36 @@ def test_doubleml_exception_trimming_rule():
_ = DoubleMLIIVM(dml_data_iivm, Lasso(), LogisticRegression(), LogisticRegression(), trimming_rule='discard')


@pytest.mark.ci
def test_doubleml_exception_subgroups():
    """Check that invalid ``subgroups`` arguments to DoubleMLIIVM raise the expected errors."""
    # Each case: (expected exception type, regex matching the error message, invalid subgroups value).
    bad_subgroup_cases = [
        (TypeError,
         'Invalid subgroups True. subgroups must be of type dictionary.',
         True),
        (ValueError,
         "Invalid subgroups {'abs': True}. subgroups must be a dictionary with keys always_takers and never_takers.",
         {'abs': True}),
        (ValueError,
         "Invalid subgroups {'always_takers': True, 'never_takers': False, 'abs': 5}. "
         "subgroups must be a dictionary with keys always_takers and never_takers.",
         {'always_takers': True, 'never_takers': False, 'abs': 5}),
        (ValueError,
         "Invalid subgroups {'always_takers': True}. "
         "subgroups must be a dictionary with keys always_takers and never_takers.",
         {'always_takers': True}),
        (TypeError,
         r"subgroups\['always_takers'\] must be True or False. Got 5.",
         {'always_takers': 5, 'never_takers': False}),
        (TypeError,
         r"subgroups\['never_takers'\] must be True or False. Got 5.",
         {'always_takers': True, 'never_takers': 5}),
    ]
    for expected_exception, expected_msg, subgroups in bad_subgroup_cases:
        with pytest.raises(expected_exception, match=expected_msg):
            _ = DoubleMLIIVM(dml_data_iivm, Lasso(), LogisticRegression(), LogisticRegression(),
                             subgroups=subgroups)


@pytest.mark.ci
def test_doubleml_exception_resampling():
msg = "The number of folds must be of int type. 1.5 of type <class 'float'> was passed."
Expand Down
1 change: 1 addition & 0 deletions doubleml/tests/test_doubleml_model_defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_irm_defaults():
def test_iivm_defaults():
_assert_resampling_default_settings(dml_iivm)
assert dml_iivm.score == 'LATE'
assert dml_iivm.subgroups == {'always_takers': True, 'never_takers': True}
assert dml_iivm.dml_procedure == 'dml2'
assert dml_iivm.trimming_rule == 'truncate'
assert dml_iivm.trimming_threshold == 1e-12
Loading