Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ jobs:
runs-on: ${{ matrix.runs-on }}
python-version: ${{ matrix.python-version }}

system-test:
uses: ./.github/workflows/_system_test.yml
# https://github.com/DiamondLightSource/blueapi/issues/1297
# Temporarily disabled until Tiled release with authz
# system-test:
# uses: ./.github/workflows/_system_test.yml

container:
needs: test
Expand Down
10 changes: 10 additions & 0 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
DodalSource,
EnvironmentConfig,
PlanSource,
TiledConfig,
)
from blueapi.core.protocols import DeviceManager
from blueapi.utils import (
Expand Down Expand Up @@ -121,6 +122,7 @@ class BlueskyContext:
run_engine: RunEngine = field(
default_factory=lambda: RunEngine(context_managers=[])
)
tiled_conf: TiledConfig | None = field(default=None, init=False, repr=False)
numtracker: NumtrackerClient | None = field(default=None, init=False, repr=False)
path_provider: PathProvider | None = None
plans: dict[str, Plan] = field(default_factory=dict)
Expand Down Expand Up @@ -173,6 +175,14 @@ def _update_scan_num(md: dict[str, Any]) -> int:
"the devices. Remove this path provider to use numtracker."
)

if (tiled_conf := configuration.tiled) is not None and tiled_conf.enabled:
if configuration.env.metadata is None:
raise InvalidConfigError(
"Tiled has been configured but `instrument` metadata is not set - "
"this field is required to make authorization decisions."
)
self.tiled_conf = tiled_conf

def find_device(self, addr: str | list[str]) -> Device | None:
"""
Find a device in this context, allows for recursive search.
Expand Down
54 changes: 38 additions & 16 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from tiled.client import from_uri

from blueapi.cli.scratch import get_python_environment
from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig, TiledConfig
from blueapi.config import ApplicationConfig, OIDCConfig, StompConfig
from blueapi.core.context import BlueskyContext
from blueapi.core.event import EventStream
from blueapi.log import set_up_logging
Expand All @@ -20,7 +20,8 @@
TaskRequest,
WorkerTask,
)
from blueapi.worker.event import TaskStatusEnum, WorkerState
from blueapi.utils.serialization import access_blob
from blueapi.worker.event import TaskStatusEnum, WorkerEvent, WorkerState
from blueapi.worker.task import Task
from blueapi.worker.task_worker import TaskWorker, TrackableTask

Expand Down Expand Up @@ -87,16 +88,6 @@ def stomp_client() -> StompClient | None:
return None


@cache
def tiled_writer() -> TiledWriter | None:
tiled_config: TiledConfig = config().tiled
if tiled_config.enabled:
client = from_uri(str(tiled_config.url), api_key=tiled_config.api_key)
return TiledWriter(client, batch_size=1)
else:
return None


def setup(config: ApplicationConfig) -> None:
"""Creates and starts a worker with supplied config"""
set_config(config)
Expand All @@ -105,8 +96,6 @@ def setup(config: ApplicationConfig) -> None:
# Eagerly initialize worker and messaging connection
worker()
stomp_client()
if writer := tiled_writer():
context().run_engine.subscribe(writer)


def teardown() -> None:
Expand All @@ -116,7 +105,6 @@ def teardown() -> None:
context.cache_clear()
worker.cache_clear()
stomp_client.cache_clear()
tiled_writer.cache_clear()


def _publish_event_streams(
Expand Down Expand Up @@ -158,10 +146,20 @@ def get_device(name: str) -> DeviceModel:

def submit_task(task_request: TaskRequest) -> str:
"""Submit a task to be run on begin_task"""
metadata: dict[str, Any] = {
"instrument_session": task_request.instrument_session,
}
if context().tiled_conf is not None:
md = config().env.metadata
# We raise an InvalidConfigError on setting tiled_conf if this isn't set
assert md
metadata["tiled_access_tags"] = [
access_blob(task_request.instrument_session, md.instrument)
]
task = Task(
name=task_request.name,
params=task_request.params,
metadata={"instrument_session": task_request.instrument_session},
metadata=metadata,
)
return worker().submit_task(task)

Expand All @@ -177,6 +175,30 @@ def begin_task(
"""Trigger a task. Will fail if the worker is busy"""
if nt := context().numtracker:
nt.set_headers(pass_through_headers or {})

if tiled_config := context().tiled_conf:
# Tiled queries the root node, so must create an authorized client
tiled_client = from_uri(
str(tiled_config.url),
api_key=tiled_config.api_key,
headers=pass_through_headers,
)
tiled_writer_token = context().run_engine.subscribe(
TiledWriter(tiled_client, batch_size=1)
)

def remove_callback_when_task_finished(
event: WorkerEvent, correlation_id: str | None
) -> None:
if (
event.task_status
and event.task_status.task_id == task.task_id
and event.task_status.task_complete
):
context().run_engine.unsubscribe(tiled_writer_token)

worker().worker_events.subscribe(remove_callback_when_task_finished)

if task.task_id is not None:
worker().begin_task(task.task_id)
return task
Expand Down
22 changes: 22 additions & 0 deletions src/blueapi/utils/serialization.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json
import re
from typing import Any

from pydantic import BaseModel
Expand All @@ -24,3 +26,23 @@ def serialize(obj: Any) -> Any:
return serialize(obj.__pydantic_model__)
else:
return obj


_INSTRUMENT_SESSION_AUTHZ_REGEX: re.Pattern = re.compile(
r"^[a-zA-Z]{2}(?P<proposal>\d+)-(?P<visit>\d+)$"
)


def access_blob(instrument_session: str, beamline: str) -> str:
m = _INSTRUMENT_SESSION_AUTHZ_REGEX.match(instrument_session)
if m is None:
raise ValueError(
"Unable to extract proposal and visit from "
f"instrument session {instrument_session}"
)
blob = {
"proposal": int(m["proposal"]),
"visit": int(m["visit"]),
"beamline": beamline,
}
return json.dumps(blob)
5 changes: 4 additions & 1 deletion tests/system_tests/test_blueapi_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ def test_instrument_session_propagated(client: BlueapiClient):
response = client.create_task(_SIMPLE_TASK)
trackable_task = client.get_task(response.task_id)
assert trackable_task.task.metadata == {
"instrument_session": FAKE_INSTRUMENT_SESSION
"instrument_session": FAKE_INSTRUMENT_SESSION,
"tiled_access_tags": [
'{"proposal": 12345, "visit": 1, "beamline": "adsim"}',
],
}


Expand Down
40 changes: 40 additions & 0 deletions tests/unit_tests/core/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,20 @@
from pytest import LogCaptureFixture

from blueapi.config import (
ApplicationConfig,
DeviceManagerSource,
DeviceSource,
DodalSource,
EnvironmentConfig,
MetadataConfig,
PlanSource,
TiledConfig,
)
from blueapi.core import BlueskyContext, is_bluesky_compatible_device
from blueapi.core.context import DefaultFactory, generic_bounds, qualified_name
from blueapi.core.protocols import DeviceConnectResult, DeviceManager
from blueapi.utils.connect_devices import _establish_device_connections
from blueapi.utils.invalid_config_error import InvalidConfigError

SIM_MOTOR_NAME = "sim"
ALT_MOTOR_NAME = "alt"
Expand Down Expand Up @@ -837,3 +840,40 @@ def test_non_device_manager_errors(empty_context: BlueskyContext):
imp_mod.side_effect = lambda mod: dev_mod if mod == "foo.bar" else None
with pytest.raises(ValueError, match="not a device manager"):
empty_context.with_config(env)


def test_setup_without_tiled_not_makes_tiled_inserter():
config = TiledConfig(enabled=False)
context = BlueskyContext(
ApplicationConfig(
tiled=config,
env=EnvironmentConfig(metadata=MetadataConfig(instrument="ixx")),
)
)
assert context.tiled_conf is None


def test_setup_default_not_makes_tiled_inserter():
context = BlueskyContext(ApplicationConfig())
assert context.tiled_conf is None


@pytest.mark.parametrize("api_key", [None, "foo"])
def test_setup_with_tiled_makes_tiled_inserter(api_key: str | None):
config = TiledConfig(enabled=True, api_key=api_key)
context = BlueskyContext(
ApplicationConfig(
tiled=config,
env=EnvironmentConfig(metadata=MetadataConfig(instrument="ixx")),
)
)
assert context.tiled_conf == config


@pytest.mark.parametrize("api_key", [None, "foo"])
def test_must_have_instrument_set_for_tiled(api_key: str | None):
config = TiledConfig(enabled=True, api_key=api_key)
with pytest.raises(InvalidConfigError):
BlueskyContext(
ApplicationConfig(tiled=config, env=EnvironmentConfig(metadata=None))
)
58 changes: 25 additions & 33 deletions tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json
import uuid
from dataclasses import dataclass
from typing import Any
from unittest.mock import ANY, MagicMock, Mock, patch

import pytest
Expand All @@ -25,7 +27,6 @@
PlanSource,
ScratchConfig,
StompConfig,
TiledConfig,
)
from blueapi.core.context import BlueskyContext
from blueapi.service import interface
Expand Down Expand Up @@ -316,10 +317,19 @@ def test_get_tasks(get_tasks_mock: MagicMock):
assert interface.get_tasks() == tasks


@pytest.mark.parametrize("tiled_enabled", [True, False])
@patch("blueapi.service.interface.context")
def test_get_task_by_id(context_mock: MagicMock):
@patch("blueapi.service.interface.config")
def test_get_task_by_id(
config_mock: MagicMock, context_mock: MagicMock, tiled_enabled: bool
):
context = BlueskyContext()
context.register_plan(my_plan)
if tiled_enabled:
context.tiled_conf = MagicMock()
config_mock.return_value = ApplicationConfig(
env=EnvironmentConfig(metadata=MetadataConfig(instrument="ixx"))
)
context_mock.return_value = context

task_id = interface.submit_task(
Expand All @@ -329,15 +339,25 @@ def test_get_task_by_id(context_mock: MagicMock):
)
)

expected_metadata: dict[str, Any] = {
"instrument_session": FAKE_INSTRUMENT_SESSION,
}

if tiled_enabled:
expected_access_tag = {
"proposal": 12345,
"visit": 1,
"beamline": "ixx",
}
expected_metadata["tiled_access_tags"] = [json.dumps(expected_access_tag)]

assert interface.get_task_by_id(task_id) == TrackableTask.model_construct(
task_id=task_id,
request_id=ANY,
task=Task(
name="my_plan",
params={},
metadata={
"instrument_session": FAKE_INSTRUMENT_SESSION,
},
metadata=expected_metadata,
),
is_complete=False,
is_pending=True,
Expand Down Expand Up @@ -509,34 +529,6 @@ def test_setup_with_numtracker_makes_start_document_provider():
clear_path_provider()


def test_setup_without_tiled_not_makes_tiled_inserter():
with patch("blueapi.service.interface.from_uri") as from_uri:
conf = ApplicationConfig()
interface.setup(conf)

assert from_uri.call_count == 0


def test_setup_with_tiled_makes_tiled_inserter():
with patch("blueapi.service.interface.from_uri") as from_uri:
conf = ApplicationConfig(tiled=TiledConfig(enabled=True))
interface.setup(conf)

assert from_uri.call_count == 1
assert from_uri.call_args.args == ("http://localhost:8407/",)
assert from_uri.call_args.kwargs == {"api_key": None}


def test_setup_with_tiled_api_key_makes_tiled_inserter():
with patch("blueapi.service.interface.from_uri") as from_uri:
conf = ApplicationConfig(tiled=TiledConfig(enabled=True, api_key="foobarbaz"))
interface.setup(conf)

assert from_uri.call_count == 1
assert from_uri.call_args.args == ("http://localhost:8407/",)
assert from_uri.call_args.kwargs == {"api_key": "foobarbaz"}


def test_setup_with_numtracker_raises_if_provider_is_defined_in_device_module():
conf = ApplicationConfig(
env=EnvironmentConfig(
Expand Down
Loading