Commit 6527b86

Merge pull request #8 from davies/col-computability

fix python tests

2 parents: fd92bc7 + f79034c
File tree: 7 files changed (+69, -37 lines)
python/pyspark/sql.py

Lines changed: 52 additions & 23 deletions
@@ -2124,6 +2124,10 @@ def head(self, n=None):
             return rs[0] if rs else None
         return self.take(n)
 
+    def first(self):
+        """ Return the first row. """
+        return self.head()
+
     def tail(self):
         raise NotImplemented
 
@@ -2159,7 +2163,7 @@ def select(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.select(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.select(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     def filter(self, condition):
@@ -2189,7 +2193,7 @@ def groupBy(self, *cols):
         else:
             cols = [c._jc for c in cols]
         jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client)
-        jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols))
+        jdf = self._jdf.groupBy(self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return GroupedDataFrame(jdf, self.sql_ctx)
 
     def agg(self, *exprs):
@@ -2278,14 +2282,17 @@ def agg(self, *exprs):
         :param exprs: list or aggregate columns or a map from column
                       name to agregate methods.
         """
+        assert exprs, "exprs should not be empty"
         if len(exprs) == 1 and isinstance(exprs[0], dict):
             jmap = MapConverter().convert(exprs[0],
                                           self.sql_ctx._sc._gateway._gateway_client)
             jdf = self._jdf.agg(jmap)
         else:
             # Columns
-            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns"
-            jdf = self._jdf.agg(*exprs)
+            assert all(isinstance(c, Column) for c in exprs), "all exprs should be Column"
+            jcols = ListConverter().convert([c._jc for c in exprs[1:]],
+                                            self.sql_ctx._sc._gateway._gateway_client)
+            jdf = self._jdf.agg(exprs[0]._jc, self.sql_ctx._sc._jvm.Dsl.toColumns(jcols))
         return DataFrame(jdf, self.sql_ctx)
 
     @dfapi
@@ -2347,7 +2354,7 @@ def _create_column_from_literal(literal):
 
 def _create_column_from_name(name):
     sc = SparkContext._active_spark_context
-    return sc._jvm.Column(name)
+    return sc._jvm.IncomputableColumn(name)
 
 
 def _scalaMethod(name):
@@ -2371,7 +2378,7 @@ def _(self):
     return _
 
 
-def _bin_op(name, pass_literal_through=False):
+def _bin_op(name, pass_literal_through=True):
     """ Create a method for given binary operator
 
     Keyword arguments:
@@ -2465,18 +2472,17 @@ def __init__(self, jc, jdf=None, sql_ctx=None):
     # __getattr__ = _bin_op("getField")
 
     # string methods
-    rlike = _bin_op("rlike", pass_literal_through=True)
-    like = _bin_op("like", pass_literal_through=True)
-    startswith = _bin_op("startsWith", pass_literal_through=True)
-    endswith = _bin_op("endsWith", pass_literal_through=True)
+    rlike = _bin_op("rlike")
+    like = _bin_op("like")
+    startswith = _bin_op("startsWith")
+    endswith = _bin_op("endsWith")
     upper = _unary_op("upper")
     lower = _unary_op("lower")
 
     def substr(self, startPos, pos):
         if type(startPos) != type(pos):
             raise TypeError("Can not mix the type")
         if isinstance(startPos, (int, long)):
-
             jc = self._jc.substr(startPos, pos)
         elif isinstance(startPos, Column):
             jc = self._jc.substr(startPos._jc, pos._jc)
@@ -2507,30 +2513,53 @@ def cast(self, dataType):
         return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx)
 
 
+def _to_java_column(col):
+    if isinstance(col, Column):
+        jcol = col._jc
+    else:
+        jcol = _create_column_from_name(col)
+    return jcol
+
+
 def _aggregate_func(name):
     """ Create a function for aggregator by name"""
     def _(col):
         sc = SparkContext._active_spark_context
-        if isinstance(col, Column):
-            jcol = col._jc
-        else:
-            jcol = _create_column_from_name(col)
-        jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol)
+        jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
         return Column(jc)
+
     return staticmethod(_)
 
 
 class Aggregator(object):
     """
     A collections of builtin aggregators
     """
-    max = _aggregate_func("max")
-    min = _aggregate_func("min")
-    avg = mean = _aggregate_func("mean")
-    sum = _aggregate_func("sum")
-    first = _aggregate_func("first")
-    last = _aggregate_func("last")
-    count = _aggregate_func("count")
+    AGGS = [
+        'lit', 'col', 'column', 'upper', 'lower', 'sqrt', 'abs',
+        'min', 'max', 'first', 'last', 'count', 'avg', 'mean', 'sum', 'sumDistinct',
+    ]
+    for _name in AGGS:
+        locals()[_name] = _aggregate_func(_name)
+    del _name
+
+    @staticmethod
+    def countDistinct(col, *cols):
+        sc = SparkContext._active_spark_context
+        jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+                                        sc._gateway._gateway_client)
+        jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
+                                       sc._jvm.Dsl.toColumns(jcols))
+        return Column(jc)
+
+    @staticmethod
+    def approxCountDistinct(col, rsd=None):
+        sc = SparkContext._active_spark_context
+        if rsd is None:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
+        else:
+            jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+        return Column(jc)
 
 
 def _test():
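
Taken together, the sql.py changes add DataFrame.first(), route multi-column select/groupBy/agg calls through Dsl.toColumns, and rebuild Aggregator from a list of Dsl function names plus explicit countDistinct/approxCountDistinct wrappers. Below is a minimal usage sketch of the resulting Python API; the SparkContext `sc`, SQLContext `sqlCtx`, and sample data are assumptions for illustration and not part of this commit (the test change further down exercises the same calls).

# Minimal usage sketch (assumes a running SparkContext `sc` and SQLContext `sqlCtx`;
# the sample data below is illustrative only).
from pyspark.sql import Row, Aggregator as Agg

rdd = sc.parallelize(range(100)).map(lambda i: Row(key=i, value=str(i)))
df = sqlCtx.inferSchema(rdd)

df.first()                     # new DataFrame.first(), equivalent to head()
df.select(df.key, df.value)    # column list is passed through Dsl.toColumns on the JVM side

g = df.groupBy()
g.agg(Agg.first(df.key), Agg.last(df.value)).first()
g.agg(Agg.countDistinct(df.value)).first()
g.agg(Agg.approxCountDistinct(df.key, rsd=0.05)).first()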

python/pyspark/tests.py

Lines changed: 4 additions & 2 deletions
@@ -1029,9 +1029,11 @@ def test_aggregator(self):
         g = df.groupBy()
         self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
         self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
-        # TODO(davies): fix aggregators
+
         from pyspark.sql import Aggregator as Agg
-        # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+        self.assertEqual((0, u'99'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))
+        self.assertTrue(95 < g.agg(Agg.approxCountDistinct(df.key)).first()[0])
+        self.assertEqual(100, g.agg(Agg.countDistinct(df.value)).first()[0])
 
     def test_help_command(self):
         # Regression test for SPARK-5464

sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala

Lines changed: 0 additions & 4 deletions
@@ -500,10 +500,6 @@ trait DataFrame extends DataFrameSpecificApi with RDDApi[Row] {
   ////////////////////////////////////////////////////////////////////////////
   // for Python API
   ////////////////////////////////////////////////////////////////////////////
-  /**
-   * A helpful function for Py4j, convert a list of Column to an array
-   */
-  protected[sql] def toColumnArray(cols: JList[Column]): Array[Column]
 
   /**
    * Converts a JavaRDD to a PythonRDD.

sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala

Lines changed: 0 additions & 4 deletions
@@ -323,10 +323,6 @@ private[sql] class DataFrameImpl protected[sql](
   ////////////////////////////////////////////////////////////////////////////
   // for Python API
   ////////////////////////////////////////////////////////////////////////////
-  protected[sql] override def toColumnArray(cols: JList[Column]): Array[Column] = {
-    cols.toList.toArray
-  }
-
   protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = {
     val fieldTypes = schema.fields.map(_.dataType)
     val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()

sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala

Lines changed: 11 additions & 2 deletions
@@ -17,8 +17,11 @@
 
 package org.apache.spark.sql
 
+import java.util.{List => JList}
+
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
+import scala.collection.JavaConversions._
 
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
@@ -105,8 +108,7 @@ object Dsl {
   def countDistinct(expr: Column, exprs: Column*): Column =
     CountDistinct((expr +: exprs).map(_.expr))
 
-  def approxCountDistinct(e: Column): Column =
-    ApproxCountDistinct(e.expr)
+  def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
   def approxCountDistinct(e: Column, rsd: Double): Column =
     ApproxCountDistinct(e.expr, rsd)
 
@@ -121,6 +123,13 @@ object Dsl {
   def sqrt(e: Column): Column = Sqrt(e.expr)
   def abs(e: Column): Column = Abs(e.expr)
 
+  /**
+   * This is a private API for Python
+   * TODO: move this to a private package
+   */
+  def toColumns(cols: JList[Column]): Seq[Column] = {
+    cols.toList.toSeq
+  }
 
   // scalastyle:off
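
Dsl.toColumns replaces the per-DataFrame toColumnArray helper: Py4J can only hand the JVM a java.util.List[Column], while the Scala select/groupBy/agg overloads take Column* (a Seq[Column] at the JVM level), so the Python side builds the JList with ListConverter and this single shared entry point performs the conversion. A sketch of the calling pattern on the Python side, mirroring the select change above (`sc`, `df`, and `cols` are assumed to already exist):

# Py4J bridging sketch (assumes an active SparkContext `sc`, a DataFrame `df`,
# and a Python list `cols` of pyspark Column objects).
from py4j.java_collections import ListConverter

jcols = ListConverter().convert([c._jc for c in cols],
                                sc._gateway._gateway_client)  # Python list -> java.util.List
jseq = sc._jvm.Dsl.toColumns(jcols)                           # java.util.List -> Seq[Column]
jdf = df._jdf.select(jseq)                                    # feeds the Column* varargs overload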

sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import java.util.{List => JList}
+
 import scala.language.implicitConversions
 import scala.collection.JavaConversions._
 

sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala

Lines changed: 0 additions & 2 deletions
@@ -156,7 +156,5 @@ private[sql] class IncomputableColumn(protected[sql] val expr: Expression) exten
 
   override def toJSON: RDD[String] = err()
 
-  protected[sql] override def toColumnArray(cols: java.util.List[Column]): Array[Column] = err()
-
   protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = err()
 }
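
On the Python side this ties into the PR title ("col-computability"): _create_column_from_name now builds an IncomputableColumn, i.e. a free-standing named column that can participate in expressions and aggregations but err()s if computed directly. A small illustration under that reading; the DataFrame `df` is assumed (as in the test above), and this is an inference from the diff, not documented behaviour.

# Sketch (assumes the DataFrame `df` from the test above).
from pyspark.sql import Aggregator as Agg

g = df.groupBy()
g.agg(Agg.max("key")).collect()   # the string "key" goes through _create_column_from_name
                                  # and is backed by IncomputableColumn on the JVM side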
