Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Spark Fast Register #2765

Merged
merged 36 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions flytekit/bin/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flytekit.configuration import (
SERIALIZED_CONTEXT_ENV_VAR,
FastSerializationSettings,
ImageConfig,
SerializationSettings,
StatsConfig,
)
Expand Down Expand Up @@ -325,16 +326,20 @@ def setup_execution(
if compressed_serialization_settings:
ss = SerializationSettings.from_transport(compressed_serialization_settings)
ssb = ss.new_builder()
ssb.project = ssb.project or exe_project
ssb.domain = ssb.domain or exe_domain
ssb.version = tk_version
if dynamic_addl_distro:
ssb.fast_serialization_settings = FastSerializationSettings(
enabled=True,
destination_dir=dynamic_dest_dir,
distribution_location=dynamic_addl_distro,
)
cb = cb.with_serialization_settings(ssb.build())
else:
ss = SerializationSettings(ImageConfig.auto())
ssb = ss.new_builder()

ssb.project = ssb.project or exe_project
ssb.domain = ssb.domain or exe_domain
ssb.version = tk_version
if dynamic_addl_distro:
ssb.fast_serialization_settings = FastSerializationSettings(
enabled=True,
destination_dir=dynamic_dest_dir,
distribution_location=dynamic_addl_distro,
)
cb = cb.with_serialization_settings(ssb.build())

with FlyteContextManager.with_context(cb) as ctx:
yield ctx
Expand Down
3 changes: 3 additions & 0 deletions flytekit/core/tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,9 @@ def _resolve_abs_module_name(self, path: str, package_root: typing.Optional[str]
# Let us remove any extensions like .py
basename = os.path.splitext(basename)[0]

if not Path(dirname).is_dir():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

since the executor will load the workflow from the zip file.
image

pingsutw marked this conversation as resolved.
Show resolved Hide resolved
return basename

if dirname == package_root:
return basename

Expand Down
5 changes: 3 additions & 2 deletions flytekit/tools/script_mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import tarfile
import tempfile
import typing
from datetime import datetime
from pathlib import Path
from types import ModuleType
from typing import List, Optional, Tuple, Union
Expand Down Expand Up @@ -68,9 +69,9 @@ def compress_scripts(source_path: str, destination: str, modules: List[ModuleTyp
# intended to be passed as a filter to tarfile.add
# https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.add
def tar_strip_file_attributes(tar_info: tarfile.TarInfo) -> tarfile.TarInfo:
# set time to epoch timestamp 0, aka 00:00:00 UTC on 1 January 1970
# set time to epoch timestamp 0, aka 00:00:00 UTC on 1 January 1980
# note that when extracting this tarfile, this time will be shown as the modified date
tar_info.mtime = 0
tar_info.mtime = datetime(1980, 1, 1).timestamp()

# user/group info
tar_info.uid = 0
Expand Down
20 changes: 12 additions & 8 deletions plugins/flytekit-spark/flytekitplugins/spark/task.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import shutil
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Union, cast

Expand All @@ -8,7 +9,6 @@
from flytekit import FlyteContextManager, PythonFunctionTask, lazy_module, logger
from flytekit.configuration import DefaultImages, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters
from flytekit.core.python_auto_container import get_registerable_container_image
from flytekit.extend import ExecutionState, TaskPlugins
from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin
from flytekit.image_spec import ImageSpec
Expand Down Expand Up @@ -158,13 +158,6 @@ def __init__(
**kwargs,
)

def get_image(self, settings: SerializationSettings) -> str:
if isinstance(self.container_image, ImageSpec):
# Ensure that the code is always copied into the image, even during fast-registration.
self.container_image.source_root = settings.source_root

return get_registerable_container_image(self.container_image, settings.image_config)

def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
job = SparkJob(
spark_conf=self.task_config.spark_conf,
Expand Down Expand Up @@ -201,6 +194,17 @@ def pre_execute(self, user_params: ExecutionParameters) -> ExecutionParameters:
sess_builder = sess_builder.config(conf=spark_conf)

self.sess = sess_builder.getOrCreate()

if (
ctx.serialization_settings.fast_serialization_settings.enabled
and ctx.execution_state
and ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION
):
file_name = "flyte_wf"
file_format = "zip"
shutil.make_archive(file_name, file_format, os.getcwd())
self.sess.sparkContext.addPyFile(f"{file_name}.{file_format}")

return user_params.builder().add_attr("SPARK_SESSION", self.sess).build()

def execute(self, **kwargs) -> Any:
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-spark/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

microlib_name = f"flytekitplugins-{PLUGIN_NAME}"

plugin_requires = ["flytekit>1.10.7", "pyspark>=3.0.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]
plugin_requires = ["flytekit>1.13.5", "pyspark>=3.0.0", "aiohttp", "flyteidl>=1.11.0b1", "pandas"]

__version__ = "0.0.0+develop"

Expand Down
40 changes: 38 additions & 2 deletions plugins/flytekit-spark/tests/test_spark_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import os.path

import pandas as pd
import pyspark
import pytest

from flytekit.core import context_manager
from flytekitplugins.spark import Spark
from flytekitplugins.spark.task import Databricks, new_spark_session
from pyspark.sql import SparkSession

import flytekit
from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task
from flytekit.configuration import Image, ImageConfig, SerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager
from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings
from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState


@pytest.fixture(scope="function")
Expand Down Expand Up @@ -118,3 +122,35 @@ def test_to_html():
tf = StructuredDatasetTransformerEngine()
output = tf.to_html(FlyteContextManager.current_context(), sd, pyspark.sql.DataFrame)
assert pd.DataFrame(df.schema, columns=["StructField"]).to_html() == output


def test_spark_addPyFile():
@task(
task_config=Spark(
spark_conf={"spark": "1"},
)
)
def my_spark(a: int) -> int:
return a

default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = SerializationSettings(
project="project",
domain="domain",
version="version",
env={"FOO": "baz"},
image_config=ImageConfig(default_image=default_img, images=[default_img]),
fast_serialization_settings=FastSerializationSettings(
enabled=True,
destination_dir="/User/flyte/workflows",
distribution_location="s3://my-s3-bucket/fast/123",
),
)

ctx = context_manager.FlyteContextManager.current_context()
with context_manager.FlyteContextManager.with_context(
ctx.with_execution_state(
ctx.new_execution_state().with_params(mode=ExecutionState.Mode.TASK_EXECUTION)).with_serialization_settings(serialization_settings)
) as new_ctx:
my_spark.pre_execute(new_ctx.user_space_params)
os.remove(os.path.join(os.getcwd(), "flyte_wf.zip"))
9 changes: 9 additions & 0 deletions tests/flytekit/unit/bin/test_python_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,15 @@ def test_setup_disk_prefix():
}


def test_setup_for_fast_register():
dynamic_addl_distro = "distro"
dynamic_dest_dir = "/root"
with setup_execution(raw_output_data_prefix="qwerty", dynamic_addl_distro=dynamic_addl_distro, dynamic_dest_dir=dynamic_dest_dir) as ctx:
assert ctx.serialization_settings.fast_serialization_settings.enabled is True
assert ctx.serialization_settings.fast_serialization_settings.distribution_location == dynamic_addl_distro
assert ctx.serialization_settings.fast_serialization_settings.destination_dir == dynamic_dest_dir


@mock.patch("google.auth.compute_engine._metadata")
def test_setup_cloud_prefix(mock_gcs):
with setup_execution("s3://", checkpoint_path=None, prev_checkpoint=None) as ctx:
Expand Down
Loading