Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 70 additions & 55 deletions src/rune/runtime/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from decimal import Decimal
from typing import Any, Never, get_args
import datetime
from typing_extensions import Self
from typing_extensions import Self, Tuple
from pydantic import (PlainSerializer, PlainValidator, WrapValidator,
WrapSerializer)
# from rune.runtime.object_registry import get_object
Expand Down Expand Up @@ -404,6 +404,17 @@ def validator(cls, allowed_meta: tuple[str] | tuple[Never, ...] = tuple()):

class BasicTypeMetaDataMixin(BaseMetaDataMixin):
'''holds the metadata associated with an instance'''
_INPUT_TYPES: Any | Tuple[Any, ...] = str # to be overridden by subclasses
_OUTPUT_TYPE: Any = str # to be overridden by subclasses
_JSON_OUTPUT = str | dict

@classmethod
def _check_type(cls, value):
if not isinstance(value, cls._INPUT_TYPES):
raise ValueError(f'{cls.__name__} can be instantiated only with '
f'one of the following type(s): {cls._INPUT_TYPES},'
f' however the value is of type {type(value)}')

@classmethod
def serialise(cls, obj, base_type) -> dict:
'''used as serialisation method with pydantic'''
Expand Down Expand Up @@ -431,7 +442,7 @@ def deserialize(cls, obj, handler, base_types, allowed_meta: set[str]):
@lru_cache
def serializer(cls):
'''should return the validator for the specific class'''
ser_fn = partial(cls.serialise, base_type=str)
ser_fn = partial(cls.serialise, base_type=cls._OUTPUT_TYPE)
return PlainSerializer(ser_fn, return_type=dict)

@classmethod
Expand All @@ -440,49 +451,62 @@ def validator(cls, allowed_meta: tuple[str]):
'''default validator for the specific class'''
allowed = set(allowed_meta)
return WrapValidator(partial(cls.deserialize,
base_types=str,
base_types=cls._INPUT_TYPES,
allowed_meta=allowed),
json_schema_input_type=str | dict)
json_schema_input_type=cls._JSON_OUTPUT)


class DateWithMeta(datetime.date, BasicTypeMetaDataMixin):
'''date with metadata'''
_INPUT_TYPES = (datetime.date, str)

def __new__(cls, value, **kwds): # pylint: disable=signature-differs
ymd = datetime.date.fromisoformat(value).timetuple()[:3]
cls._check_type(value)
if isinstance(value, str):
value = datetime.date.fromisoformat(value)
ymd = value.timetuple()[:3]
obj = datetime.date.__new__(cls, *ymd)
obj.set_meta(check_allowed=False, **kwds)
return obj


class TimeWithMeta(datetime.time, BasicTypeMetaDataMixin):
'''annotated time'''
_INPUT_TYPES = (datetime.time, str)

def __new__(cls, value, **kwds): # pylint: disable=signature-differs
aux = datetime.time.fromisoformat(value)
cls._check_type(value)
if isinstance(value, str):
value = datetime.time.fromisoformat(value)
obj = datetime.time.__new__(cls,
aux.hour,
aux.minute,
aux.second,
aux.microsecond,
aux.tzinfo,
fold=aux.fold)
value.hour,
value.minute,
value.second,
value.microsecond,
value.tzinfo,
fold=value.fold)
obj.set_meta(check_allowed=False, **kwds)
return obj


class DateTimeWithMeta(datetime.datetime, BasicTypeMetaDataMixin):
'''annotated datetime'''
_INPUT_TYPES = (datetime.datetime, str)

def __new__(cls, value, **kwds): # pylint: disable=signature-differs
aux = datetime.datetime.fromisoformat(value)
cls._check_type(value)
if isinstance(value, str):
value = datetime.datetime.fromisoformat(value)
obj = datetime.datetime.__new__(cls,
aux.year,
aux.month,
aux.day,
aux.hour,
aux.minute,
aux.second,
aux.microsecond,
aux.tzinfo,
fold=aux.fold)
value.year,
value.month,
value.day,
value.hour,
value.minute,
value.second,
value.microsecond,
value.tzinfo,
fold=value.fold)
obj.set_meta(check_allowed=False, **kwds)
return obj

Expand All @@ -500,54 +524,45 @@ def __new__(cls, value, **kwds):

class IntWithMeta(int, BasicTypeMetaDataMixin):
'''annotated integer'''
_INPUT_TYPES = int
_OUTPUT_TYPE = int
_JSON_OUTPUT = int | dict

def __new__(cls, value, **kwds):
obj = int.__new__(cls, value)
obj.set_meta(check_allowed=False, **kwds)
return obj

@classmethod
@lru_cache
def serializer(cls):
'''should return the validator for the specific class'''
ser_fn = partial(cls.serialise, base_type=int)
return PlainSerializer(ser_fn, return_type=dict)

@classmethod
@lru_cache
def validator(cls, allowed_meta: tuple[str]):
'''default validator for the specific class'''
allowed = set(allowed_meta)
return WrapValidator(partial(cls.deserialize,
base_types=int,
allowed_meta=allowed),
json_schema_input_type=int | dict)


class NumberWithMeta(Decimal, BasicTypeMetaDataMixin):
'''annotated number'''
_INPUT_TYPES = (Decimal, float, int, str)
_OUTPUT_TYPE = Decimal
_JSON_OUTPUT = float | int | str | dict

def __new__(cls, value, **kwds):
# NOTE: it could be necessary to convert the value to str if it is a
# float
obj = Decimal.__new__(cls, value)
obj.set_meta(check_allowed=False, **kwds)
return obj

@classmethod
@lru_cache
def serializer(cls):
'''should return the validator for the specific class'''
ser_fn = partial(cls.serialise, base_type=Decimal)
return PlainSerializer(ser_fn, return_type=dict)

@classmethod
@lru_cache
def validator(cls, allowed_meta: tuple[str]):
'''default validator for the specific class'''
allowed = set(allowed_meta)
return WrapValidator(partial(cls.deserialize,
base_types=(Decimal, float, int, str),
allowed_meta=allowed),
json_schema_input_type=float | int | str | dict)
# @classmethod
# @lru_cache
# def serializer(cls):
# '''should return the validator for the specific class'''
# ser_fn = partial(cls.serialise, base_type=Decimal)
# return PlainSerializer(ser_fn, return_type=dict)

# @classmethod
# @lru_cache
# def validator(cls, allowed_meta: tuple[str]):
# '''default validator for the specific class'''
# allowed = set(allowed_meta)
# return WrapValidator(partial(cls.deserialize,
# base_types=(Decimal, float, int, str),
# allowed_meta=allowed),
# json_schema_input_type=float | int | str | dict)


class _EnumWrapperDefaultVal(Enum):
Expand Down
23 changes: 23 additions & 0 deletions test/test_basic_types_with_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,29 @@ def test_dump_annotated_date_simple():
assert json_str == '{"date":{"@data":"2024-10-10"}}'


def test_dump_annotated_date_date():
'''test the annotated string'''
model = AnnotatedDateModel(date=date(2024, 10, 10))
json_str = model.model_dump_json(exclude_unset=True)
assert json_str == '{"date":{"@data":"2024-10-10"}}'

model = AnnotatedDateModel(date=DateWithMeta(date(2024, 10, 10)))
json_str = model.model_dump_json(exclude_unset=True)
assert json_str == '{"date":{"@data":"2024-10-10"}}'


def test_annotated_date_fail():
'''test instantiation failure with an incorrect type'''
with pytest.raises(AttributeError):
AnnotatedDateModel(date=10)


def test_date_with_meta_fail():
'''test instantiation failure with an incorrect type'''
with pytest.raises(ValueError):
DateWithMeta(10)


def test_load_annotated_date_scheme():
'''test the loading of annotated with a scheme strings'''
scheme_json = '{"date":{"@data":"2024-10-10","@scheme":"http://fpml.org"}}'
Expand Down