1 change: 1 addition & 0 deletions README.md
@@ -46,6 +46,7 @@ spark.readStream.format("fake").load().writeStream.format("console").start()
| [KaggleDataSource](pyspark_datasources/kaggle.py) | `kaggle` | Read datasets from Kaggle | `kagglehub`, `pandas` |
| [SimpleJsonDataSource](pyspark_datasources/simplejson.py) | `simplejson` | Write JSON data to Databricks DBFS | `databricks-sdk` |
| [OpenSkyDataSource](pyspark_datasources/opensky.py) | `opensky` | Read from OpenSky Network. | None |
| [SalesforceDataSource](pyspark_datasources/salesforce.py) | `salesforce` | Write streaming data to Salesforce objects | `simple-salesforce` |

See more here: https://allisonwang-db.github.io/pyspark-data-sources/.

6 changes: 6 additions & 0 deletions docs/datasources/salesforce.md
@@ -0,0 +1,6 @@
# SalesforceDataSource

> Requires the [`simple-salesforce`](https://github.com/simple-salesforce/simple-salesforce) library. You can install it manually: `pip install simple-salesforce`
> or use `pip install pyspark-data-sources[salesforce]`.

::: pyspark_datasources.salesforce.SalesforceDataSource
1 change: 1 addition & 0 deletions docs/index.md
@@ -38,5 +38,6 @@ spark.readStream.format("fake").load().writeStream.format("console").start()
| [HuggingFaceDatasets](./datasources/huggingface.md) | `huggingface` | Read datasets from the HuggingFace Hub | `datasets` |
| [StockDataSource](./datasources/stock.md) | `stock` | Read stock data from Alpha Vantage | None |
| [SimpleJsonDataSource](./datasources/simplejson.md) | `simplejson` | Write JSON data to Databricks DBFS | `databricks-sdk` |
| [SalesforceDataSource](./datasources/salesforce.md) | `salesforce` | Write streaming data to Salesforce objects | `simple-salesforce` |
| [GoogleSheetsDataSource](./datasources/googlesheets.md) | `googlesheets` | Read table from public Google Sheets document | None |
| [KaggleDataSource](./datasources/kaggle.md) | `kaggle` | Read datasets from Kaggle | `kagglehub`, `pandas` |
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -23,6 +23,7 @@ nav:
- datasources/huggingface.md
- datasources/stock.md
- datasources/simplejson.md
- datasources/salesforce.md
- datasources/googlesheets.md
- datasources/kaggle.md

4 changes: 3 additions & 1 deletion pyproject.toml
@@ -18,14 +18,16 @@ mkdocstrings = {extras = ["python"], version = "^0.28.0"}
datasets = {version = "^2.17.0", optional = true}
databricks-sdk = {version = "^0.28.0", optional = true}
kagglehub = {extras = ["pandas-datasets"], version = "^0.3.10", optional = true}
simple-salesforce = {version = "^1.12.0", optional = true}

[tool.poetry.extras]
faker = ["faker"]
datasets = ["datasets"]
databricks = ["databricks-sdk"]
kaggle = ["kagglehub"]
lance = ["pylance"]
all = ["faker", "datasets", "databricks-sdk", "kagglehub"]
salesforce = ["simple-salesforce"]
all = ["faker", "datasets", "databricks-sdk", "kagglehub", "simple-salesforce"]

[tool.poetry.group.dev.dependencies]
pytest = "^8.0.0"
1 change: 1 addition & 0 deletions pyspark_datasources/__init__.py
@@ -4,5 +4,6 @@
from .huggingface import HuggingFaceDatasets
from .kaggle import KaggleDataSource
from .opensky import OpenSkyDataSource
from .salesforce import SalesforceDataSource
from .simplejson import SimpleJsonDataSource
from .stock import StockDataSource
294 changes: 294 additions & 0 deletions pyspark_datasources/salesforce.py
@@ -0,0 +1,294 @@
import logging
from dataclasses import dataclass
from typing import Dict, List, Any

from pyspark.sql.types import StructType
from pyspark.sql.datasource import DataSource, DataSourceStreamWriter, WriterCommitMessage

logger = logging.getLogger(__name__)


@dataclass
class SalesforceCommitMessage(WriterCommitMessage):
"""Commit message for Salesforce write operations."""
records_written: int
batch_id: int


class SalesforceDataSource(DataSource):
"""
A Salesforce streaming data source for PySpark to write data to Salesforce objects.
This data source enables writing streaming data from Spark to Salesforce using the
Salesforce REST API. It supports common Salesforce objects like Account, Contact,
Opportunity, and custom objects.
Name: `salesforce`
Notes
-----
- Requires the `simple-salesforce` library for Salesforce API integration
- Only supports streaming write operations (not read operations)
- Uses Salesforce username/password/security token authentication
- Supports streaming processing for efficient API usage
Parameters
----------
username : str
Salesforce username (email address)
password : str
Salesforce password
security_token : str
Salesforce security token (obtained from Salesforce setup)
salesforce_object : str, optional
Target Salesforce object name (default: "Account")
Comment on lines +46 to +47

Owner:

Where can I get a list of objects?

Contributor Author (@shujingyang-db, Jul 31, 2025):

It's in the Salesforce UI. We could also add a feature to pull the list of objects and their schemas from Salesforce; a rough sketch of what that could look like follows.
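For reference, an untested sketch of that idea using simple-salesforce's describe calls (the credentials below are placeholders, and nothing in this snippet is part of this PR):

from simple_salesforce import Salesforce

# Placeholder credentials; in practice these would come from the same options
# the writer already accepts (username / password / security_token).
sf = Salesforce(
    username="your-username@company.com",
    password="your-password",
    security_token="your-security-token",
)

# Global describe: one entry per sObject in the org; keep only createable ones.
writable_objects = [o["name"] for o in sf.describe()["sobjects"] if o.get("createable")]
print(writable_objects[:20])

# Object-level describe: field names, types, and whether each field is required.
for f in getattr(sf, "Account").describe()["fields"][:10]:
    print(f["name"], f["type"], "required" if not f["nillable"] else "optional")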

    batch_size : str, optional
        Number of records to process per batch (default: "200")
    instance_url : str, optional
        Custom Salesforce instance URL (auto-detected if not provided)

    Examples
    --------
    Register the data source:

    >>> from pyspark_datasources import SalesforceDataSource
    >>> spark.dataSource.register(SalesforceDataSource)

    Write streaming data to Salesforce Accounts:

    >>> from pyspark.sql import SparkSession
    >>> from pyspark.sql.functions import col, lit
    >>>
    >>> spark = SparkSession.builder.appName("SalesforceExample").getOrCreate()
    >>> spark.dataSource.register(SalesforceDataSource)
    >>>
    >>> # Create sample streaming data
    >>> streaming_df = spark.readStream.format("rate").load()
    >>> account_data = streaming_df.select(
    ...     col("value").cast("string").alias("Name"),
    ...     lit("Technology").alias("Industry"),
    ...     (col("value") * 100000).cast("double").alias("AnnualRevenue")
    ... )
    >>>
    >>> # Write to Salesforce
    >>> query = account_data.writeStream \\
    ...     .format("salesforce") \\
    ...     .option("username", "your-username@company.com") \\
    ...     .option("password", "your-password") \\
    ...     .option("security_token", "your-security-token") \\
    ...     .option("salesforce_object", "Account") \\
    ...     .option("batch_size", "100") \\
    ...     .start()

    Write to Salesforce Contacts:

    >>> contact_data = streaming_df.select(
    ...     col("value").cast("string").alias("FirstName"),
    ...     lit("Doe").alias("LastName"),
    ...     lit("contact@example.com").alias("Email")
    ... )
    >>>
    >>> query = contact_data.writeStream \\
    ...     .format("salesforce") \\
    ...     .option("username", "your-username@company.com") \\
    ...     .option("password", "your-password") \\
    ...     .option("security_token", "your-security-token") \\
    ...     .option("salesforce_object", "Contact") \\
    ...     .start()

    Write to custom Salesforce objects:

    >>> custom_data = streaming_df.select(
    ...     col("value").cast("string").alias("Custom_Field__c"),
    ...     lit("Custom Value").alias("Another_Field__c")
    ... )
    >>>
    >>> query = custom_data.writeStream \\
    ...     .format("salesforce") \\
    ...     .option("username", "your-username@company.com") \\
    ...     .option("password", "your-password") \\
    ...     .option("security_token", "your-security-token") \\
    ...     .option("salesforce_object", "Custom_Object__c") \\
    ...     .start()
    """

    @classmethod
    def name(cls) -> str:
        """Return the short name for this data source."""
        return "salesforce"

    def schema(self) -> str:
        """
        Define the default schema for Salesforce Account objects.
        This schema can be overridden by users when creating their DataFrame.
        """
        return """
            Name STRING NOT NULL,
            Industry STRING,
            Phone STRING,
            Website STRING,
            AnnualRevenue DOUBLE,
            NumberOfEmployees INT,
            BillingStreet STRING,
            BillingCity STRING,
            BillingState STRING,
            BillingPostalCode STRING,
            BillingCountry STRING
        """

    def streamWriter(self, schema: StructType, overwrite: bool) -> "SalesforceStreamWriter":
        """Create a stream writer for Salesforce integration."""
        return SalesforceStreamWriter(schema, self.options)


class SalesforceStreamWriter(DataSourceStreamWriter):
    """Stream writer implementation for Salesforce integration."""

    def __init__(self, schema: StructType, options: Dict[str, str]):
        self.schema = schema
        self.options = options

        # Extract Salesforce configuration
        self.username = options.get("username")
        self.password = options.get("password")
        self.security_token = options.get("security_token")
        self.instance_url = options.get("instance_url")
        self.salesforce_object = options.get("salesforce_object", "Account")
        self.batch_size = int(options.get("batch_size", "200"))

Comment on lines +210 to +216

🛠️ Refactor suggestion / ⚠️ Potential issue

Security: avoid storing sensitive credentials as instance variables.

Storing passwords and security tokens as instance variables could expose them in logs, stack traces, or memory dumps. Consider extracting credentials only when needed in the write method.

Apply this diff to improve security:

     def __init__(self, schema: StructType, options: Dict[str, str]):
         self.schema = schema
         self.options = options

-        # Extract Salesforce configuration
-        self.username = options.get("username")
-        self.password = options.get("password")
-        self.security_token = options.get("security_token")
-        self.instance_url = options.get("instance_url")
         self.salesforce_object = options.get("salesforce_object", "Account")
         self.batch_size = int(options.get("batch_size", "200"))

         # Validate required options
-        if not all([self.username, self.password, self.security_token]):
+        username = options.get("username")
+        password = options.get("password")
+        security_token = options.get("security_token")
+        if not all([username, password, security_token]):
             raise ValueError(
                 "Salesforce username, password, and security_token are required. "
                 "Set them using .option() method in your streaming query."
             )

Then update the write method to extract the credentials locally:

    # In write method, after line 186:
    username = self.options.get("username")
    password = self.options.get("password")
    security_token = self.options.get("security_token")
    instance_url = self.options.get("instance_url")

        # Validate required options
        if not all([self.username, self.password, self.security_token]):
            raise ValueError(
                "Salesforce username, password, and security_token are required. "
                "Set them using .option() method in your streaming query."
            )

        logger.info(f"Initializing Salesforce writer for object '{self.salesforce_object}'")

    def write(self, iterator) -> SalesforceCommitMessage:
        """Write data to Salesforce."""
        # Import here to avoid serialization issues
        try:
            from simple_salesforce import Salesforce
        except ImportError:
            raise ImportError(
                "simple-salesforce library is required for Salesforce integration. "
                "Install it with: pip install simple-salesforce"
            )
Comment on lines +232 to +235

⚠️ Potential issue

Add exception chaining for better error traceability.

When re-raising exceptions, use `from` to preserve the original exception context.

Apply this diff:

         except ImportError:
             raise ImportError(
                 "simple-salesforce library is required for Salesforce integration. "
                 "Install it with: pip install simple-salesforce"
-            )
+            ) from None

🧰 Tools
🪛 Ruff (0.12.2)

175-178: Within an except clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling (B904)


        from pyspark import TaskContext

        # Get task context for batch identification
        context = TaskContext.get()
        batch_id = context.taskAttemptId()

        # Connect to Salesforce
        try:
            sf_kwargs = {
                'username': self.username,
                'password': self.password,
                'security_token': self.security_token
            }
            if self.instance_url:
                sf_kwargs['instance_url'] = self.instance_url

            sf = Salesforce(**sf_kwargs)
            logger.info(f"✓ Connected to Salesforce (batch {batch_id})")
        except Exception as e:
            logger.error(f"Failed to connect to Salesforce: {str(e)}")
            raise ConnectionError(f"Salesforce connection failed: {str(e)}")
Comment on lines +256 to +257

⚠️ Potential issue

Add exception chaining to preserve error context.

Apply this diff:

         except Exception as e:
             logger.error(f"Failed to connect to Salesforce: {str(e)}")
-            raise ConnectionError(f"Salesforce connection failed: {str(e)}")
+            raise ConnectionError(f"Salesforce connection failed: {str(e)}") from e

🧰 Tools
🪛 Ruff (0.12.2)

200-200: Within an except clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling (B904)


        # Convert rows to Salesforce records
        records = []
        for row in iterator:
            try:
                record = self._convert_row_to_salesforce_record(row)
                if record:  # Only add non-empty records
                    records.append(record)
Owner:

This will consume all rows in the current partition. What if we have too many rows that exceed the executor memory limit? We can either 1) batch the rows and periodically write them to Salesforce, or 2) use the Arrow writer, which takes an iterator of Arrow record batches: https://github.com/apache/spark/blob/master/python/pyspark/sql/datasource.py#L979

Owner:

Ah, I think we don't have streaming support for the Arrow record batch writer (which we should support!).

Contributor Author:

Sure, I updated it to write data to Salesforce in batches; a sketch of the buffered approach is below.
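For reference, a minimal, untested sketch of option (1), assuming the writer's existing helpers (_convert_row_to_salesforce_record, _write_to_salesforce, batch_size) stay as they are; _connect() is a hypothetical stand-in for the Salesforce(**sf_kwargs) connection code and is not a method in this PR:

def write(self, iterator) -> SalesforceCommitMessage:
    from pyspark import TaskContext

    batch_id = TaskContext.get().taskAttemptId()
    sf = self._connect()  # hypothetical helper wrapping Salesforce(**sf_kwargs)

    buffer, written = [], 0
    for row in iterator:
        record = self._convert_row_to_salesforce_record(row)
        if record:
            buffer.append(record)
        if len(buffer) >= self.batch_size:
            # Flush a full batch and drop the references so the partition
            # never has to be fully materialized in executor memory.
            written += self._write_to_salesforce(sf, buffer, batch_id)
            buffer = []

    if buffer:  # flush the final partial batch
        written += self._write_to_salesforce(sf, buffer, batch_id)

    return SalesforceCommitMessage(records_written=written, batch_id=batch_id)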

            except Exception as e:
                logger.warning(f"Failed to convert row to Salesforce record: {str(e)}")

        if not records:
            logger.info(f"No valid records to write in batch {batch_id}")
            return SalesforceCommitMessage(records_written=0, batch_id=batch_id)

        # Write records to Salesforce
        try:
            records_written = self._write_to_salesforce(sf, records, batch_id)
            logger.info(f"✅ Batch {batch_id}: Successfully wrote {records_written} records")
            return SalesforceCommitMessage(records_written=records_written, batch_id=batch_id)
        except Exception as e:
            logger.error(f"❌ Batch {batch_id}: Failed to write records: {str(e)}")
            raise

    def _convert_row_to_salesforce_record(self, row) -> Dict[str, Any]:
        """Convert a Spark Row to a Salesforce record format."""
        record = {}

        for field in self.schema.fields:
            field_name = field.name
            try:
                # Use getattr for safe field access
                value = getattr(row, field_name, None)

                if value is not None:
                    # Convert value based on field type
                    if hasattr(value, 'isoformat'):  # datetime objects
                        record[field_name] = value.isoformat()
                    elif isinstance(value, (int, float)):
                        record[field_name] = value
                    else:
                        record[field_name] = str(value)

            except Exception as e:
                logger.warning(f"Failed to convert field '{field_name}': {str(e)}")

        return record

    def _write_to_salesforce(self, sf, records: List[Dict[str, Any]], batch_id: int) -> int:
        """Write records to Salesforce using REST API."""
        success_count = 0

        # Get the Salesforce object API
        try:
            sf_object = getattr(sf, self.salesforce_object)
        except AttributeError:
            raise ValueError(f"Salesforce object '{self.salesforce_object}' not found")
⚠️ Potential issue

Add exception chaining for consistency.

Apply this diff:

         except AttributeError:
-            raise ValueError(f"Salesforce object '{self.salesforce_object}' not found")
+            raise ValueError(f"Salesforce object '{self.salesforce_object}' not found") from None

🧰 Tools
🪛 Ruff (0.12.2)

257-257: Within an except clause, raise exceptions with `raise ... from err` or `raise ... from None` to distinguish them from errors in exception handling (B904)


        # Process records in batches
        for i in range(0, len(records), self.batch_size):
            batch_records = records[i:i + self.batch_size]

            for j, record in enumerate(batch_records):
                try:
                    # Create the record in Salesforce
                    result = sf_object.create(record)

                    if result.get('success'):
                        success_count += 1
                    else:
                        logger.warning(f"Failed to create record {i+j}: {result.get('errors', 'Unknown error')}")

                except Exception as e:
                    logger.error(f"Error creating record {i+j}: {str(e)}")

            # Log progress for large batches
            if len(records) > 50 and (i + self.batch_size) % 100 == 0:
                logger.info(f"Batch {batch_id}: Processed {i + self.batch_size}/{len(records)} records")

        return success_count

    def commit(self, messages: List[SalesforceCommitMessage], batch_id: int) -> None:
        """Commit the write operation."""
        total_records = sum(msg.records_written for msg in messages if msg is not None)
        total_batches = len([msg for msg in messages if msg is not None])

        logger.info(f"✅ Commit batch {batch_id}: Successfully wrote {total_records} records across {total_batches} batches")

    def abort(self, messages: List[SalesforceCommitMessage], batch_id: int) -> None:
        """Abort the write operation."""
        total_batches = len([msg for msg in messages if msg is not None])
        logger.warning(f"❌ Abort batch {batch_id}: Rolling back {total_batches} batches")
        # Note: Salesforce doesn't support transaction rollback for individual records.
        # Records that were successfully created will remain in Salesforce.