Skip to content

Commit

Permalink
[App] Fix local app run with relative import (#16835)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored and carmocca committed Feb 27, 2023
1 parent c3a94ad commit 4ac3917
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/lightning_app/utilities/load_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def load_app_from_file(filepath: str, raise_exception: bool = False, mock_import
)

# TODO: Remove this, downstream code shouldn't depend on side-effects here but it does
_patch_sys_path(os.path.dirname(os.path.abspath(filepath))).__enter__()
sys.path.append(os.path.dirname(os.path.abspath(filepath)))
sys.modules["__main__"] = main_module

if len(apps) > 1:
Expand Down
5 changes: 5 additions & 0 deletions tests/tests_app/core/scripts/app_with_local_import.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from app_metadata import RootFlow

from lightning_app.core.app import LightningApp

app = LightningApp(RootFlow())
19 changes: 15 additions & 4 deletions tests/tests_app/utilities/test_load_app.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import os
import sys
from unittest.mock import ANY

import pytest
import tests_app.core.scripts

from lightning_app.utilities.exceptions import MisconfigurationException
from lightning_app.utilities.load_app import extract_metadata_from_app, load_app_from_file


def test_load_app_from_file():
test_script_dir = os.path.join(os.path.dirname(tests_app.core.__file__), "scripts")
def test_load_app_from_file_errors():
test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts")
with pytest.raises(MisconfigurationException, match="There should not be multiple apps instantiated within a file"):
load_app_from_file(os.path.join(test_script_dir, "two_apps.py"))

Expand All @@ -20,8 +20,19 @@ def test_load_app_from_file():
load_app_from_file(os.path.join(test_script_dir, "script_with_error.py"))


@pytest.mark.parametrize("app_path", ["app_metadata.py", "app_with_local_import.py"])
def test_load_app_from_file(app_path):
"""Test that apps load without error and that sys.path and main module are set."""
original_main = sys.modules["__main__"]
test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts")
load_app_from_file(os.path.join(test_script_dir, app_path), raise_exception=True)

assert test_script_dir in sys.path
assert sys.modules["__main__"] != original_main


def test_extract_metadata_from_component():
test_script_dir = os.path.join(os.path.dirname(tests_app.core.__file__), "scripts")
test_script_dir = os.path.join(os.path.dirname(os.path.dirname(__file__)), "core", "scripts")
app = load_app_from_file(os.path.join(test_script_dir, "app_metadata.py"))
metadata = extract_metadata_from_app(app)
assert metadata == [
Expand Down

0 comments on commit 4ac3917

Please sign in to comment.