@@ -24,6 +24,7 @@ import cern.jet.random.Poisson
24
24
import cern .jet .random .engine .DRand
25
25
26
26
import org .apache .spark .{Partition , TaskContext }
27
+ import org .apache .spark .util .random .BernoulliSampler
27
28
28
29
private [spark]
29
30
class FoldedRDDPartition (val prev : Partition , val seed : Int ) extends Partition with Serializable {
@@ -32,24 +33,10 @@ class FoldedRDDPartition(val prev: Partition, val seed: Int) extends Partition w
32
33
33
34
/**
 * One fold of `prev`, selected by sampling each parent partition with a
 * `BernoulliSampler` whose bounds are `(fold - 1) / folds` and `fold / folds`.
 *
 * NOTE(review): the third sampler argument is presumably the "complement" flag
 * (`false` = keep elements inside the bounds) -- confirm against BernoulliSampler.
 *
 * @param prev  parent RDD to take the fold from
 * @param fold  1-based index of the fold to keep; must satisfy `1 <= fold <= folds`
 * @param folds total number of folds; must be positive
 * @param seed  seed handed to `PartitionwiseSampledRDD`
 */
class FoldedRDD[T: ClassTag](
    prev: RDD[T],
    fold: Float,
    folds: Float,
    seed: Int)
  extends PartitionwiseSampledRDD[T, T](
    prev,
    // Divide in Double so the sampler bounds are not rounded to Float precision first.
    new BernoulliSampler((fold.toDouble - 1.0) / folds.toDouble, fold.toDouble / folds.toDouble, false),
    seed) {

  // Without these checks a zero `folds` silently yields NaN/Infinity bounds and an
  // out-of-range `fold` yields bounds outside [0, 1].
  require(folds > 0, s"folds must be positive, got $folds")
  require(fold >= 1 && fold <= folds, s"fold must be in [1, $folds], got $fold")
}
54
41
55
42
/**
@@ -58,14 +45,8 @@ class FoldedRDD[T: ClassTag](
58
45
*/
59
46
class CompositeFoldedRDD[T: ClassTag](
    prev: RDD[T],
    fold: Float,
    folds: Float,
    seed: Int)
  extends PartitionwiseSampledRDD[T, T](
    prev,
    // Same bounds as the fold at index `fold` ((fold - 1) / folds to fold / folds), but
    // with the third sampler argument set to `true`.
    // NOTE(review): that flag is presumably "complement" (keep everything outside the
    // bounds) -- confirm against BernoulliSampler.
    // Divide in Double so the sampler bounds are not rounded to Float precision first.
    new BernoulliSampler((fold.toDouble - 1.0) / folds.toDouble, fold.toDouble / folds.toDouble, true),
    seed) {

  // `fold` is a 1-based fold index and `folds` the total fold count. Without these
  // checks a zero `folds` silently yields NaN/Infinity bounds and an out-of-range
  // `fold` yields bounds outside [0, 1].
  require(folds > 0, s"folds must be positive, got $folds")
  require(fold >= 1 && fold <= folds, s"fold must be in [1, $folds], got $fold")
}
0 commit comments