26 changes: 20 additions & 6 deletions python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -47,6 +47,7 @@
StringStartsWith,
)
from pyspark.sql.types import StructType, VariantVal, _parse_datatype_json_string
from pyspark.sql.worker.plan_data_source_read import write_read_func_and_partitions
from pyspark.util import handle_worker_exception, local_connect_and_auth
from pyspark.worker_util import (
check_python_version,
@@ -131,11 +132,12 @@ def main(infile: IO, outfile: IO) -> None:
- a `DataSource` instance representing the data source
- a `StructType` instance representing the output schema of the data source
- a list of filters to be pushed down
- configuration values

This process then creates a `DataSourceReader` instance by calling the `reader` method
on the `DataSource` instance. It applies the filters by calling the `pushFilters` method
on the reader and determines which filters are supported. The data source with updated reader
is then sent back to the JVM along with the indices of the supported filters.
on the reader and determines which filters are supported. The indices of the supported
filters are sent back to the JVM, along with the list of partitions and the read function.
"""
faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
try:
@@ -220,10 +222,22 @@ def main(infile: IO, outfile: IO) -> None:
},
)

# Monkey patch the data source instance
# to return the existing reader with the pushed down filters.
data_source.reader = lambda schema: reader # type: ignore[method-assign]
pickleSer._write_with_length(data_source, outfile)
# Receive the max arrow batch size.
max_arrow_batch_size = read_int(infile)
assert max_arrow_batch_size > 0, (
"The maximum arrow batch size should be greater than 0, but got "
f"'{max_arrow_batch_size}'"
)

# Return the read function and partitions. Doing this in the same worker as filter pushdown
# helps reduce the number of Python worker calls.
write_read_func_and_partitions(
outfile,
reader=reader,
data_source=data_source,
schema=schema,
max_arrow_batch_size=max_arrow_batch_size,
)

# Return the supported filter indices.
write_int(len(supported_filter_indices), outfile)
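For context, a minimal sketch of a reader that participates in the pushdown protocol described in the updated docstring above. It is illustration only, not part of this change, and it assumes the public `pushFilters`/`partitions`/`read` hooks on `DataSourceReader` plus the `Filter`/`EqualTo`/`InputPartition` classes that this worker imports: `pushFilters` keeps the filters it can evaluate and returns the unsupported ones, from which the worker derives the supported-filter indices.

from typing import Iterable, Iterator, List

from pyspark.sql.datasource import DataSourceReader, EqualTo, Filter, InputPartition


class ExampleEqualityReader(DataSourceReader):  # hypothetical reader, for illustration only
    def __init__(self) -> None:
        self.pushed: List[Filter] = []

    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        unsupported = []
        for f in filters:
            if isinstance(f, EqualTo):
                # Supported: remember it and apply it inside read().
                self.pushed.append(f)
            else:
                # Unsupported: hand it back so Spark evaluates it after the scan.
                unsupported.append(f)
        return unsupported

    def partitions(self) -> List[InputPartition]:
        return [InputPartition(0)]

    def read(self, partition: InputPartition) -> Iterator[tuple]:
        # A real reader would consult self.pushed here and skip non-matching rows.
        return iter([])

Because the same reader instance stays alive in this worker after pushdown, the partitions it plans already reflect the pushed filters, which is what allows the read function and partitions to be emitted here instead of in a second worker call.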
188 changes: 103 additions & 85 deletions python/pyspark/sql/worker/plan_data_source_read.py
@@ -168,6 +168,101 @@ def batched(iterator: Iterator, n: int) -> Iterator:
yield batch


def write_read_func_and_partitions(
outfile: IO,
*,
reader: Union[DataSourceReader, DataSourceStreamReader],
data_source: DataSource,
schema: StructType,
max_arrow_batch_size: int,
) -> None:
is_streaming = isinstance(reader, DataSourceStreamReader)

# Create input converter.
converter = ArrowTableToRowsConversion._create_converter(BinaryType())

# Create output converter.
return_type = schema

def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
partition_bytes = None

# Get the partition value from the input iterator.
for batch in iterator:
# There should be only one row/column in the batch.
assert batch.num_columns == 1 and batch.num_rows == 1, (
"Expected each batch to have exactly 1 column and 1 row, "
f"but found {batch.num_columns} columns and {batch.num_rows} rows."
)
columns = [column.to_pylist() for column in batch.columns]
partition_bytes = converter(columns[0][0])

assert (
partition_bytes is not None
), "The input iterator for Python data source read function is empty."

# Deserialize the partition value.
partition = pickleSer.loads(partition_bytes)

assert partition is None or isinstance(partition, InputPartition), (
"Expected the partition value to be of type 'InputPartition', "
f"but found '{type(partition).__name__}'."
)

output_iter = reader.read(partition) # type: ignore[arg-type]

# Validate the output iterator.
if not isinstance(output_iter, Iterator):
raise PySparkRuntimeError(
errorClass="DATA_SOURCE_INVALID_RETURN_TYPE",
messageParameters={
"type": type(output_iter).__name__,
"name": data_source.name(),
"supported_types": "iterator",
},
)

return records_to_arrow_batches(output_iter, max_arrow_batch_size, return_type, data_source)

command = (data_source_read_func, return_type)
pickleSer._write_with_length(command, outfile)

if not is_streaming:
# The partitioning of python batch source read is determined before query execution.
try:
partitions = reader.partitions() # type: ignore[call-arg]
if not isinstance(partitions, list):
raise PySparkRuntimeError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "'partitions' to return a list",
"actual": f"'{type(partitions).__name__}'",
},
)
if not all(isinstance(p, InputPartition) for p in partitions):
partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
raise PySparkRuntimeError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "elements in 'partitions' to be of type 'InputPartition'",
"actual": partition_types,
},
)
if len(partitions) == 0:
partitions = [None] # type: ignore[list-item]
except NotImplementedError:
partitions = [None] # type: ignore[list-item]

# Return the serialized partition values.
write_int(len(partitions), outfile)
for partition in partitions:
pickleSer._write_with_length(partition, outfile)
else:
# Send an empty list of partitions for the stream reader because partitions are planned
# in each microbatch during query execution.
write_int(0, outfile)


def main(infile: IO, outfile: IO) -> None:
"""
Main method for planning a data source read.
@@ -284,91 +379,14 @@ def main(infile: IO, outfile: IO) -> None:
},
)

# Create input converter.
converter = ArrowTableToRowsConversion._create_converter(BinaryType())

# Create output converter.
return_type = schema

def data_source_read_func(iterator: Iterable[pa.RecordBatch]) -> Iterable[pa.RecordBatch]:
partition_bytes = None

# Get the partition value from the input iterator.
for batch in iterator:
# There should be only one row/column in the batch.
assert batch.num_columns == 1 and batch.num_rows == 1, (
"Expected each batch to have exactly 1 column and 1 row, "
f"but found {batch.num_columns} columns and {batch.num_rows} rows."
)
columns = [column.to_pylist() for column in batch.columns]
partition_bytes = converter(columns[0][0])

assert (
partition_bytes is not None
), "The input iterator for Python data source read function is empty."

# Deserialize the partition value.
partition = pickleSer.loads(partition_bytes)

assert partition is None or isinstance(partition, InputPartition), (
"Expected the partition value to be of type 'InputPartition', "
f"but found '{type(partition).__name__}'."
)

output_iter = reader.read(partition) # type: ignore[arg-type]

# Validate the output iterator.
if not isinstance(output_iter, Iterator):
raise PySparkRuntimeError(
errorClass="DATA_SOURCE_INVALID_RETURN_TYPE",
messageParameters={
"type": type(output_iter).__name__,
"name": data_source.name(),
"supported_types": "iterator",
},
)

return records_to_arrow_batches(
output_iter, max_arrow_batch_size, return_type, data_source
)

command = (data_source_read_func, return_type)
pickleSer._write_with_length(command, outfile)

if not is_streaming:
# The partitioning of python batch source read is determined before query execution.
try:
partitions = reader.partitions() # type: ignore[call-arg]
if not isinstance(partitions, list):
raise PySparkRuntimeError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "'partitions' to return a list",
"actual": f"'{type(partitions).__name__}'",
},
)
if not all(isinstance(p, InputPartition) for p in partitions):
partition_types = ", ".join([f"'{type(p).__name__}'" for p in partitions])
raise PySparkRuntimeError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
"expected": "elements in 'partitions' to be of type 'InputPartition'",
"actual": partition_types,
},
)
if len(partitions) == 0:
partitions = [None] # type: ignore[list-item]
except NotImplementedError:
partitions = [None] # type: ignore[list-item]

# Return the serialized partition values.
write_int(len(partitions), outfile)
for partition in partitions:
pickleSer._write_with_length(partition, outfile)
else:
# Send an empty list of partition for stream reader because partitions are planned
# in each microbatch during query execution.
write_int(0, outfile)
# Send the read function and partitions to the JVM.
write_read_func_and_partitions(
outfile,
reader=reader,
data_source=data_source,
schema=schema,
max_arrow_batch_size=max_arrow_batch_size,
)
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
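One note on the moved `write_read_func_and_partitions` helper: the read function it pickles expects each incoming Arrow batch to carry exactly one row and one binary column whose single cell is a pickled partition value. A small self-contained sketch of that framing, using plain `pickle` and `pyarrow` in place of the worker's serializers (illustration only):

import pickle

import pyarrow as pa

from pyspark.sql.datasource import InputPartition

# One pickled partition per single-cell binary batch, as asserted in data_source_read_func.
partition_bytes = pickle.dumps(InputPartition(0))
batch = pa.record_batch([pa.array([partition_bytes], type=pa.binary())], names=["partition"])

assert batch.num_columns == 1 and batch.num_rows == 1
restored = pickle.loads(batch.column(0).to_pylist()[0])
assert isinstance(restored, InputPartition)

On the way out, the helper writes the pickled (read function, return type) pair, then the partition count and one pickled partition per entry (or a count of 0 for streaming), which appears to be what the Scala side collects into `PythonDataSourceReadInfo`.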
@@ -52,8 +52,23 @@ class PythonDataSourceV2 extends TableProvider {
dataSourceInPython
}

def setDataSourceInPython(dataSourceInPython: PythonDataSourceCreationResult): Unit = {
this.dataSourceInPython = dataSourceInPython
private var readInfo: PythonDataSourceReadInfo = _

def getOrCreateReadInfo(
shortName: String,
options: CaseInsensitiveStringMap,
outputSchema: StructType,
isStreaming: Boolean
): PythonDataSourceReadInfo = {
if (readInfo == null) {
val creationResult = getOrCreateDataSourceInPython(shortName, options, Some(outputSchema))
readInfo = source.createReadInfoInPython(creationResult, outputSchema, isStreaming)
}
readInfo
}

def setReadInfo(readInfo: PythonDataSourceReadInfo): Unit = {
this.readInfo = readInfo
}

override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
@@ -90,10 +90,7 @@ class PythonMicroBatchStream(
}

private lazy val readInfo: PythonDataSourceReadInfo = {
ds.source.createReadInfoInPython(
ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
outputSchema,
isStreaming = true)
ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming = true)
}

override def createReaderFactory(): PartitionReaderFactory = {
@@ -63,10 +63,7 @@ class PythonBatch(
private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)

private lazy val infoInPython: PythonDataSourceReadInfo = {
ds.source.createReadInfoInPython(
ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)),
outputSchema,
isStreaming = false)
ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming = false)
}

override def planInputPartitions(): Array[InputPartition] =
@@ -42,18 +42,19 @@ class PythonScanBuilder(
}

val dataSource = ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema))
val result = ds.source.pushdownFiltersInPython(dataSource, outputSchema, filters)

// The Data Source instance state changes after pushdown to remember the reader instance
// created and the filters pushed down. So pushdownFiltersInPython returns a new pickled
// Data Source instance. We need to use that new instance for further operations.
ds.setDataSourceInPython(dataSource.copy(dataSource = result.dataSource))

// Partition the filters into supported and unsupported ones.
val isPushed = result.isFilterPushed.zip(filters)
supportedFilters = isPushed.collect { case (true, filter) => filter }.toArray
val unsupported = isPushed.collect { case (false, filter) => filter }.toArray
unsupported
ds.source.pushdownFiltersInPython(dataSource, outputSchema, filters) match {
case None => filters // No filters are supported.
case Some(result) =>
// Filter pushdown also returns partitions and the read function.
// This helps reduce the number of Python worker calls.
ds.setReadInfo(result.readInfo)

// Partition the filters into supported and unsupported ones.
val isPushed = result.isFilterPushed.zip(filters)
supportedFilters = isPushed.collect { case (true, filter) => filter }.toArray
val unsupported = isPushed.collect { case (false, filter) => filter }.toArray
unsupported
}
}

override def pushedFilters(): Array[Filter] = supportedFilters