@@ -364,6 +364,98 @@ def test_zip_with_different_number_of_items(self):
364364 self .assertEquals (a .count (), b .count ())
365365 self .assertRaises (Exception , lambda : a .zip (b ).count ())
366366
367+ def test_histogram (self ):
368+ # empty
369+ rdd = self .sc .parallelize ([])
370+ self .assertEquals ([0 ], rdd .histogram ([0 , 10 ])[1 ])
371+ self .assertEquals ([0 ], rdd .histogram ([0 , 10 ], True )[1 ])
372+
373+ # out of range
374+ rdd = self .sc .parallelize ([10.01 , - 0.01 ])
375+ self .assertEquals ([0 ], rdd .histogram ([0 , 10 ])[1 ])
376+ self .assertEquals ([0 ], rdd .histogram ([0 , 10 ], True )[1 ])
377+
378+ # in range with one bucket
379+ rdd = self .sc .parallelize (range (1 , 5 ))
380+ self .assertEquals ([4 ], rdd .histogram ([0 , 10 ])[1 ])
381+ self .assertEquals ([4 ], rdd .histogram ([0 , 10 ], True )[1 ])
382+
383+ # in range with one bucket exact match
384+ self .assertEquals ([4 ], rdd .histogram ([1 , 4 ])[1 ])
385+ self .assertEquals ([4 ], rdd .histogram ([1 , 4 ], True )[1 ])
386+
387+ # out of range with two buckets
388+ rdd = self .sc .parallelize ([10.01 , - 0.01 ])
389+ self .assertEquals ([0 , 0 ], rdd .histogram ([0 , 5 , 10 ])[1 ])
390+ self .assertEquals ([0 , 0 ], rdd .histogram ([0 , 5 , 10 ], True )[1 ])
391+
392+ # out of range with two uneven buckets
393+ rdd = self .sc .parallelize ([10.01 , - 0.01 ])
394+ self .assertEquals ([0 , 0 ], rdd .histogram ([0 , 4 , 10 ])[1 ])
395+
396+ # in range with two buckets
397+ rdd = self .sc .parallelize ([1 , 2 , 3 , 5 , 6 ])
398+ self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 10 ])[1 ])
399+ self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 10 ], True )[1 ])
400+
401+ # in range with two bucket and None
402+ rdd = self .sc .parallelize ([1 , 2 , 3 , 5 , 6 , None , float ('nan' )])
403+ self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 10 ])[1 ])
404+ self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 10 ], True )[1 ])
405+
406+ # in range with two uneven buckets
407+ rdd = self .sc .parallelize ([1 , 2 , 3 , 5 , 6 ])
408+ self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 11 ])[1 ])
409+
410+ # mixed range with two uneven buckets
411+ rdd = self .sc .parallelize ([- 0.01 , 0.0 , 1 , 2 , 3 , 5 , 6 , 11.0 , 11.01 ])
412+ self .assertEquals ([4 , 3 ], rdd .histogram ([0 , 5 , 11 ])[1 ])
413+
414+ # mixed range with four uneven buckets
415+ rdd = self .sc .parallelize ([- 0.01 , 0.0 , 1 , 2 , 3 , 5 , 6 , 11.01 , 12.0 , 199.0 , 200.0 , 200.1 ])
416+ self .assertEquals ([4 , 2 , 1 , 3 ], rdd .histogram ([0.0 , 5.0 , 11.0 , 12.0 , 200.0 ])[1 ])
417+
418+ # mixed range with uneven buckets and NaN
419+ rdd = self .sc .parallelize ([- 0.01 , 0.0 , 1 , 2 , 3 , 5 , 6 , 11.01 , 12.0 ,
420+ 199.0 , 200.0 , 200.1 , None , float ('nan' )])
421+ self .assertEquals ([4 , 2 , 1 , 3 ], rdd .histogram ([0.0 , 5.0 , 11.0 , 12.0 , 200.0 ])[1 ])
422+
423+ # out of range with infinite buckets
424+ rdd = self .sc .parallelize ([10.01 , - 0.01 , float ('nan' )])
425+ self .assertEquals ([1 , 1 ], rdd .histogram ([float ('-inf' ), 0 , float ('inf' )])[1 ])
426+
427+ # invalid buckets
428+ self .assertRaises (ValueError , lambda : rdd .histogram ([]))
429+ self .assertRaises (ValueError , lambda : rdd .histogram ([1 ]))
430+
431+ # without buckets
432+ rdd = self .sc .parallelize (range (1 , 5 ))
433+ self .assertEquals (([1 , 4 ], [4 ]), rdd .histogram (1 ))
434+
435+ # without buckets single element
436+ rdd = self .sc .parallelize ([1 ])
437+ self .assertEquals (([1 , 1 ], [1 ]), rdd .histogram (1 ))
438+
439+ # without bucket no range
440+ rdd = self .sc .parallelize ([1 ] * 4 )
441+ self .assertEquals (([1 , 1 ], [4 ]), rdd .histogram (1 ))
442+
443+ # without buckets basic two
444+ rdd = self .sc .parallelize (range (1 , 5 ))
445+ self .assertEquals (([1 , 2.5 , 4 ], [2 , 2 ]), rdd .histogram (2 ))
446+
447+ # without buckets with more requested than elements
448+ rdd = self .sc .parallelize ([1 , 2 ])
449+ buckets = [1 + 0.2 * i for i in range (6 )]
450+ hist = [1 , 0 , 0 , 0 , 1 ]
451+ self .assertEquals ((buckets , hist ), rdd .histogram (5 ))
452+
453+ # invalid RDDs
454+ rdd = self .sc .parallelize ([1 , float ('inf' )])
455+ self .assertRaises (ValueError , lambda : rdd .histogram (2 ))
456+ rdd = self .sc .parallelize ([float ('nan' )])
457+ self .assertRaises (ValueError , lambda : rdd .histogram (2 ))
458+
367459
368460class TestIO (PySparkTestCase ):
369461
0 commit comments