
Commit baf839b

[SPARK-7294] ADD BETWEEN

Author: 云峤
1 parent d11d5b9 · commit baf839b

4 files changed: +15 −11 lines changed


python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 4 deletions

@@ -1290,15 +1290,14 @@ def cast(self, dataType):
         return Column(jc)
 
     @ignore_unicode_prefix
-    def between(self, col1, col2):
+    def between(self, lowerBound, upperBound):
        """ A boolean expression that is evaluated to true if the value of this
        expression is between the given columns.
 
-        >>> df[df.col1.between(col2, col3)].collect()
+        >>> df[df.col1.between(lowerBound, upperBound)].collect()
        [Row(col1=5, col2=6, col3=8)]
        """
-        #sc = SparkContext._active_spark_context
-        jc = self > col1 & self < col2
+        jc = (self >= lowerBound) & (self <= upperBound)
        return Column(jc)
 
    def __repr__(self):
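Beyond renaming the parameters, this hunk changes the semantics to be inclusive on both ends and removes a Python operator-precedence bug: in the deleted line `self > col1 & self < col2`, `&` binds tighter than the comparisons, so the expression did not evaluate as the author intended. The replacement parenthesizes both comparisons explicitly. A minimal sketch of the new behavior (Spark 1.x-era API assumed; the DataFrame and column names here are hypothetical, not from the patch):

    # Sketch: between() is now inclusive on both bounds.
    from pyspark import SparkContext
    from pyspark.sql import SQLContext, Row

    sc = SparkContext("local", "between-sketch")
    sqlCtx = SQLContext(sc)
    df = sqlCtx.createDataFrame([Row(name="Alice", age=2),
                                 Row(name="Bob", age=5),
                                 Row(name="Carol", age=7)])

    # Desugars to (age >= 2) AND (age <= 5), per the patched between().
    df.filter(df.age.between(2, 5)).show()
    # Alice (age=2) and Bob (age=5) both pass because the bounds are inclusive.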

python/pyspark/sql/tests.py

Lines changed: 2 additions & 2 deletions

@@ -427,8 +427,8 @@ def test_rand_functions(self):
         assert row[1] >= -4.0 and row[1] <= 4.0, "got: %s" % row[1]
 
     def test_between_function(self):
-        df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=3)]).toDF()
-        self.assertEqual([False, True, False],
+        df = self.sqlCtx.parallelize([Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]).toDF()
+        self.assertEqual([False, True, True],
            df.select(df.a.between(df.b, df.c)).collect())
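The edited fixture makes the boundary case observable: in the third row (a=4, b=1, c=4), the old exclusive check a > b AND a < c yields False, while the new inclusive check a >= b AND a <= c yields True, which is why the expected list changes from [False, True, False] to [False, True, True]. A plain-Python replay of the same truth table (names illustrative):

    # (a, b, c) triples taken from the updated test fixture.
    rows = [(1, 2, 3), (2, 1, 3), (4, 1, 4)]

    old_semantics = [b < a < c for a, b, c in rows]    # exclusive: > and <
    new_semantics = [b <= a <= c for a, b, c in rows]  # inclusive: >= and <=

    print(old_semantics)  # [False, True, False]
    print(new_semantics)  # [False, True, True]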

sql/core/src/main/scala/org/apache/spark/sql/Column.scala

Lines changed: 9 additions & 4 deletions

@@ -296,18 +296,23 @@ class Column(protected[sql] val expr: Expression) extends Logging {
   def eqNullSafe(other: Any): Column = this <=> other
 
   /**
-   * Between col1 and col2.
+   * True if the current column is between the lower bound and upper bound, inclusive.
    *
    * @group java_expr_ops
    */
-  def between(col1: String, col2: String): Column = between(Column(col1), Column(col2))
+  def between(lowerBound: String, upperBound: String): Column = {
+    between(Column(lowerBound), Column(upperBound))
+  }
 
   /**
-   * Between col1 and col2.
+   * True if the current column is between the lower bound and upper bound, inclusive.
    *
    * @group java_expr_ops
    */
-  def between(col1: Column, col2: Column): Column = And(GreaterThan(this.expr, col1.expr), LessThan(this.expr, col2.expr))
+  def between(lowerBound: Column, upperBound: Column): Column = {
+    And(GreaterThanOrEqual(this.expr, lowerBound.expr),
+      LessThanOrEqual(this.expr, upperBound.expr))
+  }
 
   /**
    * True if the current expression is null.
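On the Scala side, between now builds a conjunction of GreaterThanOrEqual and LessThanOrEqual Catalyst expressions instead of the strict GreaterThan/LessThan pair. Since the Python method from the first hunk builds the same conjunction via column operators, the equivalence can be checked from PySpark; a hedged sketch reusing the hypothetical df from the earlier example:

    # between(2, 5) and the hand-built inclusive conjunction should agree,
    # because the patch defines between() as exactly this expression.
    manual = (df.age >= 2) & (df.age <= 5)
    assert (df.filter(df.age.between(2, 5)).collect()
            == df.filter(manual).collect())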

sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala

Lines changed: 1 addition & 1 deletion

@@ -211,7 +211,7 @@ class ColumnExpressionSuite extends QueryTest {
   test("between") {
     checkAnswer(
       testData4.filter($"a".between($"b", $"c")),
-      testData4.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1) && r.getInt(0) < r.getInt(2)))
+      testData4.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1) && r.getInt(0) <= r.getInt(2)))
   }
 
   val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
