Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions docs/datasources/robinhood.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# RobinhoodDataSource

> Requires the [`pynacl`](https://github.com/pyca/pynacl) library for cryptographic signing. You can install it manually: `pip install pynacl`
> or use `pip install pyspark-data-sources[robinhood]`.

::: pyspark_datasources.robinhood.RobinhoodDataSource
3 changes: 2 additions & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,5 @@ spark.readStream.format("fake").load().writeStream.format("console").start()
| [GoogleSheetsDataSource](./datasources/googlesheets.md) | `googlesheets` | Read table from public Google Sheets document | None |
| [KaggleDataSource](./datasources/kaggle.md) | `kaggle` | Read datasets from Kaggle | `kagglehub`, `pandas` |
| [JSONPlaceHolder](./datasources/jsonplaceholder.md) | `jsonplaceholder` | Read JSON data for testing and prototyping | None |
| [SalesforceDataSource](./datasources/salesforce.md) | `salesforce` | Write streaming data to Salesforce objects |`simple-salesforce` |
| [RobinhoodDataSource](./datasources/robinhood.md) | `robinhood` | Read cryptocurrency market data from Robinhood API | `pynacl` |
| [SalesforceDataSource](./datasources/salesforce.md) | `salesforce` | Write streaming data to Salesforce objects |`simple-salesforce` |
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ nav:
- datasources/googlesheets.md
- datasources/kaggle.md
- datasources/jsonplaceholder.md
- datasources/robinhood.md

markdown_extensions:
- pymdownx.highlight:
Expand Down
39 changes: 34 additions & 5 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@ datasets = {version = "^2.17.0", optional = true}
databricks-sdk = {version = "^0.28.0", optional = true}
kagglehub = {extras = ["pandas-datasets"], version = "^0.3.10", optional = true}
simple-salesforce = {version = "^1.12.0", optional = true}
pynacl = {version = "^1.5.0", optional = true}

[tool.poetry.extras]
faker = ["faker"]
datasets = ["datasets"]
databricks = ["databricks-sdk"]
kaggle = ["kagglehub"]
lance = ["pylance"]
robinhood = ["pynacl"]
salesforce = ["simple-salesforce"]
all = ["faker", "datasets", "databricks-sdk", "kagglehub", "simple-salesforce"]
all = ["faker", "datasets", "databricks-sdk", "kagglehub", "pynacl", "simple-salesforce"]

[tool.poetry.group.dev.dependencies]
pytest = "^8.0.0"
Expand Down
1 change: 1 addition & 0 deletions pyspark_datasources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .huggingface import HuggingFaceDatasets
from .kaggle import KaggleDataSource
from .opensky import OpenSkyDataSource
from .robinhood import RobinhoodDataSource
from .salesforce import SalesforceDataSource
from .simplejson import SimpleJsonDataSource
from .stock import StockDataSource
Expand Down
266 changes: 266 additions & 0 deletions pyspark_datasources/robinhood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Generator, Union
import requests
import json
import base64
import datetime

from pyspark.sql import Row
from pyspark.sql.types import StructType
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition


@dataclass
class CryptoPair(InputPartition):
"""Represents a single crypto trading pair partition for parallel processing."""

symbol: str


class RobinhoodDataReader(DataSourceReader):
"""Reader implementation for Robinhood Crypto API data source."""

def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
self.schema = schema
self.options = options

# Required API authentication
self.api_key = options.get("api_key")
self.private_key_base64 = options.get("private_key")

if not self.api_key or not self.private_key_base64:
raise ValueError(
"Robinhood Crypto API requires both 'api_key' and 'private_key' options. "
"The private_key should be base64-encoded. "
"Get your API credentials from https://docs.robinhood.com/crypto/trading/"
)

# Initialize NaCl signing key
try:
from nacl.signing import SigningKey

private_key_seed = base64.b64decode(self.private_key_base64)
self.signing_key = SigningKey(private_key_seed)
except ImportError:
raise ImportError(
"PyNaCl library is required for Robinhood Crypto API authentication. "
"Install it with: pip install pynacl"
)
except Exception as e:
raise ValueError(f"Invalid private key format: {str(e)}")

# Crypto API base URL (configurable for testing)
self.base_url = options.get("base_url", "https://trading.robinhood.com")

def _get_current_timestamp(self) -> int:
"""Get current UTC timestamp."""
return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())

def _generate_signature(self, timestamp: int, method: str, path: str, body: str = "") -> str:
"""Generate NaCl signature for API authentication following Robinhood's specification."""
# Official Robinhood signature format: f"{api_key}{current_timestamp}{path}{method}{body}"
# For GET requests with no body, omit the body parameter
if method.upper() == "GET" and not body:
message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}"
else:
message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}{body}"

signed = self.signing_key.sign(message_to_sign.encode("utf-8"))
signature = base64.b64encode(signed.signature).decode("utf-8")
return signature

def _make_authenticated_request(
self,
method: str,
path: str,
params: Optional[Dict[str, str]] = None,
json_data: Optional[Dict] = None,
) -> Optional[Dict]:
"""Make an authenticated request to the Robinhood Crypto API."""
timestamp = self._get_current_timestamp()
url = self.base_url + path

# Prepare request body for signature (only for non-GET requests)
body = ""
if method.upper() != "GET" and json_data:
body = json.dumps(json_data, separators=(",", ":")) # Compact JSON format

# Generate signature
signature = self._generate_signature(timestamp, method, path, body)

# Set authentication headers
headers = {
"x-api-key": self.api_key,
"x-signature": signature,
"x-timestamp": str(timestamp),
}

try:
# Make request
if method.upper() == "GET":
response = requests.get(url, headers=headers, params=params, timeout=10)
elif method.upper() == "POST":
headers["Content-Type"] = "application/json"
response = requests.post(url, headers=headers, json=json_data, timeout=10)
else:
response = requests.request(
method, url, headers=headers, params=params, json=json_data, timeout=10
)

response.raise_for_status()
return response.json()
except requests.RequestException as e:
print(f"Error making API request to {path}: {e}")
return None

@staticmethod
def _get_query_params(key: str, *args: str) -> str:
"""Build query parameters for API requests."""
if not args:
return ""
params = [f"{key}={arg}" for arg in args if arg]
return "?" + "&".join(params)

def partitions(self) -> List[CryptoPair]:
"""Create partitions for parallel processing of crypto pairs."""
# Use specified symbols from path
symbols_str = self.options.get("path", "")
if not symbols_str:
raise ValueError("Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')")

# Split symbols by comma and create partitions
symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
# Ensure proper format (e.g., BTC-USD)
formatted_symbols = []
for symbol in symbols:
if symbol and "-" not in symbol:
symbol = f"{symbol}-USD" # Default to USD pair
if symbol:
formatted_symbols.append(symbol)

return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]

def read(self, partition: CryptoPair) -> Generator[Row, None, None]:
"""Read crypto data for a single trading pair partition."""
symbol = partition.symbol

try:
yield from self._read_crypto_pair_data(symbol)
except Exception as e:
# Log error but don't fail the entire job
print(f"Warning: Failed to fetch data for {symbol}: {str(e)}")

def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
"""Fetch cryptocurrency market data for a given trading pair."""
try:
# Get best bid/ask data for the trading pair using query parameters
path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it support passing timestamp to query for historical data?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unfortunately, it is not supported at the moment.

market_data = self._make_authenticated_request("GET", path)

if market_data and "results" in market_data:
for quote in market_data["results"]:
# Parse numeric values safely
def safe_float(
value: Union[str, int, float, None], default: float = 0.0
) -> float:
if value is None or value == "":
return default
try:
return float(value)
except (ValueError, TypeError):
return default

# Extract market data fields from best bid/ask response
# Use the correct field names from the API response
price = safe_float(quote.get("price"))
bid_price = safe_float(quote.get("bid_inclusive_of_sell_spread"))
ask_price = safe_float(quote.get("ask_inclusive_of_buy_spread"))

yield Row(
symbol=symbol,
price=price,
bid_price=bid_price,
ask_price=ask_price,
updated_at=quote.get("timestamp", ""),
)
else:
print(f"Warning: No market data found for {symbol}")

except requests.exceptions.RequestException as e:
print(f"Network error fetching data for {symbol}: {str(e)}")
except (ValueError, KeyError) as e:
Comment on lines +153 to +191
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Build query via params, not inline; unify error logging.

Keep path constant and pass params, letting _make_authenticated_request assemble and sign deterministically.

-            # Get best bid/ask data for the trading pair using query parameters
-            path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
-            market_data = self._make_authenticated_request("GET", path)
+            # Get best bid/ask for the trading pair
+            path = "/api/v1/crypto/marketdata/best_bid_ask/"
+            market_data = self._make_authenticated_request("GET", path, params={"symbol": symbol})
@@
-            else:
-                print(f"Warning: No market data found for {symbol}")
+            else:
+                logger.warning("No market data found for %s", symbol)
@@
-        except requests.exceptions.RequestException as e:
-            print(f"Network error fetching data for {symbol}: {str(e)}")
-        except (ValueError, KeyError) as e:
-            print(f"Data parsing error for {symbol}: {str(e)}")
-        except Exception as e:
-            print(f"Unexpected error fetching data for {symbol}: {str(e)}")
+        except requests.exceptions.RequestException as e:
+            logger.error("Network error fetching data for %s: %s", symbol, str(e))
+        except (ValueError, KeyError) as e:
+            logger.error("Data parsing error for %s: %s", symbol, str(e))
+        except Exception as e:
+            logger.error("Unexpected error fetching data for %s: %s", symbol, str(e))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
"""Fetch cryptocurrency market data for a given trading pair."""
try:
# Get best bid/ask data for the trading pair using query parameters
path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
market_data = self._make_authenticated_request("GET", path)
if market_data and "results" in market_data:
for quote in market_data["results"]:
# Parse numeric values safely
def safe_float(
value: Union[str, int, float, None], default: float = 0.0
) -> float:
if value is None or value == "":
return default
try:
return float(value)
except (ValueError, TypeError):
return default
# Extract market data fields from best bid/ask response
# Use the correct field names from the API response
price = safe_float(quote.get("price"))
bid_price = safe_float(quote.get("bid_inclusive_of_sell_spread"))
ask_price = safe_float(quote.get("ask_inclusive_of_buy_spread"))
yield Row(
symbol=symbol,
price=price,
bid_price=bid_price,
ask_price=ask_price,
updated_at=quote.get("timestamp", ""),
)
else:
print(f"Warning: No market data found for {symbol}")
except requests.exceptions.RequestException as e:
print(f"Network error fetching data for {symbol}: {str(e)}")
except (ValueError, KeyError) as e:
def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
"""Fetch cryptocurrency market data for a given trading pair."""
try:
- # Get best bid/ask data for the trading pair using query parameters
- path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
# Get best bid/ask for the trading pair
path = "/api/v1/crypto/marketdata/best_bid_ask/"
market_data = self._make_authenticated_request(
"GET",
path,
params={"symbol": symbol}
)
if market_data and "results" in market_data:
for quote in market_data["results"]:
@@
- else:
else:
logger.warning("No market data found for %s", symbol)
- except requests.exceptions.RequestException as e:
- print(f"Network error fetching data for {symbol}: {str(e)}")
- except (ValueError, KeyError) as e:
- print(f"Data parsing error for {symbol}: {str(e)}")
- except Exception as e:
except requests.exceptions.RequestException as e:
logger.error("Network error fetching data for %s: %s", symbol, str(e))
except (ValueError, KeyError) as e:
logger.error("Data parsing error for %s: %s", symbol, str(e))
except Exception as e:
logger.error("Unexpected error fetching data for %s: %s", symbol, str(e))

print(f"Data parsing error for {symbol}: {str(e)}")
except Exception as e:
print(f"Unexpected error fetching data for {symbol}: {str(e)}")


class RobinhoodDataSource(DataSource):
"""
A data source for reading cryptocurrency data from Robinhood Crypto API.
This data source allows you to fetch real-time cryptocurrency market data,
trading pairs, and price information using Robinhood's official Crypto API.
It implements proper API key authentication and signature-based security.
Name: `robinhood`
Schema: `symbol string, price double, bid_price double, ask_price double, updated_at string`
Examples
--------
Register the data source:
>>> from pyspark_datasources import RobinhoodDataSource
>>> spark.dataSource.register(RobinhoodDataSource)
Load cryptocurrency market data with API authentication:
>>> df = spark.read.format("robinhood") \\
... .option("api_key", "your-api-key") \\
... .option("private_key", "your-base64-private-key") \\
... .load("BTC-USD,ETH-USD,DOGE-USD")
>>> df.show()
+--------+--------+---------+---------+--------------------+
| symbol| price|bid_price|ask_price| updated_at|
+--------+--------+---------+---------+--------------------+
|BTC-USD |45000.50|45000.25 |45000.75 |2024-01-15T16:00:...|
|ETH-USD | 2650.75| 2650.50 | 2651.00 |2024-01-15T16:00:...|
|DOGE-USD| 0.085| 0.084| 0.086|2024-01-15T16:00:...|
+--------+--------+---------+---------+--------------------+
Options
-------
- api_key: string (required) — Robinhood Crypto API key.
- private_key: string (required) — Base64-encoded Ed25519 private key seed.
- base_url: string (optional, default "https://trading.robinhood.com") — Override for sandbox/testing.
Errors
------
- Raises ValueError when required options are missing or private_key is invalid.
- Network/API errors are logged and skipped per symbol; no rows are emitted for failed symbols.
Partitioning
------------
- One partition per requested trading pair (e.g., "BTC-USD,ETH-USD"). Symbols are uppercased and auto-appended with "-USD" if missing pair format.
Arrow
-----
- Rows are yielded directly; Arrow-based batches can be added in future for improved performance.
Notes
-----
- Requires 'pynacl' for Ed25519 signing: pip install pynacl
- Refer to official Robinhood documentation for authentication details.
"""

@classmethod
def name(cls) -> str:
return "robinhood"

def schema(self) -> str:
return "symbol string, price double, bid_price double, ask_price double, updated_at string"

def reader(self, schema: StructType) -> RobinhoodDataReader:
return RobinhoodDataReader(schema, self.options)
Loading