@@ -368,26 +368,25 @@ def test_histogram(self):
368368 # empty
369369 rdd = self .sc .parallelize ([])
370370 self .assertEquals ([0 ], rdd .histogram ([0 , 10 ])[1 ])
371- self .assertEquals ([0 ], rdd .histogram ([0 , 10 ], True )[1 ])
371+ self .assertEquals ([0 , 0 ], rdd .histogram ([0 , 4 , 10 ])[1 ])
372+ self .assertRaises (ValueError , lambda : rdd .histogram (1 ))
372373
373374 # out of range
374375 rdd = self .sc .parallelize ([10.01 , - 0.01 ])
375376 self .assertEquals ([0 ], rdd .histogram ([0 , 10 ])[1 ])
376- self .assertEquals ([0 ], rdd .histogram ([ 0 , 10 ], True )[1 ])
377+ self .assertEquals ([0 , 0 ], rdd .histogram (( 0 , 4 , 10 ) )[1 ])
377378
378379 # in range with one bucket
379380 rdd = self .sc .parallelize (range (1 , 5 ))
380381 self .assertEquals ([4 ], rdd .histogram ([0 , 10 ])[1 ])
381- self .assertEquals ([4 ], rdd .histogram ([0 , 10 ], True )[1 ])
382+ self .assertEquals ([3 , 1 ], rdd .histogram ([0 , 4 , 10 ] )[1 ])
382383
383384 # in range with one bucket exact match
384385 self .assertEquals ([4 ], rdd .histogram ([1 , 4 ])[1 ])
385- self .assertEquals ([4 ], rdd .histogram ([1 , 4 ], True )[1 ])
386386
387387 # out of range with two buckets
388388 rdd = self .sc .parallelize ([10.01 , - 0.01 ])
389389 self .assertEquals ([0 , 0 ], rdd .histogram ([0 , 5 , 10 ])[1 ])
390- self .assertEquals ([0 , 0 ], rdd .histogram ([0 , 5 , 10 ], True )[1 ])
391390
392391 # out of range with two uneven buckets
393392 rdd = self .sc .parallelize ([10.01 , - 0.01 ])
@@ -396,12 +395,10 @@ def test_histogram(self):
396395 # in range with two buckets
397396 rdd = self .sc .parallelize ([1 , 2 , 3 , 5 , 6 ])
398397 self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 10 ])[1 ])
399- self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 10 ], True )[1 ])
400398
401399 # in range with two bucket and None
402400 rdd = self .sc .parallelize ([1 , 2 , 3 , 5 , 6 , None , float ('nan' )])
403401 self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 10 ])[1 ])
404- self .assertEquals ([3 , 2 ], rdd .histogram ([0 , 5 , 10 ], True )[1 ])
405402
406403 # in range with two uneven buckets
407404 rdd = self .sc .parallelize ([1 , 2 , 3 , 5 , 6 ])
@@ -421,12 +418,14 @@ def test_histogram(self):
421418 self .assertEquals ([4 , 2 , 1 , 3 ], rdd .histogram ([0.0 , 5.0 , 11.0 , 12.0 , 200.0 ])[1 ])
422419
423420 # out of range with infinite buckets
424- rdd = self .sc .parallelize ([10.01 , - 0.01 , float ('nan' )])
425- self .assertEquals ([1 , 1 ], rdd .histogram ([float ('-inf' ), 0 , float ('inf' )])[1 ])
421+ rdd = self .sc .parallelize ([10.01 , - 0.01 , float ('nan' ), float ( "inf" ) ])
422+ self .assertEquals ([1 , 2 ], rdd .histogram ([float ('-inf' ), 0 , float ('inf' )])[1 ])
426423
427424 # invalid buckets
428425 self .assertRaises (ValueError , lambda : rdd .histogram ([]))
429426 self .assertRaises (ValueError , lambda : rdd .histogram ([1 ]))
427+ self .assertRaises (ValueError , lambda : rdd .histogram (0 ))
428+ self .assertRaises (TypeError , lambda : rdd .histogram ({}))
430429
431430 # without buckets
432431 rdd = self .sc .parallelize (range (1 , 5 ))
@@ -456,6 +455,19 @@ def test_histogram(self):
456455 rdd = self .sc .parallelize ([float ('nan' )])
457456 self .assertRaises (ValueError , lambda : rdd .histogram (2 ))
458457
458+ # string
459+ rdd = self .sc .parallelize (["ab" , "ac" , "b" , "bd" , "ef" ], 2 )
460+ self .assertEquals ([2 , 2 ], rdd .histogram (["a" , "b" , "c" ])[1 ])
461+ self .assertEquals ((["ab" , "ef" ], [5 ]), rdd .histogram (1 ))
462+ self .assertRaises (TypeError , lambda : rdd .histogram (2 ))
463+
464+ # mixed RDD
465+ rdd = self .sc .parallelize ([1 , 4 , "ab" , "ac" , "b" ], 2 )
466+ self .assertEquals ([1 , 1 ], rdd .histogram ([0 , 4 , 10 ])[1 ])
467+ self .assertEquals ([2 , 1 ], rdd .histogram (["a" , "b" , "c" ])[1 ])
468+ self .assertEquals (([1 , "b" ], [5 ]), rdd .histogram (1 ))
469+ self .assertRaises (TypeError , lambda : rdd .histogram (2 ))
470+
459471
460472class TestIO (PySparkTestCase ):
461473
0 commit comments