Skip to content

Commit

Permalink
[SPARK-49607][PYTHON] Update the sampling approach for sampled based …
Browse files Browse the repository at this point in the history
…plots

### What changes were proposed in this pull request?
1. Update the sampling approach for sample-based plots
2. Eliminate the "spark.sql.pyspark.plotting.sample_ratio" config

### Why are the changes needed?
1. To be consistent with the PS plotting;
2. The "spark.sql.pyspark.plotting.sample_ratio" config is not friendly to large-scale data: the plotting backend cannot render a large number of data points efficiently, and it is hard for users to set an appropriate sample ratio;

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
ci

### Was this patch authored or co-authored using generative AI tooling?
no

Closes apache#48218 from zhengruifeng/py_plot_sampling.

Authored-by: Ruifeng Zheng <ruifengz@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
  • Loading branch information
zhengruifeng authored and HyukjinKwon committed Sep 24, 2024
1 parent 35e5d29 commit 64ea50e
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 38 deletions.
36 changes: 27 additions & 9 deletions python/pyspark/sql/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,45 @@ def get_top_n(self, sdf: "DataFrame") -> "pd.DataFrame":

class PySparkSampledPlotBase:
def get_sampled(self, sdf: "DataFrame") -> "pd.DataFrame":
from pyspark.sql import SparkSession
from pyspark.sql import SparkSession, Observation, functions as F

session = SparkSession.getActiveSession()
if session is None:
raise PySparkRuntimeError(errorClass="NO_ACTIVE_SESSION", messageParameters=dict())

sample_ratio = session.conf.get("spark.sql.pyspark.plotting.sample_ratio")
max_rows = int(
session.conf.get("spark.sql.pyspark.plotting.max_rows") # type: ignore[arg-type]
)

if sample_ratio is None:
fraction = 1 / (sdf.count() / max_rows)
fraction = min(1.0, fraction)
else:
fraction = float(sample_ratio)
observation = Observation("pyspark plotting")

sampled_sdf = sdf.sample(fraction=fraction)
rand_col_name = "__pyspark_plotting_sampled_plot_base_rand__"
id_col_name = "__pyspark_plotting_sampled_plot_base_id__"

sampled_sdf = (
sdf.observe(observation, F.count(F.lit(1)).alias("count"))
.select(
"*",
F.rand().alias(rand_col_name),
F.monotonically_increasing_id().alias(id_col_name),
)
.sort(rand_col_name)
.limit(max_rows + 1)
.coalesce(1)
.sortWithinPartitions(id_col_name)
.drop(rand_col_name, id_col_name)
)
pdf = sampled_sdf.toPandas()

return pdf
if len(pdf) > max_rows:
try:
self.fraction = float(max_rows) / observation.get["count"]
except Exception:
pass
return pdf[:max_rows]
else:
self.fraction = 1.0
return pdf


class PySparkPlotAccessor:
Expand Down
14 changes: 1 addition & 13 deletions python/pyspark/sql/tests/plot/test_frame_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,23 +39,11 @@ def test_backend(self):
)

def test_topn_max_rows(self):
try:
with self.sql_conf({"spark.sql.pyspark.plotting.max_rows": "1000"}):
self.spark.conf.set("spark.sql.pyspark.plotting.max_rows", "1000")
sdf = self.spark.range(2500)
pdf = PySparkTopNPlotBase().get_top_n(sdf)
self.assertEqual(len(pdf), 1000)
finally:
self.spark.conf.unset("spark.sql.pyspark.plotting.max_rows")

def test_sampled_plot_with_ratio(self):
try:
self.spark.conf.set("spark.sql.pyspark.plotting.sample_ratio", "0.5")
data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2500)]
sdf = self.spark.createDataFrame(data)
pdf = PySparkSampledPlotBase().get_sampled(sdf)
self.assertEqual(round(len(pdf) / 2500, 1), 0.5)
finally:
self.spark.conf.unset("spark.sql.pyspark.plotting.sample_ratio")

def test_sampled_plot_with_max_rows(self):
data = [Row(a=i, b=i + 1, c=i + 2, d=i + 3) for i in range(2000)]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3178,20 +3178,6 @@ object SQLConf {
.intConf
.createWithDefault(1000)

val PYSPARK_PLOT_SAMPLE_RATIO =
buildConf("spark.sql.pyspark.plotting.sample_ratio")
.doc(
"The proportion of data that will be plotted for sample-based plots. It is determined " +
"based on spark.sql.pyspark.plotting.max_rows if not explicitly set."
)
.version("4.0.0")
.doubleConf
.checkValue(
ratio => ratio >= 0.0 && ratio <= 1.0,
"The value should be between 0.0 and 1.0 inclusive."
)
.createOptional

val ARROW_SPARKR_EXECUTION_ENABLED =
buildConf("spark.sql.execution.arrow.sparkr.enabled")
.doc("When true, make use of Apache Arrow for columnar data transfers in SparkR. " +
Expand Down Expand Up @@ -5907,8 +5893,6 @@ class SQLConf extends Serializable with Logging with SqlApiConf {

def pysparkPlotMaxRows: Int = getConf(PYSPARK_PLOT_MAX_ROWS)

def pysparkPlotSampleRatio: Option[Double] = getConf(PYSPARK_PLOT_SAMPLE_RATIO)

def arrowSparkREnabled: Boolean = getConf(ARROW_SPARKR_EXECUTION_ENABLED)

def arrowPySparkFallbackEnabled: Boolean = getConf(ARROW_PYSPARK_FALLBACK_ENABLED)
Expand Down

0 comments on commit 64ea50e

Please sign in to comment.