Skip to content

Commit d11d5b9

Browse files
author
云峤
committed
[SPARK-7294] ADD BETWEEN
1 parent a9fc505 commit d11d5b9

File tree

5 files changed

+49
-0
lines changed

5 files changed

+49
-0
lines changed

python/pyspark/sql/dataframe.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,18 @@ def cast(self, dataType):
12891289
raise TypeError("unexpected type: %s" % type(dataType))
12901290
return Column(jc)
12911291

1292+
@ignore_unicode_prefix
def between(self, col1, col2):
    """A boolean expression that is evaluated to true if the value of this
    expression is strictly between the given columns, i.e.
    ``(this > col1) AND (this < col2)``.

    NOTE(review): SQL's BETWEEN is inclusive of both bounds; this uses
    strict ``>`` / ``<`` — confirm the intended semantics.

    >>> df = sqlCtx.createDataFrame([Row(a=5, b=4, c=8)])
    >>> df[df.a.between(df.b, df.c)].collect()
    [Row(a=5, b=4, c=8)]
    """
    # Parenthesize the comparisons: in Python `&` binds tighter than
    # `>`/`<`, so the unparenthesized `self > col1 & self < col2` would
    # parse as `self > (col1 & self) < col2` and build the wrong
    # expression (a chained comparison on Column objects).
    # The comparison operators already return Column objects, so no
    # extra Column(...) wrapping is needed — Column() expects a Java
    # column handle, not another Python Column.
    return (self > col1) & (self < col2)
1303+
12921304
def __repr__(self):
12931305
return 'Column<%s>' % self._jc.toString().encode('utf8')
12941306

python/pyspark/sql/tests.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,12 @@ def test_rand_functions(self):
426426
for row in rndn:
427427
assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
428428

429+
def test_between_function(self):
    """between() should be True only for the row where b < a < c."""
    # Use the SparkContext (self.sc) to parallelize: SQLContext has no
    # parallelize() method, so self.sqlCtx.parallelize would raise
    # AttributeError before the assertion even runs.
    df = self.sc.parallelize([
        Row(a=1, b=2, c=3),
        Row(a=2, b=1, c=3),
        Row(a=4, b=1, c=3)]).toDF()
    # collect() yields Row objects, not bare booleans — unpack the single
    # boolean column before comparing, otherwise the equality check can
    # never succeed.
    self.assertEqual([False, True, False],
                     [row[0] for row in
                      df.select(df.a.between(df.b, df.c)).collect()])
433+
434+
429435
def test_save_and_load(self):
430436
df = self.df
431437
tmpPath = tempfile.mkdtemp()

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,20 @@ class Column(protected[sql] val expr: Expression) extends Logging {
295295
*/
296296
def eqNullSafe(other: Any): Column = this <=> other
297297

298+
/**
 * True if this expression is strictly between the values of the two named
 * columns, i.e. (this > col1) AND (this < col2).
 *
 * NOTE(review): SQL's BETWEEN is inclusive of both bounds; these strict
 * comparisons should be confirmed against the intended semantics.
 *
 * @group java_expr_ops
 */
def between(col1: String, col2: String): Column = between(Column(col1), Column(col2))

/**
 * True if this expression is strictly between the given columns,
 * i.e. (this > col1) AND (this < col2).
 *
 * @group java_expr_ops
 */
def between(col1: Column, col2: Column): Column =
  // Wrap the combined predicate explicitly: And(...) is an Expression,
  // while the declared return type is Column, so do not rely on an
  // implicit Expression => Column conversion being in scope.
  Column(And(GreaterThan(this.expr, col1.expr), LessThan(this.expr, col2.expr)))
311+
298312
/**
299313
* True if the current expression is null.
300314
*

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,12 @@ class ColumnExpressionSuite extends QueryTest {
208208
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
209209
}
210210

211+
test("between") {
  // Reference result computed locally: rows whose `a` lies strictly
  // between `b` and `c`.
  val expected = testData4.collect().toSeq
    .filter(row => row.getInt(0) > row.getInt(1) && row.getInt(0) < row.getInt(2))
  checkAnswer(testData4.filter($"a".between($"b", $"c")), expected)
}
216+
211217
val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
212218
Row(false, false) ::
213219
Row(false, true) ::

sql/core/src/test/scala/org/apache/spark/sql/TestData.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,17 @@ object TestData {
5757
TestData2(3, 2) :: Nil, 2).toDF()
5858
testData2.registerTempTable("testData2")
5959

60+
// Three-int-column fixture for range/BETWEEN-style predicate tests.
case class TestData4(a: Int, b: Int, c: Int)
val testData4 =
  TestSQLContext.sparkContext.parallelize(
    TestData4(0, 1, 2) ::
    TestData4(1, 2, 3) ::
    TestData4(2, 1, 0) ::
    TestData4(2, 2, 4) ::
    TestData4(3, 1, 6) ::
    TestData4(3, 2, 0) :: Nil, 2).toDF()
// Register under "testData4" (lowerCamelCase) for consistency with the
// other fixtures in this object (e.g. testData2 -> "testData2"); the
// original "TestData4" broke that naming convention.
testData4.registerTempTable("testData4")
70+
6071
case class DecimalData(a: BigDecimal, b: BigDecimal)
6172

6273
val decimalData =

0 commit comments

Comments
 (0)