import shutil
import time
import pytest
+import logging

from pyspark.sql.types import IntegerType, StringType, FloatType
+import pyspark.sql.functions as F

import dbldatagen as dg

spark = dg.SparkSingleton.getLocalInstance("streaming tests")


+@pytest.fixture(scope="class")
+def setupLogging():
+    FORMAT = '%(asctime)-15s %(message)s'
+    logging.basicConfig(format=FORMAT)
+
+
class TestStreaming():
    row_count = 100000
    column_count = 10
    time_to_run = 10
    rows_per_second = 5000

+    def setup_log_capture(self, caplog_object):
+        """ set up log capture fixture
+
+        Sets up the log capture fixture so that it only captures messages emitted after setup,
+        and only captures warnings and errors.
+        """
+        caplog_object.set_level(logging.WARNING)
+
+        # clear any messages emitted during setup
+        caplog_object.clear()
+
+    def get_log_capture_warnings_and_errors(self, caplog_object, textFlag):
+        """
+        Gets a count of captured warnings and errors containing the specified text
+
+        :param caplog_object: log capture object from fixture
+        :param textFlag: text to search for when deciding whether to include a warning or error in the count
+        :return: count of warnings and errors containing the text specified in `textFlag`
+        """
+        streaming_warnings_and_errors = 0
+        for r in caplog_object.records:
+            if (r.levelname == "WARNING" or r.levelname == "ERROR") and textFlag in r.message:
+                streaming_warnings_and_errors += 1
+
+        return streaming_warnings_and_errors
+
    @pytest.fixture
    def getStreamingDirs(self):
        time_now = int(round(time.time() * 1000))
@@ -32,6 +67,23 @@ def getStreamingDirs(self):
        shutil.rmtree(base_dir, ignore_errors=True)
        print(f"\n\n*** test dir [{base_dir}] deleted")

+    @pytest.fixture
+    def getDataDir(self):
+        time_now = int(round(time.time() * 1000))
+        base_dir = "/tmp/testdata_{}".format(time_now)
+        data_dir = os.path.join(base_dir, "data")
+        print(f"test data dir created '{base_dir}'")
+
+        # only the base dir needs to be created here - the data dir is created when data is written to it
+        os.makedirs(base_dir)
+
+        try:
+            yield data_dir
+        finally:
+            shutil.rmtree(base_dir, ignore_errors=True)
+            print(f"\n\n*** test data dir [{base_dir}] deleted")
+
+
    @pytest.mark.parametrize("seedColumnName", ["id",
                                                "_id",
                                                None])
@@ -342,4 +394,79 @@ def test_default_options(self, options, optionsExpected):
        assert datagen_options == expected_datagen_options
        assert passthrough_options == expected_passthrough_options

+    def test_text_streaming(self, getDataDir, caplog, getStreamingDirs):
+        datadir = getDataDir
+        base_dir, test_dir, checkpoint_dir = getStreamingDirs
+
+        # caplog fixture captures log content
+        self.setup_log_capture(caplog)
+
+        df = spark.range(10000).select(F.expr("cast(id as string)").alias("id"))
+        df.write.format("text").save(datadir)
+
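+        # data generation spec defining the columns of the synthetic data set to stream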
+        testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count,
+                                         partitions=10, seedMethod='hash_fieldname')
+                        .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+                        .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
+                        .withColumn("code3", StringType(), values=['a', 'b', 'c'])
+                        .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
+                        .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1])
+                        )
+
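+        # streaming source options: use a `text` format source rooted at the test data directory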
+        streamingOptions = {
+            'dbldatagen.streaming.source': 'text',
+            'dbldatagen.streaming.sourcePath': datadir,
+        }
+        df_streaming = testDataSpec.build(withStreaming=True, options=streamingOptions)
+
+        # check that there are warnings about the `text` format
+        text_format_warnings_and_errors = self.get_log_capture_warnings_and_errors(caplog, "text")
+        assert text_format_warnings_and_errors > 0, "Should have error or warning messages about text format"
+
+        # loop until we get one second's worth of data
+        start_time = time.time()
+        elapsed_time = 0
+        rows_retrieved = 0
+        time_limit = 10.0
+
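+        # repeatedly run the stream with a one-shot trigger, counting the rows written so far after each pass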
+        while elapsed_time < time_limit and rows_retrieved < self.rows_per_second:
+            sq = (df_streaming
+                  .writeStream
+                  .format("parquet")
+                  .outputMode("append")
+                  .option("path", test_dir)
+                  .option("checkpointLocation", checkpoint_dir)
+                  .trigger(once=True)
+                  .start())
+
+            # wait for trigger once to terminate
+            sq.awaitTermination(5)
+
+            elapsed_time = time.time() - start_time
+
+            try:
+                df2 = spark.read.format("parquet").load(test_dir)
+                rows_retrieved = df2.count()
+
+            # ignore file or metadata not found issues arising from read before stream has written first batch
+            except Exception as exc:
+                print("Exception:", exc)
+
+            if sq.isActive:
+                sq.stop()
+
+        end_time = time.time()
+
+        print("*** Done ***")
+        print("read {} rows from newly written data".format(rows_retrieved))
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert rows_retrieved >= self.rows_per_second
+