Skip to content

Commit

Permalink
Fix: refuse time and pagination parameters for message WS
Browse files Browse the repository at this point in the history
Problem: the message websocket could receive "startDate"/"endDate"
and pagination parameters.

Solution: use a slightly different Pydantic model for the websocket
to reject these parameters.
  • Loading branch information
odesenfans committed Oct 17, 2022
1 parent 2ace445 commit 4554c18
Showing 1 changed file with 48 additions and 40 deletions.
88 changes: 48 additions & 40 deletions src/aleph/web/controllers/messages.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import logging
from enum import IntEnum
from typing import Any, Dict, List, Optional, Mapping
from typing import Any, List, Optional, Mapping

from aiohttp import web
from aleph_message.models import MessageType, ItemHash, Chain
Expand All @@ -14,7 +14,6 @@
LIST_FIELD_SEPARATOR,
Pagination,
cond_output,
make_date_filters,
)

LOGGER = logging.getLogger(__name__)
Expand All @@ -30,7 +29,7 @@ class SortOrder(IntEnum):
DESCENDING = -1


class MessageQueryParams(BaseModel):
class BaseMessageQueryParams(BaseModel):
sort_order: SortOrder = Field(
default=SortOrder.DESCENDING,
description="Order in which messages should be listed: "
Expand Down Expand Up @@ -72,34 +71,6 @@ class MessageQueryParams(BaseModel):
hashes: Optional[List[ItemHash]] = Field(
default=None, description="Accepted values for the 'item_hash' field."
)
history: Optional[int] = Field(
DEFAULT_WS_HISTORY,
ge=10,
lt=200,
description="Accepted values for the 'item_hash' field.",
)
pagination: int = Field(
default=DEFAULT_MESSAGES_PER_PAGE,
ge=0,
description="Maximum number of messages to return. Specifying 0 removes this limit.",
)
page: int = Field(
default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
)
start_date: float = Field(
default=0,
ge=0,
alias="startDate",
description="Start date timestamp. If specified, only messages with "
"a time field greater or equal to this value will be returned.",
)
end_date: float = Field(
default=0,
ge=0,
alias="endDate",
description="End date timestamp. If specified, only messages with "
"a time field lower than this value will be returned.",
)

@root_validator
def validate_field_dependencies(cls, values):
Expand All @@ -124,8 +95,8 @@ def split_str(cls, v):
return v.split(LIST_FIELD_SEPARATOR)
return v

def to_mongodb_filters(self) -> Mapping[str, Any]:
filters: List[Dict[str, Any]] = []
def to_filter_list(self) -> List[Mapping[str, Any]]:
filters: List[Mapping[str, Any]] = []

if self.message_type is not None:
filters.append({"type": self.message_type})
Expand Down Expand Up @@ -164,19 +135,56 @@ def to_mongodb_filters(self) -> Mapping[str, Any]:
}
)

date_filters = make_date_filters(
start=self.start_date, end=self.end_date, filter_key="time"
)
if date_filters:
filters.append(date_filters)
return filters

def to_mongodb_filters(self) -> Mapping[str, Any]:
filters = self.to_filter_list()
return self._make_and_filter(filters)

and_filter = {}
@staticmethod
def _make_and_filter(filters: List[Mapping[str, Any]]) -> Mapping[str, Any]:
and_filter: Mapping[str, Any] = {}
if filters:
and_filter = {"$and": filters} if len(filters) > 1 else filters[0]

return and_filter


class MessageQueryParams(BaseMessageQueryParams):
pagination: int = Field(
default=DEFAULT_MESSAGES_PER_PAGE,
ge=0,
description="Maximum number of messages to return. Specifying 0 removes this limit.",
)
page: int = Field(
default=DEFAULT_PAGE, ge=1, description="Offset in pages. Starts at 1."
)

start_date: float = Field(
default=0,
ge=0,
alias="startDate",
description="Start date timestamp. If specified, only messages with "
"a time field greater or equal to this value will be returned.",
)
end_date: float = Field(
default=0,
ge=0,
alias="endDate",
description="End date timestamp. If specified, only messages with "
"a time field lower than this value will be returned.",
)


class WsMessageQueryParams(BaseMessageQueryParams):
history: Optional[int] = Field(
DEFAULT_WS_HISTORY,
ge=10,
lt=200,
description="Accepted values for the 'item_hash' field.",
)


async def view_messages_list(request):
"""Messages list view with filters"""

Expand Down Expand Up @@ -244,7 +252,7 @@ async def messages_ws(request: web.Request):
collection = CappedMessage.collection
last_id = None

query_params = MessageQueryParams.parse_obj(request.query)
query_params = WsMessageQueryParams.parse_obj(request.query)
find_filters = query_params.to_mongodb_filters()

initial_count = query_params.history
Expand Down

0 comments on commit 4554c18

Please sign in to comment.