Skip to content

Commit

Permalink
Allow multiple message types (#444)
Browse files Browse the repository at this point in the history
Problem: `msgType` is the only parameter when fetching messages, which does not take a list of items as a parameter.

Solution: add `msgTypes` as a parameter to the API and deprecate `msgType`.
  • Loading branch information
MHHukiewitz authored Sep 5, 2023
1 parent 71cf823 commit dbf2048
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/aleph/db/accessors/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def make_matching_messages_query(
refs: Optional[Sequence[str]] = None,
chains: Optional[Sequence[Chain]] = None,
message_type: Optional[MessageType] = None,
message_types: Optional[Sequence[MessageType]] = None,
start_date: Optional[Union[float, dt.datetime]] = None,
end_date: Optional[Union[float, dt.datetime]] = None,
content_hashes: Optional[Sequence[ItemHash]] = None,
Expand Down Expand Up @@ -87,6 +88,8 @@ def make_matching_messages_query(
select_stmt = select_stmt.where(MessageDb.sender.in_(addresses))
if chains:
select_stmt = select_stmt.where(MessageDb.chain.in_(chains))
if message_types:
select_stmt = select_stmt.where(MessageDb.type.in_(message_types))
if message_type:
select_stmt = select_stmt.where(MessageDb.type == message_type)
if start_datetime:
Expand Down
7 changes: 5 additions & 2 deletions src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ class BaseMessageQueryParams(BaseModel):
"-1 means most recent messages first, 1 means older messages first.",
)
message_type: Optional[MessageType] = Field(
default=None, alias="msgType", description="Message type."
default=None, alias="msgType", description="Message type. Deprecated: use msgTypes instead"
)
message_types: Optional[List[MessageType]] = Field(
default=None, alias="msgTypes", description="Accepted message types."
)
addresses: Optional[List[str]] = Field(
default=None, description="Accepted values for the 'sender' field."
Expand Down Expand Up @@ -120,6 +123,7 @@ def validate_field_dependencies(cls, values):
"content_types",
"chains",
"channels",
"message_types",
"tags",
pre=True,
)
Expand Down Expand Up @@ -356,7 +360,6 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse:
except ValidationError as e:
raise web.HTTPUnprocessableEntity(body=e.json(indent=4))

message_filters = query_params.dict(exclude_none=True)
history = query_params.history

if history:
Expand Down
19 changes: 19 additions & 0 deletions tests/api/test_list_messages.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime as dt
import itertools
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union, Tuple

import aiohttp
Expand Down Expand Up @@ -221,6 +222,24 @@ async def test_get_messages_filter_by_tags(
assert messages[0]["item_hash"] == amend_message_db.item_hash


@pytest.mark.parametrize("type_field", ("msgType", "msgTypes"))
@pytest.mark.asyncio
async def test_get_by_message_type(fixture_messages, ccn_api_client, type_field: str):
messages_by_type = defaultdict(list)
for message in fixture_messages:
messages_by_type[message["type"]].append(message)

for message_type, expected_messages in messages_by_type.items():
response = await ccn_api_client.get(
MESSAGES_URI, params={type_field: message_type}
)
assert response.status == 200, await response.text()
messages = (await response.json())["messages"]
assert set(msg["item_hash"] for msg in messages) == set(
msg["item_hash"] for msg in expected_messages
)


@pytest.mark.asyncio
async def test_get_messages_filter_by_tags_no_match(fixture_messages, ccn_api_client):
"""
Expand Down

0 comments on commit dbf2048

Please sign in to comment.