Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 76 additions & 8 deletions src/dodal/common/data_util.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,22 @@
import os
from typing import TypeVar
from os.path import isabs, isfile, join, split
from typing import Protocol, Self, TypeVar

from pydantic import BaseModel

TBaseModel = TypeVar("TBaseModel", bound=BaseModel)
TBaseModel = TypeVar("TBaseModel", bound=BaseModel, covariant=True)


def load_json_file_to_class(
t: type[TBaseModel],
file: str,
) -> TBaseModel:
if not os.path.isfile(file):
def load_json_file_to_class(t: type[TBaseModel], file: str) -> TBaseModel:
"""Load json file into a pydantic model class.

Args:
t (type[TBaseModel]): type of model to load a file into to.
file (str): The file to read.

Returns:
An instance of pydantic BaseModel.
"""
if not isfile(file):
raise FileNotFoundError(f"Cannot find file {file}")

with open(file) as f:
Expand All @@ -20,5 +26,67 @@ def load_json_file_to_class(


def save_class_to_json_file(model: BaseModel, file: str) -> None:
"""Save a pydantic model as a json file.

Args:
model (BaseModel): The pydantic model to save to json file.
file (str): The file path to save the model to.
"""
with open(file, "w") as f:
f.write(model.model_dump_json())


class JsonModelLoader(Protocol[TBaseModel]):
def __call__(self, file: str | None = None) -> TBaseModel: ...


class JsonLoaderConfig(BaseModel):
default_path: str
default_file: str | None

@classmethod
def from_default_file(cls, default_file: str) -> Self:
"""Create instance by splitting path from file to set defaults."""
default_path, default_file = split(default_file)
return cls(default_path=default_path, default_file=default_file)

@classmethod
def from_default_path(cls, default_path: str) -> Self:
"""Create instance by only setting a default path."""
return cls(default_path=default_path, default_file=None)

def update_config_from_file(self, new_file: str) -> None:
"""Update exisiting config by splitting path from file to set new defaults."""
self.default_path, self.default_file = split(new_file)


def json_model_loader(
model: type[TBaseModel], config: JsonLoaderConfig | None = None
) -> JsonModelLoader[TBaseModel]:
"""Factory to create a function that loads a json file into a configured pydantic
model and with optional configuration for default path and file to use.
"""

def load_json(file: str | None = None) -> TBaseModel:
"""Load a json file and return it is as the configured pydantic model.

Args:
file (str, optional): The file to load into a pydantic class. If None
provided, use the default_file from the config.

Returns:
An instance of the configurated pydantic base_model type.
"""
if file is None:
if config is None or config.default_file is None:
raise RuntimeError(
f"{model.__name__} loader has no default file configured "
"and no file was provided."
)
file = config.default_file

if not isabs(file) and config is not None:
file = join(config.default_path, file)
return load_json_file_to_class(model, file)

return load_json
201 changes: 201 additions & 0 deletions tests/common/test_data_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from os.path import split

import pytest
from pydantic import BaseModel

from dodal.common.data_util import (
JsonLoaderConfig,
JsonModelLoader,
json_model_loader,
save_class_to_json_file,
)


class MyModel(BaseModel):
value: str
number: float


def assert_model(result: MyModel, expected: MyModel) -> None:
assert result.value == expected.value
assert result.number == expected.number


@pytest.fixture
def default_model() -> MyModel:
return MyModel(value="test1", number=0)


@pytest.fixture
def other_model() -> MyModel:
return MyModel(value="test2", number=3)


@pytest.fixture
def default_tmp_file(tmp_path, default_model: MyModel) -> str:
path = tmp_path / "json_loader_default_file.json"
# Setup tmp file for this test by saving the data.
save_class_to_json_file(default_model, path)
return str(path)


@pytest.fixture
def tmp_file(tmp_path, other_model: MyModel) -> str:
path = tmp_path / "json_loader_file.json"
# Setup tmp file for this test by saving the data.
save_class_to_json_file(other_model, path)
return str(path)


@pytest.fixture
def load_json_model_with_default_file_only(
default_tmp_file: str,
) -> JsonModelLoader[MyModel]:
return json_model_loader(
MyModel, JsonLoaderConfig.from_default_file(default_tmp_file)
)


def test_json_model_loader_with_configured_default_file_only(
load_json_model_with_default_file_only: JsonModelLoader[MyModel],
tmp_file: str,
other_model: MyModel,
default_model: MyModel,
) -> None:
model_result = load_json_model_with_default_file_only()
assert_model(model_result, default_model)

# Test we can use relative file path
path, file = split(tmp_file)
model_result = load_json_model_with_default_file_only(file)
assert_model(model_result, other_model)

# Test we can override with absolute path.
model_result = load_json_model_with_default_file_only(tmp_file)
assert_model(model_result, other_model)


@pytest.fixture
def load_json_model_with_default_path_only(
default_tmp_file: str,
) -> JsonModelLoader[MyModel]:
path, file = split(default_tmp_file)
return json_model_loader(MyModel, JsonLoaderConfig.from_default_path(path))


def test_load_json_model_with_configued_path_only(
load_json_model_with_default_path_only: JsonModelLoader[MyModel],
tmp_file: str,
other_model: MyModel,
) -> None:
# Test we can use relative file path
path, file = split(tmp_file)
model_result = load_json_model_with_default_path_only(file)
assert_model(model_result, other_model)

# Test we can still use absolute file path
model_result = load_json_model_with_default_path_only(tmp_file)
assert_model(model_result, other_model)

with pytest.raises(
RuntimeError,
match="MyModel loader has no default file configured and no file was provided.",
):
load_json_model_with_default_path_only()


@pytest.fixture
def load_json_model_with_default_path_and_file(
default_tmp_file: str,
) -> JsonModelLoader[MyModel]:
path, file = split(default_tmp_file)
return json_model_loader(
MyModel, JsonLoaderConfig(default_path=path, default_file=file)
)


def test_load_json_model_with_configued_path_and_file(
load_json_model_with_default_path_and_file: JsonModelLoader[MyModel],
tmp_file: str,
other_model: MyModel,
default_model: MyModel,
) -> None:
# Test uses default file
model_result = load_json_model_with_default_path_and_file()
assert_model(model_result, default_model)

# Test we can use relative file path
path, file = split(tmp_file)
model_result = load_json_model_with_default_path_and_file(file)
assert_model(model_result, other_model)

# Test we can still use absolute file path
model_result = load_json_model_with_default_path_and_file(tmp_file)
assert_model(model_result, other_model)


@pytest.fixture
def load_json_model_no_config() -> JsonModelLoader[MyModel]:
return json_model_loader(MyModel)


def test_json_model_loader_with_no_config(
load_json_model_no_config: JsonModelLoader[MyModel],
tmp_file: str,
other_model: MyModel,
) -> None:
with pytest.raises(
RuntimeError,
match="MyModel loader has no default file configured and no file was provided.",
):
load_json_model_no_config()

with pytest.raises(FileNotFoundError):
# Test using a relative path fails
path, file = split(tmp_file)
load_json_model_no_config(file)

# Test we can still use absolute file path
model_result = load_json_model_no_config(tmp_file)
assert_model(model_result, other_model)


def test_updating_config_updates_factory_function(
default_tmp_file: str, tmp_file: str, default_model: MyModel, other_model: MyModel
) -> None:
config = JsonLoaderConfig.from_default_file(default_tmp_file)
model_loader = json_model_loader(MyModel, config)

# Test uses default file
model_result = model_loader()
assert_model(model_result, default_model)

# Test uses new default file
config.update_config_from_file(tmp_file)
model_result = model_loader()
assert_model(model_result, other_model)


@pytest.fixture
def all_json_model_loaders(
load_json_model_with_default_file_only: JsonModelLoader[MyModel],
load_json_model_with_default_path_only: JsonModelLoader[MyModel],
load_json_model_with_default_path_and_file: JsonModelLoader[MyModel],
load_json_model_no_config: JsonModelLoader[MyModel],
) -> list[JsonModelLoader[MyModel]]:
return [
load_json_model_with_default_file_only,
load_json_model_with_default_path_only,
load_json_model_with_default_path_and_file,
load_json_model_no_config,
]


@pytest.mark.parametrize("loader_position", range(4))
def test_all_json_model_loader_raise_error_if_invalid_file(
all_json_model_loaders: list[JsonModelLoader[MyModel]],
loader_position: int,
) -> None:
json_loader = all_json_model_loaders[loader_position]
with pytest.raises(FileNotFoundError):
json_loader("sdkgsk")
35 changes: 15 additions & 20 deletions tests/devices/electron_analyser/helper_util/sequence.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dodal.common.data_util import load_json_file_to_class
from dodal.common.data_util import JsonLoaderConfig, json_model_loader
from dodal.devices.beamlines import b07, b07_shared, i09
from dodal.devices.electron_analyser.specs import (
SpecsAnalyserDriverIO,
Expand All @@ -18,29 +18,24 @@
TEST_SEQUENCE_REGION_NAMES = ["New_Region", "New_Region1", "New_Region2"]


def b07_specs_test_sequence_loader() -> SpecsSequence[b07.LensMode, b07_shared.PsuMode]:
return load_json_file_to_class(
SpecsSequence[b07.LensMode, b07_shared.PsuMode], TEST_SPECS_SEQUENCE
)


def i09_vgscienta_test_sequence_loader() -> VGScientaSequence[
i09.LensMode, i09.PsuMode, i09.PassEnergy
]:
return load_json_file_to_class(
VGScientaSequence[i09.LensMode, i09.PsuMode, i09.PassEnergy],
TEST_VGSCIENTA_SEQUENCE,
)
load_b07_specs_test_sequence = json_model_loader(
SpecsSequence[b07.LensMode, b07_shared.PsuMode],
JsonLoaderConfig.from_default_file(TEST_SPECS_SEQUENCE),
)
load_i09_vgscienta_test_sequence = json_model_loader(
VGScientaSequence[i09.LensMode, i09.PsuMode, i09.PassEnergy],
JsonLoaderConfig.from_default_file(TEST_VGSCIENTA_SEQUENCE),
)


# Map to know what function to load in sequence an analyser driver should use.
TEST_SEQUENCES = {
SpecsDetector: b07_specs_test_sequence_loader,
SpecsAnalyserDriverIO: b07_specs_test_sequence_loader,
SpecsSequence: b07_specs_test_sequence_loader,
VGScientaDetector: i09_vgscienta_test_sequence_loader,
VGScientaAnalyserDriverIO: i09_vgscienta_test_sequence_loader,
VGScientaSequence: i09_vgscienta_test_sequence_loader,
SpecsDetector: load_b07_specs_test_sequence,
SpecsAnalyserDriverIO: load_b07_specs_test_sequence,
SpecsSequence: load_b07_specs_test_sequence,
VGScientaDetector: load_i09_vgscienta_test_sequence,
VGScientaAnalyserDriverIO: load_i09_vgscienta_test_sequence,
VGScientaSequence: load_i09_vgscienta_test_sequence,
}


Expand Down