1010from  pyspark .sql .datasource  import  DataSource , DataSourceReader , InputPartition 
1111
1212
13- 
14- 
15- 
1613@dataclass  
1714class  CryptoPair (InputPartition ):
1815    """Represents a single crypto trading pair partition for parallel processing.""" 
16+ 
1917    symbol : str 
2018
2119
@@ -25,21 +23,22 @@ class RobinhoodDataReader(DataSourceReader):
2523    def  __init__ (self , schema : StructType , options : Dict [str , str ]) ->  None :
2624        self .schema  =  schema 
2725        self .options  =  options 
28-          
26+ 
2927        # Required API authentication 
3028        self .api_key  =  options .get ("api_key" )
3129        self .private_key_base64  =  options .get ("private_key" )
32-          
30+ 
3331        if  not  self .api_key  or  not  self .private_key_base64 :
3432            raise  ValueError (
3533                "Robinhood Crypto API requires both 'api_key' and 'private_key' options. " 
3634                "The private_key should be base64-encoded. " 
3735                "Get your API credentials from https://docs.robinhood.com/crypto/trading/" 
3836            )
39-          
37+ 
4038        # Initialize NaCl signing key 
4139        try :
4240            from  nacl .signing  import  SigningKey 
41+ 
4342            private_key_seed  =  base64 .b64decode (self .private_key_base64 )
4443            self .signing_key  =  SigningKey (private_key_seed )
4544        except  ImportError :
@@ -49,17 +48,14 @@ def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
4948            )
5049        except  Exception  as  e :
5150            raise  ValueError (f"Invalid private key format: { str (e )}  )
52-         
53- 
5451
55-         
5652        # Crypto API base URL (configurable for testing) 
5753        self .base_url  =  options .get ("base_url" , "https://trading.robinhood.com" )
5854
5955    def  _get_current_timestamp (self ) ->  int :
6056        """Get current UTC timestamp.""" 
6157        return  int (datetime .datetime .now (tz = datetime .timezone .utc ).timestamp ())
62-      
58+ 
6359    def  _generate_signature (self , timestamp : int , method : str , path : str , body : str  =  "" ) ->  str :
6460        """Generate NaCl signature for API authentication following Robinhood's specification.""" 
6561        # Official Robinhood signature format: f"{api_key}{current_timestamp}{path}{method}{body}" 
@@ -68,41 +64,49 @@ def _generate_signature(self, timestamp: int, method: str, path: str, body: str
6864            message_to_sign  =  f"{ self .api_key } { timestamp } { path } { method .upper ()}  
6965        else :
7066            message_to_sign  =  f"{ self .api_key } { timestamp } { path } { method .upper ()} { body }  
71-              
67+ 
7268        signed  =  self .signing_key .sign (message_to_sign .encode ("utf-8" ))
7369        signature  =  base64 .b64encode (signed .signature ).decode ("utf-8" )
7470        return  signature 
7571
76-     def  _make_authenticated_request (self , method : str , path : str , params : Optional [Dict [str , str ]] =  None , json_data : Optional [Dict ] =  None ) ->  Optional [Dict ]:
72+     def  _make_authenticated_request (
73+         self ,
74+         method : str ,
75+         path : str ,
76+         params : Optional [Dict [str , str ]] =  None ,
77+         json_data : Optional [Dict ] =  None ,
78+     ) ->  Optional [Dict ]:
7779        """Make an authenticated request to the Robinhood Crypto API.""" 
7880        timestamp  =  self ._get_current_timestamp ()
7981        url  =  self .base_url  +  path 
80-          
82+ 
8183        # Prepare request body for signature (only for non-GET requests) 
8284        body  =  "" 
8385        if  method .upper () !=  "GET"  and  json_data :
84-             body  =  json .dumps (json_data , separators = (',' ,  ':' ))  # Compact JSON format 
85-          
86+             body  =  json .dumps (json_data , separators = ("," ,  ":" ))  # Compact JSON format 
87+ 
8688        # Generate signature 
8789        signature  =  self ._generate_signature (timestamp , method , path , body )
88-          
90+ 
8991        # Set authentication headers 
9092        headers  =  {
91-             ' x-api-key' self .api_key ,
92-             ' x-signature' signature ,
93-             ' x-timestamp' str (timestamp )
93+             " x-api-key" self .api_key ,
94+             " x-signature" signature ,
95+             " x-timestamp" str (timestamp ), 
9496        }
95-          
97+ 
9698        try :
9799            # Make request 
98100            if  method .upper () ==  "GET" :
99101                response  =  requests .get (url , headers = headers , params = params , timeout = 10 )
100102            elif  method .upper () ==  "POST" :
101-                 headers [' Content-Type' =  ' application/json' 
103+                 headers [" Content-Type" =  " application/json" 
102104                response  =  requests .post (url , headers = headers , json = json_data , timeout = 10 )
103105            else :
104-                 response  =  requests .request (method , url , headers = headers , params = params , json = json_data , timeout = 10 )
105-             
106+                 response  =  requests .request (
107+                     method , url , headers = headers , params = params , json = json_data , timeout = 10 
108+                 )
109+ 
106110            response .raise_for_status ()
107111            return  response .json ()
108112        except  requests .RequestException  as  e :
@@ -116,32 +120,30 @@ def _get_query_params(key: str, *args: str) -> str:
116120            return  "" 
117121        params  =  [f"{ key } { arg }   for  arg  in  args  if  arg ]
118122        return  "?"  +  "&" .join (params )
119-      
123+ 
120124    def  partitions (self ) ->  List [CryptoPair ]:
121125        """Create partitions for parallel processing of crypto pairs.""" 
122126        # Use specified symbols from path 
123127        symbols_str  =  self .options .get ("path" , "" )
124128        if  not  symbols_str :
125-             raise  ValueError (
126-                 "Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')" 
127-             )
128-             
129+             raise  ValueError ("Must specify crypto pairs to load using .load('BTC-USD,ETH-USD')" )
130+ 
129131        # Split symbols by comma and create partitions 
130132        symbols  =  [symbol .strip ().upper () for  symbol  in  symbols_str .split ("," )]
131133        # Ensure proper format (e.g., BTC-USD) 
132134        formatted_symbols  =  []
133135        for  symbol  in  symbols :
134-             if  symbol  and  '-'  not  in symbol :
136+             if  symbol  and  "-"  not  in symbol :
135137                symbol  =  f"{ symbol }    # Default to USD pair 
136138            if  symbol :
137139                formatted_symbols .append (symbol )
138-          
140+ 
139141        return  [CryptoPair (symbol = symbol ) for  symbol  in  formatted_symbols ]
140142
141143    def  read (self , partition : CryptoPair ) ->  Generator [Row , None , None ]:
142144        """Read crypto data for a single trading pair partition.""" 
143145        symbol  =  partition .symbol 
144-          
146+ 
145147        try :
146148            yield  from  self ._read_crypto_pair_data (symbol )
147149        except  Exception  as  e :
@@ -154,34 +156,36 @@ def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
154156            # Get best bid/ask data for the trading pair using query parameters 
155157            path  =  f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={ symbol }  
156158            market_data  =  self ._make_authenticated_request ("GET" , path )
157-              
158-             if  market_data  and  ' results' in  market_data :
159-                 for  quote  in  market_data [' results' 
159+ 
160+             if  market_data  and  " results" in  market_data :
161+                 for  quote  in  market_data [" results" 
160162                    # Parse numeric values safely 
161-                     def  safe_float (value : Union [str , int , float , None ], default : float  =  0.0 ) ->  float :
163+                     def  safe_float (
164+                         value : Union [str , int , float , None ], default : float  =  0.0 
165+                     ) ->  float :
162166                        if  value  is  None  or  value  ==  "" :
163167                            return  default 
164168                        try :
165169                            return  float (value )
166170                        except  (ValueError , TypeError ):
167171                            return  default 
168-                      
172+ 
169173                    # Extract market data fields from best bid/ask response 
170174                    # Use the correct field names from the API response 
171-                     price  =  safe_float (quote .get (' price' 
172-                     bid_price  =  safe_float (quote .get (' bid_inclusive_of_sell_spread' 
173-                     ask_price  =  safe_float (quote .get (' ask_inclusive_of_buy_spread' 
174-                      
175+                     price  =  safe_float (quote .get (" price" 
176+                     bid_price  =  safe_float (quote .get (" bid_inclusive_of_sell_spread" 
177+                     ask_price  =  safe_float (quote .get (" ask_inclusive_of_buy_spread" 
178+ 
175179                    yield  Row (
176180                        symbol = symbol ,
177181                        price = price ,
178182                        bid_price = bid_price ,
179183                        ask_price = ask_price ,
180-                         updated_at = quote .get (' timestamp' "" )
184+                         updated_at = quote .get (" timestamp" "" ), 
181185                    )
182186            else :
183187                print (f"Warning: No market data found for { symbol }  )
184-                  
188+ 
185189        except  requests .exceptions .RequestException  as  e :
186190            print (f"Network error fetching data for { symbol } { str (e )}  )
187191        except  (ValueError , KeyError ) as  e :
@@ -256,10 +260,7 @@ def name(cls) -> str:
256260        return  "robinhood" 
257261
258262    def  schema (self ) ->  str :
259-         return  (
260-             "symbol string, price double, bid_price double, ask_price double, " 
261-             "updated_at string" 
262-         )
263+         return  "symbol string, price double, bid_price double, ask_price double, updated_at string" 
263264
264265    def  reader (self , schema : StructType ) ->  RobinhoodDataReader :
265-         return  RobinhoodDataReader (schema , self .options )
266+         return  RobinhoodDataReader (schema , self .options )
0 commit comments