Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

Commit b81244f

Browse files
brandondutra authored and qimingj committed
CsvDataSet no longer globs files in init. (#187)
* CsvDataSet no longer globs files in init.
* removed file_io, that fix will be done later
* removed junk lines
* sample uses .file
* fixed csv dataset def files()
* Update _dataset.py
1 parent cef4eae commit b81244f

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

datalab/mlalpha/_dataset.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,11 @@
3030
class CsvDataSet(object):
3131
"""DataSet based on CSV files and schema."""
3232

33-
def __init__(self, files, schema=None, schema_file=None):
33+
def __init__(self, file_pattern, schema=None, schema_file=None):
3434
"""
3535
Args:
36-
files: A list of CSV files. Can contain wildcards in file names. Can be local or GCS path.
36+
file_pattern: A list of CSV files, or a single string. Can contain wildcards in
37+
file names. Can be local or GCS path.
3738
schema: A BigQuery schema object in the form of
3839
[{'name': 'col1', 'type': 'STRING'},
3940
{'name': 'col2', 'type': 'INTEGER'}]
@@ -59,21 +60,32 @@ def __init__(self, files, schema=None, schema_file=None):
5960
else:
6061
with ml.util._file.open_local_or_gcs(schema_file, 'r') as f:
6162
self._schema = json.load(f)
63+
6264
if isinstance(files, basestring):
6365
files = [files]
64-
self._files = []
65-
for file in files:
66-
# glob_files() returns unicode strings which doesn't make DataFlow happy. So str().
67-
self._files += [str(x) for x in ml.util._file.glob_files(file)]
66+
self._input_files = files
67+
68+
self._glob_files = []
69+
70+
71+
@property
72+
def _input_files(self):
73+
"""Returns the file list that was given to this class without globing files."""
74+
return self._input_files
6875

6976
@property
7077
def files(self):
71-
return self._files
78+
if not self._glob_files:
79+
for file in self._input_files:
80+
# glob_files() returns unicode strings which doesn't make DataFlow happy. So str().
81+
self._glob_files += [str(x) for x in ml.util._file.glob_files(file)]
82+
83+
return self._glob_files
7284

7385
@property
7486
def schema(self):
75-
return self._schema
76-
87+
return self._schema
88+
7789
def sample(self, n):
7890
""" Samples data into a Pandas DataFrame.
7991
Args:
@@ -85,7 +97,7 @@ def sample(self, n):
8597
"""
8698
row_total_count = 0
8799
row_counts = []
88-
for file in self._files:
100+
for file in self.files:
89101
with ml.util._file.open_local_or_gcs(file, 'r') as f:
90102
num_lines = sum(1 for line in f)
91103
row_total_count += num_lines
@@ -108,7 +120,7 @@ def sample(self, n):
108120
# Note that random.sample will raise Exception if skip_count is greater than rows count.
109121
skip_all = sorted(random.sample(xrange(0, row_total_count), skip_count))
110122
dfs = []
111-
for file, row_count in zip(self._files, row_counts):
123+
for file, row_count in zip(self.files, row_counts):
112124
skip = [x for x in skip_all if x < row_count]
113125
skip_all = [x - row_count for x in skip_all if x >= row_count]
114126
with ml.util._file.open_local_or_gcs(file, 'r') as f:

0 commit comments

Comments
 (0)