Merged
3 changes: 3 additions & 0 deletions docs/datasources/googlesheets.md
@@ -0,0 +1,3 @@
# GoogleSheetsDataSource

::: pyspark_datasources.googlesheets.GoogleSheetsDataSource
15 changes: 8 additions & 7 deletions docs/index.md
@@ -28,10 +28,11 @@ spark.read.format("github").load("apache/spark").show()

## Data Sources

| Data Source | Short Name | Description | Dependencies |
|-----------------------------------------------------|---------------|---------------------------------------------|-----------------|
| [GithubDataSource](./datasources/github.md) | `github` | Read pull requests from a Github repository | None |
| [FakeDataSource](./datasources/fake.md) | `fake` | Generate fake data using the `Faker` library | `faker` |
| [HuggingFaceDatasets](./datasources/huggingface.md) | `huggingface` | Read datasets from the HuggingFace Hub | `datasets` |
| [StockDataSource](./datasources/stock.md) | `stock` | Read stock data from Alpha Vantage | None |
| [SimpleJsonDataSource](./datasources/simplejson.md) | `simplejson` | Read JSON data from a file | `databricks-sdk`|
| Data Source                                              | Short Name     | Description                                       | Dependencies     |
| -------------------------------------------------------- | -------------- | ------------------------------------------------- | ---------------- |
| [GithubDataSource](./datasources/github.md)               | `github`       | Read pull requests from a Github repository       | None             |
| [FakeDataSource](./datasources/fake.md)                   | `fake`         | Generate fake data using the `Faker` library      | `faker`          |
| [HuggingFaceDatasets](./datasources/huggingface.md)       | `huggingface`  | Read datasets from the HuggingFace Hub            | `datasets`       |
| [StockDataSource](./datasources/stock.md)                 | `stock`        | Read stock data from Alpha Vantage                | None             |
| [SimpleJsonDataSource](./datasources/simplejson.md)       | `simplejson`   | Read JSON data from a file                        | `databricks-sdk` |
| [GoogleSheetsDataSource](./datasources/googlesheets.md)   | `googlesheets` | Read a table from a public Google Sheets document | None             |
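
For a quick sense of the new entry, a minimal usage sketch based on the docstring examples further down (the spreadsheet ID is the public demo sheet used throughout this PR; the sheet must be publicly readable, since no credentials are passed):

```python
from pyspark.sql import SparkSession

from pyspark_datasources import GoogleSheetsDataSource

spark = SparkSession.builder.getOrCreate()
spark.dataSource.register(GoogleSheetsDataSource)

# Load a public sheet by spreadsheet ID; by default all columns are strings
# and the header row supplies the column names.
spark.read.format("googlesheets").load(
    "10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0"
).show()
```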
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -23,6 +23,7 @@ nav:
- datasources/huggingface.md
- datasources/stock.md
- datasources/simplejson.md
- datasources/googlesheets.md

markdown_extensions:
- pymdownx.highlight:
3 changes: 2 additions & 1 deletion pyspark_datasources/__init__.py
@@ -1,5 +1,6 @@
from .fake import FakeDataSource
from .github import GithubDataSource
from .googlesheets import GoogleSheetsDataSource
from .huggingface import HuggingFaceDatasets
from .stock import StockDataSource
from .simplejson import SimpleJsonDataSource
from .stock import StockDataSource
195 changes: 195 additions & 0 deletions pyspark_datasources/googlesheets.py
@@ -0,0 +1,195 @@
from dataclasses import dataclass
from typing import Dict, Optional

from pyspark.sql.datasource import DataSource, DataSourceReader
from pyspark.sql.types import StringType, StructField, StructType


@dataclass
class Sheet:
"""
A dataclass to identify a Google Sheets document.

Attributes
----------
spreadsheet_id : str
The ID of the Google Sheets document.
sheet_id : str, optional
The ID of the worksheet within the document.
"""

spreadsheet_id: str
sheet_id: Optional[str] = None # if None, the first sheet is used

@classmethod
def from_url(cls, url: str) -> "Sheet":
"""
Converts a Google Sheets URL to a Sheet object.
"""
from urllib.parse import parse_qs, urlparse

parsed = urlparse(url)
if parsed.netloc != "docs.google.com" or not parsed.path.startswith(
"/spreadsheets/d/"
):
raise ValueError("URL is not a Google Sheets URL")
qs = parse_qs(parsed.query)
spreadsheet_id = parsed.path.split("/")[3]
if "gid" in qs:
sheet_id = qs["gid"][0]
else:
sheet_id = None
return cls(spreadsheet_id, sheet_id)

def get_query_url(self, query: Optional[str] = None):
"""
Gets the query url that returns the results of the query as a CSV file.

If no query is provided, returns the entire sheet.
If sheet ID is None, uses the first sheet.

See https://developers.google.com/chart/interactive/docs/querylanguage
"""
from urllib.parse import urlencode

path = f"https://docs.google.com/spreadsheets/d/{self.spreadsheet_id}/gviz/tq"
url_query = {"tqx": "out:csv"}
if self.sheet_id:
url_query["gid"] = self.sheet_id
if query:
url_query["tq"] = query
return f"{path}?{urlencode(url_query)}"


@dataclass
class Parameters:
    sheet: Sheet
    has_header: bool


class GoogleSheetsDataSource(DataSource):
    """
    A DataSource for reading tables from public Google Sheets documents.

    Name: `googlesheets`

    Schema: By default, all columns are treated as strings and the header row defines the column names.

    Options
    -------
    - `url`: The URL of the Google Sheets document.
    - `path`: The ID of the Google Sheets document.
    - `sheet_id`: The ID of the worksheet within the document.
    - `has_header`: Whether the sheet has a header row. Default is `true`.

    Either `url` or `path` must be specified, but not both.

    Examples
    --------
    Register the data source.

    >>> from pyspark_datasources import GoogleSheetsDataSource
    >>> spark.dataSource.register(GoogleSheetsDataSource)

    Load data from a public Google Sheets document using `path` and an optional `sheet_id`.

    >>> spreadsheet_id = "10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0"
    >>> spark.read.format("googlesheets").options(sheet_id="0").load(spreadsheet_id).show()
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    +-------+---------+---------+-------+
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+

    Load data from a public Google Sheets document using `url`.

    >>> url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=0#gid=0"
    >>> spark.read.format("googlesheets").options(url=url).load().show()
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    +-------+---------+---------+-------+
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+

    Specify a custom schema.

    >>> schema = "id string, lat double, long double, name string"
    >>> spark.read.format("googlesheets").schema(schema).options(url=url).load().show()
    +---+---------+--------+-------+
    | id|      lat|    long|   name|
    +---+---------+--------+-------+
    | AD|42.546245|1.601554|Andorra|
    |...|      ...|     ...|    ...|
    +---+---------+--------+-------+

    Treat the first row as data instead of a header.

    >>> schema = "c1 string, c2 string, c3 string, c4 string"
    >>> spark.read.format("googlesheets").schema(schema).options(url=url, has_header="false").load().show()
    +-------+---------+---------+-------+
    |     c1|       c2|       c3|     c4|
    +-------+---------+---------+-------+
    |country| latitude|longitude|   name|
    |     AD|42.546245| 1.601554|Andorra|
    |    ...|      ...|      ...|    ...|
    +-------+---------+---------+-------+
    """

    @classmethod
    def name(cls):
        return "googlesheets"

    def __init__(self, options: Dict[str, str]):
        if "url" in options:
            sheet = Sheet.from_url(options.pop("url"))
        elif "path" in options:
            sheet = Sheet(options.pop("path"), options.pop("sheet_id", None))
        else:
            raise ValueError(
                "You must specify either `url` or `path` (spreadsheet ID)."
            )
        has_header = options.pop("has_header", "true").lower() == "true"
        self.parameters = Parameters(sheet, has_header)

    def schema(self) -> StructType:
        if not self.parameters.has_header:
            raise ValueError("Custom schema is required when `has_header` is false")

        import pandas as pd

        # Read the schema from the first row of the sheet
        df = pd.read_csv(self.parameters.sheet.get_query_url("select * limit 1"))
        return StructType([StructField(col, StringType()) for col in df.columns])

    def reader(self, schema: StructType) -> DataSourceReader:
        return GoogleSheetsReader(self.parameters, schema)


class GoogleSheetsReader(DataSourceReader):
    def __init__(self, parameters: Parameters, schema: StructType):
        self.parameters = parameters
        self.schema = schema

    def read(self, partition):
        from urllib.request import urlopen

        from pyarrow import csv
        from pyspark.sql.pandas.types import to_arrow_schema

        # Specify column types based on the schema
        convert_options = csv.ConvertOptions(
            column_types=to_arrow_schema(self.schema),
        )
        read_options = csv.ReadOptions(
            column_names=self.schema.fieldNames(),  # Rename columns
            skip_rows=(
                1 if self.parameters.has_header else 0  # Skip header row if present
            ),
        )
        with urlopen(self.parameters.sheet.get_query_url()) as file:
            yield from csv.read_csv(
                file, read_options=read_options, convert_options=convert_options
            ).to_batches()
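
To make the reader's mechanics concrete, here is a minimal standalone sketch of the same flow: build the gviz CSV export URL with `Sheet`, then parse it with pyarrow the way `GoogleSheetsReader.read` does. The column names are an assumption taken from the public demo sheet used in the docstring examples:

```python
from urllib.request import urlopen

import pyarrow as pa
from pyarrow import csv

from pyspark_datasources.googlesheets import Sheet

# Parse the sharing URL into (spreadsheet_id, sheet_id) and build the CSV export URL.
sheet = Sheet.from_url(
    "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=0#gid=0"
)
url = sheet.get_query_url()  # .../gviz/tq?tqx=out%3Acsv&gid=0

# Mirror GoogleSheetsReader.read: skip the header row and apply explicit column
# names and types (all strings here, matching the default inferred schema).
schema = pa.schema(
    [
        ("country", pa.string()),
        ("latitude", pa.string()),
        ("longitude", pa.string()),
        ("name", pa.string()),
    ]
)
with urlopen(url) as f:
    table = csv.read_csv(
        f,
        read_options=csv.ReadOptions(column_names=schema.names, skip_rows=1),
        convert_options=csv.ConvertOptions(column_types=schema),
    )
print(table.column_names)  # ['country', 'latitude', 'longitude', 'name']
```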
103 changes: 103 additions & 0 deletions tests/test_google_sheets.py
@@ -0,0 +1,103 @@
import pytest
from pyspark.errors.exceptions.captured import AnalysisException, PythonException
from pyspark.sql import SparkSession

from pyspark_datasources import GoogleSheetsDataSource


@pytest.fixture(scope="module")
def spark():
    spark = SparkSession.builder.getOrCreate()
    spark.dataSource.register(GoogleSheetsDataSource)
    yield spark


def test_url(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
    df = spark.read.format("googlesheets").options(url=url).load()
    df.show()
    assert df.count() == 2
    assert len(df.columns) == 2
    assert df.schema.simpleString() == "struct<num:string,name:string>"


def test_spreadsheet_id(spark):
    df = spark.read.format("googlesheets").load(
        "10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0"
    )
    df.show()
    assert df.count() == 2
    assert len(df.columns) == 2


def test_missing_options(spark):
    with pytest.raises(AnalysisException) as excinfo:
        spark.read.format("googlesheets").load()
    assert "ValueError" in str(excinfo.value)


def test_mutual_exclusive_options(spark):
    with pytest.raises(AnalysisException) as excinfo:
        spark.read.format("googlesheets").options(
            url="a",
            spreadsheet_id="b",
        ).load()
    assert "ValueError" in str(excinfo.value)


def test_custom_schema(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
    df = (
        spark.read.format("googlesheets")
        .options(url=url)
        .schema("a double, b string")
        .load()
    )
    df.show()
    assert df.count() == 2
    assert len(df.columns) == 2
    assert df.schema.simpleString() == "struct<a:double,b:string>"


def test_custom_schema_mismatch_count(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=846122797#gid=846122797"
    df = spark.read.format("googlesheets").options(url=url).schema("a double").load()
    with pytest.raises(PythonException) as excinfo:
        df.show()
    assert "CSV parse error" in str(excinfo.value)


def test_unnamed_column(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1579451727#gid=1579451727"
    df = spark.read.format("googlesheets").options(url=url).load()
    df.show()
    assert df.count() == 1
    assert df.columns == ["Unnamed: 0", "1", "Unnamed: 2"]


def test_duplicate_column(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1875209731#gid=1875209731"
    df = spark.read.format("googlesheets").options(url=url).load()
    df.show()
    assert df.count() == 1
    assert df.columns == ["a", "a.1"]


def test_no_header_row(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=1579451727#gid=1579451727"
    df = (
        spark.read.format("googlesheets")
        .schema("a int, b int, c int")
        .options(url=url, has_header="false")
        .load()
    )
    df.show()
    assert df.count() == 2
    assert len(df.columns) == 3


def test_empty(spark):
    url = "https://docs.google.com/spreadsheets/d/10pD8oRN3RTBJq976RKPWHuxYy0Qa_JOoGFpsaS0Lop0/edit?gid=2123944555#gid=2123944555"
    with pytest.raises(AnalysisException) as excinfo:
        spark.read.format("googlesheets").options(url=url).load()
    assert "EmptyDataError" in str(excinfo.value)