Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migration to Pydantic V2 #115

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 57 additions & 47 deletions aleph_message/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import datetime
import json
import logging
from copy import copy
from hashlib import sha256
from json import JSONDecodeError
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, TypeVar, Union, cast

from pydantic import BaseModel, Extra, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import TypeAlias

from .abstract import BaseContent, HashableModel
Expand All @@ -16,6 +17,10 @@
from .execution.program import ProgramContent
from .item_hash import ItemHash, ItemType

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


__all__ = [
"AggregateContent",
"AggregateMessage",
Expand Down Expand Up @@ -54,8 +59,7 @@ class MongodbId(BaseModel):

oid: str = Field(alias="$oid")

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class ChainRef(BaseModel):
Expand All @@ -76,8 +80,7 @@ class MessageConfirmationHash(BaseModel):
binary: str = Field(alias="$binary")
type: str = Field(alias="$type")

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class MessageConfirmation(BaseModel):
Expand All @@ -93,15 +96,13 @@ class MessageConfirmation(BaseModel):
default=None, description="The address that published the transaction."
)

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class AggregateContentKey(BaseModel):
name: str

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class PostContent(BaseContent):
Expand All @@ -116,16 +117,15 @@ class PostContent(BaseContent):
)
type: str = Field(description="User-generated 'content-type' of a POST message")

@validator("type")
@field_validator("type")
def check_type(cls, v, values):
if v == "amend":
ref = values.get("ref")
ref = values.data.get("ref")
if not ref:
raise ValueError("A 'ref' is required for POST type 'amend'")
return v

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class AggregateContent(BaseContent):
Expand All @@ -136,8 +136,7 @@ class AggregateContent(BaseContent):
)
content: Dict = Field(description="The content of an aggregate must be a dict")

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")


class StoreContent(BaseContent):
Expand All @@ -148,10 +147,11 @@ class StoreContent(BaseContent):
size: Optional[int] = None # Generated by the node on storage
content_type: Optional[str] = None # Generated by the node on storage
ref: Optional[str] = None
metadata: Optional[Dict[str, Any]] = Field(description="Metadata of the VM")
metadata: Optional[Dict[str, Any]] = Field(
default=None, description="Metadata of the VM"
)

class Config:
extra = Extra.allow
model_config = ConfigDict(extra="allow")


class ForgetContent(BaseContent):
Expand Down Expand Up @@ -214,9 +214,9 @@ class BaseMessage(BaseModel):

forgotten_by: Optional[List[str]]

@validator("item_content")
@field_validator("item_content")
def check_item_content(cls, v: Optional[str], values) -> Optional[str]:
item_type = values["item_type"]
item_type = values.data.get("item_type")
if v is None:
return None
elif item_type == ItemType.inline:
Expand All @@ -232,14 +232,14 @@ def check_item_content(cls, v: Optional[str], values) -> Optional[str]:
)
return v

@validator("item_hash")
@field_validator("item_hash")
def check_item_hash(cls, v: ItemHash, values) -> ItemHash:
item_type = values["item_type"]
item_type = values.data.get("item_type")
if item_type == ItemType.inline:
item_content: str = values["item_content"]
item_content: str = values.data.get("item_content")

# Double check that the hash function is supported
hash_type = values["hash_type"] or HashType.sha256
hash_type = values.data.get("hash_type") or HashType.sha256
assert hash_type.value == HashType.sha256

computed_hash: str = sha256(item_content.encode()).hexdigest()
Expand All @@ -255,49 +255,56 @@ def check_item_hash(cls, v: ItemHash, values) -> ItemHash:
assert item_type == ItemType.storage
return v

@validator("confirmed")
@field_validator("confirmed")
def check_confirmed(cls, v, values):
confirmations = values["confirmations"]
confirmations = values.data.get("confirmations")
if v is True and not bool(confirmations):
raise ValueError("Message cannot be 'confirmed' without 'confirmations'")
return v

@validator("time")
@field_validator("time")
def convert_float_to_datetime(cls, v, values):
if isinstance(v, float):
v = datetime.datetime.fromtimestamp(v)
assert isinstance(v, datetime.datetime)
return v

class Config:
extra = Extra.forbid
exclude = {"id_", "_id"}
model_config = ConfigDict(extra="forbid")

def custom_dump(self):
"""Exclude MongoDB identifiers from dumps for historical reasons."""
return self.model_dump(exclude={"id_", "_id"})


class PostMessage(BaseMessage):
"""Unique data posts (unique data points, events, ...)"""

type: Literal[MessageType.post]
content: PostContent
forgotten_by: Optional[List[str]] = None


class AggregateMessage(BaseMessage):
"""A key-value storage specific to an address"""

type: Literal[MessageType.aggregate]
content: AggregateContent
forgotten_by: Optional[list] = None


class StoreMessage(BaseMessage):
type: Literal[MessageType.store]
content: StoreContent
forgotten_by: Optional[list] = None
metadata: Optional[Dict[str, Any]] = None


class ForgetMessage(BaseMessage):
type: Literal[MessageType.forget]
content: ForgetContent
forgotten_by: Optional[list] = None

@validator("forgotten_by")
@field_validator("forgotten_by")
def cannot_be_forgotten(cls, v: Optional[List[str]], values) -> Optional[List[str]]:
assert values
if v:
Expand All @@ -308,25 +315,29 @@ def cannot_be_forgotten(cls, v: Optional[List[str]], values) -> Optional[List[st
class ProgramMessage(BaseMessage):
type: Literal[MessageType.program]
content: ProgramContent
forgotten_by: Optional[List[str]] = None

@validator("content")
@field_validator("content")
def check_content(cls, v, values):
item_type = values["item_type"]
"""Ensure that the content of the message is correctly formatted."""
item_type = values.data.get("item_type")
if item_type == ItemType.inline:
item_content = json.loads(values["item_content"])
if v.dict(exclude_none=True) != item_content:
# Print differences
vdict = v.dict(exclude_none=True)
for key, value in item_content.items():
if vdict[key] != value:
print(f"{key}: {vdict[key]} != {value}")
# Ensure that the content correct JSON
item_content = json.loads(values.data.get("item_content"))
# Ensure that the content matches the expected structure
if v.model_dump(exclude_none=True) != item_content:
logger.warning(
"Content and item_content differ for message %s",
values.data["item_hash"],
)
raise ValueError("Content and item_content differ")
return v


class InstanceMessage(BaseMessage):
type: Literal[MessageType.instance]
content: InstanceContent
forgotten_by: Optional[List[str]] = None


AlephMessage: TypeAlias = Union[
Expand Down Expand Up @@ -363,12 +374,12 @@ def parse_message(message_dict: Dict) -> AlephMessage:
message_class.__annotations__["type"].__args__[0]
)
if message_dict["type"] == message_type:
return message_class.parse_obj(message_dict)
return message_class.model_validate(message_dict)
else:
raise ValueError(f"Unknown message type {message_dict['type']}")


def add_item_content_and_hash(message_dict: Dict, inplace: bool = False):
def add_item_content_and_hash(message_dict: Dict, inplace: bool = False) -> Dict:
if not inplace:
message_dict = copy(message_dict)

Expand All @@ -390,7 +401,7 @@ def create_new_message(
"""
message_content = add_item_content_and_hash(message_dict)
if factory:
return cast(T, factory.parse_obj(message_content))
return cast(T, factory.model_validate(message_content))
else:
return cast(T, parse_message(message_content))

Expand All @@ -405,7 +416,7 @@ def create_message_from_json(
message_dict = json.loads(json_data)
message_content = add_item_content_and_hash(message_dict, inplace=True)
if factory:
return factory.parse_obj(message_content)
return factory.model_validate(message_content)
else:
return parse_message(message_content)

Expand All @@ -422,7 +433,7 @@ def create_message_from_file(
message_dict = decoder.load(fd)
message_content = add_item_content_and_hash(message_dict, inplace=True)
if factory:
return factory.parse_obj(message_content)
return factory.model_validate(message_content)
else:
return parse_message(message_content)

Expand All @@ -436,5 +447,4 @@ class MessagesResponse(BaseModel):
pagination_per_page: int
pagination_item: str

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")
5 changes: 2 additions & 3 deletions aleph_message/models/abstract.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import BaseModel, Extra
from pydantic import BaseModel, ConfigDict


def hashable(obj):
Expand All @@ -24,5 +24,4 @@ class BaseContent(BaseModel):
address: str
time: float

class Config:
extra = Extra.forbid
model_config = ConfigDict(extra="forbid")
2 changes: 1 addition & 1 deletion aleph_message/models/execution/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .abstract import BaseExecutableContent
from .base import Encoding, Interface, MachineType, Payment, PaymentType
from .instance import InstanceContent
from .program import ProgramContent
from .base import Encoding, MachineType, PaymentType, Payment, Interface

__all__ = [
"BaseExecutableContent",
Expand Down
2 changes: 1 addition & 1 deletion aleph_message/models/execution/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class Payment(HashableModel):

chain: Chain
"""Which chain to check for funds"""
receiver: Optional[str]
receiver: Optional[str] = None
"""Optional alternative address to send tokens to"""
type: PaymentType
"""Whether to pay by holding $ALEPH or by streaming tokens"""
Expand Down
Loading
Loading