Skip to content

Commit 8949a85

Browse files
committed
fix format
1 parent 9347ef2 commit 8949a85

File tree

2 files changed

+126
-95
lines changed

2 files changed

+126
-95
lines changed

pyspark_datasources/robinhood.py

Lines changed: 48 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,10 @@
1010
from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
1111

1212

13-
14-
15-
1613
@dataclass
1714
class CryptoPair(InputPartition):
1815
"""Represents a single crypto trading pair partition for parallel processing."""
16+
1917
symbol: str
2018

2119

@@ -25,21 +23,22 @@ class RobinhoodDataReader(DataSourceReader):
2523
def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
2624
self.schema = schema
2725
self.options = options
28-
26+
2927
# Required API authentication
3028
self.api_key = options.get("api_key")
3129
self.private_key_base64 = options.get("private_key")
32-
30+
3331
if not self.api_key or not self.private_key_base64:
3432
raise ValueError(
3533
"Robinhood Crypto API requires both 'api_key' and 'private_key' options. "
3634
"The private_key should be base64-encoded. "
3735
"Get your API credentials from https://docs.robinhood.com/crypto/trading/"
3836
)
39-
37+
4038
# Initialize NaCl signing key
4139
try:
4240
from nacl.signing import SigningKey
41+
4342
private_key_seed = base64.b64decode(self.private_key_base64)
4443
self.signing_key = SigningKey(private_key_seed)
4544
except ImportError:
@@ -49,17 +48,14 @@ def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
4948
)
5049
except Exception as e:
5150
raise ValueError(f"Invalid private key format: {str(e)}")
52-
53-
5451

55-
5652
# Crypto API base URL (configurable for testing)
5753
self.base_url = options.get("base_url", "https://trading.robinhood.com")
5854

5955
def _get_current_timestamp(self) -> int:
6056
"""Get current UTC timestamp."""
6157
return int(datetime.datetime.now(tz=datetime.timezone.utc).timestamp())
62-
58+
6359
def _generate_signature(self, timestamp: int, method: str, path: str, body: str = "") -> str:
6460
"""Generate NaCl signature for API authentication following Robinhood's specification."""
6561
# Official Robinhood signature format: f"{api_key}{current_timestamp}{path}{method}{body}"
@@ -68,41 +64,49 @@ def _generate_signature(self, timestamp: int, method: str, path: str, body: str
6864
message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}"
6965
else:
7066
message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}{body}"
71-
67+
7268
signed = self.signing_key.sign(message_to_sign.encode("utf-8"))
7369
signature = base64.b64encode(signed.signature).decode("utf-8")
7470
return signature
7571

76-
def _make_authenticated_request(self, method: str, path: str, params: Optional[Dict[str, str]] = None, json_data: Optional[Dict] = None) -> Optional[Dict]:
72+
def _make_authenticated_request(
73+
self,
74+
method: str,
75+
path: str,
76+
params: Optional[Dict[str, str]] = None,
77+
json_data: Optional[Dict] = None,
78+
) -> Optional[Dict]:
7779
"""Make an authenticated request to the Robinhood Crypto API."""
7880
timestamp = self._get_current_timestamp()
7981
url = self.base_url + path
80-
82+
8183
# Prepare request body for signature (only for non-GET requests)
8284
body = ""
8385
if method.upper() != "GET" and json_data:
84-
body = json.dumps(json_data, separators=(',', ':')) # Compact JSON format
85-
86+
body = json.dumps(json_data, separators=(",", ":")) # Compact JSON format
87+
8688
# Generate signature
8789
signature = self._generate_signature(timestamp, method, path, body)
88-
90+
8991
# Set authentication headers
9092
headers = {
91-
'x-api-key': self.api_key,
92-
'x-signature': signature,
93-
'x-timestamp': str(timestamp)
93+
"x-api-key": self.api_key,
94+
"x-signature": signature,
95+
"x-timestamp": str(timestamp),
9496
}
95-
97+
9698
try:
9799
# Make request
98100
if method.upper() == "GET":
99101
response = requests.get(url, headers=headers, params=params, timeout=10)
100102
elif method.upper() == "POST":
101-
headers['Content-Type'] = 'application/json'
103+
headers["Content-Type"] = "application/json"
102104
response = requests.post(url, headers=headers, json=json_data, timeout=10)
103105
else:
104-
response = requests.request(method, url, headers=headers, params=params, json=json_data, timeout=10)
105-
106+
response = requests.request(
107+
method, url, headers=headers, params=params, json=json_data, timeout=10
108+
)
109+
106110
response.raise_for_status()
107111
return response.json()
108112
except requests.RequestException as e:
@@ -116,32 +120,30 @@ def _get_query_params(key: str, *args: str) -> str:
116120
return ""
117121
params = [f"{key}={arg}" for arg in args if arg]
118122
return "?" + "&".join(params)
119-
123+
120124
def partitions(self) -> List[CryptoPair]:
121125
"""Create partitions for parallel processing of crypto pairs."""
122126
# Use specified symbols from path
123127
symbols_str = self.options.get("path", "")
124128
if not symbols_str:
125-
raise ValueError(
126-
"Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')"
127-
)
128-
129+
raise ValueError("Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')")
130+
129131
# Split symbols by comma and create partitions
130132
symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
131133
# Ensure proper format (e.g., BTC-USD)
132134
formatted_symbols = []
133135
for symbol in symbols:
134-
if symbol and '-' not in symbol:
136+
if symbol and "-" not in symbol:
135137
symbol = f"{symbol}-USD" # Default to USD pair
136138
if symbol:
137139
formatted_symbols.append(symbol)
138-
140+
139141
return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
140142

141143
def read(self, partition: CryptoPair) -> Generator[Row, None, None]:
142144
"""Read crypto data for a single trading pair partition."""
143145
symbol = partition.symbol
144-
146+
145147
try:
146148
yield from self._read_crypto_pair_data(symbol)
147149
except Exception as e:
@@ -154,34 +156,36 @@ def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
154156
# Get best bid/ask data for the trading pair using query parameters
155157
path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
156158
market_data = self._make_authenticated_request("GET", path)
157-
158-
if market_data and 'results' in market_data:
159-
for quote in market_data['results']:
159+
160+
if market_data and "results" in market_data:
161+
for quote in market_data["results"]:
160162
# Parse numeric values safely
161-
def safe_float(value: Union[str, int, float, None], default: float = 0.0) -> float:
163+
def safe_float(
164+
value: Union[str, int, float, None], default: float = 0.0
165+
) -> float:
162166
if value is None or value == "":
163167
return default
164168
try:
165169
return float(value)
166170
except (ValueError, TypeError):
167171
return default
168-
172+
169173
# Extract market data fields from best bid/ask response
170174
# Use the correct field names from the API response
171-
price = safe_float(quote.get('price'))
172-
bid_price = safe_float(quote.get('bid_inclusive_of_sell_spread'))
173-
ask_price = safe_float(quote.get('ask_inclusive_of_buy_spread'))
174-
175+
price = safe_float(quote.get("price"))
176+
bid_price = safe_float(quote.get("bid_inclusive_of_sell_spread"))
177+
ask_price = safe_float(quote.get("ask_inclusive_of_buy_spread"))
178+
175179
yield Row(
176180
symbol=symbol,
177181
price=price,
178182
bid_price=bid_price,
179183
ask_price=ask_price,
180-
updated_at=quote.get('timestamp', "")
184+
updated_at=quote.get("timestamp", ""),
181185
)
182186
else:
183187
print(f"Warning: No market data found for {symbol}")
184-
188+
185189
except requests.exceptions.RequestException as e:
186190
print(f"Network error fetching data for {symbol}: {str(e)}")
187191
except (ValueError, KeyError) as e:
@@ -256,10 +260,7 @@ def name(cls) -> str:
256260
return "robinhood"
257261

258262
def schema(self) -> str:
259-
return (
260-
"symbol string, price double, bid_price double, ask_price double, "
261-
"updated_at string"
262-
)
263+
return "symbol string, price double, bid_price double, ask_price double, updated_at string"
263264

264265
def reader(self, schema: StructType) -> RobinhoodDataReader:
265-
return RobinhoodDataReader(schema, self.options)
266+
return RobinhoodDataReader(schema, self.options)

0 commit comments

Comments
 (0)