@@ -10,10 +10,9 @@
 import re

 from pyspark.sql.types import LongType, IntegerType, StringType, StructType, StructField, DataType
-
+from .spark_singleton import SparkSingleton
 from .column_generation_spec import ColumnGenerationSpec
 from .datagen_constants import DEFAULT_RANDOM_SEED, RANDOM_SEED_FIXED, RANDOM_SEED_HASH_FIELD_NAME
-from .spark_singleton import SparkSingleton
 from .utils import ensure, topologicalSort, DataGenError, deprecated

 START_TIMESTAMP_OPTION = "startTimestamp"

@@ -41,7 +40,7 @@ class DataGenerator:
     :param rows: = amount of rows to generate
     :param startingId: = starting value for generated seed column
     :param randomSeed: = seed for random number generator
-    :param partitions: = number of partitions to generate
+    :param partitions: = number of partitions to generate; if not provided, defaults to `spark.sparkContext.defaultParallelism`
     :param verbose: = if `True`, generate verbose output
     :param batchSize: = UDF batch number of rows to pass via Apache Arrow to Pandas UDFs
     :param debug: = if set to True, output debug level of information

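For context, a minimal usage sketch of the new default behavior. This assumes an active `spark` session (e.g. a Databricks notebook) and the package imported as `dg`; the column spec options shown are illustrative, not part of this change:

```python
import dbldatagen as dg
from pyspark.sql.types import IntegerType

# partitions is omitted, so the generator now falls back to the
# cluster's default parallelism rather than a hard-coded 10
df = (dg.DataGenerator(sparkSession=spark, name="partition_demo", rows=100000)
      .withColumn("code", IntegerType(), minValue=1, maxValue=100)
      .build())

# the generated DataFrame is sized to the cluster
assert df.rdd.getNumPartitions() == spark.sparkContext.defaultParallelism
```
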
@@ -75,7 +74,18 @@ def __init__(self, sparkSession=None, name=None, randomSeedMethod=None,
         self._rowCount = rows
         self.starting_id = startingId
         self.__schema__ = None
-        self.partitions = partitions if partitions is not None else 10
+
+        if sparkSession is None:
+            sparkSession = SparkSingleton.getLocalInstance()
+
+        self.sparkSession = sparkSession
+
+        # if the active Spark session is stopped, you may end up with a valid SparkSession object but the underlying
+        # SparkContext will be invalid
+        assert sparkSession is not None, "Spark session not initialized"
+        assert sparkSession.sparkContext is not None, "Expecting spark session to have valid sparkContext"
+
+        self.partitions = partitions if partitions is not None else sparkSession.sparkContext.defaultParallelism

         # check for old versions of args
         if "starting_id" in kwargs:

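The fallback itself is plain Spark. A self-contained sketch of what the new `self.partitions` expression resolves to, using a local master for illustration (assumes `spark.default.parallelism` is not overridden in the session config):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[4]").appName("parallelism-demo").getOrCreate()

partitions = None  # caller did not pass a partition count
# same expression as the new assignment above
effective = partitions if partitions is not None else spark.sparkContext.defaultParallelism
print(effective)  # 4 for local[4]; cluster-dependent otherwise
```

Resolving the session before computing `self.partitions` also means the two asserts run before `sparkContext` is dereferenced, which is why the session setup had to move ahead of the partition logic.
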
@@ -131,20 +141,6 @@ def __init__(self, sparkSession=None, name=None, randomSeedMethod=None,
         self.withColumn(ColumnGenerationSpec.SEED_COLUMN, LongType(), nullable=False, implicit=True, omit=True)
         self._batchSize = batchSize

-        if sparkSession is None:
-            sparkSession = SparkSingleton.getInstance()
-
-        assert sparkSession is not None, "The spark session attribute must be initialized"
-
-        self.sparkSession = sparkSession
-        if sparkSession is None:
-            raise DataGenError("""Spark session not initialized
-
-            The spark session attribute must be initialized in the DataGenerator initialization
-
-            i.e DataGenerator(sparkSession=spark, name="test", ...)
-            """)
-
         # set up use of pandas udfs
         self._setupPandas(batchSize)
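Worth noting why this whole block could simply be deleted rather than moved: its `raise DataGenError` branch was unreachable, since `sparkSession` was guaranteed non-None by the lines just above it. A condensed sketch of the removed control flow:

```python
# condensed form of the deleted block
if sparkSession is None:
    sparkSession = SparkSingleton.getInstance()  # presumably always returns a session
assert sparkSession is not None                  # passes, or raises AssertionError first
if sparkSession is None:                         # therefore never true
    raise DataGenError("Spark session not initialized")
```

The replacement logic earlier in `__init__` keeps the useful parts (the fallback session and the non-None assertion) and adds the `sparkContext` check for stopped sessions.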