test(compute_engine): use the orchestrator mock in command test (#1098)
AlexandrePicosson authored Jun 10, 2022
1 parent 2e18f7c commit 780ab88
Showing 2 changed files with 116 additions and 122 deletions.
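Note on the pattern this commit introduces: the command tests no longer assemble a Context by hand from static assets; each test patches get_orchestrator_client and builds its compute plan, tasks, and models through an in-memory orchestrator mock, so Context.from_task exercises the same lookup path as production code. A minimal sketch of the new shape, using only names that appear in the diff below (the orchestrator fixture is added in conftest.py at the bottom of this commit):

    def test_sketch(mocker, orchestrator):
        # Route the context module's orchestrator lookups to the in-memory mock.
        mocker.patch(
            "substrapp.compute_tasks.context.get_orchestrator_client",
            return_value=orchestrator.client,
        )
        cp = orchestrator.create_compute_plan()
        task = orchestrator.create_train_task(compute_plan_key=cp.key)
        ctx = Context.from_task("mychannel", orchestrator.client.query_task(task.key))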
230 changes: 108 additions & 122 deletions backend/substrapp/tests/compute_tasks/test_command.py
@@ -1,50 +1,33 @@
 import json
-from typing import Dict
 
+from pytest_mock import MockerFixture
+
 import orchestrator.computetask_pb2 as computetask_pb2
-import substrapp.tests.assets as assets
+import orchestrator.model_pb2 as model_pb2
 from substrapp.compute_tasks.command import _get_args
 from substrapp.compute_tasks.context import Context
-from substrapp.tests.common import get_compute_plan
-from substrapp.tests.common import get_data_manager
-from substrapp.tests.common import get_task
-from substrapp.tests.common import get_task_metrics
-from substrapp.tests.common import get_test_task_input_models
+from substrapp.tests.orchestrator_factory import Orchestrator
 
 _CHANNEL = "mychannel"
-_TASK_CATEGORY_NAME_TRAIN = computetask_pb2.ComputeTaskCategory.Name(computetask_pb2.TASK_TRAIN)
-_TASK_CATEGORY_NAME_COMPOSITE = computetask_pb2.ComputeTaskCategory.Name(computetask_pb2.TASK_COMPOSITE)
 
 
-def test_get_args_train_task():
-    task = assets.get_train_task()
-    cp = get_compute_plan(task["compute_plan_key"])
-    dm = get_data_manager(task["train"]["data_manager_key"])
-    in_models = get_test_task_input_models(task)
-
-    ctx = Context(
-        channel_name=_CHANNEL,
-        task=task,
-        task_category=computetask_pb2.TASK_TRAIN,
-        task_key=task["key"],
-        compute_plan=cp,
-        compute_plan_key=cp["key"],
-        compute_plan_tag=None,
-        in_models=in_models,
-        algo=task["algo"],
-        metrics={},
-        data_manager=dm,
-        directories={},
-        has_chainkeys=False,
-    )
+def test_get_args_train_task(mocker: MockerFixture, orchestrator: Orchestrator):
+    mocker.patch("substrapp.compute_tasks.context.get_orchestrator_client", return_value=orchestrator.client)
+
+    cp = orchestrator.create_compute_plan()
+    parent_task = orchestrator.create_train_task(compute_plan_key=cp.key)
+    orchestrator.create_model(compute_task_key=parent_task.key)
+    train_task = orchestrator.create_train_task(compute_plan_key=cp.key, parent_task_keys=[parent_task.key])
+
+    ctx = Context.from_task(_CHANNEL, orchestrator.client.query_task(train_task.key))
 
     inputs = []
-    for m in get_test_task_input_models(task):
-        inputs.append({"id": "models", "value": f"/substra_internal/in_models/{m['key']}"})
+    for model in orchestrator.client.get_computetask_input_models(train_task.key):
+        inputs.append({"id": "models", "value": f"/substra_internal/in_models/{model['key']}"})
     inputs.append(
-        {"id": "opener", "value": f"/substra_internal/openers/{task['train']['data_manager_key']}/__init__.py"}
+        {"id": "opener", "value": f"/substra_internal/openers/{train_task.train.data_manager_key}/__init__.py"}
     )
-    for ds_key in task["train"]["data_sample_keys"]:
+    for ds_key in train_task.train.data_sample_keys:
         inputs.append({"id": "datasamples", "value": f"/substra_internal/data_samples/{ds_key}"})
 
     outputs = [
@@ -64,39 +47,46 @@ def test_get_args_train_task():
     ]
 
 
-def test_get_args_composite_task():
-    task = assets.get_composite_task()
-    cp = get_compute_plan(task["compute_plan_key"])
-    dm = get_data_manager(task["composite"]["data_manager_key"])
-    in_models = get_test_task_input_models(task)
-
-    ctx = Context(
-        channel_name=_CHANNEL,
-        task=task,
-        task_category=computetask_pb2.TASK_COMPOSITE,
-        task_key=task["key"],
-        compute_plan=cp,
-        compute_plan_key=cp["key"],
-        compute_plan_tag=None,
-        in_models=in_models,
-        algo=task["algo"],
-        metrics={},
-        data_manager=dm,
-        directories={},
-        has_chainkeys=False,
+def test_get_args_composite_task(mocker: MockerFixture, orchestrator: Orchestrator):
+    mocker.patch("substrapp.compute_tasks.context.get_orchestrator_client", return_value=orchestrator.client)
+
+    cp = orchestrator.create_compute_plan()
+    parent_task = orchestrator.create_composite_train_task(compute_plan_key=cp.key)
+    orchestrator.create_model(compute_task_key=parent_task.key, category=model_pb2.MODEL_SIMPLE)
+    orchestrator.create_model(compute_task_key=parent_task.key, category=model_pb2.MODEL_HEAD)
+
+    inputs = [
+        computetask_pb2.ComputeTaskInput(
+            identifier="shared",
+            parent_task_output=computetask_pb2.ParentTaskOutputRef(
+                output_identifier="shared", parent_task_key=parent_task.key
+            ),
+        ),
+        computetask_pb2.ComputeTaskInput(
+            identifier="local",
+            parent_task_output=computetask_pb2.ParentTaskOutputRef(
+                output_identifier="shared", parent_task_key=parent_task.key
+            ),
+        ),
+    ]
+
+    composite_task = orchestrator.create_composite_train_task(
+        compute_plan_key=cp.key, parent_task_keys=[parent_task.key], inputs=inputs
     )
+
+    ctx = Context.from_task(_CHANNEL, orchestrator.client.query_task(composite_task.key))
 
     inputs = []
-    in_models = get_test_task_input_models(task)
-    inputs.append({"id": "local", "value": f"/substra_internal/in_models/{in_models[0]['key']}"})
-    inputs.append({"id": "shared", "value": f"/substra_internal/in_models/{in_models[1]['key']}"})
+    in_models = orchestrator.client.get_computetask_input_models(composite_task.key)
+    if in_models:
+        inputs.append({"id": "shared", "value": f"/substra_internal/in_models/{in_models[0]['key']}"})
+        inputs.append({"id": "local", "value": f"/substra_internal/in_models/{in_models[1]['key']}"})
 
     inputs.append(
-        {"id": "opener", "value": f"/substra_internal/openers/{task['composite']['data_manager_key']}/__init__.py"}
+        {"id": "opener", "value": f"/substra_internal/openers/{composite_task.composite.data_manager_key}/__init__.py"}
     )
 
-    for ds_key in task["composite"]["data_sample_keys"]:
+    for ds_key in composite_task.composite.data_sample_keys:
         inputs.append({"id": "datasamples", "value": f"/substra_internal/data_samples/{ds_key}"})
 
     outputs = [
@@ -117,20 +107,26 @@ def test_get_args_composite_task():
     ]
 
 
-def test_get_args_predict_train():
-    task = _get_test_task_with_parent_of_type(_TASK_CATEGORY_NAME_TRAIN)
-    ctx = _get_test_ctx(task)
+def test_get_args_predict_train(mocker: MockerFixture, orchestrator: Orchestrator):
+    mocker.patch("substrapp.compute_tasks.context.get_orchestrator_client", return_value=orchestrator.client)
+
+    cp = orchestrator.create_compute_plan()
+    parent_task = orchestrator.create_train_task(compute_plan_key=cp.key)
+    orchestrator.create_model(compute_task_key=parent_task.key)
+
+    test_task = orchestrator.create_test_task(parent_task_keys=[parent_task.key])
+
+    ctx = Context.from_task(_CHANNEL, orchestrator.client.query_task(test_task.key))
 
     inputs = []
     inputs += [
-        {"id": "models", "value": f"/substra_internal/in_models/{m['key']}"} for m in get_test_task_input_models(task)
+        {"id": "models", "value": f"/substra_internal/in_models/{m['key']}"}
+        for m in orchestrator.client.get_computetask_input_models(test_task.key)
     ]
 
-    inputs.append(
-        {"id": "opener", "value": f"/substra_internal/openers/{task['test']['data_manager_key']}/__init__.py"}
-    )
+    inputs.append({"id": "opener", "value": f"/substra_internal/openers/{test_task.test.data_manager_key}/__init__.py"})
 
-    for ds_key in task["test"]["data_sample_keys"]:
+    for ds_key in test_task.test.data_sample_keys:
         inputs.append({"id": "datasamples", "value": f"/substra_internal/data_samples/{ds_key}"})
 
     outputs = [
@@ -142,43 +138,55 @@ def test_get_args_predict_train():
     assert actual == ["predict", "--inputs", f"'{json.dumps(inputs)}'", "--outputs", f"'{json.dumps(outputs)}'"]
 
 
-def test_get_args_eval_train():
-    task = _get_test_task_with_parent_of_type(_TASK_CATEGORY_NAME_TRAIN)
-    metric_key = task["test"]["metric_keys"][0]
-    ctx = _get_test_ctx(task)
+def test_get_args_eval_train(mocker: MockerFixture, orchestrator: Orchestrator):
+    mocker.patch("substrapp.compute_tasks.context.get_orchestrator_client", return_value=orchestrator.client)
+
+    cp = orchestrator.create_compute_plan()
+    parent_task = orchestrator.create_train_task(compute_plan_key=cp.key)
+    orchestrator.create_model(compute_task_key=parent_task.key)
+
+    test_task = orchestrator.create_test_task(parent_task_keys=[parent_task.key])
+
+    ctx = Context.from_task(_CHANNEL, orchestrator.client.query_task(test_task.key))
 
     cmd = [
         "--input-predictions-path",
         "/substra_internal/pred/pred.json",
         "--opener-path",
-        f"/substra_internal/openers/{task['test']['data_manager_key']}/__init__.py",
+        f"/substra_internal/openers/{test_task.test.data_manager_key}/__init__.py",
     ]
 
     cmd.append("--data-sample-paths")
-    for ds_key in task["test"]["data_sample_keys"]:
+    for ds_key in test_task.test.data_sample_keys:
         cmd.append(f"/substra_internal/data_samples/{ds_key}")
 
-    cmd += ["--output-perf-path", f"/substra_internal/perf/{task['test']['metric_keys'][0]}-perf.json"]
+    cmd += ["--output-perf-path", f"/substra_internal/perf/{test_task.test.metric_keys[0]}-perf.json"]
 
-    actual = _get_args(ctx, metric_key, True)
+    actual = _get_args(ctx, test_task.test.metric_keys[0], True)
     assert actual == cmd
 
 
-def test_get_args_predict_composite():
-    task = _get_test_task_with_parent_of_type(_TASK_CATEGORY_NAME_COMPOSITE)
-    ctx = _get_test_ctx(task)
+def test_get_args_predict_composite(mocker: MockerFixture, orchestrator: Orchestrator):
+    mocker.patch("substrapp.compute_tasks.context.get_orchestrator_client", return_value=orchestrator.client)
+
+    cp = orchestrator.create_compute_plan()
+    parent_task = orchestrator.create_composite_train_task(compute_plan_key=cp.key)
+    orchestrator.create_model(compute_task_key=parent_task.key, category=model_pb2.MODEL_HEAD)
+    orchestrator.create_model(compute_task_key=parent_task.key, category=model_pb2.MODEL_SIMPLE)
+
+    test_task = orchestrator.create_test_task(parent_task_keys=[parent_task.key])
+
+    ctx = Context.from_task(_CHANNEL, orchestrator.client.query_task(test_task.key))
 
     inputs = []
 
-    in_models = get_test_task_input_models(task)
+    in_models = orchestrator.client.get_computetask_input_models(test_task.key)
     inputs.append({"id": "local", "value": f"/substra_internal/in_models/{in_models[0]['key']}"})
     inputs.append({"id": "shared", "value": f"/substra_internal/in_models/{in_models[1]['key']}"})
 
-    inputs.append(
-        {"id": "opener", "value": f"/substra_internal/openers/{task['test']['data_manager_key']}/__init__.py"}
-    )
+    inputs.append({"id": "opener", "value": f"/substra_internal/openers/{test_task.test.data_manager_key}/__init__.py"})
 
-    for ds_key in task["test"]["data_sample_keys"]:
+    for ds_key in test_task.test.data_sample_keys:
         inputs.append({"id": "datasamples", "value": f"/substra_internal/data_samples/{ds_key}"})
 
     outputs = [
@@ -190,52 +198,30 @@ def test_get_args_predict_composite():
     assert actual == ["predict", "--inputs", f"'{json.dumps(inputs)}'", "--outputs", f"'{json.dumps(outputs)}'"]
 
 
-def test_get_args_eval_composite():
-    task = _get_test_task_with_parent_of_type(_TASK_CATEGORY_NAME_COMPOSITE)
-    metric_key = task["test"]["metric_keys"][0]
-    ctx = _get_test_ctx(task)
+def test_get_args_eval_composite(mocker: MockerFixture, orchestrator: Orchestrator):
+    mocker.patch("substrapp.compute_tasks.context.get_orchestrator_client", return_value=orchestrator.client)
+
+    cp = orchestrator.create_compute_plan()
+    parent_task = orchestrator.create_composite_train_task(compute_plan_key=cp.key)
+    orchestrator.create_model(compute_task_key=parent_task.key, category=model_pb2.MODEL_HEAD)
+    orchestrator.create_model(compute_task_key=parent_task.key, category=model_pb2.MODEL_SIMPLE)
+
+    test_task = orchestrator.create_test_task(parent_task_keys=[parent_task.key])
+
+    ctx = Context.from_task(_CHANNEL, orchestrator.client.query_task(test_task.key))
 
     cmd = [
         "--input-predictions-path",
         "/substra_internal/pred/pred.json",
         "--opener-path",
-        f"/substra_internal/openers/{task['test']['data_manager_key']}/__init__.py",
+        f"/substra_internal/openers/{test_task.test.data_manager_key}/__init__.py",
     ]
 
     cmd.append("--data-sample-paths")
-    for ds_key in task["test"]["data_sample_keys"]:
+    for ds_key in test_task.test.data_sample_keys:
        cmd.append(f"/substra_internal/data_samples/{ds_key}")
 
-    cmd += ["--output-perf-path", f"/substra_internal/perf/{task['test']['metric_keys'][0]}-perf.json"]
+    cmd += ["--output-perf-path", f"/substra_internal/perf/{test_task.test.metric_keys[0]}-perf.json"]
 
-    actual = _get_args(ctx, metric_key, True)
+    actual = _get_args(ctx, test_task.test.metric_keys[0], True)
     assert actual == cmd
-
-
-def _get_test_ctx(task: Dict) -> Context:
-    cp = get_compute_plan(task["compute_plan_key"])
-    metrics = get_task_metrics(task)
-    in_models = get_test_task_input_models(task)
-
-    return Context(
-        channel_name=_CHANNEL,
-        task=task,
-        task_category=computetask_pb2.TASK_TEST,
-        task_key=task["key"],
-        compute_plan=cp,
-        compute_plan_key=cp["key"],
-        compute_plan_tag=None,
-        in_models=in_models,
-        algo=task["algo"],
-        metrics=metrics,
-        data_manager=None,
-        directories={},
-        has_chainkeys=False,
-    )
-
-
-def _get_test_task_with_parent_of_type(category_name: str) -> Dict:
-    for t in assets.get_test_tasks():
-        if get_task(t["parent_task_keys"][0])["category"] == category_name:
-            return t
-    raise Exception("assets.py doesn't contain any test task with a parent of type " + category_name)
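The Orchestrator factory used above lives in substrapp/tests/orchestrator_factory.py and is not modified by this commit. Purely as an illustration, a hypothetical interface consistent with the calls made in these tests would look roughly like the following (a sketch inferred from usage, not the actual implementation; every name below is an assumption):

    class Orchestrator:
        # Hypothetical sketch: an in-memory orchestrator exposing a fake client
        # whose query_task / get_computetask_input_models back the tests above.
        client: object

        def create_compute_plan(self): ...  # returns an object with a .key
        def create_train_task(self, compute_plan_key, parent_task_keys=None): ...
        def create_composite_train_task(self, compute_plan_key, parent_task_keys=None, inputs=None): ...
        def create_test_task(self, parent_task_keys): ...
        def create_model(self, compute_task_key, category=None): ...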
8 changes: 8 additions & 0 deletions backend/substrapp/tests/conftest.py
@@ -0,0 +1,8 @@
+import pytest
+
+import substrapp.tests.orchestrator_factory
+
+
+@pytest.fixture
+def orchestrator() -> substrapp.tests.orchestrator_factory.Orchestrator:
+    return substrapp.tests.orchestrator_factory.Orchestrator()
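Because this fixture sits in substrapp/tests/conftest.py, pytest injects a fresh Orchestrator into any test in the package that declares an orchestrator parameter, with no further registration. A usage sketch (assuming, as the tests above do, that create_compute_plan returns an object with a .key):

    def test_uses_orchestrator(orchestrator):
        cp = orchestrator.create_compute_plan()
        assert cp.key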
