30
30
class CsvDataSet (object ):
31
31
"""DataSet based on CSV files and schema."""
32
32
33
- def __init__ (self , files , schema = None , schema_file = None ):
33
+ def __init__ (self , file_pattern , schema = None , schema_file = None ):
34
34
"""
35
35
Args:
36
- files: A list of CSV files. Can contain wildcards in file names. Can be local or GCS path.
36
+ file_pattern: A list of CSV file paths, or a single path string. Can contain wildcards in
37
+ file names. Can be local or GCS path.
37
38
schema: A BigQuery schema object in the form of
38
39
[{'name': 'col1', 'type': 'STRING'},
39
40
{'name': 'col2', 'type': 'INTEGER'}]
@@ -59,21 +60,32 @@ def __init__(self, files, schema=None, schema_file=None):
59
60
else :
60
61
with ml .util ._file .open_local_or_gcs (schema_file , 'r' ) as f :
61
62
self ._schema = json .load (f )
63
+
62
64
if isinstance (files , basestring ):
63
65
files = [files ]
64
- self ._files = []
65
- for file in files :
66
- # glob_files() returns unicode strings which doesn't make DataFlow happy. So str().
67
- self ._files += [str (x ) for x in ml .util ._file .glob_files (file )]
66
+ self ._input_files = files
67
+
68
+ self ._glob_files = []
69
+
70
+
71
@property
def _input_files(self):
    """Returns the file list that was given to this class, without globbing.

    The value lives in a separately named backing attribute. Returning
    ``self._input_files`` from a property of the same name would invoke this
    getter again and recurse without end.
    """
    return self._raw_input_files

@_input_files.setter
def _input_files(self, value):
    """Stores the un-globbed file list.

    A property without a setter rejects ``self._input_files = ...`` on a
    new-style class, so this setter is what makes that assignment work.
    """
    self._raw_input_files = value
68
75
69
76
@property
def files(self):
    """Returns the concrete file paths matched by the input file patterns.

    The patterns are expanded whenever the cached list is empty; once it holds
    at least one path, that cached list is returned as is.
    """
    if self._glob_files:
        return self._glob_files

    for pattern in self._input_files:
        # glob_files() returns unicode strings which don't make DataFlow happy. So str().
        self._glob_files += [str(path) for path in ml.util._file.glob_files(pattern)]

    return self._glob_files
72
84
73
85
@property
def schema(self):
    """Returns the schema object stored on this dataset."""
    return self._schema
88
+
77
89
def sample (self , n ):
78
90
""" Samples data into a Pandas DataFrame.
79
91
Args:
@@ -85,7 +97,7 @@ def sample(self, n):
85
97
"""
86
98
row_total_count = 0
87
99
row_counts = []
88
- for file in self ._files :
100
+ for file in self .files :
89
101
with ml .util ._file .open_local_or_gcs (file , 'r' ) as f :
90
102
num_lines = sum (1 for line in f )
91
103
row_total_count += num_lines
@@ -108,7 +120,7 @@ def sample(self, n):
108
120
# Note that random.sample will raise Exception if skip_count is greater than rows count.
109
121
skip_all = sorted (random .sample (xrange (0 , row_total_count ), skip_count ))
110
122
dfs = []
111
- for file , row_count in zip (self ._files , row_counts ):
123
+ for file , row_count in zip (self .files , row_counts ):
112
124
skip = [x for x in skip_all if x < row_count ]
113
125
skip_all = [x - row_count for x in skip_all if x >= row_count ]
114
126
with ml .util ._file .open_local_or_gcs (file , 'r' ) as f :
0 commit comments