Skip to content

Commit

Permalink
Fix: Added typing on validation method.
Browse files Browse the repository at this point in the history
  • Loading branch information
Andres D. Molins committed Oct 24, 2024
1 parent abe0f95 commit 700f957
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 39 deletions.
4 changes: 2 additions & 2 deletions src/aleph/schemas/base_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ def check_item_type(self):
)
return self

@classmethod
@field_validator("item_hash", mode="after")
@classmethod
def check_item_hash(cls, v: Any, info: ValidationInfo):
"""
For inline item types, check that the item hash is equal to
Expand Down Expand Up @@ -88,8 +88,8 @@ def check_item_hash(cls, v: Any, info: ValidationInfo):
raise ValueError(f"Unknown item type: '{item_type}'")
return v

@classmethod
@field_validator("time", mode="before")
@classmethod
def check_time(cls, v: Any, info: ValidationInfo):
"""
Parses the time field as a UTC datetime. Contrary to the default datetime
Expand Down
4 changes: 2 additions & 2 deletions src/aleph/schemas/pending_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ class BasePendingMessage(AlephBaseMessage, Generic[MType, ContentType]):
A raw Aleph message, as sent by users to the Aleph network.
"""

@classmethod
@model_validator(mode="before")
def load_content(cls, values):
@classmethod
def load_content(cls, values: Any):
"""
Preload inline content. We let the CCN populate this field later
on for ipfs and storage item types.
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/toolkit/timestamp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def timestamp_to_datetime(timestamp: float) -> dt.datetime:
object.
"""

return pytz.utc.localize(dt.datetime.fromtimestamp(timestamp, dt.UTC))
return pytz.utc.localize(dt.datetime.utcfromtimestamp(timestamp))


def coerce_to_datetime(
Expand Down
2 changes: 1 addition & 1 deletion src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ def validate_field_dependencies(self):

return self

@classmethod
@field_validator(
"hashes",
"addresses",
Expand All @@ -144,6 +143,7 @@ def validate_field_dependencies(self):
"tags",
mode="before",
)
@classmethod
def split_str(cls, v):
if isinstance(v, str):
return v.split(LIST_FIELD_SEPARATOR)
Expand Down
4 changes: 2 additions & 2 deletions src/aleph/web/controllers/posts.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,11 @@ def validate_field_dependencies(self):

return self

@classmethod
@field_validator(
"addresses", "hashes", "refs", "post_types", "channels", "tags", mode="before"
)
def split_str(cls, v):
@classmethod
def split_str(cls, v) -> List[str]:
if isinstance(v, str):
return v.split(LIST_FIELD_SEPARATOR)
return v
Expand Down
46 changes: 33 additions & 13 deletions tests/api/test_list_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ async def test_get_messages(fixture_messages: Sequence[Dict[str, Any]], ccn_api_
@pytest.mark.asyncio
async def test_get_messages_filter_by_channel(fixture_messages, ccn_api_client):
async def fetch_messages_by_channel(channel: str) -> Dict:
response = await ccn_api_client.get(MESSAGES_URI, params={"channels": [channel]})
response = await ccn_api_client.get(
MESSAGES_URI, params={"channels": channel}
)
assert response.status == 200, await response.text()
return await response.json()

Expand Down Expand Up @@ -104,7 +106,7 @@ async def fetch_messages_by_channel(channel: str) -> Dict:


async def fetch_messages_by_chain(api_client, chain: str) -> aiohttp.ClientResponse:
response = await api_client.get(MESSAGES_URI, params={"chains": [chain]})
response = await api_client.get(MESSAGES_URI, params={"chains": chain})
return response


Expand All @@ -130,7 +132,7 @@ async def test_get_messages_filter_invalid_chain(fixture_messages, ccn_api_clien
async def fetch_messages_by_content_hash(
api_client, item_hash: str
) -> aiohttp.ClientResponse:
response = await api_client.get(MESSAGES_URI, params={"contentHashes": [item_hash]})
response = await api_client.get(MESSAGES_URI, params={"contentHashes": item_hash})
return response


Expand Down Expand Up @@ -186,57 +188,75 @@ async def test_get_messages_filter_by_tags(
session.commit()

# Matching tag for both messages
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": ["mainnet"]})
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "mainnet"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 2

# Matching tags for both messages
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": ["original,amend"]})
response = await ccn_api_client.get(
MESSAGES_URI, params={"tags": "original,amend"}
)
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 2

# Matching the original tag
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": ["original"]})
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "original"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 1
assert messages[0]["item_hash"] == message_db.item_hash

# Matching the amend tag
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": ["amend"]})
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "amend"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 1
assert messages[0]["item_hash"] == amend_message_db.item_hash

# No match
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": ["not-a-tag"]})
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "not-a-tag"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 0

# Matching the amend tag with other tags
response = await ccn_api_client.get(
MESSAGES_URI, params={"tags": ["amend,not-a-tag,not-a-tag-either"]}
MESSAGES_URI, params={"tags": "amend,not-a-tag,not-a-tag-either"}
)
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 1
assert messages[0]["item_hash"] == amend_message_db.item_hash


@pytest.mark.parametrize("type_field", ("msgType", "msgTypes"))
@pytest.mark.asyncio
async def test_get_by_message_type(fixture_messages, ccn_api_client, type_field: str):
async def test_get_by_deprecated_message_type(fixture_messages, ccn_api_client):
messages_by_type = defaultdict(list)
for message in fixture_messages:
messages_by_type[message["type"]].append(message)

for message_type, expected_messages in messages_by_type.items():
response = await ccn_api_client.get(
MESSAGES_URI, params={"msgType": message_type}
)
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert set(msg["item_hash"] for msg in messages) == set(
msg["item_hash"] for msg in expected_messages
)


@pytest.mark.asyncio
async def test_get_by_message_type(fixture_messages, ccn_api_client):
messages_by_type = defaultdict(list)
for message in fixture_messages:
messages_by_type[message["type"]].append(message)

for message_type, expected_messages in messages_by_type.items():
response = await ccn_api_client.get(
MESSAGES_URI, params={type_field: message_type}
MESSAGES_URI, params={"msgTypes": [message_type]}
)
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
Expand All @@ -253,7 +273,7 @@ async def test_get_messages_filter_by_tags_no_match(fixture_messages, ccn_api_cl
"""

# Matching tag
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": ["mainnet"]})
response = await ccn_api_client.get(MESSAGES_URI, params={"tags": "mainnet"})
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert len(messages) == 0
Expand Down
30 changes: 16 additions & 14 deletions tests/api/test_posts.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ async def test_get_posts_refs(

# Match the ref
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"refs": [f"{post_db.ref}"]}
"/api/v0/posts.json", params={"refs": f"{post_db.ref}"}
)
assert response.status == 200
response_json = await response.json()
Expand All @@ -85,7 +85,7 @@ async def test_get_posts_refs(

# Unknown ref
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"refs": ["not-a-ref"]}
"/api/v0/posts.json", params={"refs": "not-a-ref"}
)
assert response.status == 200
response_json = await response.json()
Expand All @@ -94,7 +94,7 @@ async def test_get_posts_refs(

# Search for several refs
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"refs": [f"{post_db.ref},not-a-ref"]}
"/api/v0/posts.json", params={"refs": f"{post_db.ref},not-a-ref"}
)
assert response.status == 200
response_json = await response.json()
Expand Down Expand Up @@ -131,7 +131,7 @@ async def test_get_amended_posts_refs(

# Match the ref
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"refs": [f"{original_post_db.ref}"]}
"/api/v0/posts.json", params={"refs": f"{original_post_db.ref}"}
)
assert response.status == 200
response_json = await response.json()
Expand All @@ -146,7 +146,7 @@ async def test_get_amended_posts_refs(

# Unknown ref
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"refs": ["not-a-ref"]}
"/api/v0/posts.json", params={"refs": "not-a-ref"}
)
assert response.status == 200
response_json = await response.json()
Expand All @@ -155,7 +155,7 @@ async def test_get_amended_posts_refs(

# Search for several refs
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"refs": [f"{original_post_db.ref},not-a-ref"]}
"/api/v0/posts.json", params={"refs": f"{original_post_db.ref},not-a-ref"}
)
assert response.status == 200
response_json = await response.json()
Expand Down Expand Up @@ -186,7 +186,7 @@ async def test_get_posts_tags(

# Match one tag
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"tags": ["mainnet"]}
"/api/v0/posts.json", params={"tags": "mainnet"}
)
assert response.status == 200, await response.text()
response_json = await response.json()
Expand All @@ -200,7 +200,7 @@ async def test_get_posts_tags(

# Unknown tag
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"tags": ["not-a-tag"]}
"/api/v0/posts.json", params={"tags": "not-a-tag"}
)
assert response.status == 200
response_json = await response.json()
Expand All @@ -209,7 +209,7 @@ async def test_get_posts_tags(

# Search for several tags
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"tags": ["mainnet,not-a-ref"]}
"/api/v0/posts.json", params={"tags": "mainnet,not-a-ref"}
)
assert response.status == 200
response_json = await response.json()
Expand All @@ -225,7 +225,7 @@ async def test_get_posts_tags(
# Check for several matching tags
# Search for several tags
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"tags": ["original,mainnet"]}
"/api/v0/posts.json", params={"tags": "original,mainnet"}
)
assert response.status == 200
response_json = await response.json()
Expand Down Expand Up @@ -261,7 +261,9 @@ async def test_get_amended_posts_tags(
session.commit()

# Match one tag
response = await ccn_api_client.get("/api/v0/posts.json", params={"tags": ["amend"]})
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"tags": "amend"}
)
assert response.status == 200
response_json = await response.json()
assert len(response_json["posts"]) == 1
Expand All @@ -275,7 +277,7 @@ async def test_get_amended_posts_tags(

# Unknown tag
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"tags": ["not-a-tag"]}
"/api/v0/posts.json", params={"tags": "not-a-tag"}
)
assert response.status == 200
response_json = await response.json()
Expand All @@ -284,7 +286,7 @@ async def test_get_amended_posts_tags(

# Tag of the original
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"tags": ["original"]}
"/api/v0/posts.json", params={"tags": "original"}
)
assert response.status == 200
response_json = await response.json()
Expand All @@ -293,7 +295,7 @@ async def test_get_amended_posts_tags(

# Search for several tags
response = await ccn_api_client.get(
"/api/v0/posts.json", params={"tags": ["mainnet,not-a-tag"]}
"/api/v0/posts.json", params={"tags": "mainnet,not-a-tag"}
)
assert response.status == 200
response_json = await response.json()
Expand Down
16 changes: 12 additions & 4 deletions tests/message_processing/test_process_instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,9 @@ async def test_process_instance_missing_volumes(
assert rejected_message.error_code == ErrorCode.VM_VOLUME_NOT_FOUND

if fixture_instance_message.item_content:
content = InstanceContent.model_validate_json(fixture_instance_message.item_content)
content = InstanceContent.model_validate_json(
fixture_instance_message.item_content
)
volume_refs = set(volume.ref for volume in get_volume_refs(content))
assert isinstance(rejected_message.details, dict)
assert set(rejected_message.details["errors"]) == volume_refs
Expand Down Expand Up @@ -453,7 +455,9 @@ async def test_get_volume_size(
session.commit()

if fixture_instance_message.item_content:
content = InstanceContent.model_validate_json(fixture_instance_message.item_content)
content = InstanceContent.model_validate_json(
fixture_instance_message.item_content
)
with session_factory() as session:
volume_size = get_volume_size(session=session, content=content)
assert volume_size == 21512585216
Expand All @@ -469,7 +473,9 @@ async def test_get_additional_storage_price(
session.commit()

if fixture_instance_message.item_content:
content = InstanceContent.model_validate_json(fixture_instance_message.item_content)
content = InstanceContent.model_validate_json(
fixture_instance_message.item_content
)
with session_factory() as session:
additional_price = get_additional_storage_price(
content=content, session=session
Expand All @@ -487,7 +493,9 @@ async def test_get_compute_cost(
session.commit()

if fixture_instance_message.item_content:
content = InstanceContent.model_validate_json(fixture_instance_message.item_content)
content = InstanceContent.model_validate_json(
fixture_instance_message.item_content
)
with session_factory() as session:
price: Decimal = compute_cost(content=content, session=session)
assert price == Decimal("2001.8")
Expand Down

0 comments on commit 700f957

Please sign in to comment.