[xy] Improve local spark pipeline (mage-ai#2206)
* [xy] Improve local spark.

* [xy] Fix reading spark variable.
wangxiaoyou1993 authored Mar 16, 2023
1 parent e8b6618 commit 419320e
Showing 2 changed files with 58 additions and 9 deletions.
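
The core of the change is a lazily created, cached Spark session per block: the first call tries to build a session against SPARK_MASTER_HOST (defaulting to 'local'), and every later call reuses the result, even if the first attempt failed. A minimal standalone sketch of that pattern, assuming only that pyspark may or may not be installed (the holder class is illustrative; in the diff the state lives on the Block instance):

import os


class SparkSessionHolder:
    def __init__(self):
        self.spark = None        # cached session, or None if creation failed
        self.spark_init = False  # True once an initialization attempt was made

    def get_spark_session(self):
        # Return the cached result after the first attempt, successful or not.
        if self.spark_init:
            return self.spark
        try:
            from pyspark.sql import SparkSession
            self.spark = SparkSession.builder.master(
                os.getenv('SPARK_MASTER_HOST', 'local')).getOrCreate()
        except Exception:
            self.spark = None
        self.spark_init = True
        return self.spark
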
60 changes: 52 additions & 8 deletions mage_ai/data_preparation/models/block/__init__.py
@@ -4,6 +4,7 @@
from logging import Logger
from mage_ai.data_cleaner.shared.utils import (
is_geo_dataframe,
is_spark_dataframe,
)
from mage_ai.data_preparation.models.block.extension.utils import handle_run_tests
from mage_ai.data_preparation.models.block.utils import (
@@ -252,6 +253,10 @@ def __init__(
self.dynamic_block_uuid = None
self.dynamic_upstream_block_uuids = None

# Spark session
self.spark = None
self.spark_init = False

@property
def uuid(self):
return self.dynamic_block_uuid or self._uuid
Expand Down Expand Up @@ -1053,7 +1058,6 @@ def get_outputs(
block_uuid,
partition=execution_partition,
)

if not include_print_outputs:
all_variables = self.output_variables(execution_partition=execution_partition)

@@ -1063,6 +1067,7 @@
block_uuid,
v,
partition=execution_partition,
spark=self.__get_spark_session(),
)

if variable_type is not None and variable_object.variable_type != variable_type:
@@ -1071,6 +1076,7 @@
data = variable_object.read_data(
sample=True,
sample_count=sample_count,
spark=self.__get_spark_session(),
)
if type(data) is pd.DataFrame:
try:
@@ -1130,6 +1136,19 @@ def get_outputs(
type=DataType.TEXT,
variable_uuid=v,
)
elif is_spark_dataframe(data):
df = data.toPandas()
columns_to_display = df.columns.tolist()[:DATAFRAME_ANALYSIS_MAX_COLUMNS]
data = dict(
sample_data=dict(
columns=columns_to_display,
rows=json.loads(df[columns_to_display].to_json(orient='split'))['data']
),
type=DataType.TABLE,
variable_uuid=v,
)
data_products.append(data)
continue
outputs.append(data)
return outputs + data_products

@@ -1166,6 +1185,7 @@ async def get_outputs_async(
block_uuid,
v,
partition=execution_partition,
spark=self.__get_spark_session(),
)

if variable_type is not None and variable_object.variable_type != variable_type:
@@ -1174,6 +1194,7 @@ async def get_outputs_async(
data = await variable_object.read_data_async(
sample=True,
sample_count=sample_count,
spark=self.__get_spark_session(),
)
if type(data) is pd.DataFrame:
try:
@@ -1233,6 +1254,19 @@ async def get_outputs_async(
type=DataType.TEXT,
variable_uuid=v,
)
elif is_spark_dataframe(data):
df = data.toPandas()
columns_to_display = df.columns.tolist()[:DATAFRAME_ANALYSIS_MAX_COLUMNS]
data = dict(
sample_data=dict(
columns=columns_to_display,
rows=json.loads(df[columns_to_display].to_json(orient='split'))['data']
),
type=DataType.TABLE,
variable_uuid=v,
)
data_products.append(data)
continue
outputs.append(data)
return outputs + data_products
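
The elif branch above renders a Spark DataFrame the same way pandas output is rendered: collect it with toPandas, trim to the column limit, and emit the columns plus rows in pandas' 'split' JSON orientation. A rough standalone sketch of that conversion; DATAFRAME_ANALYSIS_MAX_COLUMNS and the 'table' type string are stand-ins for constants imported elsewhere in mage_ai:

import json

DATAFRAME_ANALYSIS_MAX_COLUMNS = 30  # assumed value; mage_ai imports this constant


def spark_df_to_table_payload(spark_df, variable_uuid):
    # Collect to pandas, keep only the first N columns, and build the same
    # sample_data dict shape that get_outputs appends to data_products.
    df = spark_df.toPandas()
    columns_to_display = df.columns.tolist()[:DATAFRAME_ANALYSIS_MAX_COLUMNS]
    return dict(
        sample_data=dict(
            columns=columns_to_display,
            rows=json.loads(df[columns_to_display].to_json(orient='split'))['data'],
        ),
        type='table',  # stands in for DataType.TABLE
        variable_uuid=variable_uuid,
    )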

@@ -1663,14 +1697,23 @@ def __enrich_global_vars(self, global_vars: Dict = None):
is_spark_env()):
global_vars = global_vars or dict()
if not global_vars.get('spark'):
try:
from pyspark.sql import SparkSession
global_vars['spark'] = SparkSession.builder.master(
os.getenv('SPARK_MASTER_HOST', 'local')).getOrCreate()
except Exception:
pass
spark = self.__get_spark_session()
if spark is not None:
global_vars['spark'] = spark
return global_vars

def __get_spark_session(self):
if self.spark_init:
return self.spark
try:
from pyspark.sql import SparkSession
self.spark = SparkSession.builder.master(
os.getenv('SPARK_MASTER_HOST', 'local')).getOrCreate()
except Exception:
self.spark = None
self.spark_init = True
return self.spark

def __store_variables_prepare(
self,
variable_mapping: Dict,
@@ -1722,7 +1765,8 @@ def store_variables(
dynamic_block_uuid,
)
for uuid, data in variables_data['variable_mapping'].items():
if spark is not None and type(data) is pd.DataFrame:
if spark is not None and self.pipeline.type == PipelineType.PYSPARK \
and type(data) is pd.DataFrame:
data = spark.createDataFrame(data)
self.pipeline.variable_manager.add_variable(
self.pipeline.uuid,
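
With __enrich_global_vars now calling the cached getter, a block running in a Spark-capable environment sees a 'spark' entry in its global variables whenever a session could be created. A hypothetical user-level block consuming it might look like the following; the function name and the kwargs wiring are assumptions for illustration, only the presence and meaning of kwargs['spark'] come from this diff:

import pandas as pd


def transform(df: pd.DataFrame, **kwargs):
    spark = kwargs.get('spark')  # SparkSession injected by the runtime, or None
    if spark is None:
        # No session available (e.g. pyspark not installed): stay in pandas.
        return df.drop_duplicates()
    sdf = spark.createDataFrame(df)          # promote the pandas frame to Spark
    return sdf.dropDuplicates().toPandas()   # do the work in Spark, hand back pandas
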
7 changes: 6 additions & 1 deletion mage_ai/data_preparation/models/variable.py
@@ -175,6 +175,8 @@ async def read_data_async(
"""
if self.variable_type == VariableType.DATAFRAME:
return self.__read_parquet(sample=sample, sample_count=sample_count)
elif self.variable_type == VariableType.SPARK_DATAFRAME:
return self.__read_spark_parquet(sample=sample, sample_count=sample_count, spark=spark)
elif self.variable_type == VariableType.DATAFRAME_ANALYSIS:
return await self.__read_dataframe_analysis_async(
dataframe_analysis_keys=dataframe_analysis_keys,
@@ -367,14 +369,17 @@ def __read_parquet(self, sample: bool = False, sample_count: int = None) -> pd.D
def __read_spark_parquet(self, sample: bool = False, sample_count: int = None, spark=None):
if spark is None:
return None
return (
df = (
spark.read
.format('csv')
.option('header', 'true')
.option('inferSchema', 'true')
.option('delimiter', ',')
.load(self.variable_path)
)
if sample and sample_count:
df = df.limit(sample_count)
return df

def __write_geo_dataframe(self, data) -> None:
os.makedirs(self.variable_path, exist_ok=True)
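
The variable.py change keeps the read lazy and pushes sampling into Spark: the reader is configured first, and .limit(sample_count) is applied only when a sample is requested. A rough standalone sketch of the same pattern, assuming a local session and an explicit path argument (the diff reads self.variable_path with the session passed in by the block):

from pyspark.sql import SparkSession


def read_variable_sample(variable_path: str, sample_count: int = None):
    spark = SparkSession.builder.master('local').getOrCreate()
    df = (
        spark.read
        .format('csv')
        .option('header', 'true')
        .option('inferSchema', 'true')
        .option('delimiter', ',')
        .load(variable_path)
    )
    if sample_count:
        df = df.limit(sample_count)  # limit is applied lazily by Spark
    return df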
