Skip to content

Commit

Permalink
GH-5732: Make TypeEngine.lazy_import_transformers() thread safe (#2735)
Browse files Browse the repository at this point in the history

* Progress on test

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Working test

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Fix

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Implement with a lock instead

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Autoformat

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* tests

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Mark test as serial

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Autoformat

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

* Avoid asserting on mock_call signature

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>

---------

Signed-off-by: Thomas Newton <thomas.w.newton@gmail.com>
  • Loading branch information
Tom-Newton committed Sep 10, 2024
1 parent 26559fa commit ae9c6f8
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 46 deletions.
96 changes: 51 additions & 45 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import mimetypes
import sys
import textwrap
import threading
import typing
from abc import ABC, abstractmethod
from collections import OrderedDict
Expand Down Expand Up @@ -842,6 +843,7 @@ class TypeEngine(typing.Generic[T]):
_DATACLASS_TRANSFORMER: TypeTransformer = DataclassTransformer() # type: ignore
_ENUM_TRANSFORMER: TypeTransformer = EnumTransformer() # type: ignore
has_lazy_import = False
lazy_import_lock = threading.Lock()

@classmethod
def register(
Expand Down Expand Up @@ -995,51 +997,55 @@ def lazy_import_transformers(cls):
"""
Only load the transformers if needed.
"""
if cls.has_lazy_import:
return
cls.has_lazy_import = True
from flytekit.types.structured import (
register_arrow_handlers,
register_bigquery_handlers,
register_pandas_handlers,
register_snowflake_handlers,
)
from flytekit.types.structured.structured_dataset import DuplicateHandlerError

if is_imported("tensorflow"):
from flytekit.extras import tensorflow # noqa: F401
if is_imported("torch"):
from flytekit.extras import pytorch # noqa: F401
if is_imported("sklearn"):
from flytekit.extras import sklearn # noqa: F401
if is_imported("pandas"):
try:
from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter # noqa: F401
except ValueError:
logger.debug("Transformer for pandas is already registered.")
try:
register_pandas_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for pandas is already registered.")
if is_imported("pyarrow"):
try:
register_arrow_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for arrow is already registered.")
if is_imported("google.cloud.bigquery"):
try:
register_bigquery_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for bigquery is already registered.")
if is_imported("numpy"):
from flytekit.types import numpy # noqa: F401
if is_imported("PIL"):
from flytekit.types.file import image # noqa: F401
if is_imported("snowflake.connector"):
try:
register_snowflake_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for snowflake is already registered.")
with cls.lazy_import_lock:
# Avoid a race condition where concurrent threads may exit lazy_import_transformers before the transformers
# have been imported. This could be implemented without a lock if you assume python assignments are atomic
# and re-registering transformers is acceptable, but I decided to play it safe.
if cls.has_lazy_import:
return
cls.has_lazy_import = True
from flytekit.types.structured import (
register_arrow_handlers,
register_bigquery_handlers,
register_pandas_handlers,
register_snowflake_handlers,
)
from flytekit.types.structured.structured_dataset import DuplicateHandlerError

if is_imported("tensorflow"):
from flytekit.extras import tensorflow # noqa: F401
if is_imported("torch"):
from flytekit.extras import pytorch # noqa: F401
if is_imported("sklearn"):
from flytekit.extras import sklearn # noqa: F401
if is_imported("pandas"):
try:
from flytekit.types.schema.types_pandas import PandasSchemaReader, PandasSchemaWriter # noqa: F401
except ValueError:
logger.debug("Transformer for pandas is already registered.")
try:
register_pandas_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for pandas is already registered.")
if is_imported("pyarrow"):
try:
register_arrow_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for arrow is already registered.")
if is_imported("google.cloud.bigquery"):
try:
register_bigquery_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for bigquery is already registered.")
if is_imported("numpy"):
from flytekit.types import numpy # noqa: F401
if is_imported("PIL"):
from flytekit.types.file import image # noqa: F401
if is_imported("snowflake.connector"):
try:
register_snowflake_handlers()
except DuplicateHandlerError:
logger.debug("Transformer for snowflake is already registered.")

@classmethod
def to_literal_type(cls, python_type: Type) -> LiteralType:
Expand Down
31 changes: 30 additions & 1 deletion tests/flytekit/unit/core/test_type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import mock
import pytest
import typing_extensions
from concurrent.futures import ThreadPoolExecutor
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import errors_pb2
from google.protobuf import json_format as _json_format
Expand Down Expand Up @@ -73,7 +74,7 @@
from flytekit.types.pickle import FlytePickle
from flytekit.types.pickle.pickle import BatchSize, FlytePickleTransformer
from flytekit.types.schema import FlyteSchema
from flytekit.types.structured.structured_dataset import StructuredDataset
from flytekit.types.structured.structured_dataset import StructuredDataset, StructuredDatasetTransformerEngine

T = typing.TypeVar("T")

Expand Down Expand Up @@ -3246,3 +3247,31 @@ def outer_workflow(input: OuterWorkflowInput) -> OuterWorkflowOutput:
assert float_value_output == 1.0, f"Float value was {float_value_output}, not 1.0 as expected"
none_value_output = outer_workflow(OuterWorkflowInput(input=0)).nullable_output
assert none_value_output is None, f"None value was {none_value_output}, not None as expected"


@pytest.mark.serial
def test_lazy_import_transformers_concurrently():
    # Reset the guard flag so the next TypeEngine.lazy_import_transformers() call
    # really performs the imports. Marked serial so no other test races on this
    # class-level state.
    TypeEngine.has_lazy_import = False

    # Record both mocks through a single parent mock so their relative call order
    # shows up in one mock_calls list (the classic two-mock ordering trick, cf.
    # https://stackoverflow.com/questions/29749193/python-unit-testing-with-two-mock-objects-how-to-verify-call-order).
    parent_mock = mock.Mock()
    register_mock = mock.Mock()
    post_import_mock = mock.Mock()
    parent_mock.mock_register = register_mock
    parent_mock.after_import_mock = post_import_mock

    with mock.patch.object(StructuredDatasetTransformerEngine, "register", new=register_mock):

        def worker():
            TypeEngine.lazy_import_transformers()
            post_import_mock()

        worker_count = 5
        with ThreadPoolExecutor(max_workers=worker_count) as pool:
            pending = [pool.submit(worker) for _ in range(worker_count)]
            for fut in pending:
                fut.result()

    # Every register call must have happened before any thread reached the
    # post-import marker: the tail of the call list is exactly the markers...
    assert parent_mock.mock_calls[-worker_count:] == [mock.call.after_import_mock()] * worker_count
    # ...and everything before that tail is a register call.
    register_call_count = len(parent_mock.mock_calls) - worker_count
    assert all(recorded[0] == "mock_register" for recorded in parent_mock.mock_calls[:register_call_count])

0 comments on commit ae9c6f8

Please sign in to comment.