3 changes: 3 additions & 0 deletions docs/datasources/googlesheets.md
@@ -0,0 +1,3 @@
# GoogleSheetsDataSource

::: pyspark_datasources.googlesheets.GoogleSheetsDataSource
15 changes: 8 additions & 7 deletions docs/index.md
@@ -28,10 +28,11 @@ spark.read.format("github").load("apache/spark").show()

## Data Sources

| Data Source | Short Name | Description | Dependencies |
|-----------------------------------------------------|---------------|---------------------------------------------|-----------------|
| [GithubDataSource](./datasources/github.md) | `github` | Read pull requests from a Github repository | None |
| [FakeDataSource](./datasources/fake.md) | `fake` | Generate fake data using the `Faker` library | `faker` |
| [HuggingFaceDatasets](./datasources/huggingface.md) | `huggingface` | Read datasets from the HuggingFace Hub | `datasets` |
| [StockDataSource](./datasources/stock.md) | `stock` | Read stock data from Alpha Vantage | None |
| [SimpleJsonDataSource](./datasources/simplejson.md) | `simplejson` | Read JSON data from a file | `databricks-sdk`|
| Data Source | Short Name | Description | Dependencies |
| ------------------------------------------------------- | -------------- | --------------------------------------------- | ---------------- |
| [GithubDataSource](./datasources/github.md) | `github` | Read pull requests from a Github repository | None |
| [FakeDataSource](./datasources/fake.md) | `fake` | Generate fake data using the `Faker` library | `faker` |
| [HuggingFaceDatasets](./datasources/huggingface.md) | `huggingface` | Read datasets from the HuggingFace Hub | `datasets` |
| [StockDataSource](./datasources/stock.md) | `stock` | Read stock data from Alpha Vantage | None |
| [SimpleJsonDataSource](./datasources/simplejson.md) | `simplejson` | Read JSON data from a file | `databricks-sdk` |
| [GoogleSheetsDataSource](./datasources/googlesheets.md) | `googlesheets` | Read tables from public Google Sheets documents | None             |
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -23,6 +23,7 @@ nav:
- datasources/huggingface.md
- datasources/stock.md
- datasources/simplejson.md
- datasources/googlesheets.md

markdown_extensions:
- pymdownx.highlight:
3 changes: 2 additions & 1 deletion pyspark_datasources/__init__.py
@@ -1,5 +1,6 @@
from .fake import FakeDataSource
from .github import GithubDataSource
from .googlesheets import GoogleSheetsDataSource
from .huggingface import HuggingFaceDatasets
from .stock import StockDataSource
from .simplejson import SimpleJsonDataSource
from .stock import StockDataSource
160 changes: 160 additions & 0 deletions pyspark_datasources/googlesheets.py
@@ -0,0 +1,160 @@
from dataclasses import dataclass
from typing import Dict, Optional

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StringType, StructField, StructType


@dataclass
class Sheet:
"""
A dataclass to identify a Google Sheets document.

Attributes
----------
spreadsheet_id : str
The ID of the Google Sheets document.
sheet_id : str, optional
The ID of the worksheet within the document.
"""

spreadsheet_id: str
sheet_id: Optional[str] # if None, the first sheet is used

@classmethod
def from_url(cls, url: str) -> "Sheet":
"""
Converts a Google Sheets URL to a Sheet object.
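
For example (the spreadsheet ID below is illustrative):

>>> Sheet.from_url("https://docs.google.com/spreadsheets/d/abc123/edit?gid=0")
Sheet(spreadsheet_id='abc123', sheet_id='0')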
"""
from urllib.parse import parse_qs, urlparse

parsed = urlparse(url)
if parsed.netloc != "docs.google.com" or not parsed.path.startswith(
"/spreadsheets/d/"
):
raise ValueError("URL is not a Google Sheets URL")
qs = parse_qs(parsed.query)
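# The path has the form /spreadsheets/d/<spreadsheet_id>/..., so the ID is at index 3 of the split path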
spreadsheet_id = parsed.path.split("/")[3]
if "gid" in qs:
sheet_id = qs["gid"][0]
else:
sheet_id = None
return cls(spreadsheet_id, sheet_id)

def get_query_url(self, query: Optional[str] = None):
"""
Gets a URL that returns the results of the query as a CSV file.

If no query is provided, returns the entire sheet.
If sheet ID is None, uses the first sheet.

See https://developers.google.com/chart/interactive/docs/querylanguage
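
For example (illustrative IDs), the generated URL looks like:

>>> Sheet("abc123", "0").get_query_url()
'https://docs.google.com/spreadsheets/d/abc123/gviz/tq?tqx=out%3Acsv&gid=0'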
"""
from urllib.parse import urlencode

path = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/gviz/tq"
url_query = {"tqx": "out:csv"}
if self.sheet_id:
url_query["gid"] = self.sheet_id
if query:
url_query["tq"] = query
return f"{path}?{urlencode(url_query)}"


class GoogleSheetsDataSource(DataSource):
"""
A DataSource for reading tables from public Google Sheets documents.

Name: `googlesheets`

Schema: By default, all columns are treated as strings and the header row defines the column names.

Examples
--------
Register the data source.

>>> from pyspark_datasources import GoogleSheetsDataSource
>>> spark.dataSource.register(GoogleSheetsDataSource)

Load data from a public Google Sheets document using a URL.

>>> url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=0#gid=0"
>>> spark.read.format("googlesheets").options(url=url)
+-------+----------+-----------+--------------------+
|country| latitude| longitude| name|
+-------+----------+-----------+--------------------+
| AD| 42.546245| 1.601554| Andorra|
| ...| ...| ...| ...|
+-------+----------+-----------+--------------------+

Load data using `spreadsheet_id` and an optional `sheet_id`.

>>> spark.read.format("googlesheets").options(spreadsheet_id="10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0", sheet_id="0")
+-------+----------+-----------+--------------------+
|country| latitude| longitude| name|
+-------+----------+-----------+--------------------+
| AD| 42.546245| 1.601554| Andorra|
| ...| ...| ...| ...|
+-------+----------+-----------+--------------------+

Specify a custom schema.

>>> spark.read.format("googlesheets").options(url=url).schema("id string, lat double, long double, name string")
+---+----------+-----------+--------------------+
| id| lat| long| name|
+---+----------+-----------+--------------------+
| AD| 42.546245| 1.601554| Andorra|
|...| ...| ...| ...|
+---+----------+-----------+--------------------+
"""

@classmethod
def name(cls):
return "googlesheets"

def __init__(self, options: Dict[str, str]):
if "url" in options:
self.sheet = Sheet.from_url(options["url"])
elif "spreadsheet_id" in options:
self.sheet = Sheet(options["spreadsheet_id"], options.get("sheet_id"))
else:
raise ValueError(
"You must specify a URL or spreadsheet_id in `.options()`."
)

def schema(self) -> StructType:
import pandas as pd

# Infer column names from the header row ("select * limit 0" returns no data rows)
df = pd.read_csv(self.sheet.get_query_url("select * limit 0"))
return StructType([StructField(col, StringType()) for col in df.columns])

def reader(self, schema: StructType) -> DataSourceReader:
return GoogleSheetsReader(self.sheet, schema)


class GoogleSheetsReader(DataSourceReader):
def __init__(self, sheet: Sheet, schema: StructType):
self.sheet = sheet
self.schema = schema

def read(self, partition):
from urllib.request import urlopen

from pyarrow import csv
from pyspark.sql.pandas.types import to_arrow_type

# Specify column types based on the schema
convert_options = csv.ConvertOptions(
column_types={
field.name: to_arrow_type(field.dataType) for field in self.schema
},
)
read_options = csv.ReadOptions(
column_names=self.schema.fieldNames(), # Rename columns
skip_rows=1, # Skip the header row
)
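# Stream the CSV export and yield the data to Spark as Arrow record batches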
with urlopen(self.sheet.get_query_url()) as file:
yield from csv.read_csv(
file, read_options=read_options, convert_options=convert_options
).to_batches()
83 changes: 83 additions & 0 deletions tests/test_google_sheets.py
@@ -0,0 +1,83 @@
import pytest
from pyspark.errors.exceptions.captured import AnalysisException, PythonException

from pyspark_datasources import GoogleSheetsDataSource

from .test_data_sources import spark


def test_url(spark):
spark.dataSource.register(GoogleSheetsDataSource)
url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
df = spark.read.format("googlesheets").options(url=url).load()
df.show()
assert df.count() == 2
assert len(df.columns) == 2
assert df.schema.simpleString() == "struct<num:string,name:string>"


def test_spreadsheet_id(spark):
spark.dataSource.register(GoogleSheetsDataSource)
df = (
spark.read.format("googlesheets")
.options(spreadsheet_id="10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0")
.load()
)
df.show()
assert df.count() == 2
assert len(df.columns) == 2


def test_missing_options(spark):
spark.dataSource.register(GoogleSheetsDataSource)
with pytest.raises(AnalysisException) as excinfo:
spark.read.format("googlesheets").load()
assert "ValueError" in str(excinfo.value)


def test_mutually_exclusive_options(spark):
spark.dataSource.register(GoogleSheetsDataSource)
with pytest.raises(AnalysisException) as excinfo:
spark.read.format("googlesheets").options(
url="a",
spreadsheet_id="b",
).load()
assert "ValueError" in str(excinfo.value)


def test_custom_schema(spark):
spark.dataSource.register(GoogleSheetsDataSource)
url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
df = (
spark.read.format("googlesheets")
.options(url=url)
.schema("a double, b string")
.load()
)
df.show()
assert df.count() == 2
assert len(df.columns) == 2
assert df.schema.simpleString() == "struct<a:double,b:string>"


def test_custom_schema_mismatch_count(spark):
spark.dataSource.register(GoogleSheetsDataSource)
url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
df = spark.read.format("googlesheets").options(url=url).schema("a double").load()
with pytest.raises(PythonException) as excinfo:
df.show()
assert "CSV parse error" in str(excinfo.value)


def test_custom_schema_mismatch_type(spark):
spark.dataSource.register(GoogleSheetsDataSource)
url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
df = (
spark.read.format("googlesheets")
.options(url=url)
.schema("a double, b double")
.load()
)
with pytest.raises(PythonException) as excinfo:
df.show()
assert "CSV conversion error" in str(excinfo.value)