
Commit 5f51db1

wip
1 parent b363acc commit 5f51db1

File tree: 2 files changed (+180 / -11 lines changed)


dbldatagen/data_generator.py

Lines changed: 53 additions & 11 deletions
@@ -44,10 +44,11 @@
 _STREAM_SOURCE_START_TIMESTAMP = "startTimestamp"

 _STREAMING_SOURCE_TEXT = "text"
-_STREAMING_SOURCE_TEXT = "parquet"
-_STREAMING_SOURCE_TEXT = "csv"
-_STREAMING_SOURCE_TEXT = "json"
-_STREAMING_SOURCE_TEXT = "ord"
+_STREAMING_SOURCE_PARQUET = "parquet"
+_STREAMING_SOURCE_CSV = "csv"
+_STREAMING_SOURCE_JSON = "json"
+_STREAMING_SOURCE_ORC = "orc"
+_STREAMING_SOURCE_DELTA = "delta"

 class DataGenerator:
     """ Main Class for test data set generation
@@ -901,19 +902,60 @@ def _getBaseDataFrame(self, startId=0, streaming=False, options=None):

         else:
             self._applyStreamingDefaults(build_options, passthrough_options)
-            status = (
-                f"Generating streaming data frame with {id_partitions} partitions")
+
+            assert _STREAMING_SOURCE_OPTION in build_options.keys(), "There must be a source type specified"
+            streaming_source_format = build_options[_STREAMING_SOURCE_OPTION]
+
+            if streaming_source_format in [_STREAMING_SOURCE_RATE, _STREAMING_SOURCE_RATE_MICRO_BATCH]:
+                streaming_partitions = passthrough_options[_STREAMING_SOURCE_NUM_PARTITIONS]
+                status = (
+                    f"Generating streaming data frame with {streaming_partitions} partitions")
+            else:
+                status = (
+                    f"Generating streaming data frame with '{streaming_source_format}' streaming source")

             self.logger.info(status)
             self.executionHistory.append(status)

             df1 = (self.sparkSession.readStream
-                   .format("rate"))
+                   .format(streaming_source_format))

             for k, v in passthrough_options.items():
                 df1 = df1.option(k, v)
-            df1 = (df1.load()
-                   .withColumnRenamed("value", self._seedColumnName)
-                   )
+
+            file_formats = [_STREAMING_SOURCE_TEXT, _STREAMING_SOURCE_JSON, _STREAMING_SOURCE_CSV,
+                            _STREAMING_SOURCE_PARQUET, _STREAMING_SOURCE_DELTA, _STREAMING_SOURCE_ORC]
+
+            data_path = None
+            source_table = None
+            id_column = "value"
+
+            if _STREAMING_ID_FIELD_OPTION in build_options:
+                id_column = build_options[_STREAMING_ID_FIELD_OPTION]
+
+            if _STREAMING_TABLE_OPTION in build_options:
+                source_table = build_options[_STREAMING_TABLE_OPTION]
+
+            if _STREAMING_SCHEMA_OPTION in build_options:
+                source_schema = build_options[_STREAMING_SCHEMA_OPTION]
+                df1 = df1.schema(source_schema)
+
+            # get path for file based reads
+            if _STREAMING_PATH_OPTION in build_options:
+                data_path = build_options[_STREAMING_PATH_OPTION]
+            elif streaming_source_format in file_formats:
+                if "path" in passthrough_options:
+                    data_path = passthrough_options["path"]
+
+            if data_path is not None:
+                df1 = df1.load(data_path)
+            elif source_table is not None:
+                df1 = df1.table(source_table)
+            else:
+                df1 = df1.load()
+
+            if id_column != self._seedColumnName:
+                df1 = df1.withColumnRenamed(id_column, self._seedColumnName)

         return df1

@@ -1174,7 +1216,7 @@ def _applyStreamingDefaults(self, build_options, passthrough_options):
             build_options[_STREAMING_SOURCE_OPTION] = _STREAMING_SOURCE_RATE

         # setup `numPartitions` if not specified
-        if build_options[_STREAMING_SOURCE_OPTION] in [_STREAMING_SOURCE_RATE,_STREAMING_SOURCE_RATE_MICRO_BATCH]:
+        if build_options[_STREAMING_SOURCE_OPTION] in [_STREAMING_SOURCE_RATE, _STREAMING_SOURCE_RATE_MICRO_BATCH]:
             if _STREAMING_SOURCE_NUM_PARTITIONS not in passthrough_options:
                 passthrough_options[_STREAMING_SOURCE_NUM_PARTITIONS] = self.partitions
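
The net effect of these changes: a streaming data spec still defaults to Spark's "rate" source, but build options can now point it at a file-based source or a table instead. A minimal usage sketch follows; the option keys are the ones exercised by test_text_streaming in this commit's tests, while the source path is only a placeholder.

import dbldatagen as dg
from pyspark.sql.types import IntegerType

spark = dg.SparkSingleton.getLocalInstance("streaming usage sketch")

# the data spec itself is unchanged; only the build() options differ
spec = (dg.DataGenerator(sparkSession=spark, name="usage_sketch", rows=100000, partitions=4)
        .withColumn("code1", IntegerType(), minValue=100, maxValue=200))

# no streaming options: _applyStreamingDefaults selects the "rate" source and
# defaults numPartitions to the spec's partition count
df_rate = spec.build(withStreaming=True)

# with this commit, an existing file-based source can drive the stream instead;
# option keys as used in test_text_streaming, the path is a hypothetical placeholder
df_from_files = spec.build(withStreaming=True, options={
    'dbldatagen.streaming.source': 'text',
    'dbldatagen.streaming.sourcePath': '/tmp/testdata/data',
})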

tests/test_streaming.py

Lines changed: 127 additions & 0 deletions
@@ -2,20 +2,55 @@
 import shutil
 import time
 import pytest
+import logging

 from pyspark.sql.types import IntegerType, StringType, FloatType
+import pyspark.sql.functions as F

 import dbldatagen as dg

 spark = dg.SparkSingleton.getLocalInstance("streaming tests")


+@pytest.fixture(scope="class")
+def setupLogging():
+    FORMAT = '%(asctime)-15s %(message)s'
+    logging.basicConfig(format=FORMAT)
+
+
 class TestStreaming():
     row_count = 100000
     column_count = 10
     time_to_run = 10
     rows_per_second = 5000

+    def setup_log_capture(self, caplog_object):
+        """ set up log capture fixture
+
+        Sets up log capture fixture to only capture messages after setup and only
+        capture warnings and errors
+
+        """
+        caplog_object.set_level(logging.WARNING)
+
+        # clear messages from setup
+        caplog_object.clear()
+
+    def get_log_capture_warnings_and_errors(self, caplog_object, textFlag):
+        """
+        gets count of errors containing specified text
+
+        :param caplog_object: log capture object from fixture
+        :param textFlag: text to search for to include error or warning in count
+        :return: count of errors containing text specified in `textFlag`
+        """
+        streaming_warnings_and_errors = 0
+        for r in caplog_object.records:
+            if (r.levelname == "WARNING" or r.levelname == "ERROR") and textFlag in r.message:
+                streaming_warnings_and_errors += 1
+
+        return streaming_warnings_and_errors
+
     @pytest.fixture
     def getStreamingDirs(self):
         time_now = int(round(time.time() * 1000))
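
The two helpers above wrap pytest's built-in caplog fixture. A self-contained, hypothetical example of the same pattern (the warning text is made up; the real usage appears in test_text_streaming further down):

import logging

def test_log_capture_pattern(caplog):
    # capture only WARNING/ERROR records emitted after this point
    caplog.set_level(logging.WARNING)
    caplog.clear()

    # hypothetical warning, standing in for the library's own log output
    logging.getLogger("example").warning("text source is experimental")

    # count captured warnings/errors mentioning a marker string, as
    # get_log_capture_warnings_and_errors does with r.message
    matches = sum(1 for r in caplog.records
                  if r.levelname in ("WARNING", "ERROR") and "text" in r.getMessage())
    assert matches == 1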
@@ -32,6 +67,23 @@ def getStreamingDirs(self):
             shutil.rmtree(base_dir, ignore_errors=True)
             print(f"\n\n*** test dir [{base_dir}] deleted")

+    @pytest.fixture
+    def getDataDir(self):
+        time_now = int(round(time.time() * 1000))
+        base_dir = "/tmp/testdata_{}".format(time_now)
+        data_dir = os.path.join(base_dir, "data")
+        print(f"test data dir created '{base_dir}'")
+
+        # don't need to create the data dir itself, only the base dir
+        os.makedirs(base_dir)
+
+        try:
+            yield data_dir
+        finally:
+            shutil.rmtree(base_dir, ignore_errors=True)
+            print(f"\n\n*** test data dir [{base_dir}] deleted")
+
+
     @pytest.mark.parametrize("seedColumnName", ["id",
                                                 "_id",
                                                 None])
@@ -342,4 +394,79 @@ def test_default_options(self, options, optionsExpected):
         assert datagen_options == expected_datagen_options
         assert passthrough_options == expected_passthrough_options

+    def test_text_streaming(self, getDataDir, caplog, getStreamingDirs):
+        datadir = getDataDir
+        base_dir, test_dir, checkpoint_dir = getStreamingDirs
+
+        # caplog fixture captures log content
+        self.setup_log_capture(caplog)
+
+        df = spark.range(10000).select(F.expr("cast(id as string)").alias("id"))
+        df.write.format("text").save(datadir)
+
+        testDataSpec = (dg.DataGenerator(sparkSession=spark, name="test_data_set1", rows=self.row_count,
+                                         partitions=10, seedMethod='hash_fieldname')
+                        .withColumn("code1", IntegerType(), minValue=100, maxValue=200)
+                        .withColumn("code2", IntegerType(), minValue=0, maxValue=10)
+                        .withColumn("code3", StringType(), values=['a', 'b', 'c'])
+                        .withColumn("code4", StringType(), values=['a', 'b', 'c'], random=True)
+                        .withColumn("code5", StringType(), values=['a', 'b', 'c'], random=True, weights=[9, 1, 1])
+                        )
+
+        streamingOptions = {
+            'dbldatagen.streaming.source': 'text',
+            'dbldatagen.streaming.sourcePath': datadir,
+        }
+        df_streaming = testDataSpec.build(withStreaming=True, options=streamingOptions)
+
+        # check that there are warnings about the `text` format
+        text_format_warnings_and_errors = self.get_log_capture_warnings_and_errors(caplog, "text")
+        assert text_format_warnings_and_errors > 0, "Should have error or warning messages about text format"
+
+        # loop until we get one second's worth of data
+        start_time = time.time()
+        elapsed_time = 0
+        rows_retrieved = 0
+        time_limit = 10.0
+
+        while elapsed_time < time_limit and rows_retrieved < self.rows_per_second:
+            sq = (df_streaming
+                  .writeStream
+                  .format("parquet")
+                  .outputMode("append")
+                  .option("path", test_dir)
+                  .option("checkpointLocation", checkpoint_dir)
+                  .trigger(once=True)
+                  .start())
+
+            # wait for trigger once to terminate
+            sq.awaitTermination(5)
+
+            elapsed_time = time.time() - start_time
+
+            try:
+                df2 = spark.read.format("parquet").load(test_dir)
+                rows_retrieved = df2.count()
+
+            # ignore file or metadata not found issues arising from read before stream has written first batch
+            except Exception as exc:
+                print("Exception:", exc)
+
+        if sq.isActive:
+            sq.stop()
+
+        end_time = time.time()
+
+        print("*** Done ***")
+        print("read {} rows from newly written data".format(rows_retrieved))
+        print("elapsed time (seconds)", end_time - start_time)
+
+        # check that we have at least one second of data
+        assert rows_retrieved >= self.rows_per_second

