Skip to content

Commit

Permalink
Enable Spark Fast Register (#2765)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <pingsutw@apache.org>
Co-authored-by: Yee Hing Tong <wild-endeavor@users.noreply.github.com>
  • Loading branch information
pingsutw and wild-endeavor authored Sep 25, 2024
1 parent 534673c commit f394bc9
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 23 deletions.
25 changes: 15 additions & 10 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flytekit.configuration import (
SERIALIZED_CONTEXT_ENV_VAR,
FastSerializationSettings,
ImageConfig,
SerializationSettings,
StatsConfig,
)
Expand Down Expand Up @@ -325,16 +326,20 @@ def setup_execution(
if compressed_serialization_settings:
ss = SerializationSettings.from_transport(compressed_serialization_settings)
ssb = ss.new_builder()
ssb.project = ssb.project or exe_project
ssb.domain = ssb.domain or exe_domain
ssb.version = tk_version
if dynamic_addl_distro:
ssb.fast_serialization_settings = FastSerializationSettings(
enabled=True,
destination_dir=dynamic_dest_dir,
distribution_location=dynamic_addl_distro,
)
cb = cb.with_serialization_settings(ssb.build())
else:
ss = SerializationSettings(ImageConfig.auto())
ssb = ss.new_builder()

ssb.project = ssb.project or exe_project
ssb.domain = ssb.domain or exe_domain
ssb.version = tk_version
if dynamic_addl_distro:
ssb.fast_serialization_settings = FastSerializationSettings(
enabled=True,
destination_dir=dynamic_dest_dir,
distribution_location=dynamic_addl_distro,
)
cb = cb.with_serialization_settings(ssb.build())

with FlyteContextManager.with_context(cb) as ctx:
yield ctx
Expand Down
5 changes: 5 additions & 0 deletions flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,11 @@ def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str]
# Let us remove any extensions like .py
basename = os.path.splitext(basename)[0]

# This is an escape hatch for the zipimporter (used by spark). As this function is called recursively,
# it'll eventually reach the zip file, which is not extracted, so we should return.
if not Path(dirname).is_dir():
return basename

if dirname == package_root:
return basename

Expand Down
5 changes: 3 additions & 2 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tarfile
import tempfile
import typing
from datetime import datetime
from pathlib import Path
from types import ModuleType
from typing import List, Optional, Tuple, Union
Expand Down Expand Up @@ -68,9 +69,9 @@ def compress_scripts(source_path: str, destination: str, modules: List[ModuleTyp
# intended to be passed as a filter to tarfile.add
# https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.add
def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo:
# set time to epoch timestamp 0, aka 00:00:00 UTC on 1 January 1970
# set time to a fixed timestamp of 00:00:00 on 1 January 1980 (not epoch 0, which is 1970; the zip format cannot represent dates before 1980)
# note that when extracting this tarfile, this time will be shown as the modified date
tar_info.mtime = 0
tar_info.mtime = datetime(1980, 1, 1).timestamp()

# user/group info
tar_info.uid = 0
Expand Down
22 changes: 14 additions & 8 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import shutil
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union, cast

Expand All @@ -8,7 +9,6 @@
from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
Expand Down Expand Up @@ -158,13 +158,6 @@ def __init__(
**kwargs,
)

def get_image(self, settings: SerializationSettings) -> str:
    """Return the resolved container image reference for this Spark task.

    :param settings: serialization settings carrying the image config and,
        during fast-registration, the source root to bake into the image.
    :return: the registerable image string resolved from ``self.container_image``.

    NOTE(review): this method is shown as *deleted* lines in the diff — the
    commit removes the override so the base-class image resolution applies.
    """
    if isinstance(self.container_image, ImageSpec):
        # Ensure that the code is always copied into the image, even during fast-registration.
        self.container_image.source_root = settings.source_root

    return get_registerable_container_image(self.container_image, settings.image_config)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = SparkJob(
spark_conf=self.task_config.spark_conf,
Expand Down Expand Up @@ -201,6 +194,19 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
sess_builder = sess_builder.config(conf=spark_conf)

self.sess = sess_builder.getOrCreate()

if (
ctx.serialization_settings
and ctx.serialization_settings.fast_serialization_settings
and ctx.serialization_settings.fast_serialization_settings.enabled
and ctx.execution_state
and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION
):
file_name = "flyte_wf"
file_format = "zip"
shutil.make_archive(file_name, file_format, os.getcwd())
self.sess.sparkContext.addPyFile(f"{file_name}.{file_format}")

return user_params.builder().add_attr("SPARK_SESSION", self.sess).build()

def execute(self, **kwargs) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>1.10.7", "pyspark>=3.0.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]
plugin_requires = ["flytekit>1.13.5", "pyspark>=3.0.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
40 changes: 38 additions & 2 deletions plugins/flytekit-spark/tests/test_spark_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import os.path

import pandas as pd
import pyspark
import pytest

from flytekit.core import context_manager
from flytekitplugins.spark import Spark
from flytekitplugins.spark.task import Databricks, new_spark_session
from pyspark.sql import SparkSession

import flytekit
from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -118,3 +122,35 @@ def test_to_html():
tf = StructuredDatasetTransformerEngine()
output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame)
assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output


def test_spark_addPyFile():
    """Verify that ``pre_execute`` archives the working directory as
    ``flyte_wf.zip`` and ships it to the Spark context via ``addPyFile``
    when fast registration is enabled and the task runs in
    TASK_EXECUTION mode.

    Fix over the original: cleanup of the generated archive now runs in a
    ``finally`` block, so a failing ``pre_execute`` no longer leaks
    ``flyte_wf.zip`` into the cwd (which could poison later tests), and the
    test asserts the archive was actually created before removing it.
    """

    @task(
        task_config=Spark(
            spark_conf={"spark": "1"},
        )
    )
    def my_spark(a: int) -> int:
        return a

    default_img = Image(name="default", fqn="test", tag="tag")
    serialization_settings = SerializationSettings(
        project="project",
        domain="domain",
        version="version",
        env={"FOO": "baz"},
        image_config=ImageConfig(default_image=default_img, images=[default_img]),
        # Fast registration must be enabled for pre_execute to build/ship the archive.
        fast_serialization_settings=FastSerializationSettings(
            enabled=True,
            destination_dir="/User/flyte/workflows",
            distribution_location="s3://my-s3-bucket/fast/123",
        ),
    )

    ctx = context_manager.FlyteContextManager.current_context()
    archive_path = os.path.join(os.getcwd(), "flyte_wf.zip")
    try:
        with context_manager.FlyteContextManager.with_context(
            ctx.with_execution_state(
                ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)
            ).with_serialization_settings(serialization_settings)
        ) as new_ctx:
            my_spark.pre_execute(new_ctx.user_space_params)
            # pre_execute should have zipped the cwd into flyte_wf.zip.
            assert os.path.exists(archive_path)
    finally:
        # Remove the archive even if pre_execute or the assertion above failed.
        if os.path.exists(archive_path):
            os.remove(archive_path)
9 changes: 9 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,15 @@ def test_setup_disk_prefix():
}


def test_setup_for_fast_register():
    """setup_execution should carry dynamic_addl_distro / dynamic_dest_dir
    through into the context's fast-serialization settings."""
    distro = "distro"
    dest_dir = "/root"
    with setup_execution(
        raw_output_data_prefix="qwerty",
        dynamic_addl_distro=distro,
        dynamic_dest_dir=dest_dir,
    ) as ctx:
        fast = ctx.serialization_settings.fast_serialization_settings
        assert fast.enabled is True
        assert fast.distribution_location == distro
        assert fast.destination_dir == dest_dir


@mock.patch("google.auth.compute_engine._metadata")
def test_setup_cloud_prefix(mock_gcs):
with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx:
Expand Down

0 comments on commit f394bc9

Please sign in to comment.