Skip to content

Commit

Permalink
Merge pull request #183 from fmi-faim/pydantic-questionary
Browse files Browse the repository at this point in the history
Improve IPAConfig
  • Loading branch information
imagejan authored Oct 7, 2024
2 parents 932f0a6 + 300ddd9 commit d3f4b15
Show file tree
Hide file tree
Showing 3 changed files with 303 additions and 40 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ dependencies = [
"pandas",
"pint",
"pydantic>=2",
"questionary",
"scikit-image",
"tqdm",
]
Expand Down
208 changes: 170 additions & 38 deletions src/faim_ipa/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,19 @@
import os
from datetime import datetime
from pathlib import Path
from typing import TypeVar

import pydantic
from pydantic import BaseModel
import questionary
import yaml
from pydantic import (
BaseModel,
TypeAdapter,
ValidationError,
field_serializer,
field_validator,
)
from questionary import ValidationError as QuestionaryValidationError
from questionary import Validator


def wavelength_to_rgb(wavelength, gamma=0.8):
Expand Down Expand Up @@ -115,6 +125,8 @@ def resolve_with_git_root(relative_path: Path) -> Path:
Path
Absolute path to the file.
"""
if relative_path.is_absolute():
return relative_path
git_root = get_git_root()
return (git_root / relative_path).resolve()

Expand Down Expand Up @@ -142,41 +154,161 @@ def make_relative_to_git_root(path: Path) -> Path:
return Path(os.path.relpath(path, git_root))


def prompt_with_questionary(
model: "IPAConfig",
defaults: dict | None = None,
):
schema = model.model_json_schema()
defaults = defaults or {}
responses = {}

for field_name, field_info in schema["properties"].items():
description = field_info.get("description", field_name)
field_type = field_info["type"]
default_value = defaults.get(field_name, "")

if field_type == "string":
if field_info.get("format") == "path":
responses[field_name] = Path(
questionary.path(
f"Enter {description} [{default_value}]",
validate=PathValidator(
model=model,
field_name=field_name,
),
default=default_value,
).ask()
).absolute()
elif field_info.get("format") == "directory-path":
responses[field_name] = Path(
questionary.path(
f"Enter {description} (directory) [{default_value}]",
validate=PathValidator(
model=model,
field_name=field_name,
),
default=default_value,
).ask()
).absolute()
else:
responses[field_name] = questionary.text(
f"Enter {description} [{default_value}]",
validate=QuestionaryPydanticValidator(
model=model, field_name=field_name
),
default=default_value,
).ask()
elif field_type == "integer":
min_val = field_info.get("minimum", None)
max_val = field_info.get("maximum", None)
prompt_message = f"{description} ({f'minimum: {min_val}' if min_val else ''}-{f'maximum: {max_val}' if max_val else ''}) [{default_value}]"
responses[field_name] = int(
questionary.text(
prompt_message,
validate=QuestionaryPydanticValidator(
model=model, field_name=field_name
),
default=str(default_value),
).ask()
)
elif field_type == "number":
min_val = field_info.get("minimum", None)
max_val = field_info.get("maximum", None)
prompt_message = f"{description} ({f'minimum: {min_val}' if min_val else ''}-{f'maximum: {max_val}' if max_val else ''}) [{default_value}]"
responses[field_name] = float(
questionary.text(
prompt_message,
validate=QuestionaryPydanticValidator(
model=model, field_name=field_name
),
default=str(default_value),
).ask()
)
elif field_type == "boolean":
responses[field_name] = questionary.confirm(
prompt_message,
default=default_value,
).ask()
else:
msg = f"Unknown field type: {field_type}"
raise ValueError(msg)

return model(**responses)


class QuestionaryPydanticValidator(Validator):
def __init__(self, model: BaseModel, field_name: str):
self.field_name = field_name
self.model = model
self.field_info = model.model_fields[field_name]
self.type_adapter = TypeAdapter(self.field_info.annotation)

def _preprocess(self, value):
return value

def validate(self, document):
try:
value = self._preprocess(self.type_adapter.validate_python(document.text))
self.model.__pydantic_validator__.validate_assignment(
self.model.model_construct(), self.field_name, value
)
except ValidationError as e:
raise QuestionaryValidationError(
message=f"Invalid value for field: {e.errors()[0]['msg']}"
) from e


class PathValidator(QuestionaryPydanticValidator):
def __init__(self, model: BaseModel, field_name: str):
super().__init__(model, field_name)

def _preprocess(self, value):
return Path(value).absolute()


T = TypeVar("T", bound="IPAConfig")


class IPAConfig(BaseModel):

def make_paths_absolute(self):
"""
Convert all `pathlib.Path` fields to absolute paths.
The paths are assumed to be relative to a git root directory somewhere
in the parent directories of the class implementing `IPAConfig`.
"""
fields = (
self.model_fields_set
if pydantic.__version__.startswith("2")
else self.__fields_set__
)

for f in fields:
attr = getattr(self, f)
if isinstance(attr, Path) and not attr.is_absolute():
setattr(self, f, resolve_with_git_root(attr))

def make_paths_relative(self):
"""
Convert all `pathlib.Path` fields to relative paths.
The resulting paths will be relative to the git-root directory
somewhere in the parent directories of the class implementing
`IPAConfig`.
"""
fields = (
self.model_fields_set
if pydantic.__version__.startswith("2")
else self.__fields_set__
)

for f in fields:
attr = getattr(self, f)
if isinstance(attr, Path) and attr.is_absolute():
setattr(self, f, make_relative_to_git_root(attr))
@field_serializer("*")
@classmethod
def path_relative_to_git(cls, value):
if isinstance(value, Path):
try:
return str(make_relative_to_git_root(value))
except ValueError:
return str(value)
return value

@field_validator("*", mode="before")
@classmethod
def git_relative_path_to_absolute(cls, value, info):
field_name = info.field_name
field_type = cls.__annotations__[field_name]
if isinstance(field_type, type) and issubclass(field_type, Path):
return resolve_with_git_root(value)
if hasattr(field_type, "__metadata__") and issubclass(
field_type.__origin__, Path
):
return resolve_with_git_root(Path(value))
return value

@staticmethod
def reference_dir():
return get_git_root()

@staticmethod
def config_name():
return "config.yml"

def save(self, config_file=None):
config_file = config_file or Path.cwd() / self.config_name()
with open(config_file, "w") as f:
yaml.safe_dump(self.model_dump(), f, sort_keys=False)

@classmethod
def load(cls: type[T], config_file=None) -> T:
config_file = config_file or Path.cwd() / cls.config_name()
with open(config_file) as f:
return cls(**yaml.safe_load(f))
134 changes: 132 additions & 2 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,41 @@
import os
from os.path import basename
from os.path import basename, relpath
from pathlib import Path

from faim_ipa.utils import create_logger, wavelength_to_rgb
import pytest
import questionary
from pydantic import DirectoryPath, Field, create_model
from questionary import ValidationError

from faim_ipa.utils import (
IPAConfig,
QuestionaryPydanticValidator,
create_logger,
prompt_with_questionary,
wavelength_to_rgb,
)


@pytest.fixture
def dummy_file(tmp_path: Path):
tmp_file = tmp_path / "dummy.txt"
with open(tmp_file, "w"):
pass
return tmp_file


@pytest.fixture
def model():
return create_model(
"TestModel",
__base__=IPAConfig,
path=(Path, Field(..., description="File path")),
directory=(DirectoryPath, Field(..., description="Folder path")),
string=(str, Field(..., description="Some text")),
ge0=(int, Field(..., ge=0)),
number=(float, Field(..., gt=0.0)),
boolean=(bool, Field(..., description="Checkbox")),
)


def test_wavelength_to_rgb():
Expand All @@ -25,3 +59,99 @@ def test_create_logger(tmp_path_factory):

with open(logger.handlers[0].baseFilename) as f:
assert f.read().strip()[-11:] == "INFO - Test"


def test_validator(dummy_file: Path, model):
class Document:
text: str

def __init__(self, text):
self.text = text

path_validator = QuestionaryPydanticValidator(model=model, field_name="path")
assert (
path_validator.validate(document=Document(str(dummy_file.absolute()))) is None
)

string_validator = QuestionaryPydanticValidator(model=model, field_name="string")
assert string_validator.validate(document=Document("some text")) is None

directory_validator = QuestionaryPydanticValidator(
model=model, field_name="directory"
)
with pytest.raises(ValidationError):
directory_validator.validate(document=Document(str(dummy_file.absolute())))

ge0_validator = QuestionaryPydanticValidator(model=model, field_name="ge0")
with pytest.raises(ValidationError):
ge0_validator.validate(document=Document("-1"))


def test_prompt_with_questionary(model, mocker, dummy_file):
class Question:
def __init__(self, answers):
self.answers = answers
self._answerer = self._answer()

def _answer(self):
yield from self.answers

def ask(self):
return next(self._answerer)

text_patch = mocker.patch(
"questionary.text", return_value=Question(["text", "10", "0.01"])
)
path_patch = mocker.patch(
"questionary.path",
return_value=Question(
[str(dummy_file.absolute()), str(dummy_file.parent.absolute())]
),
)
confirm_patch = mocker.patch("questionary.confirm", return_value=Question([True]))
mocker.patch("faim_ipa.utils.get_git_root", return_value=Path.cwd())
response = prompt_with_questionary(model=model)
questionary.text.assert_called()
questionary.path.assert_called()
assert text_patch.call_count == 3
assert path_patch.call_count == 2
assert confirm_patch.call_count == 1
assert Path.cwd().name == "logs0"
assert response.directory == dummy_file.parent
assert response.path == dummy_file
assert response.model_dump() == {
"boolean": True,
"directory": str(relpath(dummy_file.parent, Path.cwd())),
"ge0": 10,
"number": 0.01,
"path": str(relpath(dummy_file, Path.cwd())),
"string": "text",
}


def test_ipa_config_with_path(tmp_path):
class SomeConfig(IPAConfig):
string: str
integer: int
number: float

config = SomeConfig(string="dummy", integer=42, number=-0.01)
config_path = tmp_path / "config.yml"
config.save(config_file=config_path)

loaded_config = SomeConfig.load(config_file=config_path)

assert loaded_config.model_dump() == config.model_dump()


def test_ipa_config():
class SomeConfig(IPAConfig):
string: str
integer: int
number: float

config = SomeConfig(string="dummy", integer=42, number=-0.01)
config.save()
assert (Path.cwd() / config.config_name()).exists()
loaded_config = SomeConfig.load()
assert loaded_config.model_dump() == config.model_dump()

0 comments on commit d3f4b15

Please sign in to comment.