Skip to content

Add deprecation warnings #224

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions doubleml/_utils_resampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@
from sklearn.model_selection import KFold, RepeatedKFold, RepeatedStratifiedKFold


# Remove warnings in future versions
def deprication_apply_cross_fitting():
warnings.warn('The apply_cross_fitting argument is deprecated and will be removed in future versions. '
'In the future, crossfitting is applied by default. '
'To rely on sample splitting please use external predictions.',
DeprecationWarning)
return


class DoubleMLResampling:
def __init__(self,
n_folds,
Expand All @@ -14,6 +23,8 @@ def __init__(self,
self.n_folds = n_folds
self.n_rep = n_rep
self.n_obs = n_obs
if not apply_cross_fitting:
deprication_apply_cross_fitting()
self.apply_cross_fitting = apply_cross_fitting
self.stratify = stratify
if (self.n_folds == 1) & self.apply_cross_fitting:
Expand Down
24 changes: 23 additions & 1 deletion doubleml/double_ml.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,21 @@
_implemented_data_backends = ['DoubleMLData', 'DoubleMLClusterData']


# Remove warnings in future versions
def deprication_apply_cross_fitting():
warnings.warn('The apply_cross_fitting argument is deprecated and will be removed in future versions. '
'In the future, crossfitting is applied by default. '
'To rely on sample splitting please use external predictions.',
DeprecationWarning)
return


def deprication_dml_procedure():
warnings.warn('The dml_procedure argument is deprecated and will be removed in future versions. '
'in the future, dml_procedure is always set to dml2.', DeprecationWarning)
return


class DoubleML(ABC):
"""Double Machine Learning.
"""
Expand Down Expand Up @@ -89,6 +104,9 @@ def __init__(self,
raise TypeError('draw_sample_splitting must be True or False. '
f'Got {str(draw_sample_splitting)}.')

if not apply_cross_fitting:
deprication_apply_cross_fitting()

# set resampling specifications
if self._is_cluster_data:
if (n_folds == 1) or (not apply_cross_fitting):
Expand All @@ -103,11 +121,15 @@ def __init__(self,
# default is no stratification
self._strata = None

# check and set dml_procedure and score
if (not isinstance(dml_procedure, str)) | (dml_procedure not in ['dml1', 'dml2']):
raise ValueError('dml_procedure must be "dml1" or "dml2". '
f'Got {str(dml_procedure)}.')
self._dml_procedure = dml_procedure

if dml_procedure == 'dml1':
deprication_dml_procedure()
self._dml_procedure = dml_procedure

self._score = score

if (self.n_folds == 1) & self.apply_cross_fitting:
Expand Down