Skip to content
This repository was archived by the owner on Sep 3, 2022. It is now read-only.

CsvDataSet no longer globs files in init. #187

Merged
merged 7 commits into from
Feb 13, 2017
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions datalab/mlalpha/_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,11 @@
class CsvDataSet(object):
"""DataSet based on CSV files and schema."""

def __init__(self, files, schema=None, schema_file=None):
def __init__(self, file_pattern, schema=None, schema_file=None):
"""
Args:
files: A list of CSV files. Can contain wildcards in file names. Can be local or GCS path.
file_pattern: A list of CSV files, or a single string. Can contain wildcards in
file names. Can be local or GCS paths.
schema: A BigQuery schema object in the form of
[{'name': 'col1', 'type': 'STRING'},
{'name': 'col2', 'type': 'INTEGER'}]
Expand All @@ -59,21 +60,32 @@ def __init__(self, files, schema=None, schema_file=None):
else:
with ml.util._file.open_local_or_gcs(schema_file, 'r') as f:
self._schema = json.load(f)

if isinstance(files, basestring):
files = [files]
self._files = []
for file in files:
# glob_files() returns unicode strings which doesn't make DataFlow happy. So str().
self._files += [str(x) for x in ml.util._file.glob_files(file)]
self._input_files = files

self._glob_files = []


@property
def _input_files(self):
"""Returns the file list that was given to this class without globbing files."""
return self._input_files

@property
def files(self):
return self._files
if not self._glob_files:
for file in self._input_files:
# glob_files() returns unicode strings, which don't make DataFlow happy, so wrap each in str().
self._glob_files += [str(x) for x in ml.util._file.glob_files(file)]

return self._glob_files

@property
def schema(self):
return self._schema
return self._schema

def sample(self, n):
""" Samples data into a Pandas DataFrame.
Args:
Expand All @@ -85,7 +97,7 @@ def sample(self, n):
"""
row_total_count = 0
row_counts = []
for file in self._files:
for file in self.files:
with ml.util._file.open_local_or_gcs(file, 'r') as f:
num_lines = sum(1 for line in f)
row_total_count += num_lines
Expand All @@ -108,7 +120,7 @@ def sample(self, n):
# Note that random.sample will raise Exception if skip_count is greater than rows count.
skip_all = sorted(random.sample(xrange(0, row_total_count), skip_count))
dfs = []
for file, row_count in zip(self._files, row_counts):
for file, row_count in zip(self.files, row_counts):
skip = [x for x in skip_all if x < row_count]
skip_all = [x - row_count for x in skip_all if x >= row_count]
with ml.util._file.open_local_or_gcs(file, 'r') as f:
Expand Down