
Merge pull request #6 from synthesized-io/gcp-storage-access
GCP storage access, addition of large datasets, and enablement of parquet loading
hdaly0 authored Jan 17, 2024
2 parents b15f26c + 8a58683 commit a9071b0
Showing 3 changed files with 15 additions and 4 deletions.
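
Taken together, the changes let datasets be hosted as large parquet files on Google Cloud Storage instead of only as CSVs in the GitHub repository. A minimal sketch of what this enables, using the private helpers touched in the diffs below (illustrative only; the package's public access pattern may differ):

```python
from synthesized_datasets._datasets import _Dataset, _Tag

# Registering a GCS-hosted parquet dataset: the new URL check keeps the
# full GCS URL as-is, and load() routes ".parquet" URLs to read_parquet.
ds = _Dataset(
    "simple_fraud_5gb",
    "https://storage.googleapis.com/synthesized-datasets-public/simple_fraud_5GB.parquet",
    [_Tag.FRAUD, _Tag.BINARY_CLASSIFICATION],
)
df = ds.load()  # ~5 GB download; returns a pandas DataFrame
```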
1 change: 1 addition & 0 deletions pyproject.toml
@@ -41,6 +41,7 @@ dynamic = ["version"]
 dependencies = [
     "pandas >= 1.2",
     "pyspark>=0.7.0",
+    "fastparquet"
 ]

 [project.optional-dependencies]
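The new fastparquet dependency is what backs the parquet branch added to load() below: pandas.read_parquet delegates to an installed parquet engine, and with the default engine="auto" it tries pyarrow first and then falls back to fastparquet. A quick hedged check, with a hypothetical local file name:

```python
import pandas as pd

# "example.parquet" is a placeholder; any parquet file will do.
df = pd.read_parquet("example.parquet", engine="fastparquet")
print(df.shape)
```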
2 changes: 2 additions & 0 deletions src/synthesized_datasets/__init__.py
@@ -55,10 +55,12 @@
_Dataset("bitcoin_price", "time-series/bitcoin_price.csv", [_Tag.FINANCE, _Tag.TIME_SERIES])
_Dataset("brent_oil_prices", "time-series/brent-oil-prices.csv", [_Tag.FINANCE, _Tag.TIME_SERIES])
_Dataset("simple_fraud", "time-series/fraud-time-series.csv", [_Tag.FRAUD, _Tag.BINARY_CLASSIFICATION])
_Dataset("simple_fraud_5gb", "https://storage.googleapis.com/synthesized-datasets-public/simple_fraud_5GB.parquet", [_Tag.FRAUD, _Tag.BINARY_CLASSIFICATION])
_Dataset("household_power_consumption_small", "time-series/household_power_consumption_small.csv", [_Tag.TIME_SERIES])
_Dataset("mock_medical_data", "time-series/mock_medical_data.csv", [_Tag.HEALTHCARE, _Tag.TIME_SERIES])
_Dataset("noaa_isd_weather_additional_dtypes_small", "time-series/NoaaIsdWeather_added_dtypes_small.csv", [_Tag.TIME_SERIES])
_Dataset("noaa_isd_weather_additional_dtypes_medium", "time-series/NoaaIsdWeather_added_dtypes_medium.csv", [_Tag.TIME_SERIES])
_Dataset("noaa_isd_weather_additional_dtypes_100gb", "https://storage.googleapis.com/synthesized-datasets-public/noaa_100gb_dtypes_set.parquet", [_Tag.TIME_SERIES])
_Dataset("occupancy_data", "time-series/occupancy-data.csv", [_Tag.TIME_SERIES])
_Dataset("s_and_p_500_5yr", "time-series/sandp500_5yr.csv", [_Tag.FINANCE, _Tag.TIME_SERIES])
_Dataset("time_series_basic", "time-series/time_series_basic.csv", [_Tag.TIME_SERIES])
16 changes: 12 additions & 4 deletions src/synthesized_datasets/_datasets.py
@@ -8,7 +8,7 @@
 from pyspark import SparkFiles as _SparkFiles


-_ROOT_URL = "https://raw.githubusercontent.com/synthesized-io/datasets/master/"
+_ROOT_GITHUB_URL = "https://raw.githubusercontent.com/synthesized-io/datasets/master/"


 class _Tag(_Enum):
@@ -31,7 +31,7 @@ def __repr__(self):
 class _Dataset:
     def __init__(self, name: str, url: str, tags: _typing.Optional[_typing.List[_Tag]] = None):
         self._name = name
-        self._url = _ROOT_URL + url
+        self._url = url if url.startswith("https://storage.googleapis.com") else _ROOT_GITHUB_URL + url
         self._tags: _typing.List[_Tag] = tags if tags is not None else []
         _REGISTRIES[_Tag.ALL]._register(self)
         for tag in self._tags:
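
The one-liner above decides between the two hosting schemes. A standalone restatement of the rule, mirroring the code exactly (illustrative only):

```python
_ROOT_GITHUB_URL = "https://raw.githubusercontent.com/synthesized-io/datasets/master/"

def resolve_url(url: str) -> str:
    # Full GCS URLs pass through unchanged; bare paths are joined onto
    # the GitHub raw-content root.
    if url.startswith("https://storage.googleapis.com"):
        return url
    return _ROOT_GITHUB_URL + url

assert resolve_url("time-series/bitcoin_price.csv").startswith("https://raw.githubusercontent.com/")
assert resolve_url(
    "https://storage.googleapis.com/synthesized-datasets-public/simple_fraud_5GB.parquet"
).startswith("https://storage.googleapis.com/")
```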
@@ -51,7 +51,11 @@ def tags(self) -> _typing.List[_Tag]:

     def load(self) -> _pd.DataFrame:
         """Loads the dataset."""
-        df = _pd.read_csv(self.url)
+        if self.url.endswith("parquet"):
+            df = _pd.read_parquet(self.url)
+        else:
+            # CSV load is the default
+            df = _pd.read_csv(self.url)
         df.attrs["name"] = self.name
         return df
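
For context, df.attrs is pandas' (still experimental) metadata dict, so the dataset name travels with the returned frame; a minimal illustration of the pattern:

```python
import pandas as pd

df = pd.DataFrame({"amount": [1.0, 2.5]})
df.attrs["name"] = "simple_fraud"  # same pattern as load() above
assert df.attrs["name"] == "simple_fraud"
```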

@@ -63,7 +67,11 @@ def load_spark(self, spark: _typing.Optional[_ps.SparkSession] = None) -> _ps.DataFrame:

         spark.sparkContext.addFile(self.url)
         _, filename = _os.path.split(self.url)
-        df = spark.read.csv(_SparkFiles.get(filename), header=True, inferSchema=True)
+        if self.url.endswith("parquet"):
+            df = spark.read.parquet(_SparkFiles.get(filename))
+        else:
+            # CSV load is the default
+            df = spark.read.csv(_SparkFiles.get(filename), header=True, inferSchema=True)
         df.name = self.name
         return df

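The addFile/SparkFiles pattern above downloads the remote file once and resolves a local copy by file name. A minimal standalone sketch of the same pattern (assumes a local Spark session, pyspark installed, and a reachable URL):

```python
import os

from pyspark import SparkFiles
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").getOrCreate()

url = "https://raw.githubusercontent.com/synthesized-io/datasets/master/time-series/bitcoin_price.csv"
spark.sparkContext.addFile(url)   # fetch once, ship to workers
_, filename = os.path.split(url)
path = SparkFiles.get(filename)   # local path to the fetched copy

df = spark.read.csv(path, header=True, inferSchema=True)
df.show(5)
```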
