1
1
from doubleml import DoubleMLData
2
+ import os
3
+ import boto3
4
+ import pandas as pd
2
5
3
6
4
7
class DoubleMLDataS3 (DoubleMLData ):
@@ -18,6 +21,8 @@ def __init__(self,
18
21
z_cols ,
19
22
use_other_treat_as_covariate )
20
23
self ._bucket = bucket
24
+ self ._file_ending = os .path .splitext (file_key )[1 ]
25
+ assert self ._file_ending in ['.csv' ]
21
26
self ._file_key = file_key
22
27
23
28
@property
@@ -36,6 +41,36 @@ def get_payload(self):
36
41
}
37
42
return payload
38
43
44
+ @classmethod
45
+ def from_s3 (cls ,
46
+ bucket ,
47
+ file_key ,
48
+ y_col ,
49
+ d_cols ,
50
+ x_cols = None ,
51
+ z_cols = None ,
52
+ use_other_treat_as_covariate = True ):
53
+ s3_client = boto3 .client ('s3' )
54
+ response = s3_client .get_object (Bucket = bucket ,
55
+ Key = file_key )
56
+ file = response ["Body" ]
57
+ file_ending = os .path .splitext (file_key )[1 ]
58
+ assert file_ending in ['.csv' ]
59
+ # load csv as a pd.DataFrame
60
+ data = pd .read_csv (file )
61
+
62
+ return cls (bucket , file_key , data , y_col , d_cols , x_cols , z_cols , use_other_treat_as_covariate )
63
+
64
+ def store_and_upload_to_s3 (self ):
65
+ # load csv as a pd.DataFrame
66
+ file_name = os .path .split (self .file_key )[1 ]
67
+ self .data .to_csv (file_name )
68
+ s3_client = boto3 .client ('s3' )
69
+ response = s3_client .upload_file (Filename = file_name ,
70
+ Bucket = self .bucket ,
71
+ Key = self .file_key )
72
+ return response
73
+
39
74
40
75
class DoubleMLDataJson (DoubleMLData ):
41
76
def __init__ (self ,
0 commit comments