Commit 2b3060e

Merge branch 'master' into feature_streaming_enhancments
2 parents e0f7887 + 109707e commit 2b3060e

8 files changed, +225 -257 lines changed

dbldatagen/data_generator.py

Lines changed: 14 additions & 18 deletions

@@ -10,10 +10,9 @@
 import re
 
 from pyspark.sql.types import LongType, IntegerType, StringType, StructType, StructField, DataType
-
+from .spark_singleton import SparkSingleton
 from .column_generation_spec import ColumnGenerationSpec
 from .datagen_constants import DEFAULT_RANDOM_SEED, RANDOM_SEED_FIXED, RANDOM_SEED_HASH_FIELD_NAME
-from .spark_singleton import SparkSingleton
 from .utils import ensure, topologicalSort, DataGenError, deprecated
 
 START_TIMESTAMP_OPTION = "startTimestamp"
@@ -41,7 +40,7 @@ class DataGenerator:
     :param rows: = amount of rows to generate
     :param startingId: = starting value for generated seed column
     :param randomSeed: = seed for random number generator
-    :param partitions: = number of partitions to generate
+    :param partitions: = number of partitions to generate, if not provided, uses `spark.sparkContext.defaultParallelism`
     :param verbose: = if `True`, generate verbose output
     :param batchSize: = UDF batch number of rows to pass via Apache Arrow to Pandas UDFs
     :param debug: = if set to True, output debug level of information
@@ -75,7 +74,18 @@ def __init__(self, sparkSession=None, name=None, randomSeedMethod=None,
         self._rowCount = rows
         self.starting_id = startingId
         self.__schema__ = None
-        self.partitions = partitions if partitions is not None else 10
+
+        if sparkSession is None:
+            sparkSession = SparkSingleton.getLocalInstance()
+
+        self.sparkSession = sparkSession
+
+        # if the active Spark session is stopped, you may end up with a valid SparkSession object but the underlying
+        # SparkContext will be invalid
+        assert sparkSession is not None, "Spark session not initialized"
+        assert sparkSession.sparkContext is not None, "Expecting spark session to have valid sparkContext"
+
+        self.partitions = partitions if partitions is not None else sparkSession.sparkContext.defaultParallelism
 
         # check for old versions of args
         if "starting_id" in kwargs:
@@ -131,20 +141,6 @@ def __init__(self, sparkSession=None, name=None, randomSeedMethod=None,
         self.withColumn(ColumnGenerationSpec.SEED_COLUMN, LongType(), nullable=False, implicit=True, omit=True)
         self._batchSize = batchSize
 
-        if sparkSession is None:
-            sparkSession = SparkSingleton.getInstance()
-
-        assert sparkSession is not None, "The spark session attribute must be initialized"
-
-        self.sparkSession = sparkSession
-        if sparkSession is None:
-            raise DataGenError("""Spark session not initialized
-
-            The spark session attribute must be initialized in the DataGenerator initialization
-
-            i.e DataGenerator(sparkSession=spark, name="test", ...)
-            """)
-
         # set up use of pandas udfs
         self._setupPandas(batchSize)

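The net effect of this file's hunks: session setup moves to the top of __init__, the session is validated before use, and an omitted partitions argument now defaults to the session's defaultParallelism instead of a hard-coded 10. A minimal sketch of the new default behaviour (illustrative only, assuming the standard dbldatagen package import; "test_data" is a made-up generator name):

import dbldatagen as dg
from pyspark.sql import SparkSession

# local[4] runs Spark with 4 worker threads, so defaultParallelism == 4
spark = SparkSession.builder.master("local[4]").getOrCreate()

# partitions omitted: after this commit the generator follows
# spark.sparkContext.defaultParallelism (4 here), not the old fixed 10
ds = dg.DataGenerator(sparkSession=spark, name="test_data", rows=1000)
assert ds.partitions == spark.sparkContext.defaultParallelism

# an explicit partitions argument still takes precedence
ds8 = dg.DataGenerator(sparkSession=spark, name="test_data", rows=1000, partitions=8)
assert ds8.partitions == 8
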
dbldatagen/spark_singleton.py

Lines changed: 16 additions & 7 deletions

@@ -10,7 +10,6 @@
 """
 
 import os
-import math
 import logging
 from pyspark.sql import SparkSession
 
@@ -28,17 +27,27 @@ def getInstance(cls):
         return SparkSession.builder.getOrCreate()
 
     @classmethod
-    def getLocalInstance(cls, appName="new Spark session"):
+    def getLocalInstance(cls, appName="new Spark session", useAllCores=True):
         """Create a machine local Spark instance for Datalib.
-        It uses 3/4 of the available cores for the spark session.
+        By default, it uses `n-1` cores of the available cores for the spark session,
+        where `n` is total cores available.
 
+        :param useAllCores: If `useAllCores` is True, then use all cores rather than `n-1` cores
         :returns: A Spark instance
         """
-        cpu_count = int(math.floor(os.cpu_count() * 0.75))
-        logging.info("cpu count: %d", cpu_count)
+        cpu_count = os.cpu_count()
 
-        return SparkSession.builder \
-            .master(f"local[{cpu_count}]") \
+        if useAllCores:
+            spark_core_count = cpu_count
+        else:
+            spark_core_count = cpu_count - 1
+
+        logging.info("Spark core count: %d", spark_core_count)
+
+        sparkSession = SparkSession.builder \
+            .master(f"local[{spark_core_count}]") \
             .appName(appName) \
             .config("spark.sql.warehouse.dir", "/tmp/spark-warehouse") \
            .getOrCreate()
+
+        return sparkSession
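
A usage sketch for the revised helper (illustrative; per the diff, the code's default useAllCores=True requests every core, while useAllCores=False requests n-1 cores as described in the new docstring):

from dbldatagen.spark_singleton import SparkSingleton

# useAllCores=False -> master(f"local[{os.cpu_count() - 1}]"),
# leaving one core free for other work on the machine
spark = SparkSingleton.getLocalInstance(appName="test data generation",
                                        useAllCores=False)

# in local mode, defaultParallelism matches the requested thread count
# (assuming no pre-existing session, since the builder uses getOrCreate())
print(spark.sparkContext.defaultParallelism)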
