Skip to content

Commit

Permalink
Added parameter to choose the deployments to use in pytests (#232)
Browse files Browse the repository at this point in the history
* added parameter to choose the deployments to use during pytests

* changed some variable and method names
  • Loading branch information
LanderOtto authored Sep 14, 2023
1 parent dae6826 commit 64406ce
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 38 deletions.
109 changes: 82 additions & 27 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import asyncio
import os
import platform
Expand Down Expand Up @@ -26,42 +27,86 @@
from streamflow.persistence.loading_context import DefaultDatabaseLoadingContext


def deployment_types():
def csvtype(choices):
"""Return a function that splits and checks comma-separated values."""

def splitarg(arg):
values = arg.split(",")
for value in values:
if value not in choices:
raise argparse.ArgumentTypeError(
"invalid choice: {!r} (choose from {})".format(
value, ", ".join(map(repr, choices))
)
)
return values

return splitarg


def pytest_addoption(parser):
parser.addoption(
"--deploys",
type=csvtype(all_deployment_types()),
default=all_deployment_types(),
help=f"List of deployments to deploy. Use the comma as delimiter e.g. --deploys local,docker. (default: {all_deployment_types()})",
)


@pytest.fixture(scope="module")
def chosen_deployment_types(request):
return request.config.getoption("--deploys")


def pytest_generate_tests(metafunc):
if "deployment_src" in metafunc.fixturenames:
metafunc.parametrize(
"deployment_src",
metafunc.config.getoption("deploys"),
scope="module",
)
if "deployment_dst" in metafunc.fixturenames:
metafunc.parametrize(
"deployment_dst",
metafunc.config.getoption("deploys"),
scope="module",
)


def all_deployment_types():
deployments_ = ["local", "docker", "ssh"]
if platform.system() == "Linux":
deployments_.extend(["kubernetes", "singularity"])
return deployments_


async def get_location(
_context: StreamFlowContext, request: pytest.FixtureRequest
) -> Location:
if request.param == "local":
async def get_location(_context: StreamFlowContext, deployment_t: str) -> Location:
if deployment_t == "local":
return Location(deployment=LOCAL_LOCATION, name=LOCAL_LOCATION)
elif request.param == "docker":
elif deployment_t == "docker":
connector = _context.deployment_manager.get_connector("alpine-docker")
locations = await connector.get_available_locations()
return Location(deployment="alpine-docker", name=next(iter(locations.keys())))
elif request.param == "kubernetes":
elif deployment_t == "kubernetes":
connector = _context.deployment_manager.get_connector("alpine-kubernetes")
locations = await connector.get_available_locations(service="sf-test")
return Location(
deployment="alpine-kubernetes",
service="sf-test",
name=next(iter(locations.keys())),
)
elif request.param == "singularity":
elif deployment_t == "singularity":
connector = _context.deployment_manager.get_connector("alpine-singularity")
locations = await connector.get_available_locations()
return Location(
deployment="alpine-singularity", name=next(iter(locations.keys()))
)
elif request.param == "ssh":
elif deployment_t == "ssh":
connector = _context.deployment_manager.get_connector("linuxserver-ssh")
locations = await connector.get_available_locations()
return Location(deployment="linuxserver-ssh", name=next(iter(locations.keys())))
else:
raise Exception(f"{request.param} location type not supported")
raise Exception(f"{deployment_t} location type not supported")


def get_docker_deployment_config():
Expand Down Expand Up @@ -140,26 +185,36 @@ async def get_ssh_deployment_config(_context: StreamFlowContext):
)


@pytest_asyncio.fixture(scope="session")
async def context() -> StreamFlowContext:
def get_local_deployment_config():
return DeploymentConfig(
name=LOCAL_LOCATION,
type="local",
config={},
external=True,
lazy=False,
workdir=os.path.realpath(tempfile.gettempdir()),
)


@pytest_asyncio.fixture(scope="module")
async def context(chosen_deployment_types) -> StreamFlowContext:
_context = build_context(
{"database": {"type": "default", "config": {"connection": ":memory:"}}},
)
await _context.deployment_manager.deploy(
DeploymentConfig(
name=LOCAL_LOCATION,
type="local",
config={},
external=True,
lazy=False,
workdir=os.path.realpath(tempfile.gettempdir()),
)
)
await _context.deployment_manager.deploy(get_docker_deployment_config())
await _context.deployment_manager.deploy(await get_ssh_deployment_config(_context))
if platform.system() == "Linux":
await _context.deployment_manager.deploy(get_kubernetes_deployment_config())
await _context.deployment_manager.deploy(get_singularity_deployment_config())
for deployment_t in chosen_deployment_types:
if deployment_t == "local":
config = get_local_deployment_config()
elif deployment_t == "docker":
config = get_docker_deployment_config()
elif deployment_t == "kubernetes":
config = get_kubernetes_deployment_config()
elif deployment_t == "singularity":
config = get_singularity_deployment_config()
elif deployment_t == "ssh":
config = await get_ssh_deployment_config(_context)
else:
raise Exception(f"{deployment_t} deployment type not supported")
await _context.deployment_manager.deploy(config)
yield _context
await _context.deployment_manager.undeploy_all()
# Close the database connection
Expand Down
8 changes: 4 additions & 4 deletions tests/test_remotepath.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
from streamflow.core.deployment import Connector, Location
from streamflow.data import remotepath
from streamflow.deployment.utils import get_path_processor
from tests.conftest import deployment_types, get_location
from tests.conftest import get_location


@pytest_asyncio.fixture(scope="module", params=deployment_types())
async def location(context, request) -> Location:
return await get_location(context, request)
@pytest_asyncio.fixture(scope="module")
async def location(context, deployment_src) -> Location:
return await get_location(context, deployment_src)


@pytest.fixture(scope="module")
Expand Down
14 changes: 7 additions & 7 deletions tests/test_transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@
from streamflow.data import remotepath
from streamflow.deployment.connector import LocalConnector
from streamflow.deployment.utils import get_path_processor
from tests.conftest import deployment_types, get_location
from tests.conftest import get_location


@pytest_asyncio.fixture(scope="module", params=deployment_types())
async def src_location(context, request) -> Location:
return await get_location(context, request)
@pytest_asyncio.fixture(scope="module")
async def src_location(context, deployment_src) -> Location:
return await get_location(context, deployment_src)


@pytest.fixture(scope="module")
def src_connector(context, src_location) -> Connector:
return context.deployment_manager.get_connector(src_location.deployment)


@pytest_asyncio.fixture(scope="module", params=deployment_types())
async def dst_location(context, request) -> Location:
return await get_location(context, request)
@pytest_asyncio.fixture(scope="module")
async def dst_location(context, deployment_dst) -> Location:
return await get_location(context, deployment_dst)


@pytest.fixture(scope="module")
Expand Down

0 comments on commit 64406ce

Please sign in to comment.