[WIP][SPARK-27463][PYTHON] Support Dataframe Cogroup via Pandas UDFs #24965

Changes from all commits:
`python/pyspark/sql/cogroup.py` (new file; path inferred from the class it adds):

```python
# (standard Apache Software Foundation license header)

from pyspark.sql.dataframe import DataFrame


class CoGroupedData(object):

    def __init__(self, gd1, gd2):
        self._gd1 = gd1
        self._gd2 = gd2
        self.sql_ctx = gd1.sql_ctx

    def apply(self, udf):
        all_cols = self._extract_cols(self._gd1) + self._extract_cols(self._gd2)
        udf_column = udf(*all_cols)
        jdf = self._gd1._jgd.flatMapCoGroupsInPandas(self._gd2._jgd, udf_column._jc.expr())
        return DataFrame(jdf, self.sql_ctx)

    @staticmethod
    def _extract_cols(gd):
        df = gd._df
        return [df[col] for col in df.columns]
```
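For orientation, here is a minimal sketch of the user-facing flow this class enables, mirroring the `test_simple` test further down. The data and schema are illustrative, and the `GroupedData.cogroup` entry point is added elsewhere in this PR (not shown in this diff):

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd

spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame([(1, 20, 200), (1, 21, 210)], ("id", "k", "v"))
df2 = spark.createDataFrame([(1, 20, 2000), (1, 21, 2100)], ("id", "k", "v2"))

# A cogrouped-map UDF receives one pandas.DataFrame per side of each cogroup.
@pandas_udf("id long, k long, v long, v2 long", PandasUDFType.COGROUPED_MAP)
def merge_pandas(left, right):
    return pd.merge(left, right, how="outer", on=["k", "id"])

# groupby(...).cogroup(...) yields a CoGroupedData; apply() runs the UDF once
# per key and concatenates the returned DataFrames into the result.
result = df1.groupby("id").cogroup(df2.groupby("id")).apply(merge_pandas)
result.show()
```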
`python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py` (new file):

```python
# (standard Apache Software Foundation license header)

import datetime
import unittest
import sys

from collections import OrderedDict
from decimal import Decimal

from pyspark.sql import Row
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import *
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
    pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest

if have_pandas:
    import pandas as pd
    from pandas.util.testing import assert_frame_equal

if have_pyarrow:
    import pyarrow as pa


"""
Tests below use pd.DataFrame.assign that will infer mixed types (unicode/str) for column names
from kwargs w/ Python 2, so need to set check_column_type=False and avoid this check
"""
if sys.version < '3':
    _check_column_type = False
else:
    _check_column_type = True


@unittest.skipIf(
    not have_pandas or not have_pyarrow,
    pandas_requirement_message or pyarrow_requirement_message)
class CoGroupedMapPandasUDFTests(ReusedSQLTestCase):

    @property
    def data1(self):
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks'))) \
            .withColumn("v", col('k') * 10) \
            .drop('ks')

    @property
    def data2(self):
        return self.spark.range(10).toDF('id') \
            .withColumn("ks", array([lit(i) for i in range(20, 30)])) \
            .withColumn("k", explode(col('ks'))) \
            .withColumn("v2", col('k') * 100) \
            .drop('ks')

    def test_simple(self):
        import pandas as pd

        l = self.data1
        r = self.data2

        @pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
        def merge_pandas(left, right):
            return pd.merge(left, right, how='outer', on=['k', 'id'])

        result = l \
            .groupby('id') \
            .cogroup(r.groupby(r.id)) \
            .apply(merge_pandas) \
            .sort(['id', 'k']) \
            .toPandas()

        expected = pd \
            .merge(l.toPandas(), r.toPandas(), how='outer', on=['k', 'id'])

        assert_frame_equal(expected, result, check_column_type=_check_column_type)
```
Review comment from @hjoo: Hi @d80tb7, I work with Li and am also interested in cogroup. Can I ask how you were able to get your test to run? I wasn't able to run it without the following snippet, taken from the other similar tests like test_pandas_udf_grouped_map.

Reply from @d80tb7 (contributor, author): Hi @hjoo, so far I've just been running it via PyCharm's unit test runner under Python 3. I suspect the problem you had was that the iterator I added wasn't compatible with Python 2 (sorry!). I've fixed the iterator and added a similar snippet to the one you provided above. Now I can run using ... If you still have problems, let me know the error you get and I'll take a look.
The snippet in question, at the bottom of the test file:

```python
if __name__ == "__main__":
    from pyspark.sql.tests.test_pandas_udf_cogrouped_map import *

    try:
        import xmlrunner
        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)
```
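The suite can also be driven programmatically with the standard unittest loader. This sketch is an illustrative alternative to executing the file directly, assuming pyspark and the test module are importable:

```python
import unittest

# Load and run the cogrouped-map UDF tests by module name (illustrative;
# assumes a working PySpark development environment on PYTHONPATH).
suite = unittest.defaultTestLoader.loadTestsFromName(
    "pyspark.sql.tests.test_pandas_udf_cogrouped_map")
unittest.TextTestRunner(verbosity=2).run(suite)
```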
`python/pyspark/worker.py`:

```diff
@@ -38,7 +38,7 @@
 from pyspark.rdd import PythonEvalType
 from pyspark.serializers import write_with_length, write_int, read_long, read_bool, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
-    BatchedSerializer, ArrowStreamPandasUDFSerializer
+    BatchedSerializer, ArrowStreamPandasUDFSerializer, InterleavedArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type, StructType
 from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
```
```diff
@@ -111,8 +111,25 @@ def verify_result_length(result, length):
             map(verify_result_type, f(*iterator)))
 
 
+def wrap_cogrouped_map_pandas_udf(f, return_type):
+    def wrapped(left, right):
+        import pandas as pd
+        result = f(pd.concat(left, axis=1), pd.concat(right, axis=1))
+        if not isinstance(result, pd.DataFrame):
+            raise TypeError("Return type of the user-defined function should be "
+                            "pandas.DataFrame, but is {}".format(type(result)))
+        if not len(result.columns) == len(return_type):
+            raise RuntimeError(
+                "Number of columns of the returned pandas.DataFrame "
+                "doesn't match specified schema. "
+                "Expected: {} Actual: {}".format(len(return_type), len(result.columns)))
+        return result
+
+    return lambda v: [(wrapped(v[0], v[1]), to_arrow_type(return_type))]
+
+
 def wrap_grouped_map_pandas_udf(f, return_type, argspec):
     def wrapped(key_series, value_series):
         import pandas as pd
```

Review comment from a contributor, on `def wrapped(left, right):`: Are `left` and `right` here value series?

Reply from @d80tb7 (author): Yes they are: they are value series for the left and right sides of the cogroup respectively. Agreed that the names aren't the best. I'll improve them when I do a tidy-up.
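As the exchange above notes, each side reaches the wrapper as a list of per-column pandas Series. A small standalone illustration of the `pd.concat(..., axis=1)` step (column names here are illustrative):

```python
import pandas as pd

# Each side of a cogroup arrives as a list of value Series, one Series per
# column; concatenating along axis=1 rebuilds a DataFrame, which is what
# wrap_cogrouped_map_pandas_udf does before calling the user's function.
left = [pd.Series([1, 2], name="id"), pd.Series([10, 20], name="v")]
right = [pd.Series([1, 2], name="id"), pd.Series([100, 200], name="v2")]

left_df = pd.concat(left, axis=1)    # columns: id, v
right_df = pd.concat(right, axis=1)  # columns: id, v2
print(left_df)
print(right_df)
```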
```diff
@@ -232,6 +249,8 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
         argspec = _get_argspec(chained_func)  # signature was lost when wrapping it
         return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
+    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+        return arg_offsets, wrap_cogrouped_map_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
         return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_WINDOW_AGG_PANDAS_UDF:
```
```diff
@@ -246,6 +265,7 @@ def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
 
     if eval_type in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+                     PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
                      PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
                      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
                      PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF,
```
```diff
@@ -269,10 +289,13 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
         # pandas Series. See SPARK-27240.
-        df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
-                         eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
-        ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
-                                             df_for_struct)
+        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+            ser = InterleavedArrowStreamPandasSerializer(timezone, safecheck, assign_cols_by_name)
+        else:
+            df_for_struct = (eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF or
+                             eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF)
+            ser = ArrowStreamPandasUDFSerializer(timezone, safecheck, assign_cols_by_name,
+                                                 df_for_struct)
     else:
         ser = BatchedSerializer(PickleSerializer(), 100)
```
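`InterleavedArrowStreamPandasSerializer` itself is defined in `pyspark/serializers.py`, which is not part of this section. As a rough mental model only (the names and structure here are hypothetical, not the PR's implementation), the cogrouped path needs the deserialized stream to yield left/right pairs, since the wrapper above unpacks its argument as `v[0]`/`v[1]`:

```python
def pair_cogroups(batches):
    # Hypothetical sketch: assume the stream alternates strictly between a
    # left-side batch and the matching right-side batch for each cogroup,
    # and yield them together as one (left, right) pair.
    it = iter(batches)
    for left in it:
        right = next(it, None)
        if right is None:
            raise ValueError("stream ended with an unpaired left batch")
        yield (left, right)

# e.g. list(pair_cogroups(["L1", "R1", "L2", "R2"]))
#      -> [("L1", "R1"), ("L2", "R2")]
```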
```diff
@@ -343,6 +366,14 @@ def map_batch(batch):
         arg0 = ["a[%d]" % o for o in arg_offsets[1: split_offset]]
         arg1 = ["a[%d]" % o for o in arg_offsets[split_offset:]]
         mapper_str = "lambda a: f([%s], [%s])" % (", ".join(arg0), ", ".join(arg1))
+    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+        # We assume there is only one UDF here because cogrouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+        arg_offsets, udf = read_single_udf(
+            pickleSer, infile, eval_type, runner_conf, udf_index=0)
+        udfs['f'] = udf
+        mapper_str = "lambda a: f(a)"
     else:
         # Create function like this:
         # lambda a: (f0(a[0]), f1(a[1], a[2]), f2(a[3]))
```
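To make the `mapper_str = "lambda a: f(a)"` branch concrete, here is a self-contained sketch of how the wrapped UDF consumes the two-element cogroup value. The user function and data are illustrative, and the schema checks from the real wrapper are omitted:

```python
import pandas as pd

def wrap_cogrouped(f):
    # Same shape as wrap_cogrouped_map_pandas_udf above: v is a two-element
    # sequence (left value-series list, right value-series list).
    def wrapped(left, right):
        return f(pd.concat(left, axis=1), pd.concat(right, axis=1))
    return lambda v: wrapped(v[0], v[1])

f = wrap_cogrouped(lambda l, r: pd.merge(l, r, on="id"))
mapper = eval("lambda a: f(a)", {"f": f})  # what the worker builds from mapper_str

one_cogroup = ([pd.Series([1], name="id"), pd.Series([10], name="v")],
               [pd.Series([1], name="id"), pd.Series([100], name="v2")])
print(mapper(one_cogroup))  # one merged row: id=1, v=10, v2=100
```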
`pythonLogicalOperators.scala` (Catalyst logical plans; exact path inferred from the surrounding classes):

```diff
@@ -39,6 +39,18 @@ case class FlatMapGroupsInPandas(
   override val producedAttributes = AttributeSet(output)
 }
 
+case class FlatMapCoGroupsInPandas(
+    leftAttributes: Seq[Attribute],
+    rightAttributes: Seq[Attribute],
+    functionExpr: Expression,
+    output: Seq[Attribute],
+    left: LogicalPlan,
+    right: LogicalPlan) extends BinaryNode {
+
+  override val producedAttributes = AttributeSet(output)
+}
+
 trait BaseEvalPython extends UnaryNode {
 
   def udfs: Seq[PythonUDF]
```

Review comment from a member, on the new case class: (BTW, checkout this https://github.com/databricks/scala-style-guide)
Review comment (on serializer code not included in this diff view): I wanted to read these also using the message reader, but for some reason `pa.read_schema(self_reader.read_next_message())` didn't work.