Skip to content

Commit

Permalink
fix: Tweak more issues appeared from the migration.
Browse files Browse the repository at this point in the history
  • Loading branch information
Andres D. Molins committed Oct 24, 2024
1 parent 44510b8 commit eb31438
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 52 deletions.
6 changes: 0 additions & 6 deletions src/aleph/schemas/api/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
class MessageConfirmation(BaseModel):
"""Format of the result when a message has been confirmed on a blockchain"""

# TODO[pydantic]: The following keys were removed: `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
from_attributes=True, json_encoders={dt.datetime: lambda d: d.timestamp()}
)
Expand All @@ -49,8 +47,6 @@ class MessageConfirmation(BaseModel):


class BaseMessage(BaseModel, Generic[MType, ContentType]):
# TODO[pydantic]: The following keys were removed: `json_loads`, `json_encoders`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
from_attributes=True,
json_encoders={dt.datetime: lambda d: d.timestamp()},
Expand Down Expand Up @@ -211,8 +207,6 @@ class RejectedMessageStatus(BaseMessageStatus):


class MessageListResponse(BaseModel):
# TODO[pydantic]: The following keys were removed: `json_encoders`, `json_loads`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
json_encoders={dt.datetime: lambda d: d.timestamp()},
)
Expand Down
49 changes: 21 additions & 28 deletions src/aleph/schemas/base_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import datetime as dt
from hashlib import sha256
from typing import Any, Generic, Mapping, Optional, TypeVar, cast
from typing import Any, Generic, Optional, TypeVar, cast

from aleph_message.models import BaseContent, Chain, ItemType, MessageType
from pydantic import BaseModel, model_validator, validator
from pydantic import BaseModel, model_validator, field_validator, ValidationInfo

from aleph.toolkit.timestamp import timestamp_to_datetime
from aleph.utils import item_type_from_hash
Expand Down Expand Up @@ -36,45 +36,40 @@ class AlephBaseMessage(BaseModel, Generic[MType, ContentType]):

@model_validator(mode="after")
@classmethod
def check_item_type(cls, values):
def check_item_type(cls):
"""
Checks that the item hash of the message matches the one inferred from the hash.
Only applicable to storage/ipfs item types.
"""
item_type_value = values.get("item_type")
item_type_value = cls.item_type
if item_type_value is None:
raise ValueError("Could not determine item type")

item_type = ItemType(item_type_value)
if item_type == ItemType.inline:
return values

item_hash = values.get("item_hash")
if item_hash is None:
raise ValueError("Could not determine item hash")

expected_item_type = item_type_from_hash(item_hash)
if item_type != expected_item_type:
raise ValueError(
f"Expected {expected_item_type} based on hash but item type is {item_type}."
)
return values

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("item_hash")
def check_item_hash(cls, v: Any, values: Mapping[str, Any]):
if item_type != ItemType.inline:
item_hash = cls.item_hash
if item_hash is None:
raise ValueError("Could not determine item hash")

expected_item_type = item_type_from_hash(item_hash)
if item_type != expected_item_type:
raise ValueError(
f"Expected {expected_item_type} based on hash but item type is {item_type}."
)

@field_validator("item_hash", mode='after')
def check_item_hash(cls, v: Any, info: ValidationInfo):
"""
For inline item types, check that the item hash is equal to
the hash of the item content.
"""

item_type = values.get("item_type")
item_type = info.config.get("item_type")
if item_type is None:
raise ValueError("Could not determine item type")

if item_type == ItemType.inline:
item_content = cast(Optional[str], values.get("item_content"))
item_content = cast(Optional[str], info.config.get("item_content"))
if item_content is None:
raise ValueError("Could not find inline item content")

Expand All @@ -92,10 +87,8 @@ def check_item_hash(cls, v: Any, values: Mapping[str, Any]):
raise ValueError(f"Unknown item type: '{item_type}'")
return v

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("time", pre=True)
def check_time(cls, v, values):
@field_validator("time", mode='before')
def check_time(cls, v, info):
"""
Parses the time field as a UTC datetime. Contrary to the default datetime
validator, this implementation raises an exception if the time field is
Expand Down
6 changes: 2 additions & 4 deletions src/aleph/schemas/chains/indexer_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from enum import Enum
from typing import List, Protocol, Tuple

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator


class GenericMessageEvent(Protocol):
Expand Down Expand Up @@ -43,9 +43,7 @@ class AccountEntityState(BaseModel):
pending: List[Tuple[dt.datetime, dt.datetime]]
processed: List[Tuple[dt.datetime, dt.datetime]]

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("pending", "processed", pre=True, each_item=True)
@field_validator("pending", "processed", mode='before')
def split_datetime_ranges(cls, v):
if isinstance(v, str):
return v.split("/")
Expand Down
6 changes: 2 additions & 4 deletions src/aleph/schemas/chains/sync_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Annotated, List, Literal, Optional, Union

from aleph_message.models import Chain, ItemHash, ItemType, MessageType
from pydantic import BaseModel, ConfigDict, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from aleph.types.chain_sync import ChainSyncProtocol
from aleph.types.channel import Channel
Expand All @@ -21,9 +21,7 @@ class OnChainMessage(BaseModel):
time: float
channel: Optional[Channel] = None

# TODO[pydantic]: We couldn't refactor the `validator`, please replace it by `field_validator` manually.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-validators for more information.
@validator("time", pre=True)
@field_validator("time", mode='before')
def check_time(cls, v, values):
if isinstance(v, dt.datetime):
return v.timestamp()
Expand Down
11 changes: 5 additions & 6 deletions src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,15 @@ class BaseMessageQueryParams(BaseModel):
)

@model_validator(mode="after")
def validate_field_dependencies(cls, values):
start_date = values.get("start_date")
end_date = values.get("end_date")
def validate_field_dependencies(cls):
start_date = cls.start_date
end_date = cls.end_date
if start_date and end_date and (end_date < start_date):
raise ValueError("end date cannot be lower than start date.")
start_block = values.get("start_block")
end_block = values.get("end_block")
start_block = cls.start_block
end_block = cls.end_block
if start_block and end_block and (end_block < start_block):
raise ValueError("end block cannot be lower than start block.")
return values

@field_validator(
"hashes",
Expand Down
7 changes: 3 additions & 4 deletions src/aleph/web/controllers/posts.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,11 @@ class PostQueryParams(BaseModel):
)

@model_validator(mode="after")
def validate_field_dependencies(cls, values):
start_date = values.get("start_date")
end_date = values.get("end_date")
def validate_field_dependencies(cls):
start_date = cls.start_date
end_date = cls.end_date
if start_date and end_date and (end_date < start_date):
raise ValueError("end date cannot be lower than start date.")
return values

@field_validator(
"addresses", "hashes", "refs", "post_types", "channels", "tags", mode="before"
Expand Down

0 comments on commit eb31438

Please sign in to comment.