Skip to content

Commit 8431daf

Browse files
committed
fix gain statistics and dummy learner docstrings
1 parent eda1137 commit 8431daf

File tree

2 files changed

+39
-31
lines changed

2 files changed

+39
-31
lines changed

doubleml/utils/dummy_learners.py

+30-24
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,32 @@ class DMLDummyRegressor(BaseEstimator):
66
A dummy regressor that raises an AttributeError when attempting to access
77
its fit, predict, or set_params methods.
88
9-
Attributes
9+
Parameters
1010
----------
11-
_estimator_type : str
12-
Type of the estimator, set to "regressor".
1311
14-
Methods
15-
-------
16-
fit(*args)
17-
Raises AttributeError: "Accessed fit method of DummyRegressor!"
18-
predict(*args)
19-
Raises AttributeError: "Accessed predict method of DummyRegressor!"
20-
set_params(*args)
21-
Raises AttributeError: "Accessed set_params method of DummyRegressor!"
2212
"""
2313

2414
_estimator_type = "regressor"
2515

2616
def fit(*args):
17+
"""
18+
Raises AttributeError: "Accessed fit method of DummyRegressor!"
19+
"""
20+
2721
raise AttributeError("Accessed fit method of DMLDummyRegressor!")
2822

2923
def predict(*args):
24+
"""
25+
Raises AttributeError: "Accessed predict method of DummyRegressor!"
26+
"""
27+
3028
raise AttributeError("Accessed predict method of DMLDummyRegressor!")
3129

3230
def set_params(*args):
31+
"""
32+
Raises AttributeError: "Accessed set_params method of DummyRegressor!"
33+
"""
34+
3335
raise AttributeError("Accessed set_params method of DMLDummyRegressor!")
3436

3537

@@ -38,33 +40,37 @@ class DMLDummyClassifier(BaseEstimator):
3840
A dummy classifier that raises an AttributeError when attempting to access
3941
its fit, predict, set_params, or predict_proba methods.
4042
41-
Attributes
43+
Parameters
4244
----------
43-
_estimator_type : str
44-
Type of the estimator, set to "classifier".
4545
46-
Methods
47-
-------
48-
fit(*args)
49-
Raises AttributeError: "Accessed fit method of DummyClassifier!"
50-
predict(*args)
51-
Raises AttributeError: "Accessed predict method of DummyClassifier!"
52-
set_params(*args)
53-
Raises AttributeError: "Accessed set_params method of DummyClassifier!"
54-
predict_proba(*args, **kwargs)
55-
Raises AttributeError: "Accessed predict_proba method of DummyClassifier!"
5646
"""
5747

5848
_estimator_type = "classifier"
5949

6050
def fit(*args):
51+
"""
52+
Raises AttributeError: "Accessed fit method of DummyClassifier!"
53+
"""
54+
6155
raise AttributeError("Accessed fit method of DMLDummyClassifier!")
6256

6357
def predict(*args):
58+
"""
59+
Raises AttributeError: "Accessed predict method of DummyClassifier!"
60+
"""
61+
6462
raise AttributeError("Accessed predict method of DMLDummyClassifier!")
6563

6664
def set_params(*args):
65+
"""
66+
Raises AttributeError: "Accessed set_params method of DummyClassifier!"
67+
"""
68+
6769
raise AttributeError("Accessed set_params method of DMLDummyClassifier!")
6870

6971
def predict_proba(*args, **kwargs):
72+
"""
73+
Raises AttributeError: "Accessed predict_proba method of DummyClassifier!"
74+
"""
75+
7076
raise AttributeError("Accessed predict_proba method of DMLDummyClassifier!")

doubleml/utils/gain_statistics.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,21 @@
33

44
def gain_statistics(dml_long, dml_short):
55
"""
6-
Compute gain statistics as benchmark values for sensitivity parameters cf_d and cf_y.
6+
Compute gain statistics as benchmark values for sensitivity parameters ``cf_d`` and ``cf_y``.
77
8-
Parameters:
8+
Parameters
99
----------
1010
11-
dml_long : :class:`doubleml.DoubleML` model including all observed confounders
12-
dml_short : :class:`doubleml.DoubleML` model that excludes one or several benchmark confounders
11+
dml_long :
12+
:class:`doubleml.DoubleML` model including all observed confounders
1313
14+
dml_short :
15+
:class:`doubleml.DoubleML` model that excludes one or several benchmark confounders
1416
15-
Returns:
17+
Returns
1618
--------
17-
Benchmarking dictionary (dict) with values for cf_d, cf_y, rho, and delta_theta.
18-
19+
benchmark_dict : dict
20+
Benchmarking dictionary (dict) with values for ``cf_d``, ``cf_y``, ``rho``, and ``delta_theta``.
1921
"""
2022
if not isinstance(dml_long.sensitivity_elements, dict):
2123
raise TypeError("dml_long does not contain the necessary sensitivity elements. "

0 commit comments

Comments
 (0)