Skip to content

Commit 2bcae6a

Browse files
authored
Add Robinhood cryptocurrency data source (#23)
* Add Robinhood cryptocurrency data source - Add RobinhoodDataSource for reading crypto market data from Robinhood API - Implement NaCl cryptographic signing for API authentication - Support for individual crypto pairs and bulk loading - Add comprehensive tests with environment variable configuration - Add documentation and update project dependencies - Schema includes: symbol, price, bid_price, ask_price, updated_at * update tests * update comments * remove load all pairs * remove session and user agent * add type hints * Enhance RobinhoodDataReader with configurable base URL and expand documentation - Made the base URL configurable for testing purposes. - Added detailed options, error handling, and partitioning information to the documentation. - Updated notes to clarify requirements and references for the Robinhood API. * fix format * fix poetry.lock
1 parent 3ee0484 commit 2bcae6a

File tree

8 files changed

+481
-7
lines changed

8 files changed

+481
-7
lines changed

docs/datasources/robinhood.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# RobinhoodDataSource
2+
3+
> Requires the [`pynacl`](https://github.com/pyca/pynacl) library for cryptographic signing. You can install it manually: `pip install pynacl`
4+
> or use `pip install pyspark-data-sources[robinhood]`.
5+
6+
::: pyspark_datasources.robinhood.RobinhoodDataSource

docs/index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,5 @@ spark.readStream.format("fake").load().writeStream.format("console").start()
4242
| [GoogleSheetsDataSource](./datasources/googlesheets.md) | `googlesheets` | Read table from public Google Sheets document | None |
4343
| [KaggleDataSource](./datasources/kaggle.md) | `kaggle` | Read datasets from Kaggle | `kagglehub`, `pandas` |
4444
| [JSONPlaceHolder](./datasources/jsonplaceholder.md) | `jsonplaceholder` | Read JSON data for testing and prototyping | None |
45-
| [SalesforceDataSource](./datasources/salesforce.md) | `salesforce` | Write streaming data to Salesforce objects |`simple-salesforce` |
45+
| [RobinhoodDataSource](./datasources/robinhood.md) | `robinhood` | Read cryptocurrency market data from Robinhood API | `pynacl` |
46+
| [SalesforceDataSource](./datasources/salesforce.md) | `salesforce` | Write streaming data to Salesforce objects |`simple-salesforce` |

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ nav:
2727
- datasources/googlesheets.md
2828
- datasources/kaggle.md
2929
- datasources/jsonplaceholder.md
30+
- datasources/robinhood.md
3031

3132
markdown_extensions:
3233
- pymdownx.highlight:

poetry.lock

Lines changed: 34 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,17 @@ datasets = {version = "^2.17.0", optional = true}
1919
databricks-sdk = {version = "^0.28.0", optional = true}
2020
kagglehub = {extras = ["pandas-datasets"], version = "^0.3.10", optional = true}
2121
simple-salesforce = {version = "^1.12.0", optional = true}
22+
pynacl = {version = "^1.5.0", optional = true}
2223

2324
[tool.poetry.extras]
2425
faker = ["faker"]
2526
datasets = ["datasets"]
2627
databricks = ["databricks-sdk"]
2728
kaggle = ["kagglehub"]
2829
lance = ["pylance"]
30+
robinhood = ["pynacl"]
2931
salesforce = ["simple-salesforce"]
30-
all = ["faker", "datasets", "databricks-sdk", "kagglehub", "simple-salesforce"]
32+
all = ["faker", "datasets", "databricks-sdk", "kagglehub", "pynacl", "simple-salesforce"]
3133

3234
[tool.poetry.group.dev.dependencies]
3335
pytest = "^8.0.0"

pyspark_datasources/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .huggingface import HuggingFaceDatasets
66
from .kaggle import KaggleDataSource
77
from .opensky import OpenSkyDataSource
8+
from .robinhood import RobinhoodDataSource
89
from .salesforce import SalesforceDataSource
910
from .simplejson import SimpleJsonDataSource
1011
from .stock import StockDataSource

pyspark_datasources/robinhood.py

Lines changed: 266 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,266 @@
1+
from dataclasses import dataclass
2+
from typing import Dict, List, Optional, Generator, Union
3+
import requests
4+
import json
5+
import base64
6+
import datetime
7+
8+
from pyspark.sql import Row
9+
from pyspark.sql.types import StructType
10+
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
11+
12+
13+
@dataclass
14+
class CryptoPair(InputPartition):
15+
"""Represents a single crypto trading pair partition for parallel processing."""
16+
17+
symbol: str
18+
19+
20+
class RobinhoodDataReader(DataSourceReader):
21+
"""Reader implementation for Robinhood Crypto API data source."""
22+
23+
def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
24+
self.schema = schema
25+
self.options = options
26+
27+
# Required API authentication
28+
self.api_key = options.get("api_key")
29+
self.private_key_base64 = options.get("private_key")
30+
31+
if not self.api_key or not self.private_key_base64:
32+
raise ValueError(
33+
"Robinhood Crypto API requires both 'api_key' and 'private_key' options. "
34+
"The private_key should be base64-encoded. "
35+
"Get your API credentials from https://docs.robinhood.com/crypto/trading/"
36+
)
37+
38+
# Initialize NaCl signing key
39+
try:
40+
from nacl.signing import SigningKey
41+
42+
private_key_seed = base64.b64decode(self.private_key_base64)
43+
self.signing_key = SigningKey(private_key_seed)
44+
except ImportError:
45+
raise ImportError(
46+
"PyNaCl library is required for Robinhood Crypto API authentication. "
47+
"Install it with: pip install pynacl"
48+
)
49+
except Exception as e:
50+
raise ValueError(f"Invalid private key format: {str(e)}")
51+
52+
# Crypto API base URL (configurable for testing)
53+
self.base_url = options.get("base_url", "https://trading.robinhood.com")
54+
55+
def _get_current_timestamp(self) -> int:
56+
"""Get current UTC timestamp."""
57+
return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())
58+
59+
def _generate_signature(self, timestamp: int, method: str, path: str, body: str = "") -> str:
60+
"""Generate NaCl signature for API authentication following Robinhood's specification."""
61+
# Official Robinhood signature format: f"{api_key}{current_timestamp}{path}{method}{body}"
62+
# For GET requests with no body, omit the body parameter
63+
if method.upper() == "GET" and not body:
64+
message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}"
65+
else:
66+
message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}{body}"
67+
68+
signed = self.signing_key.sign(message_to_sign.encode("utf-8"))
69+
signature = base64.b64encode(signed.signature).decode("utf-8")
70+
return signature
71+
72+
def _make_authenticated_request(
73+
self,
74+
method: str,
75+
path: str,
76+
params: Optional[Dict[str, str]] = None,
77+
json_data: Optional[Dict] = None,
78+
) -> Optional[Dict]:
79+
"""Make an authenticated request to the Robinhood Crypto API."""
80+
timestamp = self._get_current_timestamp()
81+
url = self.base_url + path
82+
83+
# Prepare request body for signature (only for non-GET requests)
84+
body = ""
85+
if method.upper() != "GET" and json_data:
86+
body = json.dumps(json_data, separators=(",", ":")) # Compact JSON format
87+
88+
# Generate signature
89+
signature = self._generate_signature(timestamp, method, path, body)
90+
91+
# Set authentication headers
92+
headers = {
93+
"x-api-key": self.api_key,
94+
"x-signature": signature,
95+
"x-timestamp": str(timestamp),
96+
}
97+
98+
try:
99+
# Make request
100+
if method.upper() == "GET":
101+
response = requests.get(url, headers=headers, params=params, timeout=10)
102+
elif method.upper() == "POST":
103+
headers["Content-Type"] = "application/json"
104+
response = requests.post(url, headers=headers, json=json_data, timeout=10)
105+
else:
106+
response = requests.request(
107+
method, url, headers=headers, params=params, json=json_data, timeout=10
108+
)
109+
110+
response.raise_for_status()
111+
return response.json()
112+
except requests.RequestException as e:
113+
print(f"Error making API request to {path}: {e}")
114+
return None
115+
116+
@staticmethod
117+
def _get_query_params(key: str, *args: str) -> str:
118+
"""Build query parameters for API requests."""
119+
if not args:
120+
return ""
121+
params = [f"{key}={arg}" for arg in args if arg]
122+
return "?" + "&".join(params)
123+
124+
def partitions(self) -> List[CryptoPair]:
125+
"""Create partitions for parallel processing of crypto pairs."""
126+
# Use specified symbols from path
127+
symbols_str = self.options.get("path", "")
128+
if not symbols_str:
129+
raise ValueError("Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')")
130+
131+
# Split symbols by comma and create partitions
132+
symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
133+
# Ensure proper format (e.g., BTC-USD)
134+
formatted_symbols = []
135+
for symbol in symbols:
136+
if symbol and "-" not in symbol:
137+
symbol = f"{symbol}-USD" # Default to USD pair
138+
if symbol:
139+
formatted_symbols.append(symbol)
140+
141+
return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
142+
143+
def read(self, partition: CryptoPair) -> Generator[Row, None, None]:
144+
"""Read crypto data for a single trading pair partition."""
145+
symbol = partition.symbol
146+
147+
try:
148+
yield from self._read_crypto_pair_data(symbol)
149+
except Exception as e:
150+
# Log error but don't fail the entire job
151+
print(f"Warning: Failed to fetch data for {symbol}: {str(e)}")
152+
153+
def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
154+
"""Fetch cryptocurrency market data for a given trading pair."""
155+
try:
156+
# Get best bid/ask data for the trading pair using query parameters
157+
path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
158+
market_data = self._make_authenticated_request("GET", path)
159+
160+
if market_data and "results" in market_data:
161+
for quote in market_data["results"]:
162+
# Parse numeric values safely
163+
def safe_float(
164+
value: Union[str, int, float, None], default: float = 0.0
165+
) -> float:
166+
if value is None or value == "":
167+
return default
168+
try:
169+
return float(value)
170+
except (ValueError, TypeError):
171+
return default
172+
173+
# Extract market data fields from best bid/ask response
174+
# Use the correct field names from the API response
175+
price = safe_float(quote.get("price"))
176+
bid_price = safe_float(quote.get("bid_inclusive_of_sell_spread"))
177+
ask_price = safe_float(quote.get("ask_inclusive_of_buy_spread"))
178+
179+
yield Row(
180+
symbol=symbol,
181+
price=price,
182+
bid_price=bid_price,
183+
ask_price=ask_price,
184+
updated_at=quote.get("timestamp", ""),
185+
)
186+
else:
187+
print(f"Warning: No market data found for {symbol}")
188+
189+
except requests.exceptions.RequestException as e:
190+
print(f"Network error fetching data for {symbol}: {str(e)}")
191+
except (ValueError, KeyError) as e:
192+
print(f"Data parsing error for {symbol}: {str(e)}")
193+
except Exception as e:
194+
print(f"Unexpected error fetching data for {symbol}: {str(e)}")
195+
196+
197+
class RobinhoodDataSource(DataSource):
198+
"""
199+
A data source for reading cryptocurrency data from Robinhood Crypto API.
200+
201+
This data source allows you to fetch real-time cryptocurrency market data,
202+
trading pairs, and price information using Robinhood's official Crypto API.
203+
It implements proper API key authentication and signature-based security.
204+
205+
Name: `robinhood`
206+
207+
Schema: `symbol string, price double, bid_price double, ask_price double, updated_at string`
208+
209+
Examples
210+
--------
211+
Register the data source:
212+
213+
>>> from pyspark_datasources import RobinhoodDataSource
214+
>>> spark.dataSource.register(RobinhoodDataSource)
215+
216+
Load cryptocurrency market data with API authentication:
217+
218+
>>> df = spark.read.format("robinhood") \\
219+
... .option("api_key", "your-api-key") \\
220+
... .option("private_key", "your-base64-private-key") \\
221+
... .load("BTC-USD,ETH-USD,DOGE-USD")
222+
>>> df.show()
223+
+--------+--------+---------+---------+--------------------+
224+
| symbol| price|bid_price|ask_price| updated_at|
225+
+--------+--------+---------+---------+--------------------+
226+
|BTC-USD |45000.50|45000.25 |45000.75 |2024-01-15T16:00:...|
227+
|ETH-USD | 2650.75| 2650.50 | 2651.00 |2024-01-15T16:00:...|
228+
|DOGE-USD| 0.085| 0.084| 0.086|2024-01-15T16:00:...|
229+
+--------+--------+---------+---------+--------------------+
230+
231+
232+
233+
Options
234+
-------
235+
- api_key: string (required) — Robinhood Crypto API key.
236+
- private_key: string (required) — Base64-encoded Ed25519 private key seed.
237+
- base_url: string (optional, default "https://trading.robinhood.com") — Override for sandbox/testing.
238+
239+
Errors
240+
------
241+
- Raises ValueError when required options are missing or private_key is invalid.
242+
- Network/API errors are logged and skipped per symbol; no rows are emitted for failed symbols.
243+
244+
Partitioning
245+
------------
246+
- One partition per requested trading pair (e.g., "BTC-USD,ETH-USD"). Symbols are uppercased and auto-appended with "-USD" if missing pair format.
247+
248+
Arrow
249+
-----
250+
- Rows are yielded directly; Arrow-based batches can be added in future for improved performance.
251+
252+
Notes
253+
-----
254+
- Requires 'pynacl' for Ed25519 signing: pip install pynacl
255+
- Refer to official Robinhood documentation for authentication details.
256+
"""
257+
258+
@classmethod
259+
def name(cls) -> str:
260+
return "robinhood"
261+
262+
def schema(self) -> str:
263+
return "symbol string, price double, bid_price double, ask_price double, updated_at string"
264+
265+
def reader(self, schema: StructType) -> RobinhoodDataReader:
266+
return RobinhoodDataReader(schema, self.options)

0 commit comments

Comments
 (0)