
Conversation

@Yicong-Huang Yicong-Huang (Contributor) commented Aug 20, 2025

Add Robinhood cryptocurrency data source

Features:

  • Add RobinhoodDataSource for reading crypto market data from Robinhood API
  • Implement NaCl (Ed25519) cryptographic signing for API authentication
  • Support for individual crypto pairs (e.g., "BTC-USD") and bulk loading (e.g., "BTC-USD,ETH-USD,DOGE-USD")
  • Schema includes: symbol, price, bid_price, ask_price, updated_at
  • Add comprehensive tests with environment variable configuration
  • Add documentation and update project dependencies

Testing Instructions:

  1. Obtain Robinhood API credentials (an API key and a base64-encoded Ed25519 private key).

  2. Set environment variables:

    export ROBINHOOD_API_KEY="your-api-key-here"
    export ROBINHOOD_PRIVATE_KEY="your-base64-encoded-private-key"
    
  3. Run tests:

    pytest tests/test_robinhood.py -v -s
    
  4. Usage example:

    df = spark.read.format("robinhood") \
        .option("api_key", api_key) \
        .option("private_key", private_key) \
        .load("BTC-USD,ETH-USD")
    

Note: Tests will be skipped if environment variables are not set.
Real API credentials are required for integration tests to pass.
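
A minimal sketch of that environment-variable guard as it might appear in the test module (the helper name and skip message here are illustrative, not copied from tests/test_robinhood.py):

```python
import os

import pytest


def _require_credentials():
    """Skip the calling test unless real Robinhood credentials are configured."""
    api_key = os.environ.get("ROBINHOOD_API_KEY")
    private_key = os.environ.get("ROBINHOOD_PRIVATE_KEY")
    if not api_key or not private_key:
        pytest.skip(
            "ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY are required for live API tests"
        )
    return api_key, private_key
```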

Summary by CodeRabbit

  • New Features

    • Added RobinhoodDataSource to read crypto market data (symbol, price, bid, ask, timestamp) with authenticated multi-symbol support.
  • Documentation

    • Added RobinhoodDataSource docs page, updated Data Sources navigation, and included install notes for the signing dependency and an install extra.
  • Chores

    • Added optional cryptographic signing dependency and a new installation extra for Robinhood.
  • Tests

    • Added tests for registration, schema, credential/error handling, and optional live integration.

@coderabbitai coderabbitai bot commented Aug 20, 2025

Important

Review skipped

Review was skipped due to path filters

⛔ Files ignored due to path filters (1)
  • poetry.lock is excluded by !**/*.lock

CodeRabbit blocks several paths by default. You can override this behavior by explicitly including those paths in the path filters. For example, including **/dist/** will override the default block on the dist directory by removing the pattern from both lists.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Walkthrough

Introduce a Robinhood crypto PySpark DataSource that performs NaCl-signed authenticated requests, expose it at package root, add pynacl and a robinhood extra, update docs and MkDocs nav, and add pytest coverage with optional live API tests.

Changes

  • Documentation (docs/datasources/robinhood.md, docs/index.md, mkdocs.yml): New docs page for RobinhoodDataSource with install notes for pynacl and pyspark-data-sources[robinhood], public API render directive; index and mkdocs navigation updated.
  • Packaging & Dependencies (pyproject.toml): Added pynacl (^1.5.0) and a robinhood extras group; extras updated to include pynacl.
  • Package Exports (pyspark_datasources/__init__.py): Exported RobinhoodDataSource at package level via from .robinhood import RobinhoodDataSource.
  • Data Source Implementation (pyspark_datasources/robinhood.py): New RobinhoodDataSource, CryptoPair (InputPartition), and RobinhoodDataReader: base64 private-key decoding, PyNaCl signing, authenticated requests to Robinhood endpoints, partitioning by symbols (path option), and emitting rows (symbol, price, bid_price, ask_price, updated_at).
  • Tests (tests/test_robinhood.py): New pytest module covering name/schema, credential and symbol validation errors, invalid private-key handling, and optional live API tests for single and multiple crypto pairs.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Spark
  participant DataSource as RobinhoodDataSource
  participant Reader as RobinhoodDataReader
  participant API as Robinhood API

  User->>Spark: spark.read.format("robinhood").options(path, api_key, private_key).load()
  Spark->>DataSource: schema() / reader(schema)
  DataSource->>Reader: __init__(schema, options)
  note right of Reader: decode base64 private_key\ninit PyNaCl SigningKey

  Reader->>Spark: partitions() -> list of CryptoPair(symbol)

  loop per CryptoPair
    Spark->>Reader: read(CryptoPair)
    Reader->>API: GET /api/v1/crypto/marketdata/best_bid_ask?symbol=... (signed)
    API-->>Reader: market data JSON
    Reader-->>Spark: Row(symbol, price, bid_price, ask_price, updated_at)
  end
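
For orientation, a rough structural sketch of how a source like this plugs into the PySpark Python Data Source API, matching the flow in the diagram above; the signing and HTTP logic is omitted and the method bodies are placeholders, not the PR's actual implementation:

```python
from dataclasses import dataclass

from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition


@dataclass
class CryptoPair(InputPartition):
    symbol: str  # one partition per trading pair, e.g. "BTC-USD"


class RobinhoodDataSource(DataSource):
    @classmethod
    def name(cls) -> str:
        return "robinhood"

    def schema(self) -> str:
        return (
            "symbol string, price double, bid_price double, "
            "ask_price double, updated_at string"
        )

    def reader(self, schema) -> "RobinhoodDataReader":
        return RobinhoodDataReader(schema, self.options)


class RobinhoodDataReader(DataSourceReader):
    def __init__(self, schema, options):
        self.schema = schema
        self.options = options  # api_key, private_key, and the load() path

    def partitions(self) -> list:
        # The comma-separated symbols passed to load() arrive as the "path" option.
        symbols = self.options.get("path", "").split(",")
        return [CryptoPair(symbol=s.strip().upper()) for s in symbols if s.strip()]

    def read(self, partition: CryptoPair):
        # A signed GET to the Robinhood best_bid_ask endpoint would go here.
        yield (partition.symbol, 0.0, 0.0, 0.0, "")
```

After spark.dataSource.register(RobinhoodDataSource), such a source would be readable via spark.read.format("robinhood").load("BTC-USD,ETH-USD").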

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Poem

I thump and twitch—new data sprouts,
Keys tucked tight in secret clouts.
BTC, ETH, DOGE in rows,
I hop through bytes where Robinhood goes.
A carrot of rows—Spark happily scrounges. 🐇✨


@Yicong-Huang Yicong-Huang (Contributor, Author)

@allisonwang-db please help review, thanks!

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 5

🧹 Nitpick comments (15)
pyproject.toml (1)

30-32: Keep the “all” extra in sync with all optional deps (potential drift)

Nice addition of a dedicated robinhood extra and inclusion of pynacl in all. Please double-check whether all is intended to be the union of every optional dependency; if so, confirm whether anything under the existing lance extra should also be represented here (today all excludes anything listed in lance). If that exclusion is intentional, consider a brief comment to prevent future refactors from “fixing” it.

docs/datasources/robinhood.md (1)

1-6: Add Quickstart, Options, and test env vars so users can run this without hunting the docstring

The page renders the API, but a minimal “how to use” and options table will shorten onboarding and aligns with prior docs patterns.

Apply:

 # RobinhoodDataSource

 > Requires the [`pynacl`](https://github.com/pyca/pynacl) library for cryptographic signing. You can install it manually: `pip install pynacl`
 > or use `pip install pyspark-data-sources[robinhood]`.

+## Quickstart
+
+```python
+from pyspark_datasources import RobinhoodDataSource
+
+# Register (if your environment requires explicit registration)
+spark.dataSource.register(RobinhoodDataSource)
+
+df = (
+    spark.read.format("robinhood")
+    .option("api_key", "<YOUR_API_KEY>")
+    .option("private_key", "<YOUR_BASE64_ED25519_PRIVATE_KEY>")
+    .load("BTC-USD,ETH-USD")
+)
+df.show()
+```
+
+## Options
+
+- `api_key` (string, required): Robinhood API key.
+- `private_key` (string, required): Base64-encoded Ed25519 private key.
+- `load_all_pairs` (string bool, optional, default `"false"`): If `"true"`, loads all available pairs when `.load()` is called with no path.
+
+Schema returned: `symbol string, price double, bid_price double, ask_price double, updated_at string`
+
+## Testing
+
+Integration tests require real credentials. Set:
+
+```bash
+export ROBINHOOD_API_KEY="..."
+export ROBINHOOD_PRIVATE_KEY="..."  # base64-encoded Ed25519 private key
+```
+
 ::: pyspark_datasources.robinhood.RobinhoodDataSource
mkdocs.yml (1)

30-30: Optional: keep Data Sources nav roughly alphabetical

Placing robinhood near stock (or otherwise alphabetically) helps discoverability. Not a blocker.

Apply (move robinhood under stock):

   - datasources/stock.md
+  - datasources/robinhood.md
   - datasources/simplejson.md
-  - datasources/salesforce.md
+  - datasources/salesforce.md
   - datasources/googlesheets.md
   - datasources/kaggle.md
   - datasources/jsonplaceholder.md
-  - datasources/robinhood.md
tests/test_robinhood.py (4)

1-6: Remove unused imports or put them to use.

  • Mock and Row are unused.
  • If you keep patch (recommended for auth mocking below), drop Mock and Row to satisfy Ruff F401.

Apply:

-import os
-import pytest
-from unittest.mock import Mock, patch
-from pyspark.sql import SparkSession, Row
-from pyspark.errors.exceptions.captured import AnalysisException
+import os
+import pytest
+from unittest.mock import patch
+from pyspark.sql import SparkSession
+from pyspark.errors.exceptions.captured import AnalysisException

66-87: Mark these as integration tests and reduce noisy prints.

These run against live APIs and should be clearly marked to allow selective runs. Also, prefer pytest assertions over prints for CI signal.

Apply:

- def test_robinhood_btc_data(spark):
+ @pytest.mark.integration
+ def test_robinhood_btc_data(spark):
@@
-    rows = df.collect()
-    print(f"Retrieved {len(rows)} rows")
+    rows = df.collect()
@@
-    for i, row in enumerate(rows):
-        print(f"Row {i+1}: {row}")
+    for row in rows:
         # Validate data structure

99-139: Same as above: mark as integration and trim prints; assert on content.

Apply:

- def test_robinhood_multiple_crypto_pairs(spark):
+ @pytest.mark.integration
+ def test_robinhood_multiple_crypto_pairs(spark):
@@
-    rows = df.collect()
-    print(f"Retrieved {len(rows)} rows")
+    rows = df.collect()
@@
-    for i, row in enumerate(rows):
-        symbols_found.add(row.symbol)
-        print(f"Row {i+1}: {row}")
+    for row in rows:
+        symbols_found.add(row.symbol)
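
If the integration marker is adopted, it also needs to be registered so pytest does not warn about an unknown marker; a hypothetical conftest.py (or pytest.ini/pyproject) addition:

```python
# conftest.py (hypothetical): register the custom marker used above.
def pytest_configure(config):
    config.addinivalue_line(
        "markers", "integration: tests that hit the live Robinhood API"
    )
```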

38-38: Wrap long lines to satisfy Ruff E501.

Multiple assertion lines exceed 100 chars.

Example fixes:

-    assert "ValueError" in str(excinfo.value) and ("api_key" in str(excinfo.value) or "private_key" in str(excinfo.value))
+    msg = str(excinfo.value)
+    assert "ValueError" in msg and ("api_key" in msg or "private_key" in msg)
-        pytest.skip("ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY environment variables required for real API tests")
+        pytest.skip(
+            "ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY environment variables "
+            "required for real API tests"
+        )

And split the long f-strings similarly using intermediate variables or parentheses.

Also applies to: 73-73, 92-96, 106-106, 131-135, 138-138
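
For example (variable names here are made up, purely to show the parenthesized f-string split):

```python
expected = {"BTC-USD", "ETH-USD"}
symbols_found = {"BTC-USD", "ETH-USD", "DOGE-USD"}
message = (
    f"Expected symbols {sorted(expected)} in the results, "
    f"but only found {sorted(symbols_found)}"
)
assert expected <= symbols_found, message
```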

pyspark_datasources/robinhood.py (8)

155-190: Use lazy client, add retries, and replace prints with logging; wrap long lines.

  • Create the client on demand.
  • Add simple retry/backoff to be more resilient.
  • Replace print with logger warnings/errors.
  • Break long lines flagged by E501.

Apply:

+import logging
@@
-    def _make_authenticated_request(self, method: str, path: str, params: Dict = None, json_data: Dict = None):
+    def _make_authenticated_request(
+        self,
+        method: str,
+        path: str,
+        params: dict | None = None,
+        json_data: dict | None = None,
+    ):
         """Make an authenticated request to the Robinhood Crypto API."""
+        self._ensure_client()
         timestamp = self._get_current_timestamp()
         url = self.base_url + path
@@
-        try:
-            # Make request
-            if method.upper() == "GET":
-                response = self.session.get(url, headers=headers, params=params, timeout=10)
-            elif method.upper() == "POST":
-                headers['Content-Type'] = 'application/json'
-                response = self.session.post(url, headers=headers, json=json_data, timeout=10)
-            else:
-                response = self.session.request(method, url, headers=headers, params=params, json=json_data, timeout=10)
-            
-            response.raise_for_status()
-            return response.json()
-        except requests.RequestException as e:
-            print(f"Error making API request to {path}: {e}")
-            return None
+        # Simple retry loop
+        attempts, last_exc = 0, None
+        while attempts < 3:
+            attempts += 1
+            try:
+                if method.upper() == "GET":
+                    response = self._session.get(  # type: ignore[union-attr]
+                        url, headers=headers, params=params, timeout=10
+                    )
+                elif method.upper() == "POST":
+                    headers["Content-Type"] = "application/json"
+                    response = self._session.post(  # type: ignore[union-attr]
+                        url, headers=headers, json=json_data, timeout=10
+                    )
+                else:
+                    response = self._session.request(  # type: ignore[union-attr]
+                        method,
+                        url,
+                        headers=headers,
+                        params=params,
+                        json=json_data,
+                        timeout=10,
+                    )
+                response.raise_for_status()
+                return response.json()
+            except requests.RequestException as e:
+                last_exc = e
+        logging.getLogger(__name__).warning(
+            "Error making API request to %s after %d attempts: %s",
+            path, attempts, last_exc,
+        )
+        return None

191-198: Remove dead code: _get_query_params is unused.

The helper is not referenced anywhere.

Apply:

-    @staticmethod
-    def _get_query_params(key: str, *args) -> str:
-        """Build query parameters for API requests."""
-        if not args:
-            return ""
-        params = [f"{key}={arg}" for arg in args if arg]
-        return "?" + "&".join(params)

199-234: Harden partitions(): deduplicate symbols and keep order; validate input early.

  • Comma-split can yield duplicates and empties. Deduplicate while preserving order.
  • Keep the USD defaulting behavior.

Apply:

-            symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
+            symbols = [s.strip().upper() for s in symbols_str.split(",")]
@@
-            formatted_symbols = []
-            for symbol in symbols:
-                if symbol and '-' not in symbol:
-                    symbol = f"{symbol}-USD"  # Default to USD pair
-                if symbol:
-                    formatted_symbols.append(symbol)
-            
-            return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
+            formatted_symbols: list[str] = []
+            for s in symbols:
+                if not s:
+                    continue
+                if "-" not in s:
+                    s = f"{s}-USD"  # Default to USD pair
+                formatted_symbols.append(s)
+            # Deduplicate while preserving order
+            unique_symbols = list(dict.fromkeys(formatted_symbols))
+            return [CryptoPair(symbol=s) for s in unique_symbols]

235-244: Consider surfacing partition-level failures to callers instead of silently swallowing.

Swallowing all exceptions loses observability and can produce silently incomplete datasets. Prefer logging with context (done below), and optionally re-raising a structured error or emitting an error Row with a status column (if your consumers expect resilient reads).

Minimal change (switch to logger):

-        except Exception as e:
-            # Log error but don't fail the entire job
-            print(f"Warning: Failed to fetch data for {symbol}: {str(e)}")
+        except Exception as e:
+            logging.getLogger(__name__).warning(
+                "Failed to fetch data for %s: %s", symbol, e
+            )

79-83: Schema type consideration: updated_at could be a timestamp.

If the API returns ISO-8601, you may prefer timestamp to enable time operations natively. Keeping string is fine for now if it matches downstream expectations and tests.

Would you like a follow-up PR to switch to updated_at timestamp and update tests/docs accordingly?
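
If the column stays a string for now, a downstream cast is straightforward; a sketch assuming df is a DataFrame read via the robinhood format and the API returns ISO-8601 strings:

```python
from pyspark.sql import functions as F

# Parse the ISO-8601 string into a proper timestamp column downstream.
df_ts = df.withColumn("updated_at", F.to_timestamp("updated_at"))
```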


13-73: Docstring: add Options, error cases, partitioning strategy, Arrow usage, and printSchema.

To align with the repo’s guidelines:

  • Options (api_key, private_key, load_all_pairs, base_url [new], timeouts).
  • Error cases: missing credentials, invalid private key, no symbols vs load_all_pairs.
  • Partitioning strategy: one partition per symbol; how dedupe works.
  • Arrow optimizations and whether RecordBatch is supported.
  • Include df.printSchema() example.

I can draft the docstring update if you’d like.
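
A rough outline of the sections such a docstring could cover, per the repo guidelines (the wording below is placeholder text, not a proposed final docstring):

```python
from pyspark.sql.datasource import DataSource


class RobinhoodDataSource(DataSource):
    """Read crypto market data from the Robinhood Crypto API.

    Name: "robinhood"

    Options
    -------
    api_key : str, required
        Robinhood API key.
    private_key : str, required
        Base64-encoded Ed25519 private key used to sign requests.

    Errors
    ------
    Raises ValueError if api_key or private_key is missing, or if the
    private key cannot be decoded.

    Partitioning
    ------------
    One input partition per requested trading pair.

    Examples
    --------
    >>> spark.dataSource.register(RobinhoodDataSource)
    >>> df = spark.read.format("robinhood").load("BTC-USD,ETH-USD")
    >>> df.printSchema()
    """
```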


2-2: Modernize typing: prefer built-in dict over typing.Dict; wrap long signatures.

This quiets UP006/UP035 and E501.

Apply:

-from typing import Dict
+from typing import Any
@@
-    def __init__(self, schema: StructType, options: Dict):
+    def __init__(self, schema: StructType, options: dict[str, Any]):
@@
-    def _make_authenticated_request(self, method: str, path: str, params: Dict = None, json_data: Dict = None):
+    def _make_authenticated_request(
+        self,
+        method: str,
+        path: str,
+        params: dict | None = None,
+        json_data: dict | None = None,
+    ):

Also applies to: 98-98, 155-155


245-285: Refine field mapping and structured logging

Verified the /best_bid_ask endpoint only returns price, bid_inclusive_of_sell_spread, ask_inclusive_of_buy_spread, and timestamp; the suggested fallbacks for mid_price, mark_price, and last_trade_price aren't present in Robinhood's response and can be removed. While retaining the intent to avoid masking missing data, switch safe_float's default to None and replace all print calls with structured logger invocations.

Key changes to apply:

  • Update safe_float default from 0.0 to None
  • Assign price directly from quote.get("price") only
  • Swap print statements for:
    • logger.warning("No market data found for %s", symbol)
    • logger.error("Network error fetching data for %s: %s", symbol, e)
    • logger.error("Data parsing error for %s: %s", symbol, e)
    • logger.exception("Unexpected error fetching data for %s", symbol)

Suggested diff:

     def _read_crypto_pair_data(self, symbol: str):
         """Fetch cryptocurrency market data for a given trading pair."""
         try:
             path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
             market_data = self._make_authenticated_request("GET", path)
             
             if market_data and 'results' in market_data:
                 for quote in market_data['results']:
-                    def safe_float(value, default=0.0):
+                    def safe_float(value, default=None):
                         if value is None or value == "":
                             return default
                         try:
                             return float(value)
                         except (ValueError, TypeError):
                             return default
                     
-                    price = safe_float(
-                        quote.get("price")
-                        or quote.get("mid_price")
-                        or quote.get("mark_price")
-                        or quote.get("last_trade_price")
-                    )
+                    price = safe_float(quote.get("price"))
                     bid_price = safe_float(quote.get('bid_inclusive_of_sell_spread'))
                     ask_price = safe_float(quote.get('ask_inclusive_of_buy_spread'))
                     
                     yield Row(
                         symbol=symbol,
                         price=price,
                         bid_price=bid_price,
                         ask_price=ask_price,
                         updated_at=quote.get('timestamp', "")
                     )
             else:
-                print(f"Warning: No market data found for {symbol}")
+                logger = logging.getLogger(__name__)
+                logger.warning("No market data found for %s", symbol)
                 
         except requests.exceptions.RequestException as e:
-            print(f"Network error fetching data for {symbol}: {str(e)}")
+            logger.error("Network error fetching data for %s: %s", symbol, e)
         except (ValueError, KeyError) as e:
-            print(f"Data parsing error for {symbol}: {str(e)}")
+            logger.error("Data parsing error for %s: %s", symbol, e)
         except Exception as e:
-            print(f"Unexpected error fetching data for {symbol}: {str(e)}")
+            logger.exception("Unexpected error fetching data for %s", symbol)
📜 Review details

Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 3ee0484 and 2b483d2.

📒 Files selected for processing (7)
  • docs/datasources/robinhood.md (1 hunks)
  • docs/index.md (1 hunks)
  • mkdocs.yml (1 hunks)
  • pyproject.toml (1 hunks)
  • pyspark_datasources/__init__.py (1 hunks)
  • pyspark_datasources/robinhood.py (1 hunks)
  • tests/test_robinhood.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
{pyspark_datasources,tests}/**/*.py

📄 CodeRabbit Inference Engine (CLAUDE.md)

{pyspark_datasources,tests}/**/*.py: When specifying file paths with Spark, use load("/path") instead of option("path", "/path")
Format code with Ruff (ruff format) and fix/lint with ruff check

Files:

  • pyspark_datasources/__init__.py
  • tests/test_robinhood.py
  • pyspark_datasources/robinhood.py
pyspark_datasources/!(__init__).py

📄 CodeRabbit Inference Engine (CLAUDE.md)

pyspark_datasources/!(__init__).py: All data source classes must inherit from Spark's DataSource base class
Implement schema() in each data source to define the schema using PySpark StructType
Implement reader(schema) for batch reads; if streaming is supported, implement streamReader() or simpleStreamReader() accordingly
If batch writing is supported, implement writer(schema, overwrite); if streaming writes are supported, implement streamWriter(schema, overwrite)
All classes used by the data sources (including readers/writers) must be pickle-serializable
Use only standard Spark SQL data types in schemas
Implement robust exception handling in data source read/write paths
For streaming sources, ensure resources are cleaned up (e.g., implement stop()/commit() as appropriate)
Prefer returning pyarrow.RecordBatch from readers for better serialization performance
Implement partitions() on DataSourceReader to enable parallel processing when beneficial
Defer expensive operations until read time (lazy evaluation) in data source implementations
Include comprehensive class docstrings for each data source with: brief description and Name: "format_name", an Options section (parameters/types/defaults), and Examples (registration and basic usage)
In docstrings, include schema output (e.g., df.printSchema()) for clarity
In docstrings, document error cases (e.g., missing files or invalid options)
In docstrings, document the partitioning strategy used to improve performance
In docstrings, document Arrow optimizations (how Arrow is used to transmit data)
DataSourceReader.read must return tuples, Rows, or pyarrow.RecordBatch; implement accordingly

Files:

  • pyspark_datasources/robinhood.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T20:07:33.249Z
Learnt from: CR
PR: allisonwang-db/pyspark-data-sources#0
File: CLAUDE.md:0-0
Timestamp: 2025-08-19T20:07:33.249Z
Learning: Applies to pyspark_datasources/!(__init__).py : Include comprehensive class docstrings for each data source with: brief description and Name: "format_name", an Options section (parameters/types/defaults), and Examples (registration and basic usage)

Applied to files:

  • docs/index.md
🧬 Code Graph Analysis (3)
pyspark_datasources/__init__.py (1)
pyspark_datasources/robinhood.py (1)
  • RobinhoodDataSource (13-86)
tests/test_robinhood.py (3)
pyspark_datasources/robinhood.py (4)
  • RobinhoodDataSource (13-86)
  • name (76-77)
  • schema (79-83)
  • read (235-243)
tests/test_data_sources.py (1)
  • spark (11-13)
tests/test_google_sheets.py (1)
  • spark (9-12)
pyspark_datasources/robinhood.py (1)
pyspark_datasources/stock.py (2)
  • StockDataSource (8-41)
  • StockDataReader (49-86)
🪛 Ruff (0.12.2)
pyspark_datasources/__init__.py

8-8: .robinhood.RobinhoodDataSource imported but unused; consider removing, adding to __all__, or using a redundant alias

(F401)

tests/test_robinhood.py

3-3: unittest.mock.Mock imported but unused

Remove unused import

(F401)


3-3: unittest.mock.patch imported but unused

Remove unused import

(F401)


4-4: pyspark.sql.Row imported but unused

Remove unused import: pyspark.sql.Row

(F401)


38-38: Line too long (122 > 100)

(E501)


73-73: Line too long (116 > 100)

(E501)


92-92: Line too long (101 > 100)

(E501)


94-94: Line too long (113 > 100)

(E501)


95-95: Line too long (113 > 100)

(E501)


96-96: Line too long (113 > 100)

(E501)


106-106: Line too long (116 > 100)

(E501)


131-131: Line too long (101 > 100)

(E501)


133-133: Line too long (113 > 100)

(E501)


134-134: Line too long (113 > 100)

(E501)


135-135: Line too long (113 > 100)

(E501)


138-138: Line too long (119 > 100)

(E501)

pyspark_datasources/robinhood.py

2-2: typing.Dict is deprecated, use dict instead

(UP035)


98-98: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


155-155: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


155-155: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


155-155: Line too long (111 > 100)

(E501)


183-183: Line too long (120 > 100)

(E501)

🪛 Gitleaks (8.27.2)
tests/test_robinhood.py

46-46: Detected a Generic API Key, potentially exposing access to various services and sensitive operations.

(generic-api-key)

🔇 Additional comments (1)
pyproject.toml (1)

22-22: Good: optional PyNaCl is correctly modeled

Declaring pynacl as optional aligns with the new signed-auth requirement without forcing it on users who don’t need Robinhood.

@allisonwang-db allisonwang-db (Owner) left a comment

Looks good!

Comment on lines 43 to 44
|ETH-USD | 2650.75| 2650.50 | 2651.00 |2024-01-15T16:00:...|
|DOGE-USD| 0.085| 0.084| 0.086|2024-01-15T16:00:...|
Owner:
Let's improve the formatting (or we can remove this row)

Contributor Author:
fixed.

Comment on lines 48 to 54
>>> df = spark.read.format("robinhood") \\
... .option("api_key", "your-api-key") \\
... .option("private_key", "your-base64-private-key") \\
... .load("BTC-USD,ETH-USD")
>>> df.show()
Owner:
Is this example the same as the previous one?

Contributor Author:
It is a slightly different example, but since it looked redundant, I have removed it.

Comment on lines 81 to 83
"symbol string, price double, bid_price double, ask_price double, "
"updated_at string"
)
Owner:
does this only support fetching the latest snapshot price or also historical prices?

Contributor Author:
Yes, according to the docs there is no option to supply a historical time, so it appears to support only the latest snapshot price.

# Initialize session for connection pooling
self.session = requests.Session()
self.session.headers.update({
'User-Agent': 'PySpark Robinhood Crypto DataSource/1.0'
Owner:
Nice :) do we actually need this

Contributor Author:
I initially thought there would be many crypto pairs to retrieve, so I added a session and pagination. However, only 32 crypto pairs are currently listed on this API, so both are overkill. I have removed them.

params = [f"{key}={arg}" for arg in args if arg]
return "?" + "&".join(params)

def partitions(self):
Owner:
Can we add a type hint here?

Contributor Author:
Added. I also added type hints for the other methods.

"""Fetch cryptocurrency market data for a given trading pair."""
try:
# Get best bid/ask data for the trading pair using query parameters
path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
Owner:
Does it support passing a timestamp to query for historical data?

Contributor Author:
unfortunately, it is not supported at the moment.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/test_robinhood.py (1)

99-139: Add an offline parsing test by mocking HTTP to validate field extraction.

This ensures core parsing works without live credentials and shields CI from vendor changes.

Add outside this range:

def test_robinhood_parsing_offline(spark, monkeypatch):
    sample = {
        "results": [{
            "price": "50000.1",
            "bid_price": "49999.9",
            "ask_price": "50000.3",
            "timestamp": "2025-08-20T12:34:56Z",
        }]
    }
    class DummyResp:
        status_code = 200
        def raise_for_status(self): pass
        def json(self): return sample
    class DummySession:
        def get(self, *a, **kw): return DummyResp()
        def post(self, *a, **kw): return DummyResp()
        def request(self, *a, **kw): return DummyResp()

    # Patch lazy client/session and signing to avoid crypto/network
    import pyspark_datasources.robinhood as rh
    monkeypatch.setattr(rh.RobinhoodDataReader, "_ensure_client", lambda self: None)
    monkeypatch.setattr(rh, "requests", Mock(Session=lambda: DummySession()))
    # Also patch signature to a no-op to avoid depending on format
    monkeypatch.setattr(rh.RobinhoodDataReader, "_generate_signature", lambda *a, **k: "sig")

    df = (
        spark.read.format("robinhood")
        .option("api_key", "k")
        .option("private_key", "p")
        .load("BTC-USD")
    )
    rows = df.collect()
    assert rows and rows[0].symbol == "BTC-USD"
    assert rows[0].price == 50000.1
♻️ Duplicate comments (2)
pyspark_datasources/robinhood.py (1)

83-121: Reader is not pickle-serializable; eagerly creates crypto key and HTTP session. Defer initialization and implement pickling.

Per project guidelines (and our prior learning), reader/writer classes must be pickle-serializable. Storing requests.Session and a NaCl SigningKey on the instance prevents pickling and will break executor-side deserialization. Defer creation until first use and drop transient fields from __getstate__. Also expose base_url as an option for sandboxing.

Apply within this range:

-    def __init__(self, schema: StructType, options: Dict):
+    def __init__(self, schema: StructType, options: dict[str, Any]):
         self.schema = schema
         self.options = options
-        
-        # Required API authentication
+        # Required API authentication
         self.api_key = options.get("api_key")
         self.private_key_base64 = options.get("private_key")
-        
         if not self.api_key or not self.private_key_base64:
             raise ValueError(
                 "Robinhood Crypto API requires both 'api_key' and 'private_key' options. "
                 "The private_key should be base64-encoded. "
                 "Get your API credentials from https://docs.robinhood.com/crypto/trading/"
             )
-        
-        # Initialize NaCl signing key
-        try:
-            from nacl.signing import SigningKey
-            private_key_seed = base64.b64decode(self.private_key_base64)
-            self.signing_key = SigningKey(private_key_seed)
-        except ImportError:
-            raise ImportError(
-                "PyNaCl library is required for Robinhood Crypto API authentication. "
-                "Install it with: pip install pynacl"
-            )
-        except Exception as e:
-            raise ValueError(f"Invalid private key format: {str(e)}")
-
-        # Initialize session for connection pooling
-        self.session = requests.Session()
-        self.session.headers.update({
-            'User-Agent': 'PySpark Robinhood Crypto DataSource/1.0'
-        })
-        
-        # Crypto API base URL
-        self.base_url = "https://trading.robinhood.com"
+        # Transient, lazily initialized on first request (pickling-friendly)
+        self._signing_key = None  # set in _ensure_client()
+        self._session = None      # set in _ensure_client()
+        # Crypto API base URL (overridable for testing/sandbox)
+        self.base_url = options.get("base_url", "https://trading.robinhood.com")

Add outside this range (new helpers):

def _ensure_client(self) -> None:
    """Lazily initialize signing key and HTTP session on the executor."""
    if self._signing_key is None:
        try:
            from nacl.signing import SigningKey
        except ImportError as ie:
            raise ImportError(
                "PyNaCl library is required. Install with: pip install pynacl"
            ) from ie
        try:
            private_key_seed = base64.b64decode(self.private_key_base64)
            self._signing_key = SigningKey(private_key_seed)
        except Exception as e:
            raise ValueError(f"Invalid private key format: {e}") from e
    if self._session is None:
        self._session = requests.Session()
        self._session.headers.update(
            {"User-Agent": "PySpark Robinhood Crypto DataSource/1.0"}
        )

def __getstate__(self):
    state = dict(self.__dict__)
    state["_session"] = None
    state["_signing_key"] = None
    return state

def __setstate__(self, state):
    self.__dict__.update(state)
tests/test_robinhood.py (1)

41-49: Avoid hard-coded key-like test strings; patch SigningKey and base64 decode.

Secret scanners will flag the base64-looking value. Also, tests on error paths need not construct real crypto state. Patch the crypto bits and use a benign placeholder.

-def test_robinhood_missing_symbols(spark):
+def test_robinhood_missing_symbols(spark):
     """Test that missing symbols raises an error."""
-    with pytest.raises(AnalysisException) as excinfo:
-        df = spark.read.format("robinhood") \
-            .option("api_key", "test-key") \
-            .option("private_key", "FAPmPMsQqDFOFiRvpUMJ6BC5eFOh/tPx7qcTYGKc8nE=") \
-            .load("")
-        df.collect()  # Trigger execution
+    with patch("pyspark_datasources.robinhood.base64.b64decode", return_value=b"\x00"*32), \
+         patch("pyspark_datasources.robinhood.nacl.signing.SigningKey") as _sk:
+        with pytest.raises(AnalysisException) as excinfo:
+            df = (
+                spark.read.format("robinhood")
+                .option("api_key", "test-key")
+                .option("private_key", "placeholder-key")
+                .load("")
+            )
+            df.collect()  # Trigger execution
🧹 Nitpick comments (8)
pyspark_datasources/robinhood.py (4)

2-2: Replace deprecated typing.Dict with builtin dict; import Any.

Ruff flags this (UP035/UP006). Use builtin generics for annotations.

-from typing import Dict
+from typing import Any

175-182: Helper _get_query_params is currently unused. Remove or use via _make_authenticated_request.

Dead code increases maintenance cost. Prefer removing it or reusing it consistently when constructing path.

-    @staticmethod
-    def _get_query_params(key: str, *args) -> str:
-        """Build query parameters for API requests."""
-        if not args:
-            return ""
-        params = [f"{key}={arg}" for arg in args if arg]
-        return "?" + "&".join(params)
+    # NOTE: Removed _get_query_params as construction is centralized in _make_authenticated_request.

183-203: Normalize symbols consistently; dedupe and validate; minor ergonomics.

  • Deduplicate symbols and preserve order.
  • Normalize case and default quote currency while avoiding double “-USD”.
  • Fail fast on empty post-normalization set.
-        symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
-        # Ensure proper format (e.g., BTC-USD)
-        formatted_symbols = []
-        for symbol in symbols:
-            if symbol and '-' not in symbol:
-                symbol = f"{symbol}-USD"  # Default to USD pair
-            if symbol:
-                formatted_symbols.append(symbol)
-        
-        return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
+        raw = [s.strip().upper() for s in symbols_str.split(",")]
+        seen = set()
+        formatted = []
+        for s in raw:
+            if not s:
+                continue
+            sym = s if "-" in s else f"{s}-USD"
+            if sym not in seen:
+                seen.add(sym)
+                formatted.append(sym)
+        if not formatted:
+            raise ValueError("No valid crypto pairs provided.")
+        return [CryptoPair(symbol=s) for s in formatted]

214-245: Field mapping and robustness for market data parsing.

  • Field names bid_inclusive_of_sell_spread and ask_inclusive_of_buy_spread may differ across endpoints/versions. Provide fallbacks.
  • Consider emitting updated_at as timestamp type in schema (optional, breaking change).
-                    price = safe_float(quote.get('price'))
-                    bid_price = safe_float(quote.get('bid_inclusive_of_sell_spread'))
-                    ask_price = safe_float(quote.get('ask_inclusive_of_buy_spread'))
+                    price = safe_float(quote.get('price') or quote.get('mark_price') or quote.get('mid_price'))
+                    bid_price = safe_float(
+                        quote.get('bid_inclusive_of_sell_spread') or quote.get('bid_price') or quote.get('bid')
+                    )
+                    ask_price = safe_float(
+                        quote.get('ask_inclusive_of_buy_spread') or quote.get('ask_price') or quote.get('ask')
+                    )

Optional: switch schema’s updated_at to timestamp and parse to ISO 8601 for better downstream ops. If you want this, I can provide a follow-up patch and migration note.

tests/test_robinhood.py (4)

3-4: Trim unused imports or make use of them.

Mock and Row are unused. If you accept my patch below for crypto patching, patch will be used; keep it, drop Mock and Row.

-from unittest.mock import Mock, patch
-from pyspark.sql import SparkSession, Row
+from unittest.mock import patch
+from pyspark.sql import SparkSession

38-38: Wrap long assertions to satisfy Ruff E501 (line length).

Use implicit string concatenation or parentheses to break lines.

Example:

-    assert "ValueError" in str(excinfo.value) and ("api_key" in str(excinfo.value) or "private_key" in str(excinfo.value))
+    err = str(excinfo.value)
+    assert "ValueError" in err and ("api_key" in err or "private_key" in err)

Also applies to: 73-73, 92-96, 106-106, 131-135, 138-138


32-39: Error assertion is brittle; assert on message shape without relying on “ValueError” prefixing.

Spark may wrap differently across versions. Assert on informative substrings.

-    assert "ValueError" in str(excinfo.value) and ("api_key" in str(excinfo.value) or "private_key" in str(excinfo.value))
+    err = str(excinfo.value)
+    assert ("api_key" in err) or ("private_key" in err)

66-97: Live API test is fine; consider marking as flaky and adding a strict timeout.

Network and vendor outages will fail CI. You already skip without creds; additionally mark as flaky and shorten Spark job where possible.

@pytest.mark.flaky(reruns=2, reruns_delay=2)
def test_robinhood_btc_data(spark):
    ...
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 2b483d2 and 9718baa.

📒 Files selected for processing (2)
  • pyspark_datasources/robinhood.py (1 hunks)
  • tests/test_robinhood.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
pyspark_datasources/!(__init__).py

📄 CodeRabbit inference engine (CLAUDE.md)

pyspark_datasources/!(__init__).py: All data source classes must inherit from Spark's DataSource base class
Implement schema() in each data source to define the schema using PySpark StructType
Implement reader(schema) for batch reads; if streaming is supported, implement streamReader() or simpleStreamReader() accordingly
If batch writing is supported, implement writer(schema, overwrite); if streaming writes are supported, implement streamWriter(schema, overwrite)
All classes used by the data sources (including readers/writers) must be pickle-serializable
Use only standard Spark SQL data types in schemas
Implement robust exception handling in data source read/write paths
For streaming sources, ensure resources are cleaned up (e.g., implement stop()/commit() as appropriate)
Prefer returning pyarrow.RecordBatch from readers for better serialization performance
Implement partitions() on DataSourceReader to enable parallel processing when beneficial
Defer expensive operations until read time (lazy evaluation) in data source implementations
Include comprehensive class docstrings for each data source with: brief description and Name: "format_name", an Options section (parameters/types/defaults), and Examples (registration and basic usage)
In docstrings, include schema output (e.g., df.printSchema()) for clarity
In docstrings, document error cases (e.g., missing files or invalid options)
In docstrings, document the partitioning strategy used to improve performance
In docstrings, document Arrow optimizations (how Arrow is used to transmit data)
DataSourceReader.read must return tuples, Rows, or pyarrow.RecordBatch; implement accordingly

Files:

  • pyspark_datasources/robinhood.py
{pyspark_datasources,tests}/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

{pyspark_datasources,tests}/**/*.py: When specifying file paths with Spark, use load("/path") instead of option("path", "/path")
Format code with Ruff (ruff format) and fix/lint with ruff check

Files:

  • pyspark_datasources/robinhood.py
  • tests/test_robinhood.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T20:07:33.249Z
Learnt from: CR
PR: allisonwang-db/pyspark-data-sources#0
File: CLAUDE.md:0-0
Timestamp: 2025-08-19T20:07:33.249Z
Learning: Applies to pyspark_datasources/!(__init__).py : All classes used by the data sources (including readers/writers) must be pickle-serializable

Applied to files:

  • pyspark_datasources/robinhood.py
🧬 Code graph analysis (2)
pyspark_datasources/robinhood.py (2)
pyspark_datasources/stock.py (2)
  • StockDataSource (8-41)
  • StockDataReader (49-86)
pyspark_datasources/github.py (1)
  • GithubPullRequestReader (48-72)
tests/test_robinhood.py (2)
pyspark_datasources/robinhood.py (4)
  • RobinhoodDataSource (13-71)
  • name (61-62)
  • schema (64-68)
  • read (204-212)
tests/test_data_sources.py (1)
  • spark (11-13)
🪛 Ruff (0.12.2)
pyspark_datasources/robinhood.py

2-2: typing.Dict is deprecated, use dict instead

(UP035)


83-83: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


139-139: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


139-139: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


139-139: Line too long (111 > 100)

(E501)


167-167: Line too long (120 > 100)

(E501)

tests/test_robinhood.py

3-3: unittest.mock.Mock imported but unused

Remove unused import

(F401)


3-3: unittest.mock.patch imported but unused

Remove unused import

(F401)


4-4: pyspark.sql.Row imported but unused

Remove unused import: pyspark.sql.Row

(F401)


38-38: Line too long (122 > 100)

(E501)


73-73: Line too long (116 > 100)

(E501)


92-92: Line too long (101 > 100)

(E501)


94-94: Line too long (113 > 100)

(E501)


95-95: Line too long (113 > 100)

(E501)


96-96: Line too long (113 > 100)

(E501)


106-106: Line too long (116 > 100)

(E501)


131-131: Line too long (101 > 100)

(E501)


133-133: Line too long (113 > 100)

(E501)


134-134: Line too long (113 > 100)

(E501)


135-135: Line too long (113 > 100)

(E501)


138-138: Line too long (119 > 100)

(E501)

🔇 Additional comments (3)
pyspark_datasources/robinhood.py (2)

60-71: DataSource interface usage looks correct.

name(), schema(), and reader() implement the expected interface and match other sources in the repo. Good.


122-137: Ignore the newline-separated canonicalization suggestion; the existing signature logic already matches Robinhood’s spec.

I’ve verified against the community‐published examples and Robinhood’s announcement that the string to sign is simply the concatenation of:

  • API key
  • Timestamp (integer seconds since the Unix epoch, UTC)
  • Request path (including query string, but not the domain)
  • HTTP method (uppercase)
  • Request body (stringified JSON; omit entirely for GET requests)

They do not insert newline separators or hash the body—just UTF-8-encode the raw concatenated string and sign it with your Ed25519 key, then Base64-encode the raw signature and send it in the x-signature header alongside x-api-key and x-timestamp. That exactly matches this code:

if method.upper() == "GET" and not body:
    message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}"
else:
    message_to_sign = f"{self.api_key}{timestamp}{path}{method.upper()}{body}"
signed = self.signing_key.sign(message_to_sign.encode("utf-8"))
signature = base64.b64encode(signed.signature).decode("utf-8")

No SHA-512 hashing or newline formatting is required—and unless you’ve uncovered a separate bug in how self.signing_key is initialized elsewhere, there’s no need to switch to a lazy _ensure_client() call here. You can safely ignore the proposed refactor.

Likely an incorrect or invalid review comment.
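
For reference, a standalone sketch of that signing scheme using PyNaCl directly; the header names x-api-key, x-timestamp, and x-signature come from the comment above, while the key and request values below are dummies:

```python
import base64
import time

from nacl.signing import SigningKey

api_key = "rh-api-key-placeholder"
private_key_base64 = base64.b64encode(bytes(32)).decode()  # dummy 32-byte Ed25519 seed
path = "/api/v1/crypto/marketdata/best_bid_ask/?symbol=BTC-USD"
method = "GET"
timestamp = int(time.time())

signing_key = SigningKey(base64.b64decode(private_key_base64))
message = f"{api_key}{timestamp}{path}{method}"  # GET request with no body
signed = signing_key.sign(message.encode("utf-8"))
signature = base64.b64encode(signed.signature).decode("utf-8")

headers = {
    "x-api-key": api_key,
    "x-timestamp": str(timestamp),
    "x-signature": signature,
}
```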

tests/test_robinhood.py (1)

18-30: Registration and schema tests look good.

Covers source name and the exposed schema string. Solid baseline.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

♻️ Duplicate comments (4)
pyspark_datasources/robinhood.py (4)

14-58: Docstring is inconsistent with implementation and missing required sections.

The docstring claims rate limiting is handled automatically (Line 54), but the code doesn't implement any rate limiting logic. Also missing required sections per coding guidelines: Options, Errors, Partitioning, and Arrow details.

Based on the past review comment, this needs to be addressed with proper Options/Errors/Partitioning/Arrow sections and accurate rate limiting claims.


83-115: Constructor creates non-serializable objects that prevent proper parallelization.

The constructor initializes requests.Session and SigningKey objects that are not pickle-serializable, which will cause issues when the reader is distributed to Spark executors. Per coding guidelines and past review feedback, these expensive/non-serializable operations should be deferred until read time.

This was flagged in previous reviews and needs to be addressed with lazy initialization and proper pickle handling via __getstate__/__setstate__ methods.


121-132: Signature generation doesn't match Robinhood API specification.

The current signature format doesn't align with the Robinhood Crypto API's Ed25519 signing requirements, and it uses self.signing_key which won't exist after implementing lazy initialization.

Based on the past review comment, this needs to use the proper canonical string format with newline-separated components (method, path, timestamp, SHA-512 body hash) and call self._ensure_client() before using self._signing_key.


134-168: Request method needs improvements for reliability and type safety.

Multiple issues: missing type hints, no rate limiting despite docstring claims, uses print instead of logging, doesn't handle query parameters in signature, and lacks proper error handling.

Based on the past review comment, this needs proper typing, 429/retry handling, logging instead of prints, and inclusion of sorted query parameters in the signed path.
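
A minimal sketch of what 429-aware retry handling could look like (illustrative only; the function name and backoff policy are assumptions, not the PR's code):

```python
import logging
import time

import requests

logger = logging.getLogger(__name__)


def get_json_with_retry(session, url, *, headers=None, params=None, max_attempts=3, timeout=10):
    """GET with basic retry: honor Retry-After on 429, back off on transient errors."""
    for attempt in range(1, max_attempts + 1):
        try:
            response = session.get(url, headers=headers, params=params, timeout=timeout)
            if response.status_code == 429 and attempt < max_attempts:
                wait = float(response.headers.get("Retry-After", 2 ** attempt))
                logger.warning("Rate limited on %s; retrying in %.1fs", url, wait)
                time.sleep(wait)
                continue
            response.raise_for_status()
            return response.json()
        except requests.RequestException as exc:
            if attempt == max_attempts:
                logger.error("Request to %s failed after %d attempts: %s", url, attempt, exc)
                return None
            time.sleep(2 ** attempt)
```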

🧹 Nitpick comments (2)
pyspark_datasources/robinhood.py (2)

2-2: Update type annotations to use modern syntax.

Use dict instead of the deprecated typing.Dict for type annotations.

-from typing import Dict
+from typing import Any

178-178: Add type hints for method parameters and return value.

The method signature lacks type annotations as requested in previous review.

-    def partitions(self):
+    def partitions(self) -> list[CryptoPair]:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 9718baa and bdf2072.

📒 Files selected for processing (1)
  • pyspark_datasources/robinhood.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
pyspark_datasources/!(__init__).py

📄 CodeRabbit inference engine (CLAUDE.md)

pyspark_datasources/!(__init__).py: All data source classes must inherit from Spark's DataSource base class
Implement schema() in each data source to define the schema using PySpark StructType
Implement reader(schema) for batch reads; if streaming is supported, implement streamReader() or simpleStreamReader() accordingly
If batch writing is supported, implement writer(schema, overwrite); if streaming writes are supported, implement streamWriter(schema, overwrite)
All classes used by the data sources (including readers/writers) must be pickle-serializable
Use only standard Spark SQL data types in schemas
Implement robust exception handling in data source read/write paths
For streaming sources, ensure resources are cleaned up (e.g., implement stop()/commit() as appropriate)
Prefer returning pyarrow.RecordBatch from readers for better serialization performance
Implement partitions() on DataSourceReader to enable parallel processing when beneficial
Defer expensive operations until read time (lazy evaluation) in data source implementations
Include comprehensive class docstrings for each data source with: brief description and Name: "format_name", an Options section (parameters/types/defaults), and Examples (registration and basic usage)
In docstrings, include schema output (e.g., df.printSchema()) for clarity
In docstrings, document error cases (e.g., missing files or invalid options)
In docstrings, document the partitioning strategy used to improve performance
In docstrings, document Arrow optimizations (how Arrow is used to transmit data)
DataSourceReader.read must return tuples, Rows, or pyarrow.RecordBatch; implement accordingly

Files:

  • pyspark_datasources/robinhood.py
{pyspark_datasources,tests}/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

{pyspark_datasources,tests}/**/*.py: When specifying file paths with Spark, use load("/path") instead of option("path", "/path")
Format code with Ruff (ruff format) and fix/lint with ruff check

Files:

  • pyspark_datasources/robinhood.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T20:07:33.249Z
Learnt from: CR
PR: allisonwang-db/pyspark-data-sources#0
File: CLAUDE.md:0-0
Timestamp: 2025-08-19T20:07:33.249Z
Learning: Applies to pyspark_datasources/!(__init__).py : All classes used by the data sources (including readers/writers) must be pickle-serializable

Applied to files:

  • pyspark_datasources/robinhood.py
🧬 Code graph analysis (1)
pyspark_datasources/robinhood.py (2)
pyspark_datasources/stock.py (2)
  • StockDataSource (8-41)
  • StockDataReader (49-86)
pyspark_datasources/github.py (1)
  • GithubDataSource (7-45)
🪛 Ruff (0.12.2)
pyspark_datasources/robinhood.py

2-2: typing.Dict is deprecated, use dict instead

(UP035)


83-83: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


134-134: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


134-134: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


134-134: Line too long (111 > 100)

(E501)


162-162: Line too long (116 > 100)

(E501)

🔇 Additional comments (1)
pyspark_datasources/robinhood.py (1)

213-213: API endpoint only supports current market data, not historical.

Based on previous discussion, this endpoint only provides latest snapshot prices without historical data capability, which aligns with the current implementation.

…umentation

- Made the base URL configurable for testing purposes.
- Added detailed options, error handling, and partitioning information to the documentation.
- Updated notes to clarify requirements and references for the Robinhood API.

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (5)
pyspark_datasources/robinhood.py (5)

40-51: Critical: Initialize crypto components lazily to ensure pickle compatibility.

The current implementation creates non-serializable objects (SigningKey) during initialization, which will cause serialization failures when Spark distributes this reader to executors. Per coding guidelines and past review feedback, defer expensive operations until read time.


63-74: Critical: Fix signature generation to match Robinhood API specification.

The current signature format doesn't match the documented Robinhood Crypto API specification. Based on web search results, the canonical string should use newline-separated components with SHA-512 body hashing.


76-110: Improve request handling with proper error management and lazy initialization.

Multiple issues: missing lazy client initialization, no retry logic for 429 responses, using print instead of logging, and missing query parameter handling in signature.


109-109: Replace print statements with proper logging.

Using print statements in production code is inappropriate. These should use Python's logging module for proper log level control.

Also applies to: 149-149, 183-183, 186-186, 188-188, 190-190


194-238: Docstring lacks required sections per coding guidelines.

The docstring is missing required Options, Errors, Partitioning, and Arrow sections as specified in the coding guidelines. Also, the claim about rate limiting is incorrect since no retry logic is implemented.
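
For reference, a hedged skeleton of the requested docstring layout (the section texts below are placeholders, not the PR's final wording):

from pyspark.sql.datasource import DataSource


class RobinhoodDataSource(DataSource):
    """Read current crypto quotes from the Robinhood Crypto API.

    Name: "robinhood"

    Options:
        api_key (str, required): Robinhood Crypto API key.
        private_key (str, required): base64-encoded Ed25519 private key seed.
        base_url (str, default "https://trading.robinhood.com"): override for testing.

    Errors:
        Missing credentials or an empty symbol list raise an error at read time.

    Partitioning:
        One input partition per requested trading pair, so pairs are fetched in parallel.

    Arrow:
        Rows are yielded per partition; Arrow batching should only be documented if added.

    Examples:
        >>> spark.dataSource.register(RobinhoodDataSource)
        >>> spark.read.format("robinhood").option("api_key", k) \
        ...     .option("private_key", p).load("BTC-USD").printSchema()
    """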

🧹 Nitpick comments (4)
pyspark_datasources/robinhood.py (4)

2-2: Update imports to use modern type annotations.

Static analysis suggests using more modern typing constructs that are available in newer Python versions.

-from typing import Dict, List, Optional, Generator, Union
+from typing import Optional, Union
+from collections.abc import Generator

25-25: Use modern dict type annotation.

The Dict import is deprecated in favor of the built-in dict type.

-    def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
+    def __init__(self, schema: StructType, options: dict[str, str]) -> None:

120-120: Use modern list type annotation.

The List import is deprecated in favor of the built-in list type.

-    def partitions(self) -> List[CryptoPair]:
+    def partitions(self) -> list[CryptoPair]:

161-167: Optimize line length and improve readability.

The function signature exceeds the 100-character line limit. Consider breaking it into multiple lines for better readability.

-                    def safe_float(value: Union[str, int, float, None], default: float = 0.0) -> float:
+                    def safe_float(
+                        value: Union[str, int, float, None], default: float = 0.0
+                    ) -> float:
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between bdf2072 and 6d6c2f9.

📒 Files selected for processing (1)
  • pyspark_datasources/robinhood.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
pyspark_datasources/!(__init__).py

📄 CodeRabbit inference engine (CLAUDE.md)

pyspark_datasources/!(__init__).py: All data source classes must inherit from Spark's DataSource base class
Implement schema() in each data source to define the schema using PySpark StructType
Implement reader(schema) for batch reads; if streaming is supported, implement streamReader() or simpleStreamReader() accordingly
If batch writing is supported, implement writer(schema, overwrite); if streaming writes are supported, implement streamWriter(schema, overwrite)
All classes used by the data sources (including readers/writers) must be pickle-serializable
Use only standard Spark SQL data types in schemas
Implement robust exception handling in data source read/write paths
For streaming sources, ensure resources are cleaned up (e.g., implement stop()/commit() as appropriate)
Prefer returning pyarrow.RecordBatch from readers for better serialization performance
Implement partitions() on DataSourceReader to enable parallel processing when beneficial
Defer expensive operations until read time (lazy evaluation) in data source implementations
Include comprehensive class docstrings for each data source with: brief description and Name: "format_name", an Options section (parameters/types/defaults), and Examples (registration and basic usage)
In docstrings, include schema output (e.g., df.printSchema()) for clarity
In docstrings, document error cases (e.g., missing files or invalid options)
In docstrings, document the partitioning strategy used to improve performance
In docstrings, document Arrow optimizations (how Arrow is used to transmit data)
DataSourceReader.read must return tuples, Rows, or pyarrow.RecordBatch; implement accordingly

Files:

  • pyspark_datasources/robinhood.py
{pyspark_datasources,tests}/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

{pyspark_datasources,tests}/**/*.py: When specifying file paths with Spark, use load("/path") instead of option("path", "/path")
Format code with Ruff (ruff format) and fix/lint with ruff check

Files:

  • pyspark_datasources/robinhood.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T20:07:33.249Z
Learnt from: CR
PR: allisonwang-db/pyspark-data-sources#0
File: CLAUDE.md:0-0
Timestamp: 2025-08-19T20:07:33.249Z
Learning: Applies to pyspark_datasources/!(__init__).py : All classes used by the data sources (including readers/writers) must be pickle-serializable

Applied to files:

  • pyspark_datasources/robinhood.py
🧬 Code graph analysis (1)
pyspark_datasources/robinhood.py (2)
pyspark_datasources/stock.py (2)
  • StockDataSource (8-41)
  • StockDataReader (49-86)
pyspark_datasources/fake.py (1)
  • FakeDataSource (35-125)
🪛 Ruff (0.12.2)
pyspark_datasources/robinhood.py

2-2: Import from collections.abc instead: Generator

Import from collections.abc

(UP035)


2-2: typing.Dict is deprecated, use dict instead

(UP035)


2-2: typing.List is deprecated, use list instead

(UP035)


25-25: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


76-76: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


76-76: Line too long (159 > 100)

(E501)


76-76: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


76-76: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


104-104: Line too long (116 > 100)

(E501)


120-120: Use list instead of List for type annotation

Replace with list

(UP006)


161-161: Line too long (103 > 100)

(E501)

🔇 Additional comments (2)
pyspark_datasources/robinhood.py (2)

154-156: Confirmed: Robinhood Crypto API endpoint is correct
Community implementations, including the StackOverflow example based on the robinhood-unofficial/pyrh client, use:

path = f"/api/v1/crypto/marketdata/best_bid_ask/{query_params}"

where query_params is "?symbol=BTC-USD" (or similar), resulting in the exact format you’ve implemented (/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}) (stackoverflow.com)

No changes required.


1-251: Ensure Robinhood Crypto API signature, timestamp, and headers strictly follow the spec.

The current implementation diverges from the required canonical string format, timestamp encoding, and header names:

  • Signature generation (_generate_signature)
    • Spec requires signing the canonical string:
      <HTTP_METHOD>\n
      <REQUEST_PATH>\n
      <TIMESTAMP>\n
      <REQUEST_BODY>
      
      – Remove the API key from the message, and join elements with newline delimiters instead of simple concatenation.
  • Timestamp formatting (_get_current_timestamp)
    • Spec enforces a 30-second replay window using an ISO 8601 UTC timestamp (e.g. 2025-08-21T23:58:32Z) or UNIX epoch in milliseconds.
    • Update from int(datetime.now().timestamp()) to an ISO 8601 string (or ms epoch) matching the server’s expectation.
  • Request headers (_make_authenticated_request)
    • Replace:
      • x-api-key → X-RH-API-Key
      • x-signature → X-RH-Signature
      • x-timestamp → X-RH-Timestamp
    • Ensure Content-Type: application/json is only set for non-GET requests.

These changes are critical to achieving compatibility with Robinhood’s Ed25519-based authentication scheme.

⛔ Skipped due to learnings
Learnt from: CR
PR: allisonwang-db/pyspark-data-sources#0
File: CLAUDE.md:0-0
Timestamp: 2025-08-19T20:07:33.249Z
Learning: Applies to pyspark_datasources/!(__init__).py : Implement partitions() on DataSourceReader to enable parallel processing when beneficial

Comment on lines 53 to 57



# Crypto API base URL
self.base_url = "https://trading.robinhood.com"

⚠️ Potential issue

Missing session initialization and cleanup.

The reader lacks a requests.Session for connection pooling and doesn't implement resource cleanup; the session should be created lazily and closed when the reader is disposed.

Add lazy session initialization:

+        # Initialize session for connection pooling (lazily initialized)
+        self._session = None
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# Crypto API base URL
self.base_url = "https://trading.robinhood.com"
# Crypto API base URL
self.base_url = "https://trading.robinhood.com"
# Initialize session for connection pooling (lazily initialized)
self._session = None
🤖 Prompt for AI Agents
In pyspark_datasources/robinhood.py around lines 53 to 57, the class sets
self.base_url but does not initialize an HTTP session or handle cleanup; add a
lazily-initialized requests.Session stored as self._session (create it only when
first needed for making requests), update all request calls to use this session
for connection pooling, and provide a public close() method plus context-manager
support (__enter__/__exit__) or a __del__ fallback to call close() so the
session is properly closed when the client is disposed.
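
A rough sketch of that shape (only the requests API is real; the wrapper names are illustrative):

import requests


class SessionOwnerSketch:
    def __init__(self) -> None:
        self._session = None  # created on first request

    @property
    def session(self) -> requests.Session:
        if self._session is None:
            self._session = requests.Session()
        return self._session

    def close(self) -> None:
        if self._session is not None:
            self._session.close()
            self._session = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()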

@Yicong-Huang
Copy link
Contributor Author

@allisonwang-db I've taken care of the comments. The PR is ready for another review. Thanks.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
tests/test_robinhood.py (1)

42-54: Remove hard-coded key-like string; patch crypto construction for this error-path test.

This string is flagged by secret scanners, and the test shouldn't construct real crypto state. Mock the signer and patch the base64 decode to return a benign constant.

-def test_robinhood_missing_symbols(spark):
-    """Test that missing symbols raises an error."""
-    with pytest.raises(AnalysisException) as excinfo:
-        df = (
-            spark.read.format("robinhood")
-            .option("api_key", "test-key")
-            .option("private_key", "FAPmPMsQqDFOFiRvpUMJ6BC5eFOh/tPx7qcTYGKc8nE=")
-            .load("")
-        )
-        df.collect()  # Trigger execution
+def test_robinhood_missing_symbols(spark):
+    """Test that missing symbols raises an error."""
+    with patch("nacl.signing.SigningKey") as _sk, \
+         patch("pyspark_datasources.robinhood.base64.b64decode", return_value=b"\x00" * 32):
+        with pytest.raises(AnalysisException) as excinfo:
+            df = (
+                spark.read.format("robinhood")
+                .option("api_key", "test-key")
+                .option("private_key", "placeholder-key")
+                .load("")
+            )
+            df.collect()  # Trigger execution
pyspark_datasources/robinhood.py (2)

23-54: Avoid non-picklable state in DataSourceReader; defer expensive crypto init.

SigningKey is not reliably pickle-serializable. Readers get shipped to executors; eager creation risks serialization errors and violates the repo guideline to “defer expensive operations until read time.” Store only primitives; lazily init crypto/client and clear in __getstate__.

-    def __init__(self, schema: StructType, options: Dict[str, str]) -> None:
+    def __init__(self, schema: StructType, options: dict[str, Any]) -> None:
         self.schema = schema
         self.options = options
@@
-        # Initialize NaCl signing key
-        try:
-            from nacl.signing import SigningKey
-
-            private_key_seed = base64.b64decode(self.private_key_base64)
-            self.signing_key = SigningKey(private_key_seed)
-        except ImportError:
-            raise ImportError(
-                "PyNaCl library is required for Robinhood Crypto API authentication. "
-                "Install it with: pip install pynacl"
-            )
-        except Exception as e:
-            raise ValueError(f"Invalid private key format: {str(e)}")
+        # Lazily initialized client state for pickling and performance
+        self._signing_key = None  # type: ignore[assignment]
+        self._session = None
@@
-        self.base_url = options.get("base_url", "https://trading.robinhood.com")
+        self.base_url = options.get("base_url", "https://trading.robinhood.com")
+
+    def _ensure_client(self) -> None:
+        """Initialize signing key and HTTP session on first use."""
+        if self._signing_key is None:
+            try:
+                from nacl.signing import SigningKey
+            except ImportError as ie:
+                raise ImportError(
+                    "PyNaCl library is required. Install with: pip install pynacl"
+                ) from ie
+            try:
+                private_key_seed = base64.b64decode(self.private_key_base64)
+                self._signing_key = SigningKey(private_key_seed)
+            except Exception as e:
+                raise ValueError(f"Invalid private key format: {e}") from e
+        if self._session is None:
+            self._session = requests.Session()
+            self._session.headers.update(
+                {"User-Agent": "PySpark Robinhood Crypto DataSource/1.0"}
+            )
+
+    def __getstate__(self):
+        state = dict(self.__dict__)
+        state["_signing_key"] = None
+        state["_session"] = None
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)

72-115: Sign the exact request path+query; avoid body divergence; replace prints with logging and add basic retries.

  • Include a sorted query string in the signed path to avoid mismatches.
  • Ensure the transmitted body matches the signed body; pass data=body if you pre-dumped JSON.
  • Replace print with logging; optionally add simple retry for 429/5xx.
     def _make_authenticated_request(
         self,
         method: str,
         path: str,
-        params: Optional[Dict[str, str]] = None,
-        json_data: Optional[Dict] = None,
-    ) -> Optional[Dict]:
+        params: Optional[dict[str, str]] = None,
+        json_data: Optional[dict[str, Any]] = None,
+    ) -> Optional[dict[str, Any]]:
         """Make an authenticated request to the Robinhood Crypto API."""
-        timestamp = self._get_current_timestamp()
-        url = self.base_url + path
+        from urllib.parse import urlencode
+        import time
+        self._ensure_client()
+        timestamp = self._get_current_timestamp()
+        query = f"?{urlencode(sorted(params.items()))}" if params else ""
+        signed_path = f"{path}{query}"
+        url = f"{self.base_url}{signed_path}"
@@
-        body = ""
-        if method.upper() != "GET" and json_data:
-            body = json.dumps(json_data, separators=(",", ":"))  # Compact JSON format
+        body = ""
+        if method.upper() != "GET" and json_data:
+            body = json.dumps(json_data, separators=(",", ":"))
@@
-        signature = self._generate_signature(timestamp, method, path, body)
+        signature = self._generate_signature(timestamp, method, signed_path, body)
@@
-        try:
-            # Make request
-            if method.upper() == "GET":
-                response = requests.get(url, headers=headers, params=params, timeout=10)
-            elif method.upper() == "POST":
-                headers["Content-Type"] = "application/json"
-                response = requests.post(url, headers=headers, json=json_data, timeout=10)
-            else:
-                response = requests.request(
-                    method, url, headers=headers, params=params, json=json_data, timeout=10
-                )
-
-            response.raise_for_status()
-            return response.json()
-        except requests.RequestException as e:
-            print(f"Error making API request to {path}: {e}")
-            return None
+        # Basic retry for transient issues
+        for attempt in range(3):
+            try:
+                if method.upper() == "GET":
+                    response = self._session.get(url, headers=headers, timeout=10)  # type: ignore[union-attr]
+                else:
+                    headers["Content-Type"] = "application/json"
+                    response = self._session.request(  # type: ignore[union-attr]
+                        method, url, headers=headers, data=(body or None), timeout=10
+                    )
+                if response.status_code == 429:
+                    retry_after = min(int(response.headers.get("Retry-After", "1")), 5)
+                    time.sleep(retry_after)
+                    continue
+                response.raise_for_status()
+                try:
+                    return response.json()
+                except ValueError:
+                    logger.warning("Non-JSON response from %s", url)
+                    return None
+            except requests.RequestException as e:
+                logger.warning("API request error (%s): %s", url, e)
+                if attempt == 2:
+                    return None
+                time.sleep(1)
🧹 Nitpick comments (9)
tests/test_robinhood.py (3)

1-6: Clean up unused imports to satisfy Ruff and reduce noise.

Mock and Row are unused. Keep patch only if you adopt the mocking change below; otherwise remove it too.

-import os
-import pytest
-from unittest.mock import Mock, patch
-from pyspark.sql import SparkSession, Row
+import os
+import pytest
+from unittest.mock import patch
+from pyspark.sql import SparkSession
 from pyspark.errors.exceptions.captured import AnalysisException

76-79: Wrap overly long skip messages to meet line-length policy (<=100).

Shorten the message or wrap with parentheses to satisfy E501.

-        pytest.skip(
-            "ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY environment variables required for real API tests"
-        )
+        pytest.skip(
+            "ROBINHOOD_API_KEY and ROBINHOOD_PRIVATE_KEY env vars required "
+            "for real API tests"
+        )

Also applies to: 121-124


115-142: Reduce external flakiness: assert by distinct symbols; optional mark for live tests.

The combination of len(rows) >= 3 and symbol set check can be brittle under transient API hiccups. Rely on the distinct symbol count; optionally add a pytest marker (e.g., @pytest.mark.live) to let CI skip by default.

-    # CRITICAL: Should get data for all 3 requested pairs
-    assert len(rows) >= 3, f"TEST FAILED: Expected 3 crypto pairs, got {len(rows)} records."
+    # Expect data for all 3 requested pairs
+    # (avoid brittle dependency on total row count)

If you’d like, I can add a pytest.ini marker and gate these tests behind -m live.
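
For reference, one way to wire that up (marker name and test body are illustrative):

# conftest.py: register the marker so pytest does not warn about it
def pytest_configure(config):
    config.addinivalue_line("markers", "live: tests that call the real Robinhood API")


# tests/test_robinhood.py: gate the integration test behind the marker
import os

import pytest


@pytest.mark.live
@pytest.mark.skipif(
    not os.environ.get("ROBINHOOD_API_KEY"), reason="requires real API credentials"
)
def test_bulk_load_live(spark):
    df = (
        spark.read.format("robinhood")
        .option("api_key", os.environ["ROBINHOOD_API_KEY"])
        .option("private_key", os.environ["ROBINHOOD_PRIVATE_KEY"])
        .load("BTC-USD,ETH-USD,DOGE-USD")
    )
    assert {r.symbol for r in df.collect()} == {"BTC-USD", "ETH-USD", "DOGE-USD"}

CI could then run pytest -m "not live" by default and opt in with pytest -m live.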

pyspark_datasources/robinhood.py (6)

1-12: Modernize type imports and add logging scaffolding.

Use collections.abc for Generator, prefer builtin generics, and prepare logging for later print→logger refactor.

-from dataclasses import dataclass
-from typing import Dict, List, Optional, Generator, Union
-import requests
-import json
-import base64
-import datetime
+from dataclasses import dataclass
+from collections.abc import Generator
+from typing import Optional, Union, Any
+import requests
+import json
+import base64
+import datetime
+import logging
@@
-from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
+from pyspark.sql.datasource import DataSource, DataSourceReader, InputPartition
+
+logger = logging.getLogger(__name__)

116-123: Remove unused helper _get_query_params or integrate it into signing.

It’s currently unused and can confuse readers.

-    @staticmethod
-    def _get_query_params(key: str, *args: str) -> str:
-        """Build query parameters for API requests."""
-        if not args:
-            return ""
-        params = [f"{key}={arg}" for arg in args if arg]
-        return "?" + "&".join(params)
+    # (removed; query building is handled centrally in _make_authenticated_request)

124-142: Deduplicate symbols while preserving order; tighten types.

Avoid fetching the same symbol twice if provided redundantly.

-    def partitions(self) -> List[CryptoPair]:
+    def partitions(self) -> list[CryptoPair]:
@@
-        # Split symbols by comma and create partitions
-        symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
+        # Split by comma, normalize, and preserve order
+        symbols = [symbol.strip().upper() for symbol in symbols_str.split(",")]
@@
-        return [CryptoPair(symbol=symbol) for symbol in formatted_symbols]
+        # Deduplicate while preserving order
+        deduped = list(dict.fromkeys(formatted_symbols))
+        return [CryptoPair(symbol=s) for s in deduped]

143-152: Replace print with logging in read path; don’t fail job for single-partition errors.

Use the module logger for proper levels and structured logs.

-        except Exception as e:
-            # Log error but don't fail the entire job
-            print(f"Warning: Failed to fetch data for {symbol}: {str(e)}")
+        except Exception as e:
+            # Log error but don't fail the entire job
+            logger.warning("Failed to fetch data for %s: %s", symbol, str(e))

196-266: Docstring largely aligns with repo guidelines; minor wording/line-width nits.

Looks good: includes Name, schema, Options, Errors, Partitioning, Arrow, and Examples. Consider wrapping long lines (>100) to satisfy Ruff E501. No functional changes needed.


59-71: _generate_signature implementation aligns with Robinhood's spec; minimal tweaks suggested

The current code already concatenates API key + timestamp (Unix seconds) + full request path (including any query string) + HTTP method (uppercase) + body (omitted for bodiless requests), then signs the UTF-8 bytes with Ed25519 and Base64-encodes the raw signature, exactly as community SDKs and examples prescribe.

To make the intent crystal-clear and guard against common pitfalls, consider these optional refinements:

• Rename the path parameter to path_with_query and update the docstring to note:
– Must include the “?key=val…” portion exactly as sent (no reordering)
– The timestamp must be a Unix-epoch integer in seconds (not milliseconds)
– The body must be the exact JSON string you’ll send (deterministically serialized)

• (Optional) At the top of _generate_signature, call your lazy-init helper (e.g. self._ensure_client()) to guarantee self.signing_key is ready prior to use.

Neither hashing of the body nor newline-delimited canonical blocks are required; those would deviate from the official format.
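
A minimal sketch of that flow with PyNaCl (parameter names are illustrative; verify the exact canonical string against the official API docs before relying on it):

import base64

from nacl.signing import SigningKey


def sign_request(private_key_b64: str, api_key: str, timestamp: int,
                 path_with_query: str, method: str, body: str = "") -> str:
    """Base64-encode an Ed25519 signature over api_key + timestamp + path + METHOD + body."""
    signing_key = SigningKey(base64.b64decode(private_key_b64))
    message = f"{api_key}{timestamp}{path_with_query}{method.upper()}{body}"
    return base64.b64encode(signing_key.sign(message.encode("utf-8")).signature).decode("utf-8")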

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 6d6c2f9 and 8949a85.

📒 Files selected for processing (2)
  • pyspark_datasources/robinhood.py (1 hunks)
  • tests/test_robinhood.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
pyspark_datasources/!(__init__).py

📄 CodeRabbit inference engine (CLAUDE.md)

pyspark_datasources/!(__init__).py: All data source classes must inherit from Spark's DataSource base class
Implement schema() in each data source to define the schema using PySpark StructType
Implement reader(schema) for batch reads; if streaming is supported, implement streamReader() or simpleStreamReader() accordingly
If batch writing is supported, implement writer(schema, overwrite); if streaming writes are supported, implement streamWriter(schema, overwrite)
All classes used by the data sources (including readers/writers) must be pickle-serializable
Use only standard Spark SQL data types in schemas
Implement robust exception handling in data source read/write paths
For streaming sources, ensure resources are cleaned up (e.g., implement stop()/commit() as appropriate)
Prefer returning pyarrow.RecordBatch from readers for better serialization performance
Implement partitions() on DataSourceReader to enable parallel processing when beneficial
Defer expensive operations until read time (lazy evaluation) in data source implementations
Include comprehensive class docstrings for each data source with: brief description and Name: "format_name", an Options section (parameters/types/defaults), and Examples (registration and basic usage)
In docstrings, include schema output (e.g., df.printSchema()) for clarity
In docstrings, document error cases (e.g., missing files or invalid options)
In docstrings, document the partitioning strategy used to improve performance
In docstrings, document Arrow optimizations (how Arrow is used to transmit data)
DataSourceReader.read must return tuples, Rows, or pyarrow.RecordBatch; implement accordingly

Files:

  • pyspark_datasources/robinhood.py
{pyspark_datasources,tests}/**/*.py

📄 CodeRabbit inference engine (CLAUDE.md)

{pyspark_datasources,tests}/**/*.py: When specifying file paths with Spark, use load("/path") instead of option("path", "/path")
Format code with Ruff (ruff format) and fix/lint with ruff check

Files:

  • pyspark_datasources/robinhood.py
  • tests/test_robinhood.py
🧠 Learnings (1)
📚 Learning: 2025-08-19T20:07:33.249Z
Learnt from: CR
PR: allisonwang-db/pyspark-data-sources#0
File: CLAUDE.md:0-0
Timestamp: 2025-08-19T20:07:33.249Z
Learning: Applies to pyspark_datasources/!(__init__).py : All classes used by the data sources (including readers/writers) must be pickle-serializable

Applied to files:

  • pyspark_datasources/robinhood.py
🧬 Code graph analysis (2)
pyspark_datasources/robinhood.py (1)
pyspark_datasources/stock.py (2)
  • StockDataSource (8-41)
  • StockDataReader (49-86)
tests/test_robinhood.py (1)
pyspark_datasources/robinhood.py (4)
  • RobinhoodDataSource (197-266)
  • name (259-260)
  • schema (262-263)
  • read (143-151)
🪛 Ruff (0.12.2)
pyspark_datasources/robinhood.py

2-2: Import from collections.abc instead: Generator

Import from collections.abc

(UP035)


2-2: typing.Dict is deprecated, use dict instead

(UP035)


2-2: typing.List is deprecated, use list instead

(UP035)


23-23: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


76-76: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


77-77: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


78-78: Use dict instead of Dict for type annotation

Replace with dict

(UP006)


124-124: Use list instead of List for type annotation

Replace with list

(UP006)


237-237: Line too long (106 > 100)

(E501)


246-246: Line too long (150 > 100)

(E501)


250-250: Line too long (101 > 100)

(E501)

tests/test_robinhood.py

3-3: unittest.mock.Mock imported but unused

Remove unused import

(F401)


3-3: unittest.mock.patch imported but unused

Remove unused import

(F401)


4-4: pyspark.sql.Row imported but unused

Remove unused import: pyspark.sql.Row

(F401)


78-78: Line too long (107 > 100)

(E501)


123-123: Line too long (107 > 100)

(E501)

🪛 Gitleaks (8.27.2)
tests/test_robinhood.py

48-48: Detected a Generic API Key, potentially exposing access to various services and sensitive operations.

(generic-api-key)

Comment on lines +153 to +191
def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
"""Fetch cryptocurrency market data for a given trading pair."""
try:
# Get best bid/ask data for the trading pair using query parameters
path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
market_data = self._make_authenticated_request("GET", path)

if market_data and "results" in market_data:
for quote in market_data["results"]:
# Parse numeric values safely
def safe_float(
value: Union[str, int, float, None], default: float = 0.0
) -> float:
if value is None or value == "":
return default
try:
return float(value)
except (ValueError, TypeError):
return default

# Extract market data fields from best bid/ask response
# Use the correct field names from the API response
price = safe_float(quote.get("price"))
bid_price = safe_float(quote.get("bid_inclusive_of_sell_spread"))
ask_price = safe_float(quote.get("ask_inclusive_of_buy_spread"))

yield Row(
symbol=symbol,
price=price,
bid_price=bid_price,
ask_price=ask_price,
updated_at=quote.get("timestamp", ""),
)
else:
print(f"Warning: No market data found for {symbol}")

except requests.exceptions.RequestException as e:
print(f"Network error fetching data for {symbol}: {str(e)}")
except (ValueError, KeyError) as e:

🛠️ Refactor suggestion

Build query via params, not inline; unify error logging.

Keep path constant and pass params, letting _make_authenticated_request assemble and sign deterministically.

-            # Get best bid/ask data for the trading pair using query parameters
-            path = f"/api/v1/crypto/marketdata/best_bid_ask/?symbol={symbol}"
-            market_data = self._make_authenticated_request("GET", path)
+            # Get best bid/ask for the trading pair
+            path = "/api/v1/crypto/marketdata/best_bid_ask/"
+            market_data = self._make_authenticated_request("GET", path, params={"symbol": symbol})
@@
-            else:
-                print(f"Warning: No market data found for {symbol}")
+            else:
+                logger.warning("No market data found for %s", symbol)
@@
-        except requests.exceptions.RequestException as e:
-            print(f"Network error fetching data for {symbol}: {str(e)}")
-        except (ValueError, KeyError) as e:
-            print(f"Data parsing error for {symbol}: {str(e)}")
-        except Exception as e:
-            print(f"Unexpected error fetching data for {symbol}: {str(e)}")
+        except requests.exceptions.RequestException as e:
+            logger.error("Network error fetching data for %s: %s", symbol, str(e))
+        except (ValueError, KeyError) as e:
+            logger.error("Data parsing error for %s: %s", symbol, str(e))
+        except Exception as e:
+            logger.error("Unexpected error fetching data for %s: %s", symbol, str(e))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _read_crypto_pair_data(self, symbol: str) -> Generator[Row, None, None]:
    """Fetch cryptocurrency market data for a given trading pair."""
    try:
        # Get best bid/ask for the trading pair
        path = "/api/v1/crypto/marketdata/best_bid_ask/"
        market_data = self._make_authenticated_request(
            "GET",
            path,
            params={"symbol": symbol}
        )
        if market_data and "results" in market_data:
            for quote in market_data["results"]:
                # ... (elided: safe_float parsing and the yield Row(...), unchanged) ...
        else:
            logger.warning("No market data found for %s", symbol)
    except requests.exceptions.RequestException as e:
        logger.error("Network error fetching data for %s: %s", symbol, str(e))
    except (ValueError, KeyError) as e:
        logger.error("Data parsing error for %s: %s", symbol, str(e))
    except Exception as e:
        logger.error("Unexpected error fetching data for %s: %s", symbol, str(e))

@allisonwang-db allisonwang-db merged commit 2bcae6a into allisonwang-db:master Aug 26, 2025
5 checks passed
