9
9
10
10
import doubleml as dml
11
11
12
- from ._utils import draw_smpls
12
+ from ... tests . _utils import draw_smpls
13
13
from ._utils_did_manual import boot_did
14
14
from ._utils_did_cs_manual import fit_did_cs , tune_nuisance_did_cs
15
15
@@ -77,15 +77,22 @@ def dml_did_cs_fixture(generate_data_did_cs, learner_g, learner_m, score, in_sam
77
77
ml_g = clone (learner_g )
78
78
ml_m = clone (learner_m )
79
79
80
+ n_obs = len (y )
81
+ all_smpls = draw_smpls (n_obs , n_folds , n_rep = 1 , groups = d + 2 * t )
82
+
80
83
np .random .seed (3141 )
81
84
obj_dml_data = dml .DoubleMLData .from_arrays (x , y , d , t = t )
82
85
dml_did_cs_obj = dml .DoubleMLDIDCS (obj_dml_data ,
83
86
ml_g , ml_m ,
84
87
n_folds ,
85
88
score = score ,
86
89
in_sample_normalization = in_sample_normalization ,
87
- dml_procedure = dml_procedure )
90
+ dml_procedure = dml_procedure ,
91
+ draw_sample_splitting = False )
92
+ # synchronize the sample splitting
93
+ dml_did_cs_obj .set_sample_splitting (all_smpls = all_smpls )
88
94
95
+ np .random .seed (3141 )
89
96
# tune hyperparameters
90
97
tune_res = dml_did_cs_obj .tune (par_grid , tune_on_folds = tune_on_folds ,
91
98
n_folds_tune = n_folds_tune ,
@@ -95,8 +102,6 @@ def dml_did_cs_fixture(generate_data_did_cs, learner_g, learner_m, score, in_sam
95
102
dml_did_cs_obj .fit ()
96
103
97
104
np .random .seed (3141 )
98
- n_obs = len (y )
99
- all_smpls = draw_smpls (n_obs , n_folds )
100
105
smpls = all_smpls [0 ]
101
106
102
107
if tune_on_folds :
@@ -152,14 +157,14 @@ def dml_did_cs_fixture(generate_data_did_cs, learner_g, learner_m, score, in_sam
152
157
153
158
@pytest .mark .ci
154
159
def test_dml_did_cs_coef (dml_did_cs_fixture ):
155
- assert math .isclose (dml_did_cs_fixture ['coef' ],
160
+ assert math .isclose (dml_did_cs_fixture ['coef' ][ 0 ] ,
156
161
dml_did_cs_fixture ['coef_manual' ],
157
162
rel_tol = 1e-9 , abs_tol = 1e-4 )
158
163
159
164
160
165
@pytest .mark .ci
161
166
def test_dml_did_cs_se (dml_did_cs_fixture ):
162
- assert math .isclose (dml_did_cs_fixture ['se' ],
167
+ assert math .isclose (dml_did_cs_fixture ['se' ][ 0 ] ,
163
168
dml_did_cs_fixture ['se_manual' ],
164
169
rel_tol = 1e-9 , abs_tol = 1e-4 )
165
170
0 commit comments