Skip to content

Commit 020cbdf

Browse files
author
jbencook
committed
[SPARK-4860][pyspark][sql] using Scala implementations of sample() and takeSample()
1 parent 6ee6aa7 commit 020cbdf

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

python/pyspark/sql.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2085,6 +2085,35 @@ def subtract(self, other, numPartitions=None):
20852085
else:
20862086
raise ValueError("Can only subtract another SchemaRDD")
20872087

2088+
def sample(self, withReplacement, fraction, seed=None):
2089+
"""
2090+
Return a sampled subset of this SchemaRDD.
2091+
2092+
>>> srdd = sqlCtx.inferSchema(rdd)
2093+
>>> srdd.sample(False, 0.5, 97).count()
2094+
2L
2095+
"""
2096+
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
2097+
seed = seed if seed is not None else random.randint(0, sys.maxint)
2098+
rdd = self._jschema_rdd.baseSchemaRDD().sample(
2099+
withReplacement, fraction, long(seed))
2100+
return SchemaRDD(rdd.toJavaSchemaRDD(), self.sql_ctx)
2101+
2102+
def takeSample(self, withReplacement, num, seed=None):
2103+
"""Return a fixed-size sampled subset of this SchemaRDD.
2104+
2105+
>>> srdd = sqlCtx.inferSchema(rdd)
2106+
>>> srdd.takeSample(False, 2, 97)
2107+
[Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
2108+
"""
2109+
seed = seed if seed is not None else random.randint(0, sys.maxint)
2110+
with SCCallSiteSync(self.context) as css:
2111+
bytesInJava = self._jschema_rdd.baseSchemaRDD() \
2112+
.takeSampleToPython(withReplacement, num, long(seed)) \
2113+
.iterator()
2114+
cls = _create_cls(self.schema())
2115+
return map(cls, self._collect_iterator_through_file(bytesInJava))
2116+
20882117

20892118
def _test():
20902119
import doctest

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,20 @@ class SchemaRDD(
437437
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
438438
}
439439

440+
/**
441+
* Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same
442+
* format as javaToPython and collectToPython. It is used by pyspark.
443+
*/
444+
private[sql] def takeSampleToPython(withReplacement: Boolean,
445+
num: Int,
446+
seed: Long): JList[Array[Byte]] = {
447+
val fieldTypes = schema.fields.map(_.dataType)
448+
val pickle = new Pickler
449+
new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row =>
450+
EvaluatePython.rowToArray(row, fieldTypes)
451+
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
452+
}
453+
440454
/**
441455
* Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
442456
* of base RDD functions that do not change schema.

0 commit comments

Comments
 (0)