Skip to content

Commit

Permalink
[SPARK-49776][PYTHON][CONNECT] Support pie plots
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support pie plots with plotly backend on both Spark Connect and Spark classic.

### Why are the changes needed?
While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments.

See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress.

Part of https://issues.apache.org/jira/browse/SPARK-49530.

### Does this PR introduce _any_ user-facing change?
Yes. Pie plots are supported as shown below.

```py
>>> from datetime import datetime
>>> data = [
...     (3, 5, 20, datetime(2018, 1, 31)),
...     (2, 5, 42, datetime(2018, 2, 28)),
...     (3, 6, 28, datetime(2018, 3, 31)),
...     (9, 12, 62, datetime(2018, 4, 30))]
>>> columns = ["sales", "signups", "visits", "date"]
>>> df = spark.createDataFrame(data, columns)
>>> fig = df.plot(kind="pie", x="date", y="sales")  # equivalently: df.plot.pie(x="date", y="sales")
>>> fig.show()
```
![newplot (8)](https://github.com/user-attachments/assets/c4078bb7-4d84-4607-bcd7-bdd6fbbf8e28)

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48256 from xinrong-meng/plot_pie.

Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Xinrong Meng <xinrong@apache.org>
  • Loading branch information
xinrong-meng authored and himadripal committed Oct 19, 2024
1 parent bc3ce4a commit e4c5566
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 1 deletion.
5 changes: 5 additions & 0 deletions python/pyspark/errors/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -812,6 +812,11 @@
"Pipe function `<func_name>` exited with error code <error_code>."
]
},
"PLOT_NOT_NUMERIC_COLUMN": {
"message": [
"Argument <arg_name> must be a numerical column for plotting, got <arg_type>."
]
},
"PYTHON_HASH_SEED_NOT_SET": {
"message": [
"Randomness of hash of string should be disabled via PYTHONHASHSEED."
Expand Down
41 changes: 40 additions & 1 deletion python/pyspark/sql/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

from typing import Any, TYPE_CHECKING, Optional, Union
from types import ModuleType
from pyspark.errors import PySparkRuntimeError, PySparkValueError
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
from pyspark.sql.types import NumericType
from pyspark.sql.utils import require_minimum_plotly_version


Expand Down Expand Up @@ -97,6 +98,7 @@ class PySparkPlotAccessor:
"bar": PySparkTopNPlotBase().get_top_n,
"barh": PySparkTopNPlotBase().get_top_n,
"line": PySparkSampledPlotBase().get_sampled,
"pie": PySparkTopNPlotBase().get_top_n,
"scatter": PySparkSampledPlotBase().get_sampled,
}
_backends = {} # type: ignore[var-annotated]
Expand Down Expand Up @@ -299,3 +301,40 @@ def area(self, x: str, y: str, **kwargs: Any) -> "Figure":
>>> df.plot.area(x='date', y=['sales', 'signups', 'visits']) # doctest: +SKIP
"""
return self(kind="area", x=x, y=y, **kwargs)

def pie(self, x: str, y: str, **kwargs: Any) -> "Figure":
    """
    Generate a pie plot.

    A pie plot is a proportional representation of the numerical data in a
    column.

    Parameters
    ----------
    x : str
        Name of column to be used as the category labels for the pie plot.
    y : str
        Name of the column to plot. It must refer to an existing column of
        a numeric type.
    **kwargs
        Additional keyword arguments, forwarded unchanged to the plotting call.

    Returns
    -------
    :class:`plotly.graph_objs.Figure`

    Raises
    ------
    PySparkTypeError
        With error class ``PLOT_NOT_NUMERIC_COLUMN`` when ``y`` is not a
        column of the DataFrame or its data type is not numeric.

    Examples
    --------
    >>> from datetime import datetime
    >>> data = [
    ...     (3, 5, 20, datetime(2018, 1, 31)),
    ...     (2, 5, 42, datetime(2018, 2, 28)),
    ...     (3, 6, 28, datetime(2018, 3, 31)),
    ...     (9, 12, 62, datetime(2018, 4, 30))
    ... ]
    >>> columns = ["sales", "signups", "visits", "date"]
    >>> df = spark.createDataFrame(data, columns)  # doctest: +SKIP
    >>> df.plot.pie(x='date', y='sales')  # doctest: +SKIP
    """
    frame_schema = self.data.schema

    # Look the field up only when the name is present, so that an unknown
    # column is reported through the same error as a non-numeric one.
    field = frame_schema[y] if y in frame_schema.names else None

    values_are_numeric = field is not None and isinstance(field.dataType, NumericType)
    if values_are_numeric:
        return self(kind="pie", x=x, y=y, **kwargs)

    # A missing column has no data type to report; the literal "None" is used.
    reported_type = "None" if not field else str(field.dataType)
    raise PySparkTypeError(
        errorClass="PLOT_NOT_NUMERIC_COLUMN",
        messageParameters={
            "arg_name": "y",
            "arg_type": reported_type,
        },
    )
15 changes: 15 additions & 0 deletions python/pyspark/sql/plot/plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,19 @@
def plot_pyspark(data: "DataFrame", kind: str, **kwargs: Any) -> "Figure":
    """
    Dispatch a plotting request to the plotly backend.

    Pie plots are routed to :func:`plot_pie`. Every other kind is prepared
    with the function registered for it in
    ``PySparkPlotAccessor.plot_data_map`` and then handed to ``plotly.plot``.

    Parameters
    ----------
    data : DataFrame
        The DataFrame to plot.
    kind : str
        The plot kind; apart from ``"pie"`` it must be a key of
        ``PySparkPlotAccessor.plot_data_map``.
    **kwargs
        Additional keyword arguments, forwarded unchanged.

    Returns
    -------
    :class:`plotly.graph_objs.Figure`
    """
    # Imported eagerly (before any dispatch) so a missing plotly installation
    # surfaces for every plot kind, pie included.
    import plotly

    if kind != "pie":
        prepare = PySparkPlotAccessor.plot_data_map[kind]
        prepared = prepare(data)
        return plotly.plot(prepared, kind, **kwargs)

    return plot_pie(data, **kwargs)


def plot_pie(data: "DataFrame", **kwargs: Any) -> "Figure":
    """
    Build a pie figure with ``plotly.express.pie``.

    Parameters
    ----------
    data : DataFrame
        The DataFrame to plot; it is first prepared with the function
        registered under ``"pie"`` in ``PySparkPlotAccessor.plot_data_map``.
    **kwargs
        ``x`` (slice labels) and ``y`` (slice values) are extracted and passed
        as ``names`` and ``values`` respectively; both default to ``None``
        when absent. All remaining keyword arguments are forwarded unchanged.

    Returns
    -------
    :class:`plotly.graph_objs.Figure`
    """
    # TODO(SPARK-49530): Support pie subplots with plotly backend
    from plotly import express

    # Remove x/y from kwargs so they are not passed twice to express.pie.
    label_column = kwargs.pop("x", None)
    value_column = kwargs.pop("y", None)

    prepared = PySparkPlotAccessor.plot_data_map["pie"](data)
    return express.pie(prepared, values=value_column, names=label_column, **kwargs)
25 changes: 25 additions & 0 deletions python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from datetime import datetime

import pyspark.sql.plot # noqa: F401
from pyspark.errors import PySparkTypeError
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_plotly, plotly_requirement_message


Expand Down Expand Up @@ -64,6 +65,11 @@ def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=
self.assertEqual(fig_data["type"], "scatter")
self.assertEqual(fig_data["orientation"], "v")
self.assertEqual(fig_data["mode"], "lines")
elif kind == "pie":
self.assertEqual(fig_data["type"], "pie")
self.assertEqual(list(fig_data["labels"]), expected_x)
self.assertEqual(list(fig_data["values"]), expected_y)
return

self.assertEqual(fig_data["xaxis"], "x")
self.assertEqual(list(fig_data["x"]), expected_x)
Expand Down Expand Up @@ -133,6 +139,25 @@ def test_area_plot(self):
self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 12], "signups")
self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 62], "visits")

def test_pie_plot(self):
    """Check pie plot output and the error raised for a non-numeric ``y``."""
    # Labels come from the "date" column, values from the "sales" column.
    month_ends = ((1, 31), (2, 28), (3, 31), (4, 30))
    expected_labels = [datetime(2018, month, day, 0, 0) for month, day in month_ends]
    expected_values = [3, 2, 3, 9]

    pie_fig = self.sdf3.plot(kind="pie", x="date", y="sales")
    self._check_fig_data("pie", pie_fig["data"][0], expected_labels, expected_values)

    # y is not a numerical column
    with self.assertRaises(PySparkTypeError) as raised:
        self.sdf.plot.pie(x="int_val", y="category")

    expected_params = {"arg_name": "y", "arg_type": "StringType()"}
    self.check_error(
        exception=raised.exception,
        errorClass="PLOT_NOT_NUMERIC_COLUMN",
        messageParameters=expected_params,
    )


class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
pass
Expand Down

0 comments on commit e4c5566

Please sign in to comment.